# app/services/task_service.py
from datetime import datetime, date, timedelta
from typing import Any, Dict, List, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import func, text, select, or_, and_
from redis.asyncio import Redis  # Redis client type, used for annotations

from app.utils.search import apply_keyword_search_stmt
from app.utils.pagination import paginate
from app.core.biz_exception import NotFoundError, BizLogicError
from app.models.vas_task import VasTask
from app.models.order import VasOrder
from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
from app.services.task_handlers import task_processor


class VasTaskService:
    """Async persistence and queue operations for ``VasTask`` rows."""

    @staticmethod
    async def _get_task_or_raise(db: AsyncSession, task_id: int) -> VasTask:
        """Load a task by primary key.

        :param db: active async session.
        :param task_id: primary key of the task.
        :return: the matching ``VasTask``.
        :raises NotFoundError: if no task has ``task_id``.
        """
        result = await db.execute(select(VasTask).where(VasTask.id == task_id))
        rec = result.scalar_one_or_none()
        if rec is None:
            raise NotFoundError("Task not exist")
        return rec

    @staticmethod
    async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
        """Insert a new task in ``pending`` status and return the refreshed row.

        :param db: active async session; the insert is committed here.
        :param data: validated creation payload; its fields map onto ``VasTask``.
        """
        rec = VasTask(
            **data.dict(),
            status="pending",
            created_at=datetime.utcnow(),
        )
        db.add(rec)
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def pop_vas_task(
        db: AsyncSession, routing_key: str, cooldown_seconds: int = 60
    ) -> VasTask:
        """Claim the next runnable task for ``routing_key``.

        A task is runnable when it is ``pending`` and either has never been
        attempted (``attempt_count == 0``) or was last updated more than
        ``cooldown_seconds`` ago. The highest-priority candidate (ties broken
        by lowest id) is read with ``FOR UPDATE SKIP LOCKED`` so concurrent
        workers skip rows another transaction has already locked. It is then
        marked ``running`` with ``attempt_count`` incremented and committed.

        :param db: active async session.
        :param routing_key: queue / handler key to pop from.
        :param cooldown_seconds: minimum delay before a previously attempted
            task may be claimed again.
        :return: the claimed, refreshed task.
        :raises NotFoundError: if no task is currently claimable.
        """
        # INTERVAL is spliced into the SQL text, so force an integer first:
        # anything that is not a number must never reach the query string.
        cooldown = int(cooldown_seconds)
        # NOTE: utc_timestamp()/INTERVAL are MySQL-specific SQL.
        retry_cutoff = func.utc_timestamp() - text(f"INTERVAL {cooldown} SECOND")

        stmt = (
            select(VasTask)
            .where(VasTask.routing_key == routing_key)
            .where(VasTask.status == 'pending')
            .where(
                or_(
                    VasTask.attempt_count == 0,
                    and_(
                        VasTask.attempt_count > 0,
                        VasTask.updated_at < retry_cutoff,
                    ),
                )
            )
            .order_by(VasTask.priority.desc(), VasTask.id.asc())
            .limit(1)
            .with_for_update(skip_locked=True)
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if task is None:
            # End the transaction opened by the locking read so any locks it
            # acquired are released before signalling an empty queue.
            await db.rollback()
            raise NotFoundError(message="Task not found")

        task.status = 'running'
        task.attempt_count += 1
        # Assigned as a SQL expression so the DB clock (same clock as the
        # cooldown comparison above) stamps the row; refresh loads the value.
        task.updated_at = func.utc_timestamp()
        await db.commit()
        await db.refresh(task)
        return task

    @staticmethod
    async def get_expiring_tasks(
        db: AsyncSession, threshold_days: int = 3
    ) -> List[Dict[str, Any]]:
        """Return active tasks that are already overdue or due soon.

        A task qualifies when its status is pending, running or grabbed and
        ``user_inputs["expected_end_date"]`` (format ``YYYY-MM-DD``) is at
        most ``threshold_days`` days after today. Tasks whose end date is
        missing or cannot be parsed are skipped.

        :param db: active async session.
        :param threshold_days: warning window in days (default 3). Overdue
            tasks (negative ``days_left``) are always included.
        :return: one summary dict per qualifying task, most urgent first
            (ascending ``days_left``).
        """
        # 1. Load every active task; the due date lives inside the JSON
        #    user_inputs column, so the date filter is applied in Python.
        stmt = select(VasTask).where(
            VasTask.status.in_(['pending', 'running', 'grabbed'])
        )
        tasks = (await db.execute(stmt)).scalars().all()

        # NOTE(review): date.today() uses server-local time while the rest of
        # this service stores UTC timestamps — confirm the intended day boundary.
        today = date.today()
        results = []

        for task in tasks:
            user_inputs = task.user_inputs or {}
            end_date_str = user_inputs.get("expected_end_date")
            # No due date recorded: nothing to warn about.
            if not end_date_str:
                continue

            try:
                end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
            except (TypeError, ValueError):
                # Malformed or non-string dates are ignored rather than
                # failing the whole report.
                continue

            # Negative means already overdue.
            days_left = (end_date - today).days
            if days_left > threshold_days:
                continue

            first = user_inputs.get("first_name", "")
            last = user_inputs.get("last_name", "")
            full_name = f"{first} {last}".strip() or "未知客户"

            results.append({
                "id": task.id,
                "order_id": task.order_id,
                "routing_key": task.routing_key,
                "status": task.status,
                "social_media_account": user_inputs.get("social_media_account", ""),
                "customer_name": full_name,
                "expected_end_date": end_date_str,
                "email": user_inputs.get("email", ""),
                "days_left": days_left,
            })

        # Most urgent (fewest days left) first.
        results.sort(key=lambda item: item["days_left"])
        return results

    @staticmethod
    async def list_task(
        db: AsyncSession,
        status: Optional[str] = None,
        routing_key: Optional[str] = None,
        script_version: Optional[str] = None,
        keyword: Optional[str] = None,
        page: int = 0,
        size: int = 10,
    ):
        """List tasks with optional exact-match filters and keyword search.

        Results are ordered by priority (descending), then id (ascending),
        and returned through the project ``paginate`` helper.

        :param status: exact status to match, if given.
        :param routing_key: exact routing key to match, if given.
        :param script_version: exact script version to match, if given.
        :param keyword: free-text search over order_id, routing_key and
            user_inputs (delegated to ``apply_keyword_search_stmt``).
        :param page: page index passed to ``paginate`` (base presumably 0,
            given the default — confirm against ``paginate``).
        :param size: page size.
        """
        stmt = select(VasTask)

        if status:
            stmt = stmt.where(VasTask.status == status)
        if routing_key:
            stmt = stmt.where(VasTask.routing_key == routing_key)
        if script_version:
            stmt = stmt.where(VasTask.script_version == script_version)

        stmt = apply_keyword_search_stmt(
            stmt=stmt,
            model=VasTask,
            keyword=keyword,
            fields=["order_id", "routing_key", "user_inputs"],
        )

        stmt = stmt.order_by(
            VasTask.priority.desc(),
            VasTask.id.asc(),
        )

        return await paginate(db, stmt, page, size)

    @staticmethod
    async def update(
        db: AsyncSession,
        id: int,
        payload: VasTaskUpdate,
    ) -> VasTask:
        """Apply the explicitly-set fields of ``payload`` to a task.

        :param id: task primary key (name kept for caller compatibility).
        :param payload: partial update; only fields the client set are applied.
        :raises NotFoundError: if the task does not exist.
        """
        obj = await VasTaskService._get_task_or_raise(db, id)

        # exclude_unset: leave fields the client did not send untouched.
        data = payload.dict(exclude_unset=True)
        for key, value in data.items():
            setattr(obj, key, value)

        await db.commit()
        await db.refresh(obj)
        return obj

    @staticmethod
    async def get_task_by_order_id(
        db: AsyncSession,
        order_id: str,
    ) -> List[VasTask]:
        """Return all tasks belonging to ``order_id`` (possibly empty)."""
        stmt = select(VasTask).where(
            VasTask.order_id == order_id,
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    @staticmethod
    async def get_task_by_id(
        db: AsyncSession,
        task_id: int,
    ) -> VasTask:
        """Return a single task by primary key.

        :raises NotFoundError: if the task does not exist.
        """
        return await VasTaskService._get_task_or_raise(db, task_id)

    @staticmethod
    async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
        """Set a task's status back to ``pending`` so it can be popped again.

        ``attempt_count`` and ``updated_at`` are not touched here.

        :raises NotFoundError: if the task does not exist.
        """
        rec = await VasTaskService._get_task_or_raise(db, id)
        rec.status = "pending"
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def pause(db: AsyncSession, id: int) -> VasTask:
        """Set a task's status to ``pause`` (it no longer matches ``pending``).

        :raises NotFoundError: if the task does not exist.
        """
        rec = await VasTaskService._get_task_or_raise(db, id)
        rec.status = "pause"
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def manual_confirm(db: AsyncSession, id: int, redis_client: Redis) -> VasTask:
        """Manually mark a task and its order as completed.

        Both rows are set to ``completed``, then the routing-key-specific
        handler in ``task_processor`` runs before the commit.

        :param id: task primary key.
        :param redis_client: passed through to the task handler.
        :raises NotFoundError: if the task or its order does not exist.
        """
        task = await VasTaskService._get_task_or_raise(db, id)

        # Validate the order BEFORE mutating anything, so a missing order
        # does not leave a half-completed task dirty in the session.
        order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
        order_result = await db.execute(order_stmt)
        order = order_result.scalar_one_or_none()
        if not order:
            raise NotFoundError("Order not exist")

        task.status = "completed"
        order.status = "completed"

        await task_processor.execute(task.routing_key, db, redis_client, task, order)

        await db.commit()
        await db.refresh(task)
        return task