# app/services/task_service.py
from typing import List, Optional
from datetime import datetime

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update

from app.core.biz_exception import NotFoundError
from app.models.task import Task
from app.schemas.task import TaskCreate, TaskUpdate


class TaskService:
    """Async CRUD and queue-claiming operations for ``Task`` rows.

    Status values used in this service: ``0`` = pending, ``1`` = running.
    Every mutating method commits the session it is given.
    """

    # ======================
    # Create task
    # ======================
    @staticmethod
    async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
        """Insert a new task and return it refreshed from the database.

        ``status`` falls back to ``0`` (pending) when the input leaves it
        unset or falsy.
        """
        db_obj = Task(
            command=obj_in.command,
            args=obj_in.args,
            status=obj_in.status or 0,
        )
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    # ======================
    # Atomically claim a task
    # ======================
    @staticmethod
    async def pop_task(db: AsyncSession, queue_name: str) -> Task:
        """Atomically claim the lowest-id pending task of a queue (MySQL 8.0+).

        Uses ``SELECT ... FOR UPDATE SKIP LOCKED`` so that concurrent workers
        skip rows another transaction has already locked instead of waiting
        on them. The claimed row is marked running (``status = 1``) and the
        transaction is committed before returning.

        Args:
            db: Async session. This method ends its transaction, committing
                on success and rolling back on any failure.
            queue_name: Value matched against ``Task.command``.

        Returns:
            The claimed task, already committed with ``status == 1``.

        Raises:
            NotFoundError: No unlocked pending task exists for the queue.
        """
        stmt = (
            select(Task)
            .where(Task.command == queue_name)  # queue / command type
            .where(Task.status == 0)  # pending tasks only
            .order_by(Task.id.asc())  # oldest (lowest id) first
            .limit(1)
            # SKIP LOCKED (MySQL 8.0+) keeps concurrent workers from piling
            # up behind the same locked row.
            .with_for_update(skip_locked=True)
        )
        try:
            task = (await db.execute(stmt)).scalar_one_or_none()
            if task is None:
                raise NotFoundError(message="Task not found")

            task.status = 1  # mark as running
            # NOTE(review): naive local time; confirm the column expects it.
            task.update_at = datetime.now()
            await db.commit()
        except Exception:
            # Nothing here opens ``session.begin()``, so nothing rolls back
            # automatically. End the locking-read transaction explicitly.
            # An idle open transaction may keep InnoDB gap locks that block
            # producers inserting new tasks.
            await db.rollback()
            raise
        return task

    # ======================
    # Get by ID
    # ======================
    @staticmethod
    async def get_by_id(db: AsyncSession, task_id: int) -> Task:
        """Return the task with primary key ``task_id``.

        Raises:
            NotFoundError: No such task.
        """
        stmt = select(Task).where(Task.id == task_id)
        obj = (await db.execute(stmt)).scalar_one_or_none()
        if not obj:
            raise NotFoundError("Task not exist")
        return obj

    # ======================
    # Update task
    # ======================
    @staticmethod
    async def update(
        db: AsyncSession, task_id: int, obj_in: TaskUpdate
    ) -> Task:
        """Apply the non-``None`` fields of ``obj_in`` to a task and commit.

        Only ``result`` and ``status`` are updatable. Fields left as ``None``
        are not touched.

        Raises:
            NotFoundError: No task with ``task_id`` exists.
        """
        # Look up (and raise NotFoundError) before touching any attribute;
        # the previous version dereferenced the row before the None check.
        db_obj = await TaskService.get_by_id(db, task_id)

        if obj_in.result is not None:
            db_obj.result = obj_in.result
        if obj_in.status is not None:
            db_obj.status = obj_in.status

        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    # ======================
    # List pending tasks (paginated)
    # ======================
    @staticmethod
    async def get_pending(
        db: AsyncSession, command: str, page: int, size: int
    ) -> List[Task]:
        """Return one page of pending (``status == 0``) tasks for ``command``.

        Results are ordered by ``create_at`` ascending. ``page`` is 0-based:
        page 0 returns the first ``size`` rows.
        """
        offset = page * size
        stmt = (
            select(Task)
            .where(
                Task.command == command,
                Task.status == 0,
            )
            .order_by(Task.create_at.asc())
            .offset(offset)
            .limit(size)
        )
        return list((await db.execute(stmt)).scalars().all())