from typing import Any, Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.biz_exception import NotFoundError
from app.models.troov_session import TroovSession
from app.schemas.troov_session import TroovSessionCreate, TroovSessionUpdate

# Assumed locations of the pagination and keyword-search helpers;
# adjust these import paths to match the actual project layout.
from app.utils.pagination import paginate
from app.utils.search import apply_keyword_search_stmt


class TroovSessionService:
    """Async data-access operations for ``TroovSession`` rows.

    Each mutating method commits the ``AsyncSession`` it is given before
    returning.
    """

    @staticmethod
    async def add(db: AsyncSession, obj_in: TroovSessionCreate) -> TroovSession:
        """Insert a new session built from *obj_in* and return the stored row.

        Args:
            db: Async database session; committed by this method.
            obj_in: Creation payload. All of its fields are passed straight to
                the ``TroovSession`` constructor, so defaults (e.g. ``status``)
                are expected to be defined on the schema, not applied here.

        Returns:
            The newly inserted row, refreshed from the database after commit.
        """
        db_obj = TroovSession(**obj_in.model_dump())
        db.add(db_obj)
        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    @staticmethod
    async def pop(
        db: AsyncSession,
        slot_date: Optional[str] = None,
        slot_time: Optional[str] = None,
    ) -> TroovSession:
        """Claim the oldest ``"pending"`` session and mark it ``"booking"``.

        Args:
            db: Async database session; committed by this method on success.
            slot_date: When truthy, only sessions with this ``slot_date`` are
                considered.
            slot_time: When truthy, only sessions with this ``slot_time`` are
                considered.

        Returns:
            The claimed session, refreshed from the database after commit.

        Raises:
            NotFoundError: If no matching pending session is available.
        """
        stmt = select(TroovSession).where(TroovSession.status == "pending")
        if slot_date:
            stmt = stmt.where(TroovSession.slot_date == slot_date)
        if slot_time:
            stmt = stmt.where(TroovSession.slot_time == slot_time)

        # FOR UPDATE SKIP LOCKED: rows already locked by a concurrent pop() are
        # skipped rather than waited on, so concurrent callers do not claim the
        # same row (requires backend support, e.g. PostgreSQL or MySQL 8+).
        # session_id breaks ties so FIFO order is deterministic when several
        # rows share the same created_at.
        stmt = (
            stmt.order_by(
                TroovSession.created_at.asc(),
                TroovSession.session_id.asc(),
            )
            .limit(1)
            .with_for_update(skip_locked=True)
        )
        obj = (await db.execute(stmt)).scalar_one_or_none()
        if obj is None:
            raise NotFoundError("No pending troov session found")

        obj.status = "booking"
        await db.commit()
        await db.refresh(obj)
        return obj

    @staticmethod
    async def update(
        db: AsyncSession, session_id: str, obj_in: TroovSessionUpdate
    ) -> TroovSession:
        """Apply a partial update to the session identified by *session_id*.

        Only fields the client explicitly set on *obj_in* are written
        (``exclude_unset=True``), so omitted fields keep their stored values.

        Args:
            db: Async database session; committed by this method.
            session_id: Identifier of the session to update.
            obj_in: Update payload.

        Returns:
            The updated row, refreshed from the database after commit.

        Raises:
            NotFoundError: If no session with *session_id* exists.
        """
        db_obj = await TroovSessionService.get_by_session_id(db, session_id)
        if db_obj is None:
            raise NotFoundError("TroovSession not found")

        update_data = obj_in.model_dump(exclude_unset=True)
        # Never let a payload rewrite the identifier (primary key) itself.
        update_data.pop("session_id", None)

        # Generic assignment: new schema fields need no change here.
        for field, value in update_data.items():
            setattr(db_obj, field, value)

        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    @staticmethod
    async def get_by_session_id(
        db: AsyncSession, session_id: str
    ) -> Optional[TroovSession]:
        """Return the session with *session_id*, or ``None`` if it does not exist."""
        stmt = select(TroovSession).where(TroovSession.session_id == session_id)
        return (await db.execute(stmt)).scalar_one_or_none()

    # NOTE: this method name shadows the builtin ``list`` inside the class
    # body; any class-level code placed after it cannot use ``list[...]``.
    @staticmethod
    async def list(
        db: AsyncSession,
        status: Optional[str] = None,
        keyword: Optional[str] = None,
        page: int = 1,
        size: int = 10,
    ) -> Any:
        """Return one page of sessions, optionally filtered by status and keyword.

        Args:
            db: Async database session (read-only use).
            status: When truthy, only sessions with this status are returned.
            keyword: When truthy, fuzzy-matched against ``session_id`` and
                ``source`` via ``apply_keyword_search_stmt``.
            page: Page number; presumably 1-based per ``paginate`` — confirm.
            size: Page size.

        Returns:
            Whatever ``paginate`` returns (shape defined by that helper).
        """
        stmt = select(TroovSession)
        if status:
            stmt = stmt.where(TroovSession.status == status)

        if keyword:
            stmt = apply_keyword_search_stmt(
                stmt=stmt,
                model=TroovSession,
                keyword=keyword,
                fields=["session_id", "source"],
            )

        # session_id breaks created_at ties; without a unique tie-breaker,
        # rows with equal timestamps can shift between pages across requests.
        stmt = stmt.order_by(
            TroovSession.created_at.asc(),
            TroovSession.session_id.asc(),
        )
        return await paginate(db, stmt, page, size)