task_service.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # app/services/task_service.py
  2. from typing import List, Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, update
  5. from app.core.queue_manager import queue_manager
  6. from app.core.biz_exception import NotFoundError
  7. from app.models.task import Task
  8. from app.schemas.task import TaskCreate, TaskUpdate
  9. class TaskService:
  10. # ======================
  11. # 创建任务
  12. # ======================
  13. @staticmethod
  14. async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
  15. db_obj = Task(
  16. command=obj_in.command,
  17. args=obj_in.args,
  18. status=obj_in.status or 0,
  19. )
  20. db.add(db_obj)
  21. await db.commit()
  22. await db.refresh(db_obj)
  23. # 👇 只在 READY 状态才入队
  24. if db_obj.status == 0:
  25. queue_manager.put(
  26. queue_name=db_obj.command,
  27. task_id=db_obj.id,
  28. priority=0, # 先用默认优先级
  29. )
  30. return db_obj
  31. # ======================
  32. # 根据 ID 获取
  33. # ======================
  34. @staticmethod
  35. async def get_by_id(db: AsyncSession, task_id: int) -> Task:
  36. stmt = select(Task).where(Task.id == task_id)
  37. obj = (await db.execute(stmt)).scalar_one_or_none()
  38. if not obj:
  39. raise NotFoundError("Task not exist")
  40. return obj
  41. # ======================
  42. # 更新任务
  43. # ======================
  44. @staticmethod
  45. async def update(
  46. db: AsyncSession,
  47. task_id: int,
  48. obj_in: TaskUpdate
  49. ) -> Task:
  50. stmt = select(Task).where(Task.id == task_id)
  51. db_obj = (await db.execute(stmt)).scalar_one_or_none()
  52. old_status = db_obj.status
  53. if not db_obj:
  54. raise NotFoundError("Task not exist")
  55. if obj_in.result is not None:
  56. db_obj.result = obj_in.result
  57. if obj_in.status is not None:
  58. db_obj.status = obj_in.status
  59. await db.commit()
  60. await db.refresh(db_obj)
  61. # 👇 核心:从非 READY → READY,才重新入队
  62. if old_status != 0 and db_obj.status == 0:
  63. queue_manager.put(
  64. queue_name=db_obj.command,
  65. task_id=db_obj.id,
  66. priority=0,
  67. )
  68. return db_obj
  69. # ======================
  70. # 获取待处理任务(分页)
  71. # ======================
  72. @staticmethod
  73. async def get_pending(
  74. db: AsyncSession,
  75. command: str,
  76. page: int,
  77. size: int
  78. ) -> List[Task]:
  79. offset = page * size
  80. stmt = (
  81. select(Task)
  82. .where(
  83. Task.command == command,
  84. Task.status == 0
  85. )
  86. .order_by(Task.create_at.asc())
  87. .offset(offset)
  88. .limit(size)
  89. )
  90. return (await db.execute(stmt)).scalars().all()