jerry 1 mēnesi atpakaļ
vecāks
revīzija
1fe7541aa7

+ 16 - 11
app/main.py

@@ -1,4 +1,4 @@
-import os
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Depends, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -15,19 +15,24 @@ from app.core.biz_exception import BizException
 from app.core.logger import logger
 
 
-app = FastAPI(title=settings.app_name)
-
-# -----------------------
-# Startup
-# -----------------------
-@app.on_event("startup")
-async def startup():
-    # 支付配置
+@asynccontextmanager
+async def lifespan(app: FastAPI):
     init_stripe()
     logger.info("🟢 Stripe config done")
-    
 
-    # Notification processing is handled by an external daemon.
+    from app.services.scheduler_handlers import scheduler_service, visametric_slot_expire_checker
+    scheduler_service.start()
+    logger.info("🟢 Scheduler started")
+
+    yield
+
+    from app.services.scheduler_service import scheduler_service
+    scheduler_service.stop()
+    logger.info("🟢 Scheduler stopped")
+
+
+app = FastAPI(title=settings.app_name, lifespan=lifespan)
+
 
 # -----------------------
 # Exception Handlers

+ 63 - 0
app/services/scheduler_handlers.py

@@ -0,0 +1,63 @@
+from app.services.scheduler_service import register_scheduled_task, scheduler_service
+from app.core.logger import logger
+
+
+@register_scheduled_task(name="visametric_slot_expire_checker", interval_hours=1)
+async def visametric_slot_expire_checker():
+    """
+    检查 vas_task 中状态为 grabbed 的任务,
+    如果 grabbed_history 中的 slot_date 距离现在不足4天,发送企业微信通知
+    """
+    from app.core.database import AsyncSessionLocal
+    from sqlalchemy import select
+    from datetime import datetime, timedelta
+    from app.models.vas_task import VasTask
+    from app.services.wechat_service import WechatService
+
+    try:
+        expire_days = 4
+        now = datetime.utcnow()
+        threshold_date = (now + timedelta(days=expire_days)).strftime("%Y-%m-%d")
+
+        async with AsyncSessionLocal() as db:
+            stmt = select(VasTask).where(
+                VasTask.routing_key == "auto.slot.dub.de.tourist",
+                VasTask.status == "grabbed"
+            )
+            result = await db.execute(stmt)
+            tasks = result.scalars().all()
+
+            expiring_tasks = []
+            for task in tasks:
+                grabbed_history = task.grabbed_history or {}
+                slot_date_str = grabbed_history.get("slot_date", "")
+                if not slot_date_str:
+                    continue
+
+                try:
+                    try:
+                        slot_date = datetime.strptime(slot_date_str, "%Y-%m-%d")
+                    except ValueError:
+                        slot_date = datetime.strptime(slot_date_str, "%d/%m/%Y")
+                    if slot_date <= now + timedelta(days=expire_days):
+                        expiring_tasks.append({
+                            "order_id": task.order_id,
+                            "slot_date": slot_date_str,
+                            "slot_time": grabbed_history.get("slot_time", "N/A"),
+                        })
+                except ValueError:
+                    logger.warning(f"[Scheduler] Invalid slot_date format: {slot_date_str}")
+
+            if expiring_tasks:
+                lines = [f"⚠️ 以下德签 slot 即将到期(不足 {expire_days} 天):"]
+                for t in expiring_tasks:
+                    lines.append(f"- 订单号: {t['order_id']}, Slot: {t['slot_date']} {t['slot_time']}")
+
+                content = "\n".join(lines)
+                await WechatService.push_markdown_no_token(content)
+                logger.info(f"[Scheduler] Sent WeChat notification for {len(expiring_tasks)} expiring tasks")
+            else:
+                logger.info("[Scheduler] No expiring tasks found")
+
+    except Exception as e:
+        logger.error(f"[Scheduler] Error in visametric_slot_expire_checker: {e}")

+ 97 - 0
app/services/scheduler_service.py

@@ -0,0 +1,97 @@
+import asyncio
+import threading
+from typing import Callable, List, Dict, Any
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from apscheduler.triggers.interval import IntervalTrigger
+from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
+from app.core.logger import logger
+
+
+class ScheduledTask:
+    def __init__(self, name: str, handler: Callable, interval_hours: int = 1):
+        self.name = name
+        self.handler = handler
+        self.interval_hours = interval_hours
+
+
+class SchedulerService:
+    _instance = None
+    _lock = threading.Lock()
+    _scheduler: AsyncIOScheduler = None
+    _tasks: List[ScheduledTask] = []
+    _started = False
+
+    @classmethod
+    def get_instance(cls) -> "SchedulerService":
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = cls()
+        return cls._instance
+
+    def __init__(self):
+        self._scheduler = AsyncIOScheduler(timezone="UTC")
+        self._setup_event_listeners()
+
+    def _setup_event_listeners(self):
+        def on_job_error(event):
+            logger.error(f"[Scheduler] Job {event.job_id} error: {event.exception}")
+
+        def on_job_executed(event):
+            if event.exception is None:
+                logger.info(f"[Scheduler] Job {event.job_id} completed successfully")
+            else:
+                logger.error(f"[Scheduler] Job {event.job_id} failed: {event.exception}")
+
+        self._scheduler.add_listener(on_job_error, EVENT_JOB_ERROR)
+        self._scheduler.add_listener(on_job_executed, EVENT_JOB_EXECUTED)
+
+    def register_task(self, name: str, handler: Callable, interval_hours: int = 1):
+        task = ScheduledTask(name=name, handler=handler, interval_hours=interval_hours)
+        self._tasks.append(task)
+        logger.info(f"[Scheduler] Registered task: {name} (interval: {interval_hours}h)")
+
+    def start(self):
+        if self._started:
+            logger.warning("[Scheduler] Already started")
+            return
+
+        for task in self._tasks:
+            self._scheduler.add_job(
+                task.handler,
+                IntervalTrigger(hours=task.interval_hours),
+                id=task.name,
+                replace_existing=True,
+                misfire_grace_time=3600,
+            )
+            logger.info(f"[Scheduler] Scheduled task: {task.name} every {task.interval_hours}h")
+
+        self._scheduler.start()
+        self._started = True
+        logger.info("[Scheduler] Started successfully")
+
+    def stop(self):
+        if self._scheduler.running:
+            self._scheduler.shutdown(wait=False)
+            self._started = False
+            logger.info("[Scheduler] Stopped")
+
+    def get_jobs(self) -> List[Dict[str, Any]]:
+        jobs = []
+        for job in self._scheduler.get_jobs():
+            jobs.append({
+                "id": job.id,
+                "next_run": str(job.next_run_time) if job.next_run_time else None,
+                "interval_hours": job.trigger.interval_hours if hasattr(job.trigger, 'interval_hours') else None,
+            })
+        return jobs
+
+
+scheduler_service = SchedulerService.get_instance()
+
+
+def register_scheduled_task(name: str, interval_hours: int = 1):
+    def decorator(func: Callable):
+        scheduler_service.register_task(name, func, interval_hours)
+        return func
+    return decorator

+ 1 - 1
app/services/troov_session_service.py

@@ -94,7 +94,7 @@ class TroovSessionService:
                 fields=["session_id", "source"] # 根据业务需求定义支持模糊搜索的字段
             )
 
-        stmt = stmt.order_by(TroovSession.created_at.desc())
+        stmt = stmt.order_by(TroovSession.created_at.asc())
         
         # 优化点 4: 使用统一的 paginate 方法取代手动 offset 和 limit
         return await paginate(db, stmt, page, size)