| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select
- from app.core.logger import logger
- from app.core.biz_exception import NotFoundError, BizLogicError
- from app.core.queue_manager import queue_manager
- from app.models.task import Task
- from app.models.vas_task import VasTask
class QueueService:
    """Keep in-memory task queues in sync with the database and hand out work.

    Two task tables are served:

    * ``Task`` rows, routed by ``Task.command``, with integer statuses
      (``0`` = ready, ``1`` = running).
    * ``VasTask`` rows, routed by ``VasTask.routing_key``, with string
      statuses (``"pending"`` / ``"running"``).

    Every method is a ``@staticmethod``; the class only groups them, so they
    may be called on the class (``QueueService.pop_task(db, name)``) or on an
    instance.
    """

    # Status values used by the queries and state transitions below.
    _TASK_READY = 0
    _TASK_RUNNING = 1
    _VAS_PENDING = "pending"
    _VAS_RUNNING = "running"

    @staticmethod
    async def rebuild_task_queue(db: AsyncSession, queue_name: str) -> None:
        """Load every ready ``Task`` for *queue_name* into the in-memory queue.

        No-op when ``queue_manager`` already reports the queue as initialized.
        Otherwise the queue is cleared, refilled from the database with
        priority ``0``, and marked initialized.

        Args:
            db: Session used to query ``Task`` rows.
            queue_name: Queue to rebuild; matched against ``Task.command``.
        """
        if queue_manager.is_initialized(queue_name):
            return

        queue_manager.get_queue(queue_name).clear()

        # Order by id so rows are enqueued in a deterministic order instead of
        # whatever order the database happens to return them in.
        stmt = (
            select(Task)
            .where(
                Task.command == queue_name,
                Task.status == QueueService._TASK_READY,
            )
            .order_by(Task.id)
        )
        tasks = (await db.execute(stmt)).scalars().all()
        for task in tasks:
            queue_manager.put(
                queue_name=queue_name,
                task_id=task.id,
                priority=0,
            )
        queue_manager.mark_initialized(queue_name)
        logger.info(f"[Queue] rebuilt: {queue_name}")

    @staticmethod
    async def rebuild_vas_task_queue(db: AsyncSession, queue_name: str) -> None:
        """Load every pending ``VasTask`` for *queue_name* into the in-memory queue.

        No-op when ``queue_manager`` already reports the queue as initialized.
        Otherwise the queue is cleared, refilled from the database using each
        row's own ``priority``, and marked initialized.

        Args:
            db: Session used to query ``VasTask`` rows.
            queue_name: Queue to rebuild; matched against ``VasTask.routing_key``.
        """
        if queue_manager.is_initialized(queue_name):
            return

        queue_manager.get_queue(queue_name).clear()

        # Order by id so rows are enqueued in a deterministic order instead of
        # whatever order the database happens to return them in.
        stmt = (
            select(VasTask)
            .where(
                VasTask.routing_key == queue_name,
                VasTask.status == QueueService._VAS_PENDING,
            )
            .order_by(VasTask.id)
        )
        tasks = (await db.execute(stmt)).scalars().all()
        for task in tasks:
            queue_manager.put(
                # Equal to queue_name because of the WHERE clause above.
                queue_name=task.routing_key,
                task_id=task.id,
                priority=task.priority,
            )
        queue_manager.mark_initialized(queue_name)
        logger.info(f"[Queue] rebuilt: {queue_name}")

    @staticmethod
    async def pop_task(db: AsyncSession, queue_name: str) -> Task:
        """Dequeue one ``Task`` from *queue_name* and mark it running.

        The queue is rebuilt from the database first if it has not been
        initialized yet.

        Args:
            db: Session used to load and update the ``Task`` row.
            queue_name: Queue to pop from.

        Returns:
            The claimed ``Task``, committed with status ``1`` and refreshed.

        Raises:
            NotFoundError: the queue is empty, or the dequeued id no longer
                matches a ``Task`` row.
            BizLogicError: the row is no longer in the ready state.
        """
        await QueueService.rebuild_task_queue(db, queue_name)
        # NOTE(review): assumes pop() returns a falsy value (e.g. None) when the
        # queue is empty; an id of 0 would also be treated as "empty".
        task_id = queue_manager.pop(queue_name)
        if not task_id:
            raise NotFoundError(f'{queue_name} is empty')

        stmt = select(Task).where(Task.id == task_id)
        task = (await db.execute(stmt)).scalar_one_or_none()
        if task is None:
            # The row may have been deleted after it was enqueued; report it
            # rather than failing with AttributeError on ``None.status``.
            raise NotFoundError(f'task {task_id} not found')
        if task.status != QueueService._TASK_READY:
            raise BizLogicError('task not READY, skipped')

        # NOTE(review): no row lock is taken between the status check and the
        # commit; if several workers/processes can claim from the same table,
        # consider select(...).with_for_update().
        task.status = QueueService._TASK_RUNNING
        await db.commit()
        await db.refresh(task)
        return task

    @staticmethod
    async def pop_vas_task(db: AsyncSession, queue_name: str) -> VasTask:
        """Dequeue one ``VasTask`` from *queue_name* and mark it running.

        The queue is rebuilt from the database first if it has not been
        initialized yet.

        Args:
            db: Session used to load and update the ``VasTask`` row.
            queue_name: Queue to pop from.

        Returns:
            The claimed ``VasTask``, committed with status ``"running"`` and
            refreshed.

        Raises:
            NotFoundError: the queue is empty, or the dequeued id no longer
                matches a ``VasTask`` row.
            BizLogicError: the row is no longer in the ``"pending"`` state.
        """
        await QueueService.rebuild_vas_task_queue(db, queue_name)
        # NOTE(review): assumes pop() returns a falsy value (e.g. None) when the
        # queue is empty; an id of 0 would also be treated as "empty".
        task_id = queue_manager.pop(queue_name)
        if not task_id:
            raise NotFoundError(f'{queue_name} is empty')

        stmt = select(VasTask).where(VasTask.id == task_id)
        task = (await db.execute(stmt)).scalar_one_or_none()
        if task is None:
            # The row may have been deleted after it was enqueued; report it
            # rather than failing with AttributeError on ``None.status``.
            raise NotFoundError(f'task {task_id} not found')
        if task.status != QueueService._VAS_PENDING:
            raise BizLogicError('task not READY, skipped')

        # NOTE(review): no row lock is taken between the status check and the
        # commit; if several workers/processes can claim from the same table,
        # consider select(...).with_for_update().
        task.status = QueueService._VAS_RUNNING
        await db.commit()
        await db.refresh(task)
        return task

    @staticmethod
    async def dump_all():
        """Return whatever ``queue_manager.dump_all()`` returns, unchanged."""
        return queue_manager.dump_all()
|