task_service.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # app/services/task_service.py
  2. from typing import List, Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select, update
  5. from app.core.biz_exception import NotFoundError
  6. from app.models.task import Task
  7. from app.schemas.task import TaskCreate, TaskUpdate
  8. class TaskService:
  9. # ======================
  10. # 创建任务
  11. # ======================
  12. @staticmethod
  13. async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
  14. db_obj = Task(
  15. command=obj_in.command,
  16. args=obj_in.args,
  17. status=obj_in.status or 0,
  18. )
  19. db.add(db_obj)
  20. await db.commit()
  21. await db.refresh(db_obj)
  22. return db_obj
  23. # ======================
  24. # 根据 ID 获取
  25. # ======================
  26. @staticmethod
  27. async def get_by_id(db: AsyncSession, task_id: int) -> Task:
  28. stmt = select(Task).where(Task.id == task_id)
  29. obj = (await db.execute(stmt)).scalar_one_or_none()
  30. if not obj:
  31. raise NotFoundError("Task not exist")
  32. return obj
  33. # ======================
  34. # 更新任务
  35. # ======================
  36. @staticmethod
  37. async def update(
  38. db: AsyncSession,
  39. task_id: int,
  40. obj_in: TaskUpdate
  41. ) -> Task:
  42. stmt = select(Task).where(Task.id == task_id)
  43. db_obj = (await db.execute(stmt)).scalar_one_or_none()
  44. if not db_obj:
  45. raise NotFoundError("Task not exist")
  46. if obj_in.result is not None:
  47. db_obj.result = obj_in.result
  48. if obj_in.status is not None:
  49. db_obj.status = obj_in.status
  50. await db.commit()
  51. await db.refresh(db_obj)
  52. return db_obj
  53. # ======================
  54. # 获取待处理任务(分页)
  55. # ======================
  56. @staticmethod
  57. async def get_pending(
  58. db: AsyncSession,
  59. command: str,
  60. page: int,
  61. size: int
  62. ) -> List[Task]:
  63. offset = page * size
  64. stmt = (
  65. select(Task)
  66. .where(
  67. Task.command == command,
  68. Task.status == 0
  69. )
  70. .order_by(Task.create_at.asc())
  71. .offset(offset)
  72. .limit(size)
  73. )
  74. return (await db.execute(stmt)).scalars().all()