task_service.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # app/services/task_service.py
  2. from typing import List, Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, update
  5. from app.core.queue_manager import queue_manager
  6. from app.core.biz_exception import NotFoundError
  7. from app.models.task import Task
  8. from app.schemas.task import TaskCreate, TaskUpdate
  9. class TaskService:
  10. # ======================
  11. # 创建任务
  12. # ======================
  13. @staticmethod
  14. async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
  15. db_obj = Task(
  16. command=obj_in.command,
  17. args=obj_in.args,
  18. status=obj_in.status or 0,
  19. )
  20. db.add(db_obj)
  21. await db.commit()
  22. await db.refresh(db_obj)
  23. # 👇 只在 READY 状态才入队
  24. if db_obj.status == 0:
  25. queue_manager.put(
  26. queue_name=db_obj.command,
  27. task_id=db_obj.id,
  28. priority=0, # 先用默认优先级
  29. )
  30. return db_obj
  31. # ======================
  32. # 根据 ID 获取
  33. # ======================
  34. @staticmethod
  35. async def get_by_id(db: AsyncSession, task_id: int) -> Task:
  36. stmt = select(Task).where(Task.id == task_id)
  37. obj = (await db.execute(stmt)).scalar_one_or_none()
  38. if not obj:
  39. raise NotFoundError("Task not exist")
  40. return obj
  41. # ======================
  42. # 更新任务
  43. # ======================
  44. @staticmethod
  45. async def update(
  46. db: AsyncSession,
  47. task_id: int,
  48. obj_in: TaskUpdate
  49. ) -> Task:
  50. stmt = select(Task).where(Task.id == task_id)
  51. db_obj = (await db.execute(stmt)).scalar_one_or_none()
  52. if not db_obj:
  53. raise NotFoundError("Task not exist")
  54. if obj_in.result is not None:
  55. db_obj.result = obj_in.result
  56. if obj_in.status is not None:
  57. db_obj.status = obj_in.status
  58. await db.commit()
  59. await db.refresh(db_obj)
  60. # 👇 核心:从非 READY → READY,才重新入队
  61. if old_status != 0 and db_obj.status == 0:
  62. queue_manager.put(
  63. queue_name=db_obj.command,
  64. task_id=db_obj.id,
  65. priority=0,
  66. )
  67. return db_obj
  68. # ======================
  69. # 获取待处理任务(分页)
  70. # ======================
  71. @staticmethod
  72. async def get_pending(
  73. db: AsyncSession,
  74. command: str,
  75. page: int,
  76. size: int
  77. ) -> List[Task]:
  78. offset = page * size
  79. stmt = (
  80. select(Task)
  81. .where(
  82. Task.command == command,
  83. Task.status == 0
  84. )
  85. .order_by(Task.create_at.asc())
  86. .offset(offset)
  87. .limit(size)
  88. )
  89. return (await db.execute(stmt)).scalars().all()