vas_task_service.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. # app/services/task_service.py
  2. from datetime import datetime, date, timedelta
  3. from typing import List, Optional
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy import select, or_, and_
  6. from redis.asyncio import Redis # 引入 Redis 类型
  7. from app.utils.search import apply_keyword_search_stmt
  8. from app.utils.pagination import paginate
  9. from app.core.biz_exception import NotFoundError,BizLogicError
  10. from app.models.vas_task import VasTask
  11. from app.models.order import VasOrder
  12. from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
  13. from app.services.task_handlers import task_processor
  14. class VasTaskService:
  15. @staticmethod
  16. async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
  17. rec = VasTask(
  18. **data.dict(),
  19. status="pending",
  20. created_at=datetime.utcnow(),
  21. )
  22. db.add(rec)
  23. await db.commit()
  24. await db.refresh(rec)
  25. return rec
  26. @staticmethod
  27. async def pop_vas_task(session: AsyncSession, routing_key: str, cooldown_seconds: int = 60):
  28. """
  29. 异步获取任务,支持冷却期机制。
  30. :param session: 数据库异步会话
  31. :param routing_key: 队列键值
  32. :param cooldown_seconds: 失败后的冷却时间(秒),默认60秒
  33. :return: VasTask 对象 or None
  34. """
  35. # 计算冷却截止时间:当前时间 - 冷却秒数
  36. # 只有 updated_at 早于这个时间的重试任务,才会被提取
  37. cutoff_time = datetime.utcnow() - timedelta(seconds=cooldown_seconds)
  38. try:
  39. # --- 构造查询语句 ---
  40. stmt = (
  41. select(VasTask)
  42. .where(VasTask.routing_key == routing_key)
  43. .where(VasTask.status == 'pending')
  44. # === 核心逻辑:冷却期筛选 ===
  45. .where(
  46. or_(
  47. # 情况1:这是一个全新任务 (从未尝试过)
  48. VasTask.attempt_count == 0,
  49. # 情况2:这是一个重试任务,且距离上次更新(失败)已经过了冷却期
  50. and_(
  51. VasTask.attempt_count > 0,
  52. VasTask.updated_at < cutoff_time
  53. )
  54. )
  55. )
  56. # 排序:优先级优先(0假设是最高优先级?),其次是先创建的
  57. # 注意:根据你的业务,priority 可能 desc 才是高优先级,这里按 asc 写
  58. .order_by(VasTask.priority.desc(), VasTask.id.asc())
  59. .limit(1)
  60. # MySQL 8.0+ 必加,跳过被锁定的行
  61. .with_for_update(skip_locked=True)
  62. )
  63. result = await session.execute(stmt)
  64. task = result.scalar_one_or_none()
  65. # --- 更新状态 ---
  66. if task:
  67. task.status = 'running' # 标记为已被抓取
  68. task.attempt_count += 1 # 增加尝试次数
  69. task.updated_at = datetime.utcnow() # 更新时间(重置冷却计时起点)
  70. await session.commit()
  71. # session.begin() 结束时自动 commit
  72. return task
  73. raise NotFoundError(message="Task not found")
  74. except Exception as e:
  75. # 记录日志
  76. raise e
  77. @staticmethod
  78. async def get_expiring_tasks(db: AsyncSession, threshold_days: int = 3):
  79. """
  80. 获取即将过期或已过期的活跃任务
  81. :param threshold_days: 预警阈值,默认 7 天内到期
  82. """
  83. # 1. 查出所有活跃任务
  84. stmt = select(VasTask).where(
  85. VasTask.status.in_(['pending', 'running', 'grabbed'])
  86. )
  87. tasks = (await db.execute(stmt)).scalars().all()
  88. results = []
  89. today = date.today()
  90. for task in tasks:
  91. user_inputs = task.user_inputs or {}
  92. end_date_str = user_inputs.get("expected_end_date")
  93. # 如果没有截止日期,跳过
  94. if not end_date_str:
  95. continue
  96. try:
  97. # 解析日期
  98. end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
  99. # 计算剩余天数
  100. delta = (end_date - today).days
  101. # 筛选条件: 已过期 (delta < 0) 或 即将过期 (delta <= threshold)
  102. if delta <= threshold_days:
  103. # 获取客户姓名
  104. first = user_inputs.get("first_name", "")
  105. last = user_inputs.get("last_name", "")
  106. email = user_inputs.get("email", "")
  107. social_media_account = user_inputs.get("social_media_account", "")
  108. full_name = f"{first} {last}".strip() or "未知客户"
  109. results.append({
  110. "id": task.id,
  111. "order_id": task.order_id,
  112. "routing_key": task.routing_key,
  113. "status": task.status,
  114. "social_media_account": social_media_account,
  115. "customer_name": full_name,
  116. "expected_end_date": end_date_str,
  117. "email": email,
  118. "days_left": delta
  119. })
  120. except ValueError:
  121. continue # 日期格式错误则忽略
  122. # 按剩余天数升序排列 (最急的在前面)
  123. results.sort(key=lambda x: x["days_left"])
  124. return results
  125. @staticmethod
  126. async def list_task(
  127. db: AsyncSession,
  128. status: Optional[str] = None,
  129. routing_key: Optional[str] = None,
  130. script_version: Optional[str] = None,
  131. keyword: Optional[str] = None,
  132. page: int = 0,
  133. size: int = 10,
  134. ):
  135. stmt = select(VasTask)
  136. if status:
  137. stmt = stmt.where(VasTask.status == status)
  138. if routing_key:
  139. stmt = stmt.where(VasTask.routing_key == routing_key)
  140. if script_version:
  141. stmt = stmt.where(VasTask.script_version == script_version)
  142. stmt = apply_keyword_search_stmt(
  143. stmt=stmt,
  144. model=VasTask,
  145. keyword=keyword,
  146. fields=["order_id", "routing_key", "user_inputs"],
  147. )
  148. stmt = stmt.order_by(
  149. VasTask.priority.desc(),
  150. VasTask.id.asc(),
  151. )
  152. return await paginate(db, stmt, page, size)
  153. @staticmethod
  154. async def update(
  155. db: AsyncSession,
  156. id: int,
  157. payload: VasTaskUpdate,
  158. ) -> VasTask:
  159. stmt = select(VasTask).where(VasTask.id == id)
  160. result = await db.execute(stmt)
  161. obj = result.scalar_one_or_none()
  162. if not obj:
  163. raise NotFoundError("Task not exist")
  164. data = payload.dict(exclude_unset=True)
  165. for key, value in data.items():
  166. setattr(obj, key, value)
  167. await db.commit()
  168. await db.refresh(obj)
  169. return obj
  170. @staticmethod
  171. async def get_active_task_by_order_id(
  172. db: AsyncSession,
  173. order_id: str,
  174. ) -> List[VasTask]:
  175. stmt = select(VasTask).where(
  176. VasTask.status == "pending",
  177. VasTask.order_id == order_id,
  178. )
  179. result = await db.execute(stmt)
  180. return result.scalars().all()
  181. @staticmethod
  182. async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
  183. stmt = select(VasTask).where(VasTask.id == id)
  184. result = await db.execute(stmt)
  185. rec = result.scalar_one_or_none()
  186. if not rec:
  187. raise NotFoundError("Task not exist")
  188. if rec.status == "pending":
  189. raise BizLogicError("Task is in queue already")
  190. rec.status = "pending"
  191. await db.commit()
  192. await db.refresh(rec)
  193. return rec
  194. @staticmethod
  195. async def manual_confirm(db: AsyncSession, id: int, redis_client: Redis) -> VasTask:
  196. stmt = select(VasTask).where(VasTask.id == id)
  197. result = await db.execute(stmt)
  198. task = result.scalar_one_or_none()
  199. if not task:
  200. raise NotFoundError("Task not exist")
  201. task.status = "completed"
  202. order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
  203. order_result = await db.execute(order_stmt)
  204. order = order_result.scalar_one_or_none()
  205. if not order:
  206. raise NotFoundError("Order not exist")
  207. order.status = "completed"
  208. await task_processor.execute(task.routing_key, db, redis_client, task, order)
  209. await db.commit()
  210. await db.refresh(task)
  211. return task