# app/services/order_service.py
"""Order service: order creation, lookup, listing, admin payment override,
and creation of per-routing tasks for paid orders."""
import json
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional

from redis.asyncio import Redis
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import BizLogicError, NotFoundError
from app.models.order import VasOrder
from app.models.product import VasProduct
from app.models.product_routing import VasProductRouting
from app.models.user import VasUser
from app.models.vas_task import VasTask
from app.schemas.order import VasOrderCreate, VasOrderPatchUserInputs
from app.utils.pagination import paginate
from app.utils.search import apply_keyword_search_stmt

# Lifetime of a freshly created task; added to the creation time to get
# VasTask.expire_at.
TASK_EXPIRE_AFTER = timedelta(days=7)


def _normalize_user_inputs(raw: Any) -> Dict[str, Any]:
    """Return a *new* dict built from an order's stored ``user_inputs``.

    ``raw`` may be a dict, a JSON-encoded string, or ``None``. Anything that
    does not yield a JSON object (invalid JSON, a JSON array/number/null, or
    any other type) is replaced by an empty dict. A fresh dict is always
    returned, so the caller can modify it and assign it back without mutating
    the object currently held by the ORM instance.
    """
    if isinstance(raw, dict):
        return dict(raw)
    if isinstance(raw, str):
        try:
            parsed = json.loads(raw)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return {}
        if isinstance(parsed, dict):
            return parsed
    return {}


