# app/services/task_service.py
from datetime import datetime, date, timedelta
from typing import List, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, or_, and_
from redis.asyncio import Redis  # Redis client type

from app.utils.search import apply_keyword_search_stmt
from app.utils.pagination import paginate
from app.core.biz_exception import NotFoundError, BizLogicError
from app.models.vas_task import VasTask
from app.models.order import VasOrder
from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
from app.services.task_handlers import task_processor


class VasTaskService:
    """Data-access and workflow operations for ``VasTask`` records."""

    @staticmethod
    async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
        """Insert a new task in ``pending`` status and return the persisted row.

        :param db: async database session (committed by this method)
        :param data: validated creation payload; its fields are passed to ``VasTask``
        :return: the refreshed ``VasTask`` instance
        """
        rec = VasTask(
            **data.dict(),
            status="pending",
            created_at=datetime.utcnow(),
        )
        db.add(rec)
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def pop_vas_task(
        session: AsyncSession, routing_key: str, cooldown_seconds: int = 60
    ) -> VasTask:
        """Atomically claim the next pending task for ``routing_key``.

        A task is eligible when it is ``pending`` and either has never been
        attempted (``attempt_count == 0``) or was last updated more than
        ``cooldown_seconds`` ago (retry cooldown). The claimed task is marked
        ``running``, its ``attempt_count`` is incremented and ``updated_at`` is
        reset, which restarts the cooldown clock.

        Note: this method owns the transaction. ``commit()`` also flushes any
        other pending changes in ``session``, and ``rollback()`` discards them.

        :param session: async database session
        :param routing_key: queue key to pull from
        :param cooldown_seconds: minimum seconds between retries of a failed task
        :return: the claimed ``VasTask``
        :raises NotFoundError: when no eligible task is available
        """
        # Retry tasks are only eligible if their last update is older than this.
        cutoff_time = datetime.utcnow() - timedelta(seconds=cooldown_seconds)

        stmt = (
            select(VasTask)
            .where(VasTask.routing_key == routing_key)
            .where(VasTask.status == "pending")
            .where(
                or_(
                    # Case 1: a brand-new task that has never been attempted.
                    VasTask.attempt_count == 0,
                    # Case 2: a retry whose cooldown since the last update has elapsed.
                    and_(
                        VasTask.attempt_count > 0,
                        VasTask.updated_at < cutoff_time,
                    ),
                )
            )
            # Highest priority value first (descending), then oldest id (FIFO).
            .order_by(VasTask.priority.desc(), VasTask.id.asc())
            .limit(1)
            # Row lock the claimed task; SKIP LOCKED (MySQL 8.0+) lets concurrent
            # workers skip rows another worker is already claiming.
            .with_for_update(skip_locked=True)
        )

        try:
            result = await session.execute(stmt)
            task = result.scalar_one_or_none()
            if task is not None:
                task.status = "running"  # claimed by this worker
                task.attempt_count += 1
                task.updated_at = datetime.utcnow()  # restarts the cooldown clock
                await session.commit()
                # Reload after commit, consistent with the other methods, so attribute
                # access does not trigger an implicit (async-unsafe) lazy load.
                await session.refresh(task)
        except Exception:
            # Release the locking transaction before propagating the error.
            await session.rollback()
            raise

        if task is None:
            # Nothing claimable: end the transaction opened by the locking read
            # instead of leaving it (and any locks it took) open.
            await session.rollback()
            raise NotFoundError(message="Task not found")
        return task

    @staticmethod
    async def get_expiring_tasks(db: AsyncSession, threshold_days: int = 3) -> List[dict]:
        """Return active tasks whose expected end date is past or near.

        Active means status ``pending``, ``running`` or ``grabbed``. The end date
        is read from ``user_inputs["expected_end_date"]`` in ``YYYY-MM-DD``
        format. Tasks without one, or with an unparsable value, are skipped.

        :param db: async database session
        :param threshold_days: warning window in days (default 3). Tasks with
            ``days_left <= threshold_days`` are returned, so overdue tasks
            (negative ``days_left``) are always included.
        :return: summary dicts sorted by ``days_left`` ascending (most urgent first)
        """
        stmt = select(VasTask).where(
            VasTask.status.in_(["pending", "running", "grabbed"])
        )
        tasks = (await db.execute(stmt)).scalars().all()

        # NOTE(review): date.today() is the server's local date, while the rest of
        # this service uses UTC timestamps. Confirm which calendar end dates use.
        today = date.today()
        results = []
        for task in tasks:
            user_inputs = task.user_inputs or {}
            end_date_str = user_inputs.get("expected_end_date")
            if not end_date_str:
                continue  # no deadline recorded

            try:
                end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
            except (TypeError, ValueError):
                # Malformed or non-string date: skip this task instead of failing the report.
                continue

            days_left = (end_date - today).days
            if days_left > threshold_days:
                continue

            # `or ""` guards against keys that are present but null.
            first = user_inputs.get("first_name") or ""
            last = user_inputs.get("last_name") or ""
            full_name = f"{first} {last}".strip() or "未知客户"

            results.append({
                "id": task.id,
                "order_id": task.order_id,
                "routing_key": task.routing_key,
                "status": task.status,
                "social_media_account": user_inputs.get("social_media_account", ""),
                "customer_name": full_name,
                "expected_end_date": end_date_str,
                "email": user_inputs.get("email", ""),
                "days_left": days_left,
            })

        results.sort(key=lambda item: item["days_left"])
        return results

    @staticmethod
    async def list_task(
        db: AsyncSession,
        status: Optional[str] = None,
        routing_key: Optional[str] = None,
        script_version: Optional[str] = None,
        keyword: Optional[str] = None,
        page: int = 0,
        size: int = 10,
    ):
        """Return a paginated task list filtered by the given optional criteria.

        Falsy filter values are ignored. ``keyword`` is matched against
        ``order_id``, ``routing_key`` and ``user_inputs`` via
        ``apply_keyword_search_stmt``. Results are ordered by priority
        (descending), then id (ascending).

        :return: whatever ``paginate`` returns for the built statement
        """
        stmt = select(VasTask)
        if status:
            stmt = stmt.where(VasTask.status == status)
        if routing_key:
            stmt = stmt.where(VasTask.routing_key == routing_key)
        if script_version:
            stmt = stmt.where(VasTask.script_version == script_version)

        stmt = apply_keyword_search_stmt(
            stmt=stmt,
            model=VasTask,
            keyword=keyword,
            fields=["order_id", "routing_key", "user_inputs"],
        )
        stmt = stmt.order_by(
            VasTask.priority.desc(),
            VasTask.id.asc(),
        )
        return await paginate(db, stmt, page, size)

    @staticmethod
    async def update(
        db: AsyncSession,
        id: int,
        payload: VasTaskUpdate,
    ) -> VasTask:
        """Apply the explicitly-set fields of ``payload`` to task ``id``.

        :raises NotFoundError: when the task does not exist
        :return: the refreshed ``VasTask``
        """
        stmt = select(VasTask).where(VasTask.id == id)
        result = await db.execute(stmt)
        obj = result.scalar_one_or_none()
        if not obj:
            raise NotFoundError("Task not exist")

        # Only fields the client actually sent are written (partial update).
        data = payload.dict(exclude_unset=True)
        for key, value in data.items():
            setattr(obj, key, value)

        await db.commit()
        await db.refresh(obj)
        return obj

    @staticmethod
    async def get_active_task_by_order_id(
        db: AsyncSession,
        order_id: str,
    ) -> List[VasTask]:
        """Return the tasks for ``order_id`` whose status is ``pending``.

        NOTE(review): "active" here means only ``pending``, whereas
        ``get_expiring_tasks`` also treats ``running``/``grabbed`` as active.
        Confirm the narrower meaning is intended by callers.
        """
        stmt = select(VasTask).where(
            VasTask.status == "pending",
            VasTask.order_id == order_id,
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    @staticmethod
    async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
        """Put task ``id`` back into ``pending`` status so it can be popped again.

        :raises NotFoundError: when the task does not exist
        :raises BizLogicError: when the task is already ``pending``
        """
        stmt = select(VasTask).where(VasTask.id == id)
        result = await db.execute(stmt)
        rec = result.scalar_one_or_none()
        if not rec:
            raise NotFoundError("Task not exist")
        if rec.status == "pending":
            raise BizLogicError("Task is in queue already")

        rec.status = "pending"
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def manual_confirm(db: AsyncSession, id: int, redis_client: Redis) -> VasTask:
        """Manually mark task ``id`` and its order as completed.

        Both rows are looked up before anything is modified. After both
        statuses are set to ``completed``, ``task_processor.execute`` runs for
        the task's ``routing_key`` before the commit. On any failure the session
        is rolled back.

        :raises NotFoundError: when the task or its order does not exist
        :return: the refreshed ``VasTask``
        """
        task_result = await db.execute(select(VasTask).where(VasTask.id == id))
        task = task_result.scalar_one_or_none()
        if not task:
            raise NotFoundError("Task not exist")

        order_result = await db.execute(select(VasOrder).where(VasOrder.id == task.order_id))
        order = order_result.scalar_one_or_none()
        if not order:
            raise NotFoundError("Order not exist")

        try:
            task.status = "completed"
            order.status = "completed"
            await task_processor.execute(task.routing_key, db, redis_client, task, order)
            await db.commit()
        except Exception:
            # Do not leave half-applied "completed" states in the session.
            await db.rollback()
            raise

        await db.refresh(task)
        return task