# app/services/payment_qr_service.py
from typing import List

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select

from app.core.biz_exception import NotFoundError
from app.models.payment_provider import VasPaymentProvider
from app.models.payment_qr import VasPaymentQR
from app.schemas.payment_qr import (
    VasPaymentQrCreate,
    VasPaymentQrSetEnableIn,
)


class PaymentQrService:
    """CRUD and lookup operations for ``VasPaymentQR`` records.

    Every method is a ``@staticmethod`` that receives the ``AsyncSession``
    to operate on; write operations commit on that session themselves.
    """

    # --------------------------------------------------
    # Create a payment QR code
    # --------------------------------------------------
    @staticmethod
    async def create(
        db: AsyncSession,
        data: VasPaymentQrCreate,
    ) -> VasPaymentQR:
        """Persist a new QR record built from ``data`` and return it.

        The record is added, committed, and then refreshed from the
        database before being returned.
        """
        # NOTE(review): ``.dict()` is the Pydantic v1 spelling; if the
        # project is on Pydantic v2 this should become ``model_dump()``
        # -- confirm against the installed version.
        rec = VasPaymentQR(**data.dict())
        db.add(rec)
        await db.commit()
        await db.refresh(rec)
        return rec

    # --------------------------------------------------
    # Get by ID
    # --------------------------------------------------
    @staticmethod
    async def get_by_id(
        db: AsyncSession,
        qr_id: int,
    ) -> VasPaymentQR:
        """Return the QR record whose ``id`` equals ``qr_id``.

        Raises:
            NotFoundError: if no such record exists.
        """
        stmt = select(VasPaymentQR).where(VasPaymentQR.id == qr_id)
        obj = (await db.execute(stmt)).scalar_one_or_none()
        # Compare against None explicitly rather than relying on the
        # truthiness of a mapped instance.
        if obj is None:
            raise NotFoundError("QR not exist")
        return obj

    # --------------------------------------------------
    # Enable / disable a QR
    # --------------------------------------------------
    @staticmethod
    async def set_enable(
        db: AsyncSession,
        qr_id: int,
        payload: VasPaymentQrSetEnableIn,
    ) -> VasPaymentQR:
        """Set ``is_active`` on the QR record from ``payload`` and return it.

        Raises:
            NotFoundError: if no record with ``qr_id`` exists.
        """
        # Reuse the shared lookup so the not-found behaviour stays in one place.
        obj = await PaymentQrService.get_by_id(db, qr_id)
        obj.is_active = payload.is_active
        await db.commit()
        await db.refresh(obj)
        return obj

    # --------------------------------------------------
    # Delete a QR
    # --------------------------------------------------
    @staticmethod
    async def delete(
        db: AsyncSession,
        qr_id: int,
    ) -> bool:
        """Delete the QR record with ``qr_id`` and return ``True``.

        Raises:
            NotFoundError: if no record with ``qr_id`` exists.
        """
        obj = await PaymentQrService.get_by_id(db, qr_id)
        await db.delete(obj)
        await db.commit()
        return True

    # --------------------------------------------------
    # Query by device ID
    # --------------------------------------------------
    @staticmethod
    async def get_by_devid(
        db: AsyncSession,
        devid: str,
    ) -> List[VasPaymentQR]:
        """Return all QR records whose ``devid`` equals ``devid``."""
        stmt = select(VasPaymentQR).where(VasPaymentQR.devid == devid)
        result = await db.execute(stmt)
        # ``all()`` is typed as a Sequence; materialize a list to match
        # the declared return type.
        return list(result.scalars().all())

    # --------------------------------------------------
    # Query by provider name
    # --------------------------------------------------
    @staticmethod
    async def get_by_provider(
        db: AsyncSession,
        provider: str,
    ) -> List[VasPaymentQR]:
        """Return all QR records whose ``provider`` equals ``provider``."""
        stmt = select(VasPaymentQR).where(VasPaymentQR.provider == provider)
        result = await db.execute(stmt)
        return list(result.scalars().all())

    # --------------------------------------------------
    # Query QRs by provider_id (with existence check)
    # --------------------------------------------------
    @staticmethod
    async def list_by_provider(
        db: AsyncSession,
        provider_id: int,
    ) -> List[VasPaymentQR]:
        """Return the QR records belonging to the provider with ``provider_id``.

        The provider row is loaded first; QR records are then matched on
        ``VasPaymentQR.provider == provider.name``.

        Raises:
            NotFoundError: if no provider with ``provider_id`` exists.
        """
        stmt = select(VasPaymentProvider).where(
            VasPaymentProvider.id == provider_id
        )
        provider = (await db.execute(stmt)).scalar_one_or_none()
        if provider is None:
            raise NotFoundError("Provider not exist")

        # Same filter as get_by_provider, keyed on the resolved name.
        return await PaymentQrService.get_by_provider(db, provider.name)