| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
- # app/services/task_service.py
- from datetime import datetime, date, timedelta
- from typing import List, Optional
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import func, text, select, or_, and_
- from redis.asyncio import Redis # 引入 Redis 类型
- from app.utils.search import apply_keyword_search_stmt
- from app.utils.pagination import paginate
- from app.core.biz_exception import NotFoundError,BizLogicError
- from app.models.vas_task import VasTask
- from app.models.order import VasOrder
- from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
- from app.services.task_handlers import task_processor
class VasTaskService:
    """Service-layer operations on ``VasTask`` rows.

    All methods are static coroutines that take an ``AsyncSession``. Methods
    that modify data commit the session themselves and return the refreshed
    ORM object.
    """

    @staticmethod
    async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
        """Insert a new task in ``pending`` status.

        :param db: async session; committed by this method.
        :param data: creation payload whose fields are passed directly to the
            ``VasTask`` constructor.
        :return: the persisted, refreshed task.
        """
        rec = VasTask(
            **data.dict(),
            status="pending",
            created_at=datetime.utcnow(),  # naive datetime expressed in UTC
        )
        db.add(rec)
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def pop_vas_task(
        db: AsyncSession, routing_key: str, cooldown_seconds: int = 60
    ) -> VasTask:
        """Claim the next runnable ``pending`` task for ``routing_key``.

        A task is runnable when it has never been attempted, or when it has been
        attempted and its ``updated_at`` is older than ``cooldown_seconds``
        (retry back-off). Among runnable tasks the highest ``priority`` wins;
        ties go to the lowest ``id``. The row is read with
        ``FOR UPDATE SKIP LOCKED``, so rows locked by another transaction are
        skipped rather than waited on.

        The claimed task is set to ``running``, its ``attempt_count`` is
        incremented, ``updated_at`` is set from the database's
        ``utc_timestamp()``, and the session is committed.

        :param routing_key: queue/routing key the task must match.
        :param cooldown_seconds: minimum age in seconds of ``updated_at`` before
            a previously attempted task may be picked again.
        :raises NotFoundError: if no task is currently claimable.
        """
        # The interval is spliced into raw SQL text, so force it to an int;
        # this guarantees nothing but digits (and a sign) reaches the query.
        cooldown = int(cooldown_seconds)
        retry_cutoff = func.utc_timestamp() - text(f"INTERVAL {cooldown} SECOND")
        stmt = (
            select(VasTask)
            .where(VasTask.routing_key == routing_key)
            .where(VasTask.status == "pending")
            .where(
                or_(
                    VasTask.attempt_count == 0,
                    and_(
                        VasTask.attempt_count > 0,
                        VasTask.updated_at < retry_cutoff,
                    ),
                )
            )
            .order_by(VasTask.priority.desc(), VasTask.id.asc())
            .limit(1)
            .with_for_update(skip_locked=True)
        )
        result = await db.execute(stmt)
        task = result.scalar_one_or_none()
        if not task:
            raise NotFoundError(message="Task not found")

        task.status = "running"
        task.attempt_count += 1
        # Use the database clock, the same one the cooldown filter above uses.
        task.updated_at = func.utc_timestamp()
        await db.commit()
        # Reload so updated_at holds the stored value, not the SQL expression.
        await db.refresh(task)
        return task

    @staticmethod
    async def get_expiring_tasks(
        db: AsyncSession, threshold_days: int = 3
    ) -> List[dict]:
        """Return active tasks whose expected end date has passed or is near.

        Active means status ``pending``, ``running`` or ``grabbed``. The end
        date is read from ``user_inputs["expected_end_date"]`` in
        ``YYYY-MM-DD`` format. Tasks are skipped when ``user_inputs`` is not a
        mapping, has no end date, or has one that cannot be parsed.

        :param threshold_days: warning window in days (default 3). A task is
            included when ``expected_end_date - today <= threshold_days``, so
            already-expired tasks (negative ``days_left``) are always included.
        :return: summary dicts sorted by ``days_left`` ascending (most urgent
            first).
        """
        stmt = select(VasTask).where(
            VasTask.status.in_(["pending", "running", "grabbed"])
        )
        tasks = (await db.execute(stmt)).scalars().all()

        # NOTE(review): date.today() is the server's local date, while other
        # timestamps in this service are UTC — confirm which calendar
        # expected_end_date is meant in.
        today = date.today()
        results = []
        for task in tasks:
            user_inputs = task.user_inputs or {}
            if not isinstance(user_inputs, dict):
                continue  # malformed payload: no way to read an end date
            end_date_str = user_inputs.get("expected_end_date")
            if not end_date_str:
                continue  # no deadline recorded

            try:
                end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
            except (TypeError, ValueError):
                # Non-string value or wrong format: skip this task instead of
                # failing the whole report.
                continue

            days_left = (end_date - today).days
            if days_left > threshold_days:
                continue

            # `or ""` guards against explicit nulls, which would otherwise
            # render as the literal text "None".
            first = user_inputs.get("first_name") or ""
            last = user_inputs.get("last_name") or ""
            results.append({
                "id": task.id,
                "order_id": task.order_id,
                "routing_key": task.routing_key,
                "status": task.status,
                "social_media_account": user_inputs.get("social_media_account", ""),
                "customer_name": f"{first} {last}".strip() or "未知客户",
                "expected_end_date": end_date_str,
                "email": user_inputs.get("email", ""),
                "days_left": days_left,
            })

        results.sort(key=lambda item: item["days_left"])
        return results

    @staticmethod
    async def list_task(
        db: AsyncSession,
        status: Optional[str] = None,
        routing_key: Optional[str] = None,
        script_version: Optional[str] = None,
        keyword: Optional[str] = None,
        page: int = 0,
        size: int = 10,
    ):
        """List tasks with optional exact-match filters and keyword search.

        Each of ``status``, ``routing_key`` and ``script_version`` is applied
        only when truthy. ``keyword`` is handed to
        ``apply_keyword_search_stmt`` over ``order_id``, ``routing_key`` and
        ``user_inputs``. Results are ordered by ``priority`` descending, then
        ``id`` ascending.

        :param page: page index passed to ``paginate`` (presumably 0-based
            given the default — confirm against ``paginate``).
        :param size: page size passed to ``paginate``.
        :return: whatever ``paginate`` returns.
        """
        stmt = select(VasTask)
        if status:
            stmt = stmt.where(VasTask.status == status)
        if routing_key:
            stmt = stmt.where(VasTask.routing_key == routing_key)
        if script_version:
            stmt = stmt.where(VasTask.script_version == script_version)
        stmt = apply_keyword_search_stmt(
            stmt=stmt,
            model=VasTask,
            keyword=keyword,
            fields=["order_id", "routing_key", "user_inputs"],
        )
        stmt = stmt.order_by(
            VasTask.priority.desc(),
            VasTask.id.asc(),
        )
        return await paginate(db, stmt, page, size)

    @staticmethod
    async def update(
        db: AsyncSession,
        id: int,
        payload: VasTaskUpdate,
    ) -> VasTask:
        """Apply the explicitly-set fields of ``payload`` to task ``id``.

        Only fields present in the request (``exclude_unset=True``) are
        written, so omitted fields keep their current values.

        :raises NotFoundError: if the task does not exist.
        """
        obj = await VasTaskService.get_task_by_id(db, id)
        for key, value in payload.dict(exclude_unset=True).items():
            setattr(obj, key, value)
        await db.commit()
        await db.refresh(obj)
        return obj

    @staticmethod
    async def get_task_by_order_id(
        db: AsyncSession,
        order_id: str,
    ) -> List[VasTask]:
        """Return all tasks belonging to ``order_id`` (possibly empty)."""
        stmt = select(VasTask).where(VasTask.order_id == order_id)
        result = await db.execute(stmt)
        return result.scalars().all()

    @staticmethod
    async def get_task_by_id(
        db: AsyncSession,
        task_id: int,
    ) -> VasTask:
        """Fetch a single task by primary key.

        :raises NotFoundError: if no task has this id.
        """
        stmt = select(VasTask).where(VasTask.id == task_id)
        result = await db.execute(stmt)
        rec = result.scalar_one_or_none()
        if not rec:
            raise NotFoundError("Task not exist")
        return rec

    @staticmethod
    async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
        """Put task ``id`` back into ``pending`` status.

        Only ``status`` is changed here; ``attempt_count`` is left as is.

        :raises NotFoundError: if the task does not exist.
        """
        rec = await VasTaskService.get_task_by_id(db, id)
        rec.status = "pending"
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def pause(db: AsyncSession, id: int) -> VasTask:
        """Set task ``id`` to ``pause`` status.

        :raises NotFoundError: if the task does not exist.
        """
        rec = await VasTaskService.get_task_by_id(db, id)
        rec.status = "pause"
        await db.commit()
        await db.refresh(rec)
        return rec

    @staticmethod
    async def manual_confirm(db: AsyncSession, id: int, redis_client: Redis) -> VasTask:
        """Manually mark task ``id`` and its order as ``completed``.

        Both rows are looked up first. Both statuses are then set, the
        routing-key-specific handler ``task_processor.execute`` runs, and the
        session is committed.

        :raises NotFoundError: if the task or its order does not exist.
        """
        task = await VasTaskService.get_task_by_id(db, id)

        order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
        order = (await db.execute(order_stmt)).scalar_one_or_none()
        if not order:
            raise NotFoundError("Order not exist")

        # Mutate only once both rows are known to exist, so a missing order
        # doesn't leave a half-applied status change pending in the session.
        task.status = "completed"
        order.status = "completed"

        await task_processor.execute(task.routing_key, db, redis_client, task, order)
        await db.commit()
        await db.refresh(task)
        return task
-
|