from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.logger import logger
from app.core.biz_exception import NotFoundError, BizLogicError
from app.core.queue_manager import queue_manager
from app.models.task import Task
from app.models.vas_task import VasTask


class QueueService:
    """Feed ``queue_manager`` queues from the task tables and dequeue tasks.

    A queue is lazily rebuilt from the database the first time it is used
    (tracked via ``queue_manager.is_initialized`` / ``mark_initialized``).
    Popping a task moves its row from the "ready" status to the "running"
    status and commits.
    """

    # Task.status values used here (integer-coded column).
    _TASK_READY = 0
    _TASK_RUNNING = 1
    # VasTask.status values used here (string-coded column).
    _VAS_READY = "pending"
    _VAS_RUNNING = "running"

    @staticmethod
    async def _rebuild_queue(db: AsyncSession, queue_name: str, stmt, priority_of) -> None:
        """Load the rows selected by ``stmt`` into ``queue_name`` once.

        Args:
            db: Async session used to run ``stmt``.
            queue_name: Queue to (re)build in ``queue_manager``.
            stmt: SELECT returning the task rows to enqueue.
            priority_of: Callable mapping a row to its queue priority.
        """
        if queue_manager.is_initialized(queue_name):
            return
        tasks = (await db.execute(stmt)).scalars().all()
        # Another coroutine may have finished a rebuild of the same queue
        # while we awaited the query. Re-check so the queue is not filled
        # twice. The queue_manager calls below are not awaited, so this
        # check-clear-put-mark section does not yield to other coroutines.
        if queue_manager.is_initialized(queue_name):
            return
        queue = queue_manager.get_queue(queue_name)
        queue.clear()
        for task in tasks:
            queue_manager.put(
                queue_name=queue_name,
                task_id=task.id,
                priority=priority_of(task),
            )
        queue_manager.mark_initialized(queue_name)
        logger.info(f"[Queue] rebuilt: {queue_name}")

    @staticmethod
    async def _pop(db: AsyncSession, queue_name: str, model, ready_status, running_status):
        """Pop one id from ``queue_name``, load it as ``model`` and mark it running.

        Raises:
            NotFoundError: The queue is empty, or the popped id has no row.
            BizLogicError: The row's status is not ``ready_status``.
        """
        task_id = queue_manager.pop(queue_name)
        if not task_id:
            raise NotFoundError(f'{queue_name} is empty')

        stmt = select(model).where(model.id == task_id)
        task = (await db.execute(stmt)).scalar_one_or_none()
        if task is None:
            # The id was queued but the row is gone (e.g. deleted after the
            # queue was built); report it instead of failing on None.status.
            raise NotFoundError(f'task {task_id} not found')
        if task.status != ready_status:
            raise BizLogicError('task not READY, skipped')

        task.status = running_status
        await db.commit()
        await db.refresh(task)
        return task

    @staticmethod
    async def rebuild_task_queue(db: AsyncSession, queue_name: str) -> None:
        """Populate ``queue_name`` with ready ``Task`` rows whose command matches it.

        No-op if the queue is already initialized. All tasks get priority 0.
        """
        stmt = select(Task).where(
            Task.command == queue_name,
            Task.status == QueueService._TASK_READY,
        )
        await QueueService._rebuild_queue(db, queue_name, stmt, lambda task: 0)

    @staticmethod
    async def rebuild_vas_task_queue(db: AsyncSession, queue_name: str) -> None:
        """Populate ``queue_name`` with pending ``VasTask`` rows routed to it.

        No-op if the queue is already initialized. Each task is enqueued with
        its own ``priority`` column value.
        """
        stmt = select(VasTask).where(
            VasTask.routing_key == queue_name,
            VasTask.status == QueueService._VAS_READY,
        )
        await QueueService._rebuild_queue(db, queue_name, stmt, lambda task: task.priority)

    @staticmethod
    async def pop_task(db: AsyncSession, queue_name: str) -> Task:
        """Dequeue one task from the given queue and mark it as RUNNING.

        Rebuilds the queue from the database first if needed.

        Raises:
            NotFoundError: The queue is empty or the dequeued task no longer exists.
            BizLogicError: The dequeued task is not in the READY status.
        """
        await QueueService.rebuild_task_queue(db, queue_name)
        return await QueueService._pop(
            db, queue_name, Task, QueueService._TASK_READY, QueueService._TASK_RUNNING
        )

    @staticmethod
    async def pop_vas_task(db: AsyncSession, queue_name: str) -> VasTask:
        """Dequeue one VAS task from the given queue and mark it as RUNNING.

        Rebuilds the queue from the database first if needed.

        Raises:
            NotFoundError: The queue is empty or the dequeued task no longer exists.
            BizLogicError: The dequeued task is not in the "pending" status.
        """
        await QueueService.rebuild_vas_task_queue(db, queue_name)
        return await QueueService._pop(
            db, queue_name, VasTask, QueueService._VAS_READY, QueueService._VAS_RUNNING
        )

    @staticmethod
    async def dump_all():
        """Return whatever ``queue_manager.dump_all()`` returns."""
        return queue_manager.dump_all()