| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100 |
- from typing import Optional, Any
- from sqlalchemy.ext.asyncio import AsyncSession
- from sqlalchemy import select
- from app.core.biz_exception import NotFoundError
- from app.models.troov_session import TroovSession
- from app.schemas.troov_session import TroovSessionCreate, TroovSessionUpdate
- # 假设你的项目中分页和搜索工具方法路径如下,请根据实际情况调整
- from app.utils.pagination import paginate
- from app.utils.search import apply_keyword_search_stmt
class TroovSessionService:
    """Persistence operations for ``TroovSession`` rows.

    Every method is a ``staticmethod`` that receives the caller's
    ``AsyncSession``. Write methods commit on that session before returning.
    """

    @staticmethod
    async def add(db: AsyncSession, obj_in: TroovSessionCreate) -> TroovSession:
        """Insert a new session built from ``obj_in`` and return it.

        Args:
            db: Async database session; this call commits it.
            obj_in: Creation payload. Every field from ``model_dump()`` is
                passed straight to the ``TroovSession`` constructor.

        Returns:
            The newly created ``TroovSession``, refreshed after the commit.
        """
        # Unpack the schema directly instead of assigning fields one by one.
        # Defaults are defined on the schema, so no fallback such as
        # ``obj_in.status or "pending"`` is applied here.
        db_obj = TroovSession(**obj_in.model_dump())
        db.add(db_obj)
        await db.commit()
        # Reload the row so its attributes reflect the committed state.
        await db.refresh(db_obj)
        return db_obj

    @staticmethod
    async def pop(
        db: AsyncSession,
        slot_date: Optional[str] = None,
        slot_time: Optional[str] = None,
    ) -> TroovSession:
        """Claim the oldest ``"pending"`` session and mark it ``"booking"``.

        The candidate row is selected with ``FOR UPDATE SKIP LOCKED``. On
        backends that support it, concurrent callers skip rows already locked
        by another transaction instead of blocking or claiming the same row.

        Args:
            db: Async database session; this call commits it.
            slot_date: If truthy, only sessions with this ``slot_date`` match.
            slot_time: If truthy, only sessions with this ``slot_time`` match.

        Returns:
            The claimed session, refreshed after its status change is committed.

        Raises:
            NotFoundError: No pending session matches the filters.
        """
        stmt = select(TroovSession).where(TroovSession.status == "pending")
        if slot_date:
            stmt = stmt.where(TroovSession.slot_date == slot_date)
        if slot_time:
            stmt = stmt.where(TroovSession.slot_time == slot_time)

        # skip_locked=True lets concurrent workers each grab a different row.
        stmt = (
            stmt.order_by(TroovSession.created_at.asc())
            .limit(1)
            .with_for_update(skip_locked=True)
        )
        obj = (await db.execute(stmt)).scalar_one_or_none()
        if obj is None:
            raise NotFoundError("No pending troov session found")

        obj.status = "booking"
        await db.commit()
        await db.refresh(obj)
        return obj

    @staticmethod
    async def update(db: AsyncSession, session_id: str, obj_in: TroovSessionUpdate) -> TroovSession:
        """Apply a partial update to the session identified by ``session_id``.

        Only fields the client explicitly set on ``obj_in`` are written
        (``model_dump(exclude_unset=True)``). ``session_id`` is never changed,
        even if it appears in the payload.

        Args:
            db: Async database session; this call commits it.
            session_id: Identifier of the session to update.
            obj_in: Partial update payload.

        Returns:
            The updated session, refreshed after the commit.

        Raises:
            NotFoundError: No session has this ``session_id``.
        """
        # Reuse the shared lookup instead of duplicating the query here.
        db_obj = await TroovSessionService.get_by_session_id(db, session_id)
        if db_obj is None:
            raise NotFoundError("TroovSession not found")

        # Only the fields the client actually sent.
        update_data = obj_in.model_dump(exclude_unset=True)
        # Guard the primary key against accidental or malicious modification.
        update_data.pop("session_id", None)
        # Generic assignment: new schema fields need no change here.
        for field, value in update_data.items():
            setattr(db_obj, field, value)

        await db.commit()
        await db.refresh(db_obj)
        return db_obj

    @staticmethod
    async def get_by_session_id(db: AsyncSession, session_id: str) -> Optional[TroovSession]:
        """Return the session with ``session_id``, or ``None`` if absent.

        Uses ``scalar_one_or_none``, which raises if more than one row matches.
        ``session_id`` is therefore expected to be unique.
        """
        stmt = select(TroovSession).where(TroovSession.session_id == session_id)
        return (await db.execute(stmt)).scalar_one_or_none()

    # NOTE: this method name shadows the builtin ``list`` inside the class
    # body. It is kept unchanged because callers use it.
    @staticmethod
    async def list(
        db: AsyncSession,
        status: Optional[str] = None,
        keyword: Optional[str] = None,
        page: int = 1,
        size: int = 10,
    ) -> Any:
        """Return one page of sessions, oldest first.

        Args:
            db: Async database session.
            status: If truthy, only sessions with this status are returned.
            keyword: If truthy, passed to ``apply_keyword_search_stmt``, which
                matches it against the ``session_id`` and ``source`` fields.
            page: Page number, presumably 1-based; confirm against ``paginate``.
            size: Number of rows per page.

        Returns:
            Whatever ``paginate`` returns for the built statement.
        """
        stmt = select(TroovSession)
        if status:
            stmt = stmt.where(TroovSession.status == status)
        if keyword:
            stmt = apply_keyword_search_stmt(
                stmt=stmt,
                model=TroovSession,
                keyword=keyword,
                fields=["session_id", "source"],  # columns eligible for fuzzy search
            )
        # created_at is not guaranteed unique. Without a tiebreaker, rows that
        # share a timestamp may come back in a different order on each query,
        # so offset/limit pages can repeat or skip them. session_id (treated as
        # unique by get_by_session_id) makes the ordering total.
        stmt = stmt.order_by(TroovSession.created_at.asc(), TroovSession.session_id.asc())
        # Delegate offset/limit handling to the shared pagination helper.
        return await paginate(db, stmt, page, size)
|