from datetime import datetime, timezone
from typing import List, Optional

from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import NotFoundError
from app.models.slot_refresh_status import VasSlotRefreshStatus
from app.schemas.slot_refresh_status import RefreshBase, RefreshFail, RefreshStatusOut


class SlotRefreshStatusService:
    """Operations on ``VasSlotRefreshStatus`` rows, looked up by ``routing_key``.

    Every public method takes the caller's ``AsyncSession``, commits it, and
    returns the affected ORM object(s).
    """

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value shape as the deprecated ``datetime.utcnow()`` (naive, UTC
        based), so timestamps written to the existing columns are unchanged.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    async def _get_by_routing_key(
        db: AsyncSession, routing_key
    ) -> Optional[VasSlotRefreshStatus]:
        """Return the status row for ``routing_key``, or ``None`` if absent."""
        stmt = select(VasSlotRefreshStatus).where(
            VasSlotRefreshStatus.routing_key == routing_key
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    @staticmethod
    async def _get_or_raise(db: AsyncSession, routing_key) -> VasSlotRefreshStatus:
        """Return the status row for ``routing_key``.

        Raises:
            NotFoundError: if no row exists for ``routing_key``.
        """
        record = await SlotRefreshStatusService._get_by_routing_key(db, routing_key)
        if record is None:
            raise NotFoundError(message="refresh record not found")
        return record

    @staticmethod
    async def _commit_and_refresh(
        db: AsyncSession, record: VasSlotRefreshStatus
    ) -> VasSlotRefreshStatus:
        """Commit the session and reload ``record`` before handing it back.

        With AsyncSession's default ``expire_on_commit=True`` the committed
        object's attributes are expired; touching them later (e.g. while a
        caller serializes the result) would trigger implicit lazy IO, which
        fails under asyncio. Refreshing here loads them eagerly.
        """
        await db.commit()
        await db.refresh(record)
        return record

    @staticmethod
    async def refresh_start(
        db: AsyncSession, data: RefreshBase
    ) -> VasSlotRefreshStatus:
        """Record that a refresh for ``data.routing_key`` has started.

        If a row exists, sets ``last_refresh_at`` to now, updates
        ``snapshot_source`` and clears ``last_error``. Otherwise inserts a new
        row populated from ``data``. ``country``/``city``/``visa_type`` are
        only written on insert, never on update.

        Returns:
            The updated or newly created row, committed and refreshed.
        """
        now = SlotRefreshStatusService._utcnow()
        record = await SlotRefreshStatusService._get_by_routing_key(
            db, data.routing_key
        )
        if record is not None:
            record.last_refresh_at = now
            record.snapshot_source = data.snapshot_source
            record.last_error = None
        else:
            # NOTE(review): check-then-insert is not atomic; two concurrent
            # starts for a new routing_key could both insert. Confirm whether
            # routing_key has a unique constraint / whether this can happen.
            record = VasSlotRefreshStatus(
                routing_key=data.routing_key,
                snapshot_source=data.snapshot_source,
                country=data.country,
                city=data.city,
                visa_type=data.visa_type,
                last_refresh_at=now,
            )
            db.add(record)
        return await SlotRefreshStatusService._commit_and_refresh(db, record)

    @staticmethod
    async def refresh_success(
        db: AsyncSession, data: RefreshBase
    ) -> VasSlotRefreshStatus:
        """Mark the refresh for ``data.routing_key`` as successful.

        Sets ``last_success_at`` to now and clears ``last_error``.

        Raises:
            NotFoundError: if no row exists for ``data.routing_key``.
        """
        record = await SlotRefreshStatusService._get_or_raise(db, data.routing_key)
        record.last_success_at = SlotRefreshStatusService._utcnow()
        record.last_error = None
        return await SlotRefreshStatusService._commit_and_refresh(db, record)

    @staticmethod
    async def refresh_fail(
        db: AsyncSession, data: RefreshFail
    ) -> VasSlotRefreshStatus:
        """Record a failed refresh for ``data.routing_key``.

        Stores ``data.error`` in ``last_error``; timestamps are left untouched.

        Raises:
            NotFoundError: if no row exists for ``data.routing_key``.
        """
        record = await SlotRefreshStatusService._get_or_raise(db, data.routing_key)
        record.last_error = data.error
        return await SlotRefreshStatusService._commit_and_refresh(db, record)

    @staticmethod
    async def list_all(
        db: AsyncSession
    ) -> List[VasSlotRefreshStatus]:
        """Return every ``VasSlotRefreshStatus`` row as a list."""
        stmt = select(VasSlotRefreshStatus)
        result = await db.execute(stmt)
        # scalars().all() returns a Sequence; materialize to match the annotation.
        return list(result.scalars().all())