vas_task_service.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. # app/services/task_service.py
  2. from datetime import datetime
  3. from typing import List, Optional
  4. from sqlalchemy.ext.asyncio import AsyncSession
  5. from sqlalchemy import select
  6. from app.utils.search import apply_keyword_search_stmt
  7. from app.utils.pagination import paginate
  8. from app.core.biz_exception import NotFoundError
  9. from app.models.vas_task import VasTask
  10. from app.models.order import VasOrder
  11. from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
  12. class VasTaskService:
  13. @staticmethod
  14. async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
  15. rec = VasTask(
  16. **data.dict(),
  17. status="pending",
  18. created_at=datetime.utcnow(),
  19. )
  20. db.add(rec)
  21. await db.commit()
  22. await db.refresh(rec)
  23. return rec
  24. @staticmethod
  25. async def list_task(
  26. db: AsyncSession,
  27. status: Optional[str] = None,
  28. routing_key: Optional[str] = None,
  29. script_version: Optional[str] = None,
  30. keyword: Optional[str] = None,
  31. page: int = 0,
  32. size: int = 10,
  33. ):
  34. stmt = select(VasTask)
  35. if status:
  36. stmt = stmt.where(VasTask.status == status)
  37. if routing_key:
  38. stmt = stmt.where(VasTask.routing_key == routing_key)
  39. if script_version:
  40. stmt = stmt.where(VasTask.script_version == script_version)
  41. stmt = apply_keyword_search_stmt(
  42. stmt=stmt,
  43. model=VasTask,
  44. keyword=keyword,
  45. fields=["order_id", "routing_key", "user_inputs"],
  46. )
  47. stmt = stmt.order_by(
  48. VasTask.priority.desc(),
  49. VasTask.id.asc(),
  50. )
  51. return await paginate(db, stmt, page, size)
  52. @staticmethod
  53. async def update(
  54. db: AsyncSession,
  55. id: int,
  56. payload: VasTaskUpdate,
  57. ) -> VasTask:
  58. stmt = select(VasTask).where(VasTask.id == id)
  59. result = await db.execute(stmt)
  60. obj = result.scalar_one_or_none()
  61. if not obj:
  62. raise NotFoundError("Task not exist")
  63. data = payload.dict(exclude_unset=True)
  64. for key, value in data.items():
  65. setattr(obj, key, value)
  66. await db.commit()
  67. await db.refresh(obj)
  68. return obj
  69. @staticmethod
  70. async def get_active_task_by_order_id(
  71. db: AsyncSession,
  72. order_id: str,
  73. ) -> List[VasTask]:
  74. stmt = select(VasTask).where(
  75. VasTask.status == "pending",
  76. VasTask.order_id == order_id,
  77. )
  78. result = await db.execute(stmt)
  79. return result.scalars().all()
  80. @staticmethod
  81. async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
  82. stmt = select(VasTask).where(VasTask.id == id)
  83. result = await db.execute(stmt)
  84. rec = result.scalar_one_or_none()
  85. if not rec:
  86. raise NotFoundError("Task not exist")
  87. rec.status = "pending"
  88. rec.attempt_count = (rec.attempt_count or 0) + 1
  89. await db.commit()
  90. await db.refresh(rec)
  91. return rec
  92. @staticmethod
  93. async def manual_confirm(db: AsyncSession, id: int) -> VasTask:
  94. stmt = select(VasTask).where(VasTask.id == id)
  95. result = await db.execute(stmt)
  96. task = result.scalar_one_or_none()
  97. if not task:
  98. raise NotFoundError("Task not exist")
  99. task.status = "completed"
  100. order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
  101. order_result = await db.execute(order_stmt)
  102. order = order_result.scalar_one_or_none()
  103. if not order:
  104. raise NotFoundError("Order not exist")
  105. order.status = "completed"
  106. await db.commit()
  107. await db.refresh(task)
  108. return task