vas_task_service.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # app/services/task_service.py
  2. from datetime import datetime, date, timedelta
  3. from typing import List, Optional
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy import func, text, select, or_, and_
  6. from redis.asyncio import Redis # 引入 Redis 类型
  7. from app.utils.search import apply_keyword_search_stmt
  8. from app.utils.pagination import paginate
  9. from app.core.biz_exception import NotFoundError,BizLogicError
  10. from app.models.vas_task import VasTask
  11. from app.models.order import VasOrder
  12. from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
  13. from app.services.task_handlers import task_processor
  14. class VasTaskService:
  15. @staticmethod
  16. async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
  17. rec = VasTask(
  18. **data.dict(),
  19. status="pending",
  20. created_at=datetime.utcnow(),
  21. )
  22. db.add(rec)
  23. await db.commit()
  24. await db.refresh(rec)
  25. return rec
  26. @staticmethod
  27. async def pop_vas_task(db: AsyncSession, routing_key: str, cooldown_seconds: int = 60):
  28. stmt = (
  29. select(VasTask)
  30. .where(VasTask.routing_key == routing_key)
  31. .where(VasTask.status == 'pending')
  32. .where(
  33. or_(
  34. VasTask.attempt_count == 0,
  35. and_(
  36. VasTask.attempt_count > 0,
  37. VasTask.updated_at < (func.utc_timestamp() - text(f"INTERVAL {cooldown_seconds} SECOND"))
  38. )
  39. )
  40. )
  41. .order_by(VasTask.priority.desc(), VasTask.id.asc())
  42. .limit(1)
  43. .with_for_update(skip_locked=True)
  44. )
  45. result = await db.execute(stmt)
  46. task = result.scalar_one_or_none()
  47. if not task:
  48. raise NotFoundError(message="Task not found")
  49. task.status = 'running'
  50. task.attempt_count += 1
  51. task.updated_at = func.utc_timestamp()
  52. await db.commit()
  53. await db.refresh(task)
  54. return task
  55. @staticmethod
  56. async def get_expiring_tasks(db: AsyncSession, threshold_days: int = 3):
  57. """
  58. 获取即将过期或已过期的活跃任务
  59. :param threshold_days: 预警阈值,默认 7 天内到期
  60. """
  61. # 1. 查出所有活跃任务
  62. stmt = select(VasTask).where(
  63. VasTask.status.in_(['pending', 'running', 'grabbed'])
  64. )
  65. tasks = (await db.execute(stmt)).scalars().all()
  66. results = []
  67. today = date.today()
  68. for task in tasks:
  69. user_inputs = task.user_inputs or {}
  70. end_date_str = user_inputs.get("expected_end_date")
  71. # 如果没有截止日期,跳过
  72. if not end_date_str:
  73. continue
  74. try:
  75. # 解析日期
  76. end_date = datetime.strptime(end_date_str, "%Y-%m-%d").date()
  77. # 计算剩余天数
  78. delta = (end_date - today).days
  79. # 筛选条件: 已过期 (delta < 0) 或 即将过期 (delta <= threshold)
  80. if delta <= threshold_days:
  81. # 获取客户姓名
  82. first = user_inputs.get("first_name", "")
  83. last = user_inputs.get("last_name", "")
  84. email = user_inputs.get("email", "")
  85. social_media_account = user_inputs.get("social_media_account", "")
  86. full_name = f"{first} {last}".strip() or "未知客户"
  87. results.append({
  88. "id": task.id,
  89. "order_id": task.order_id,
  90. "routing_key": task.routing_key,
  91. "status": task.status,
  92. "social_media_account": social_media_account,
  93. "customer_name": full_name,
  94. "expected_end_date": end_date_str,
  95. "email": email,
  96. "days_left": delta
  97. })
  98. except ValueError:
  99. continue # 日期格式错误则忽略
  100. # 按剩余天数升序排列 (最急的在前面)
  101. results.sort(key=lambda x: x["days_left"])
  102. return results
  103. @staticmethod
  104. async def list_task(
  105. db: AsyncSession,
  106. status: Optional[str] = None,
  107. routing_key: Optional[str] = None,
  108. script_version: Optional[str] = None,
  109. keyword: Optional[str] = None,
  110. page: int = 0,
  111. size: int = 10,
  112. ):
  113. stmt = select(VasTask)
  114. if status:
  115. stmt = stmt.where(VasTask.status == status)
  116. if routing_key:
  117. stmt = stmt.where(VasTask.routing_key == routing_key)
  118. if script_version:
  119. stmt = stmt.where(VasTask.script_version == script_version)
  120. stmt = apply_keyword_search_stmt(
  121. stmt=stmt,
  122. model=VasTask,
  123. keyword=keyword,
  124. fields=["order_id", "routing_key", "user_inputs"],
  125. )
  126. stmt = stmt.order_by(
  127. VasTask.priority.desc(),
  128. VasTask.id.asc(),
  129. )
  130. return await paginate(db, stmt, page, size)
  131. @staticmethod
  132. async def update(
  133. db: AsyncSession,
  134. id: int,
  135. payload: VasTaskUpdate,
  136. ) -> VasTask:
  137. stmt = select(VasTask).where(VasTask.id == id)
  138. result = await db.execute(stmt)
  139. obj = result.scalar_one_or_none()
  140. if not obj:
  141. raise NotFoundError("Task not exist")
  142. data = payload.dict(exclude_unset=True)
  143. for key, value in data.items():
  144. setattr(obj, key, value)
  145. await db.commit()
  146. await db.refresh(obj)
  147. return obj
  148. @staticmethod
  149. async def get_task_by_order_id(
  150. db: AsyncSession,
  151. order_id: str,
  152. ) -> List[VasTask]:
  153. stmt = select(VasTask).where(
  154. VasTask.order_id == order_id,
  155. )
  156. result = await db.execute(stmt)
  157. return result.scalars().all()
  158. @staticmethod
  159. async def get_task_by_id(
  160. db: AsyncSession,
  161. task_id: int,
  162. ) -> VasTask:
  163. stmt = select(VasTask).where(
  164. VasTask.id == task_id,
  165. )
  166. result = await db.execute(stmt)
  167. rec = result.scalar_one_or_none()
  168. if not rec:
  169. raise NotFoundError("Task not exist")
  170. return rec
  171. @staticmethod
  172. async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
  173. stmt = select(VasTask).where(VasTask.id == id)
  174. result = await db.execute(stmt)
  175. rec = result.scalar_one_or_none()
  176. if not rec:
  177. raise NotFoundError("Task not exist")
  178. rec.status = "pending"
  179. await db.commit()
  180. await db.refresh(rec)
  181. return rec
  182. @staticmethod
  183. async def pause(db: AsyncSession, id: int) -> VasTask:
  184. stmt = select(VasTask).where(VasTask.id == id)
  185. result = await db.execute(stmt)
  186. rec = result.scalar_one_or_none()
  187. if not rec:
  188. raise NotFoundError("Task not exist")
  189. rec.status = "pause"
  190. await db.commit()
  191. await db.refresh(rec)
  192. return rec
  193. @staticmethod
  194. async def manual_confirm(db: AsyncSession, id: int, redis_client: Redis) -> VasTask:
  195. stmt = select(VasTask).where(VasTask.id == id)
  196. result = await db.execute(stmt)
  197. task = result.scalar_one_or_none()
  198. if not task:
  199. raise NotFoundError("Task not exist")
  200. task.status = "completed"
  201. order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
  202. order_result = await db.execute(order_stmt)
  203. order = order_result.scalar_one_or_none()
  204. if not order:
  205. raise NotFoundError("Order not exist")
  206. order.status = "completed"
  207. await task_processor.execute(task.routing_key, db, redis_client, task, order)
  208. await db.commit()
  209. await db.refresh(task)
  210. return task