| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- # app/services/task_service.py
from datetime import datetime
from typing import List, Optional

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import NotFoundError
from app.models.task import Task
from app.schemas.task import TaskCreate, TaskUpdate
class TaskService:
    """Async CRUD and queue-style operations on ``Task`` rows.

    Every method is a ``@staticmethod`` taking the caller's ``AsyncSession``;
    the write methods commit (or roll back) that session's transaction.
    """

    # Status codes used by this service: 0 = waiting to run, 1 = running.
    # NOTE(review): other status values may exist in the schema — confirm
    # against the Task model.
    STATUS_PENDING = 0
    STATUS_RUNNING = 1

    # ======================
    # Create task
    # ======================
    @staticmethod
    async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
        """Insert a new task, commit, and return it reloaded from the DB.

        A falsy ``obj_in.status`` (``None`` or ``0``) is stored as pending.
        """
        db_obj = Task(
            command=obj_in.command,
            args=obj_in.args,
            status=obj_in.status or TaskService.STATUS_PENDING,
        )
        db.add(db_obj)
        await db.commit()
        # Reload the row so DB-generated values are loaded on the object.
        await db.refresh(db_obj)
        return db_obj

    # ======================
    # Pop (claim) a task
    # ======================
    @staticmethod
    async def pop_task(db: AsyncSession, queue_name: str) -> Task:
        """Atomically claim one pending task for ``queue_name`` (MySQL 8.0+).

        Uses ``SELECT ... FOR UPDATE SKIP LOCKED`` so concurrent workers lock
        different rows instead of waiting on each other, then marks the
        claimed row as running and commits.

        This call owns the session's current transaction: it is committed on
        success and rolled back on failure or when nothing can be claimed, so
        any other pending changes on ``db`` are committed/discarded with it.

        Args:
            db: Async session to run the claim on.
            queue_name: Value matched against ``Task.command``.

        Returns:
            The claimed task, refreshed after the commit.

        Raises:
            NotFoundError: No unlocked pending task exists for the queue.
        """
        stmt = (
            select(Task)
            .where(Task.command == queue_name)  # queue / command type
            .where(Task.status == TaskService.STATUS_PENDING)
            # Lowest id first; no priority column is consulted.
            .order_by(Task.id.asc())
            .limit(1)
            # SKIP LOCKED: rows already locked by another worker are skipped
            # rather than waited on, which avoids lock pile-ups under load.
            .with_for_update(skip_locked=True)
        )
        try:
            task = (await db.execute(stmt)).scalar_one_or_none()
            if task is not None:
                task.status = TaskService.STATUS_RUNNING
                task.update_at = datetime.now()
                await db.commit()
                # AsyncSession expires objects on commit by default; reload
                # so callers can read attributes without implicit lazy IO.
                await db.refresh(task)
                return task
        except Exception:
            # No session.begin() context manager wraps this method, so
            # nothing rolls back automatically — do it before re-raising.
            await db.rollback()
            raise
        # Nothing claimable: close the transaction opened by the SELECT.
        await db.rollback()
        raise NotFoundError(message="Task not found")

    # ======================
    # Get by ID
    # ======================
    @staticmethod
    async def get_by_id(db: AsyncSession, task_id: int) -> Task:
        """Return the task with primary key ``task_id``.

        Raises:
            NotFoundError: No such task exists.
        """
        stmt = select(Task).where(Task.id == task_id)
        obj = (await db.execute(stmt)).scalar_one_or_none()
        if not obj:
            raise NotFoundError("Task not exist")
        return obj

    # ======================
    # Update task
    # ======================
    @staticmethod
    async def update(
        db: AsyncSession,
        task_id: int,
        obj_in: TaskUpdate
    ) -> Task:
        """Apply the non-``None`` fields of ``obj_in`` to a task and commit.

        Only ``result`` and ``status`` are copied; ``None`` means "leave as is".

        Raises:
            NotFoundError: No task has ``task_id``.
        """
        # Look the row up first so a missing id raises NotFoundError before
        # any attribute of the (absent) object is touched.
        db_obj = await TaskService.get_by_id(db, task_id)
        if obj_in.result is not None:
            db_obj.result = obj_in.result
        if obj_in.status is not None:
            db_obj.status = obj_in.status
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    # ======================
    # List pending tasks (paged)
    # ======================
    @staticmethod
    async def get_pending(
        db: AsyncSession,
        command: str,
        page: int,
        size: int
    ) -> List[Task]:
        """Return one page of pending tasks for ``command``, oldest first.

        Args:
            db: Async session to query with.
            command: Value matched against ``Task.command``.
            page: Zero-based page index (``page=0`` is the first page).
            size: Maximum number of tasks per page.

        Returns:
            Up to ``size`` pending tasks ordered by ``create_at`` ascending.
        """
        offset = page * size
        stmt = (
            select(Task)
            .where(
                Task.command == command,
                Task.status == TaskService.STATUS_PENDING,
            )
            .order_by(Task.create_at.asc())
            .offset(offset)
            .limit(size)
        )
        return list((await db.execute(stmt)).scalars().all())
|