task_service.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # app/services/task_service.py
  2. from typing import List, Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, update
  5. from datetime import datetime
  6. from app.core.biz_exception import NotFoundError
  7. from app.models.task import Task
  8. from app.schemas.task import TaskCreate, TaskUpdate
  9. class TaskService:
  # ======================
  # Create a task
  # ======================
  @staticmethod
  async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
      """Insert a new task row, commit it, and return the persisted object.

      Args:
          db: Async session used for the insert; it is committed here.
          obj_in: Input schema supplying ``command``, ``args`` and an
              optional ``status``.

      Returns:
          The new ``Task``, refreshed after the commit.
      """
      # Any falsy status (e.g. None) falls back to 0 -- the same value that
      # pop_task/get_pending treat as "pending".
      initial_status = obj_in.status or 0
      new_task = Task(command=obj_in.command, args=obj_in.args, status=initial_status)
      db.add(new_task)
      await db.commit()
      # Reload the row so values assigned by the database are present on
      # the returned object.
      await db.refresh(new_task)
      return new_task
  @staticmethod
  async def pop_task(db: AsyncSession, queue_name: str) -> Task:
      """Atomically claim one pending task from a queue (MySQL 8.0+ only).

      Selects a single row with ``command == queue_name`` and ``status == 0``
      using ``SELECT ... FOR UPDATE SKIP LOCKED``. Rows already locked by
      another worker are skipped rather than waited on, so concurrent workers
      do not claim the same row. The claimed row is marked running
      (``status = 1``), ``update_at`` is stamped, and the transaction is
      committed.

      Args:
          db: Async session. Its transaction is committed on success and
              rolled back on any failure, including "no task available".
          queue_name: Queue / command type to pop from.

      Returns:
          The claimed ``Task``, refreshed after the commit.

      Raises:
          NotFoundError: No unlocked pending task exists for ``queue_name``.
      """
      stmt = (
          select(Task)
          .where(Task.command == queue_name)  # target queue / command type
          .where(Task.status == 0)  # pending tasks only
          .order_by(Task.id.asc())  # lowest id first
          .limit(1)
          # skip_locked=True: without it, concurrent workers would block on
          # rows already locked by one another.
          .with_for_update(skip_locked=True)
      )
      try:
          task = (await db.execute(stmt)).scalar_one_or_none()
          if task is None:
              raise NotFoundError(message="Task not found")
          task.status = 1  # mark as running
          task.update_at = datetime.now()
          await db.commit()
      except Exception:
          # This method does not use session.begin(), so nothing else ends
          # the transaction opened by the locking SELECT. Roll back to release
          # its locks and keep the session usable, then re-raise unchanged.
          await db.rollback()
          raise
      # commit() may expire loaded attributes (unless the session uses
      # expire_on_commit=False). Reload now so callers do not trigger implicit
      # IO on an AsyncSession -- same as create() and update().
      await db.refresh(task)
      return task
  # ======================
  # Fetch by ID
  # ======================
  @staticmethod
  async def get_by_id(db: AsyncSession, task_id: int) -> Task:
      """Return the task whose ``id`` equals ``task_id``.

      Args:
          db: Async session used for the lookup (read-only here).
          task_id: Id of the task to load.

      Raises:
          NotFoundError: No task has that id.
      """
      found = (
          await db.execute(select(Task).where(Task.id == task_id))
      ).scalar_one_or_none()
      if found:
          return found
      raise NotFoundError("Task not exist")
  # ======================
  # Update a task
  # ======================
  @staticmethod
  async def update(
      db: AsyncSession,
      task_id: int,
      obj_in: TaskUpdate
  ) -> Task:
      """Apply the provided fields of ``obj_in`` to task ``task_id`` and commit.

      Only ``result`` and ``status`` are considered. A field whose value is
      ``None`` is left unchanged.

      Args:
          db: Async session; committed after the changes are applied.
          task_id: Id of the task to modify.
          obj_in: Partial update carrying optional ``result`` / ``status``.

      Returns:
          The updated ``Task``, refreshed after the commit.

      Raises:
          NotFoundError: No task has that id.
      """
      stmt = select(Task).where(Task.id == task_id)
      db_obj = (await db.execute(stmt)).scalar_one_or_none()
      # Check existence before reading any attribute. Otherwise a missing
      # id surfaces as AttributeError instead of NotFoundError.
      if db_obj is None:
          raise NotFoundError("Task not exist")
      if obj_in.result is not None:
          db_obj.result = obj_in.result
      if obj_in.status is not None:
          db_obj.status = obj_in.status
      await db.commit()
      await db.refresh(db_obj)
      return db_obj
  # ======================
  # List pending tasks (paginated)
  # ======================
  @staticmethod
  async def get_pending(
      db: AsyncSession,
      command: str,
      page: int,
      size: int
  ) -> List[Task]:
      """Return one page of pending (``status == 0``) tasks for ``command``.

      Rows are ordered by ``create_at`` ascending, then by ``id``. The ``id``
      tie-breaker makes offset pagination deterministic when several tasks
      share the same ``create_at``; without it, rows could repeat or be
      skipped across pages.

      Args:
          db: Async session used for the query (read-only here).
          command: Queue / command type to filter on.
          page: Zero-based page index (offset is ``page * size``).
          size: Maximum number of rows per page.

      Returns:
          A list of at most ``size`` tasks.
      """
      stmt = (
          select(Task)
          .where(Task.command == command, Task.status == 0)
          .order_by(Task.create_at.asc(), Task.id.asc())
          .offset(page * size)
          .limit(size)
      )
      # Materialize into a real list to match the List[Task] annotation.
      return list((await db.execute(stmt)).scalars().all())