vas_task_service.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # app/services/task_service.py
  2. from datetime import datetime, timedelta
  3. from typing import List, Optional
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy import select, or_, and_
  6. from app.utils.search import apply_keyword_search_stmt
  7. from app.utils.pagination import paginate
  8. from app.core.biz_exception import NotFoundError,BizLogicError
  9. from app.models.vas_task import VasTask
  10. from app.models.order import VasOrder
  11. from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
  12. from app.services.task_handlers import task_processor
  13. class VasTaskService:
  @staticmethod
  async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
      """Insert a new task built from ``data`` and return the persisted row.

      The task starts with ``status="pending"`` and ``created_at`` set to the
      current time from ``datetime.utcnow()`` (a naive UTC datetime). The
      session is committed and the instance is refreshed, which reloads its
      attributes from the database.

      NOTE(review): if ``VasTaskCreate`` itself exposes ``status`` or
      ``created_at``, the keyword expansion below raises ``TypeError``
      (duplicate keyword argument).
      """
      fields = data.dict()
      new_task = VasTask(**fields, status="pending", created_at=datetime.utcnow())
      db.add(new_task)
      await db.commit()
      await db.refresh(new_task)
      return new_task
  @staticmethod
  async def pop_vas_task(
      session: AsyncSession,
      routing_key: str,
      cooldown_seconds: int = 60,
  ) -> VasTask:
      """Claim the next eligible pending task for ``routing_key``.

      A task with ``status == "pending"`` is eligible when either:

      * it has never been attempted (``attempt_count == 0``), or
      * it was attempted before (``attempt_count > 0``) and its ``updated_at``
        is older than ``cooldown_seconds`` ago (cooldown after a failed try).

      Among eligible tasks, the largest ``priority`` value wins; ties go to the
      smallest ``id``. The row is read with ``SELECT ... FOR UPDATE SKIP LOCKED``,
      so rows currently locked by another worker are skipped. The claimed task is
      set to ``"running"``, its ``attempt_count`` is incremented and its
      ``updated_at`` is reset (restarting the cooldown clock), then the session
      is committed.

      :param session: async DB session. It is committed on success and rolled
          back if a database error occurs while claiming.
      :param routing_key: queue key to pop from.
      :param cooldown_seconds: seconds that must pass after an attempt before a
          retried task can be claimed again (default 60).
      :return: the claimed ``VasTask``.
      :raises NotFoundError: if no eligible task exists.
      """
      # Retried tasks are only eligible if last touched before this instant.
      cutoff_time = datetime.utcnow() - timedelta(seconds=cooldown_seconds)
      stmt = (
          select(VasTask)
          .where(VasTask.routing_key == routing_key)
          .where(VasTask.status == "pending")
          .where(
              or_(
                  # Brand-new task: never attempted.
                  VasTask.attempt_count == 0,
                  # Retried task whose cooldown has elapsed.
                  and_(
                      VasTask.attempt_count > 0,
                      VasTask.updated_at < cutoff_time,
                  ),
              )
          )
          # Larger priority value first, then oldest id (FIFO within a priority).
          .order_by(VasTask.priority.desc(), VasTask.id.asc())
          .limit(1)
          # Skip rows locked by other workers (SKIP LOCKED needs MySQL 8.0+).
          .with_for_update(skip_locked=True)
      )
      try:
          result = await session.execute(stmt)
          task = result.scalar_one_or_none()
          if task is not None:
              task.status = "running"  # mark as claimed
              task.attempt_count += 1
              task.updated_at = datetime.utcnow()  # restart the cooldown clock
              await session.commit()
      except Exception:
          # Don't leave a failed transaction (and any row locks it holds) open.
          await session.rollback()
          raise
      if task is None:
          # NOTE(review): the transaction begun by the locking SELECT is not
          # ended on this path; the caller's session lifecycle must close it.
          raise NotFoundError(message="Task not found")
      return task
  @staticmethod
  async def list_task(
      db: AsyncSession,
      status: Optional[str] = None,
      routing_key: Optional[str] = None,
      script_version: Optional[str] = None,
      keyword: Optional[str] = None,
      page: int = 0,
      size: int = 10,
  ):
      """Return one page of tasks matching the supplied filters.

      ``status``, ``routing_key`` and ``script_version`` are each applied as an
      equality filter only when truthy, so ``None`` and ``""`` are ignored.
      ``keyword`` is handed to ``apply_keyword_search_stmt`` over the
      ``order_id``, ``routing_key`` and ``user_inputs`` fields. Rows are ordered
      by ``priority`` descending, then ``id`` ascending, and the result of
      ``paginate(db, stmt, page, size)`` is returned.
      """
      equality_filters = (
          (VasTask.status, status),
          (VasTask.routing_key, routing_key),
          (VasTask.script_version, script_version),
      )
      query = select(VasTask)
      for column, wanted in equality_filters:
          if wanted:
              query = query.where(column == wanted)
      query = apply_keyword_search_stmt(
          stmt=query,
          model=VasTask,
          keyword=keyword,
          fields=["order_id", "routing_key", "user_inputs"],
      )
      query = query.order_by(VasTask.priority.desc(), VasTask.id.asc())
      return await paginate(db, query, page, size)
  @staticmethod
  async def update(
      db: AsyncSession,
      id: int,
      payload: VasTaskUpdate,
  ) -> VasTask:
      """Partially update the task whose primary key is ``id``.

      Only the fields explicitly set on ``payload`` (``exclude_unset=True``) are
      copied onto the row. The session is then committed and the instance
      refreshed.

      :raises NotFoundError: if no task has this ``id``.
      """
      lookup = await db.execute(select(VasTask).where(VasTask.id == id))
      found = lookup.scalar_one_or_none()
      if not found:
          raise NotFoundError("Task not exist")
      changes = payload.dict(exclude_unset=True)
      for field_name in changes:
          setattr(found, field_name, changes[field_name])
      await db.commit()
      await db.refresh(found)
      return found
  @staticmethod
  async def get_active_task_by_order_id(
      db: AsyncSession,
      order_id: str,
  ) -> List[VasTask]:
      """Return all tasks belonging to ``order_id`` whose status is ``"pending"``.

      NOTE(review): only ``"pending"`` counts as active here. Tasks already
      claimed by a worker (``"running"``) are not returned; confirm that is
      intended.
      """
      pending_for_order = select(VasTask).where(
          VasTask.status == "pending",
          VasTask.order_id == order_id,
      )
      rows = await db.execute(pending_for_order)
      return rows.scalars().all()
  @staticmethod
  async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
      """Put task ``id`` back in the queue by setting its status to ``"pending"``.

      This method does not itself change ``attempt_count`` or ``updated_at``.
      NOTE(review): ``pop_vas_task`` applies a cooldown based on ``updated_at``
      to tasks with ``attempt_count > 0``, so a requeued retry may not be
      claimable immediately.

      :raises NotFoundError: if no task has this ``id``.
      :raises BizLogicError: if the task is already ``"pending"``.
      """
      lookup = await db.execute(select(VasTask).where(VasTask.id == id))
      task = lookup.scalar_one_or_none()
      if not task:
          raise NotFoundError("Task not exist")
      if task.status == "pending":
          raise BizLogicError("Task is in queue already")
      task.status = "pending"
      await db.commit()
      await db.refresh(task)
      return task
  @staticmethod
  async def manual_confirm(db: AsyncSession, id: int) -> VasTask:
      """Manually mark task ``id`` and its order as completed.

      The task and its ``VasOrder`` (looked up by ``task.order_id``) are both
      fetched before anything is modified, so a missing order raises while the
      task is still untouched. Both statuses are then set to ``"completed"``,
      ``task_processor.execute`` is run for the task's ``routing_key``, and the
      session is committed. If the processor or the commit fails, the session is
      rolled back and the error re-raised.

      :raises NotFoundError: if the task or its order does not exist.
      """
      task_result = await db.execute(select(VasTask).where(VasTask.id == id))
      task = task_result.scalar_one_or_none()
      if not task:
          raise NotFoundError("Task not exist")

      # Validate the order before mutating anything: this SELECT autoflushes
      # pending changes, so mutating the task first would push a half-applied
      # "completed" status into the transaction.
      order_result = await db.execute(
          select(VasOrder).where(VasOrder.id == task.order_id)
      )
      order = order_result.scalar_one_or_none()
      if not order:
          raise NotFoundError("Order not exist")

      # NOTE(review): nothing prevents confirming an already "completed" task,
      # so the processor runs again on repeat calls; confirm that is safe.
      task.status = "completed"
      order.status = "completed"
      try:
          await task_processor.execute(task.routing_key, db, task, order)
          await db.commit()
      except Exception:
          # Discard the status changes and anything staged through `db`.
          await db.rollback()
          raise
      await db.refresh(task)
      return task