# app/services/task_service.py
from typing import List, Optional

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update

from app.core.queue_manager import queue_manager
from app.core.biz_exception import NotFoundError
from app.models.task import Task
from app.schemas.task import TaskCreate, TaskUpdate

# Status value this service treats as READY: only tasks in this state are
# pushed onto the per-command queue and returned by ``get_pending``.
TASK_STATUS_READY = 0

# Priority passed to the queue on every enqueue (no per-task priority yet).
DEFAULT_QUEUE_PRIORITY = 0


class TaskService:
    """Async persistence helpers for ``Task`` rows that keep the queue in sync.

    Whenever a task is persisted in (or moved into) the READY state, its id is
    pushed onto the queue named after the task's ``command``.
    """

    # ======================
    # Create a task
    # ======================
    @staticmethod
    async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
        """Insert a new task and enqueue it if it starts out READY.

        Args:
            db: Async session used for the insert; it is committed here.
            obj_in: Creation payload. A falsy ``status`` (e.g. ``None``)
                defaults to READY.

        Returns:
            The persisted ``Task``, refreshed from the database.
        """
        db_obj = Task(
            command=obj_in.command,
            args=obj_in.args,
            status=obj_in.status or TASK_STATUS_READY,
        )
        db.add(db_obj)
        await db.commit()
        # Refresh so DB-generated fields (e.g. the primary key) are loaded.
        await db.refresh(db_obj)

        # Only READY tasks are queued for processing.
        if db_obj.status == TASK_STATUS_READY:
            queue_manager.put(
                queue_name=db_obj.command,
                task_id=db_obj.id,
                priority=DEFAULT_QUEUE_PRIORITY,
            )
        return db_obj

    # ======================
    # Fetch by ID
    # ======================
    @staticmethod
    async def get_by_id(db: AsyncSession, task_id: int) -> Task:
        """Return the task with ``task_id``.

        Raises:
            NotFoundError: if no task with that id exists.
        """
        stmt = select(Task).where(Task.id == task_id)
        obj = (await db.execute(stmt)).scalar_one_or_none()
        if not obj:
            raise NotFoundError("Task not exist")
        return obj

    # ======================
    # Update a task
    # ======================
    @staticmethod
    async def update(
        db: AsyncSession, task_id: int, obj_in: TaskUpdate
    ) -> Task:
        """Apply the non-``None`` fields of ``obj_in`` to a task and commit.

        If the update moves the task from a non-READY status into READY, the
        task is re-enqueued. A task that was already READY is not queued again.

        Args:
            db: Async session; it is committed here.
            task_id: Primary key of the task to update.
            obj_in: Partial update; ``None`` fields are left unchanged.

        Returns:
            The updated ``Task``, refreshed from the database.

        Raises:
            NotFoundError: if no task with ``task_id`` exists.
        """
        db_obj = await TaskService.get_by_id(db, task_id)

        # Capture the status before mutation so the READY transition can be
        # detected after the commit (previously this name was never bound).
        old_status = db_obj.status

        if obj_in.result is not None:
            db_obj.result = obj_in.result
        if obj_in.status is not None:
            db_obj.status = obj_in.status

        await db.commit()
        await db.refresh(db_obj)

        # Core rule: only a non-READY -> READY transition re-enqueues the task.
        if old_status != TASK_STATUS_READY and db_obj.status == TASK_STATUS_READY:
            queue_manager.put(
                queue_name=db_obj.command,
                task_id=db_obj.id,
                priority=DEFAULT_QUEUE_PRIORITY,
            )
        return db_obj

    # ======================
    # List pending tasks (paginated)
    # ======================
    @staticmethod
    async def get_pending(
        db: AsyncSession, command: str, page: int, size: int
    ) -> List[Task]:
        """Return one page of READY tasks for ``command``, oldest first.

        Args:
            db: Async session used for the query.
            command: Only tasks whose ``command`` equals this are returned.
            page: Zero-based page index (the offset is ``page * size``).
            size: Maximum number of tasks to return.

        Returns:
            A list of at most ``size`` tasks, ordered by ``create_at`` ascending.
        """
        offset = page * size
        stmt = (
            select(Task)
            .where(
                Task.command == command,
                Task.status == TASK_STATUS_READY,
            )
            .order_by(Task.create_at.asc())
            .offset(offset)
            .limit(size)
        )
        # .all() returns a Sequence; materialize to honour the List annotation.
        return list((await db.execute(stmt)).scalars().all())