|
|
@@ -4,7 +4,7 @@ from datetime import datetime, date, timedelta
|
|
|
from typing import List, Optional
|
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
-from sqlalchemy import select, or_, and_
|
|
|
+from sqlalchemy import func, text, select, or_, and_
|
|
|
from redis.asyncio import Redis # 引入 Redis 类型
|
|
|
|
|
|
|
|
|
@@ -32,62 +32,34 @@ class VasTaskService:
|
|
|
return rec
|
|
|
|
|
|
@staticmethod
|
|
|
- async def pop_vas_task(session: AsyncSession, routing_key: str, cooldown_seconds: int = 60):
|
|
|
- """
|
|
|
- 异步获取任务,支持冷却期机制。
|
|
|
-
|
|
|
- :param session: 数据库异步会话
|
|
|
- :param routing_key: 队列键值
|
|
|
- :param cooldown_seconds: 失败后的冷却时间(秒),默认60秒
|
|
|
- :return: VasTask 对象 or None
|
|
|
- """
|
|
|
- # 计算冷却截止时间:当前时间 - 冷却秒数
|
|
|
- # 只有 updated_at 早于这个时间的重试任务,才会被提取
|
|
|
- cutoff_time = datetime.utcnow() - timedelta(seconds=cooldown_seconds)
|
|
|
-
|
|
|
- try:
|
|
|
- # --- 构造查询语句 ---
|
|
|
- stmt = (
|
|
|
- select(VasTask)
|
|
|
- .where(VasTask.routing_key == routing_key)
|
|
|
- .where(VasTask.status == 'pending')
|
|
|
- # === 核心逻辑:冷却期筛选 ===
|
|
|
- .where(
|
|
|
- or_(
|
|
|
- # 情况1:这是一个全新任务 (从未尝试过)
|
|
|
- VasTask.attempt_count == 0,
|
|
|
-
|
|
|
- # 情况2:这是一个重试任务,且距离上次更新(失败)已经过了冷却期
|
|
|
- and_(
|
|
|
- VasTask.attempt_count > 0,
|
|
|
- VasTask.updated_at < cutoff_time
|
|
|
- )
|
|
|
+ async def pop_vas_task(db: AsyncSession, routing_key: str, cooldown_seconds: int = 60):
|
|
|
+ stmt = (
|
|
|
+ select(VasTask)
|
|
|
+ .where(VasTask.routing_key == routing_key)
|
|
|
+ .where(VasTask.status == 'pending')
|
|
|
+ .where(
|
|
|
+ or_(
|
|
|
+ VasTask.attempt_count == 0,
|
|
|
+ and_(
|
|
|
+ VasTask.attempt_count > 0,
|
|
|
+ VasTask.updated_at < (func.utc_timestamp() - text(f"INTERVAL {cooldown_seconds} SECOND"))
|
|
|
)
|
|
|
)
|
|
|
- # 排序:优先级优先(0假设是最高优先级?),其次是先创建的
|
|
|
- # 注意:根据你的业务,priority 可能 desc 才是高优先级,这里按 asc 写
|
|
|
- .order_by(VasTask.priority.desc(), VasTask.id.asc())
|
|
|
- .limit(1)
|
|
|
- # MySQL 8.0+ 必加,跳过被锁定的行
|
|
|
- .with_for_update(skip_locked=True)
|
|
|
)
|
|
|
-
|
|
|
- result = await session.execute(stmt)
|
|
|
- task = result.scalar_one_or_none()
|
|
|
-
|
|
|
- # --- 更新状态 ---
|
|
|
- if task:
|
|
|
- task.status = 'running' # 标记为已被抓取
|
|
|
- task.attempt_count += 1 # 增加尝试次数
|
|
|
- task.updated_at = datetime.utcnow() # 更新时间(重置冷却计时起点)
|
|
|
-
|
|
|
- await session.commit()
|
|
|
- # session.begin() 结束时自动 commit
|
|
|
- return task
|
|
|
+ .order_by(VasTask.priority.desc(), VasTask.id.asc())
|
|
|
+ .limit(1)
|
|
|
+ .with_for_update(skip_locked=True)
|
|
|
+ )
|
|
|
+ result = await db.execute(stmt)
|
|
|
+ task = result.scalar_one_or_none()
|
|
|
+ if not task:
|
|
|
raise NotFoundError(message="Task not found")
|
|
|
- except Exception as e:
|
|
|
- # 记录日志
|
|
|
- raise e
|
|
|
+ task.status = 'running'
|
|
|
+ task.attempt_count += 1
|
|
|
+ task.updated_at = func.utc_timestamp()
|
|
|
+ await db.commit()
|
|
|
+ await db.refresh(task)
|
|
|
+ return task
|
|
|
|
|
|
@staticmethod
|
|
|
async def get_expiring_tasks(db: AsyncSession, threshold_days: int = 3):
|
|
|
@@ -237,9 +209,6 @@ class VasTaskService:
|
|
|
if not rec:
|
|
|
raise NotFoundError("Task not exist")
|
|
|
|
|
|
- if rec.status == "pending":
|
|
|
- raise BizLogicError("Task is in queue already")
|
|
|
-
|
|
|
rec.status = "pending"
|
|
|
await db.commit()
|
|
|
await db.refresh(rec)
|