class OrderService:
    # --------------------------------------------------
    # Admin: force-mark an order as paid
    # --------------------------------------------------
    @staticmethod
    async def mark_as_admin_paid(
        db: AsyncSession,
        order: VasOrder,
        admin_user: VasUser,
    ) -> VasOrder:
        """Mark ``order`` as paid on behalf of an administrator.

        Records an ``_admin_bypass`` entry (admin id, UTC timestamp, reason)
        inside ``order.user_inputs``, then commits and refreshes the order.
        An order whose status is already ``"paid"`` is returned unchanged
        without touching the database.

        Args:
            db: Async session used to persist the change.
            order: The order to update.
            admin_user: The administrator performing the override.

        Returns:
            The (refreshed) order.
        """
        if order.status == "paid":
            return order

        order.status = "paid"

        # Stored user_inputs may be a dict, a JSON string, None or garbage;
        # normalize to a fresh dict so the bypass marker can always be added.
        user_inputs = _normalize_user_inputs(order.user_inputs)
        user_inputs["_admin_bypass"] = {
            "enabled": True,
            "by": admin_user.id,
            "at": datetime.utcnow().isoformat(),
            "reason": "admin manual order",
        }
        # Assign a new object instead of mutating in place: a plain SQLAlchemy
        # JSON column does not detect in-place dict mutation, so the marker
        # would otherwise not be flushed when user_inputs was already a dict.
        order.user_inputs = user_inputs

        db.add(order)
        await db.commit()
        await db.refresh(order)
        return order

    # --------------------------------------------------
    # Create tasks for an order (idempotent)
    # --------------------------------------------------
    @staticmethod
    async def create_tasks_for_order(
        db: AsyncSession, order: VasOrder
    ) -> List[VasTask]:
        """Create one pending ``VasTask`` per active routing of the order's product.

        A routing is skipped when a task already exists for this order with the
        same ``(routing_key, script_version)``, or when an earlier routing in
        the same call produced that pair. Repeated calls therefore do not
        create duplicate tasks.

        NOTE(review): this check is not safe against concurrent callers for
        the same order unless the database enforces a unique constraint —
        verify.

        Args:
            db: Async session; committed whenever at least one active routing
                exists.
            order: The order to create tasks for.

        Returns:
            Only the tasks created by this call. Returns ``[]`` when the order
            is not ``"paid"`` or its product has no active routings.
        """
        if order.status != "paid":
            return []

        routing_stmt = select(VasProductRouting).where(
            VasProductRouting.product_id == order.product_id,
            VasProductRouting.is_active == 1,
        )
        routings = (await db.execute(routing_stmt)).scalars().all()
        if not routings:
            return []

        # Load every existing (routing_key, script_version) for this order in a
        # single query instead of one lookup per routing. This also cannot
        # raise MultipleResultsFound if duplicate tasks already exist.
        existing_stmt = select(VasTask.routing_key, VasTask.script_version).where(
            VasTask.order_id == order.id
        )
        seen = {tuple(row) for row in (await db.execute(existing_stmt)).all()}

        # One timestamp per batch so created_at and expire_at stay consistent.
        now = datetime.utcnow()
        created_tasks: List[VasTask] = []
        for routing in routings:
            key = (routing.routing_key, routing.script_version)
            if key in seen:
                continue
            # Guard against duplicate routings within this same batch too.
            seen.add(key)

            task = VasTask(
                order_id=order.id,
                routing_key=routing.routing_key,
                script_version=routing.script_version,
                priority=routing.priority,
                status="pending",
                user_inputs=order.user_inputs,
                config=routing.config,
                attempt_count=0,
                notify_count=0,
                expire_at=now + TASK_EXPIRE_AFTER,
                created_at=now,
            )
            db.add(task)
            created_tasks.append(task)

        await db.commit()
        # NOTE(review): the returned tasks are not refreshed after commit; if
        # the session expires objects on commit, reading their attributes in
        # async code may need an explicit refresh — verify session config.
        return created_tasks

    # --------------------------------------------------
    # Create an order
    # --------------------------------------------------
    @staticmethod
    async def create(
        db: AsyncSession,
        data: VasOrderCreate,
        product: VasProduct,
        auth_user: VasUser,
        redis_client: Redis,
    ) -> VasOrder:
        """Create and persist a new order for ``auth_user``.

        The order id has the form ``ORD-<UTC yyyymmddHHMMSS>-<8 hex chars>``.
        Product title and price are copied onto the order at creation time.

        Args:
            db: Async session used to persist the order.
            data: Client-supplied order fields, expanded into ``VasOrder``.
            product: The product being ordered.
            auth_user: The authenticated user placing the order.
            redis_client: Not used by this method at present; kept for
                interface compatibility.

        Returns:
            The persisted, refreshed order.

        Raises:
            BizLogicError: If the user has no email address linked.
        """
        if not auth_user.email:
            raise BizLogicError(
                "Your account must be linked to an email address before you can place an order."
            )

        order_id = f"ORD-{datetime.utcnow():%Y%m%d%H%M%S}-{uuid.uuid4().hex[:8]}"
        order = VasOrder(
            id=order_id,
            **data.dict(),
            product_name=product.title,
            base_amount=product.price_amount,
            base_currency=product.price_currency,
            user_id=auth_user.id,
        )
        db.add(order)
        await db.commit()
        await db.refresh(order)
        return order

    # --------------------------------------------------
    # Get an order
    # --------------------------------------------------
    @staticmethod
    async def get(db: AsyncSession, order_id: str) -> Optional[VasOrder]:
        """Return the order with ``order_id``, or ``None`` if it does not exist."""
        stmt = select(VasOrder).where(VasOrder.id == order_id)
        return (await db.execute(stmt)).scalar_one_or_none()

    # --------------------------------------------------
    # Orders of a single user
    # --------------------------------------------------
    @staticmethod
    async def list_by_user(
        db: AsyncSession,
        user_id: str,
        page: int = 0,
        size: int = 10,
        keyword: Optional[str] = None,
    ):
        """Return one page of ``user_id``'s orders, newest first.

        ``keyword`` is matched by ``apply_keyword_search_stmt`` against the
        ``id``, ``user_id`` and ``product_name`` fields. The return value is
        whatever ``paginate`` produces.
        """
        stmt = select(VasOrder).where(VasOrder.user_id == user_id)
        stmt = apply_keyword_search_stmt(
            stmt=stmt,
            model=VasOrder,
            keyword=keyword,
            fields=["id", "user_id", "product_name"],
        ).order_by(VasOrder.created_at.desc())
        return await paginate(db, stmt, page, size)

    # --------------------------------------------------
    # Admin: all orders
    # --------------------------------------------------
    @staticmethod
    async def list_all(
        db: AsyncSession,
        page: int = 0,
        size: int = 10,
        keyword: Optional[str] = None,
    ):
        """Return one page of all orders, newest first (admin view).

        ``keyword`` is matched against ``id``, ``user_id``, ``user_name`` and
        ``product_name``. The return value is whatever ``paginate`` produces.
        """
        stmt = select(VasOrder)
        query = apply_keyword_search_stmt(
            stmt=stmt,
            model=VasOrder,
            keyword=keyword,
            fields=["id", "user_id", "user_name", "product_name"],
        ).order_by(VasOrder.created_at.desc())
        return await paginate(db, query, page, size)

    # --------------------------------------------------
    # Replace user_inputs
    # --------------------------------------------------
    @staticmethod
    async def patch_user_inputs(
        db: AsyncSession,
        order_id: str,
        payload: VasOrderPatchUserInputs,
    ) -> VasOrder:
        """Replace ``user_inputs`` of order ``order_id`` with ``payload.user_inputs``.

        Returns:
            The updated, refreshed order.

        Raises:
            NotFoundError: If no order with ``order_id`` exists.
        """
        order = await OrderService.get(db, order_id)
        if not order:
            raise NotFoundError("Order not exist")

        order.user_inputs = payload.user_inputs
        await db.commit()
        await db.refresh(order)
        return order