queue_service.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from sqlalchemy.ext.asyncio import AsyncSession
  2. from sqlalchemy import select
  3. from app.core.logger import logger
  4. from app.core.biz_exception import NotFoundError, BizLogicError
  5. from app.core.queue_manager import queue_manager
  6. from app.models.task import Task
  7. from app.models.vas_task import VasTask
  8. class QueueService:
  9. async def rebuild_task_queue(db: AsyncSession, queue_name: str):
  10. if queue_manager.is_initialized(queue_name):
  11. return
  12. queue = queue_manager.get_queue(queue_name)
  13. queue.clear()
  14. stmt = (
  15. select(Task)
  16. .where(
  17. Task.command == queue_name,
  18. Task.status == 0
  19. )
  20. )
  21. tasks = (await db.execute(stmt)).scalars().all()
  22. for task in tasks:
  23. queue_manager.put(
  24. queue_name=queue_name,
  25. task_id=task.id,
  26. priority=0,
  27. )
  28. queue_manager.mark_initialized(queue_name)
  29. logger.info(f"[Queue] rebuilt: {queue_name}")
  30. async def rebuild_vas_task_queue(db: AsyncSession, queue_name: str):
  31. if queue_manager.is_initialized(queue_name):
  32. return
  33. queue = queue_manager.get_queue(queue_name)
  34. queue.clear()
  35. stmt = (
  36. select(VasTask)
  37. .where(
  38. VasTask.routing_key == queue_name,
  39. VasTask.status == 'pending'
  40. )
  41. )
  42. tasks = (await db.execute(stmt)).scalars().all()
  43. for task in tasks:
  44. queue_manager.put(
  45. queue_name=task.routing_key,
  46. task_id=task.id,
  47. priority=task.priority,
  48. )
  49. queue_manager.mark_initialized(queue_name)
  50. logger.info(f"[Queue] rebuilt: {queue_name}")
  51. async def pop_task(db: AsyncSession, queue_name: str):
  52. """
  53. 从指定队列出队一个任务,并标记为 RUNNING
  54. """
  55. await QueueService.rebuild_task_queue(db, queue_name)
  56. task_id = queue_manager.pop(queue_name)
  57. if not task_id:
  58. raise NotFoundError(f'{queue_name} is empty')
  59. stmt = select(Task).where(Task.id == task_id)
  60. task = (await db.execute(stmt)).scalar_one_or_none()
  61. if task.status != 0:
  62. raise BizLogicError(f'task not READY, skipped')
  63. task.status = 1
  64. await db.commit()
  65. await db.refresh(task)
  66. return task
  67. async def pop_vas_task(db: AsyncSession, queue_name: str):
  68. """
  69. 从指定队列出队一个任务,并标记为 RUNNING
  70. """
  71. await QueueService.rebuild_vas_task_queue(db, queue_name)
  72. task_id = queue_manager.pop(queue_name)
  73. if not task_id:
  74. raise NotFoundError(f'{queue_name} is empty')
  75. stmt = select(VasTask).where(VasTask.id == task_id)
  76. task = (await db.execute(stmt)).scalar_one_or_none()
  77. if task.status != "pending":
  78. raise BizLogicError(f'task not READY, skipped')
  79. task.status = "running"
  80. await db.commit()
  81. await db.refresh(task)
  82. return task
  83. async def dump_all():
  84. return queue_manager.dump_all()