troov_session_service.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from typing import Optional, Any
  2. from sqlalchemy.ext.asyncio import AsyncSession
  3. from sqlalchemy import select
  4. from app.core.biz_exception import NotFoundError
  5. from app.models.troov_session import TroovSession
  6. from app.schemas.troov_session import TroovSessionCreate, TroovSessionUpdate
  7. # 假设你的项目中分页和搜索工具方法路径如下,请根据实际情况调整
  8. from app.utils.pagination import paginate
  9. from app.utils.search import apply_keyword_search_stmt
  10. class TroovSessionService:
  11. @staticmethod
  12. async def add(db: AsyncSession, obj_in: TroovSessionCreate) -> TroovSession:
  13. # 优化点 1: 使用 model_dump 直接解包,避免手动挨个赋值
  14. db_obj = TroovSession(**obj_in.model_dump())
  15. # 默认值已经在 schema 中定义了,这里不需要再 obj_in.status or "pending"
  16. db.add(db_obj)
  17. await db.commit()
  18. await db.refresh(db_obj)
  19. return db_obj
  20. @staticmethod
  21. async def pop(db: AsyncSession, slot_date: Optional[str] = None, slot_time: Optional[str] = None) -> TroovSession:
  22. stmt = select(TroovSession).where(TroovSession.status == "pending")
  23. if slot_date:
  24. stmt = stmt.where(TroovSession.slot_date == slot_date)
  25. if slot_time:
  26. stmt = stmt.where(TroovSession.slot_time == slot_time)
  27. # 优点保留: 使用 skip_locked=True 完美解决并发争抢问题
  28. stmt = stmt.order_by(TroovSession.created_at.asc()).limit(1).with_for_update(skip_locked=True)
  29. result = await db.execute(stmt)
  30. obj = result.scalar_one_or_none()
  31. if not obj:
  32. raise NotFoundError("No pending troov session found")
  33. obj.status = "booking"
  34. await db.commit()
  35. await db.refresh(obj)
  36. return obj
  37. @staticmethod
  38. async def update(db: AsyncSession, session_id: str, obj_in: TroovSessionUpdate) -> TroovSession:
  39. stmt = select(TroovSession).where(TroovSession.session_id == session_id)
  40. db_obj = (await db.execute(stmt)).scalar_one_or_none()
  41. if not db_obj:
  42. raise NotFoundError("TroovSession not found")
  43. # 优化点 2: 使用 exclude_unset=True 仅提取客户端真正传入的字段
  44. update_data = obj_in.model_dump(exclude_unset=True)
  45. # 安全防御: 防止客户端恶意或错误修改主键
  46. update_data.pop("session_id", None)
  47. # 动态赋值,后续 Schema 加字段这里完全不用改代码
  48. for field, value in update_data.items():
  49. setattr(db_obj, field, value)
  50. await db.commit()
  51. await db.refresh(db_obj)
  52. return db_obj
  53. @staticmethod
  54. async def get_by_session_id(db: AsyncSession, session_id: str) -> Optional[TroovSession]:
  55. stmt = select(TroovSession).where(TroovSession.session_id == session_id)
  56. return (await db.execute(stmt)).scalar_one_or_none()
  57. @staticmethod
  58. async def list(
  59. db: AsyncSession,
  60. status: Optional[str] = None,
  61. keyword: Optional[str] = None, # 新增 keyword 支持
  62. page: int = 1, # 分页库通常 page 从 1 开始
  63. size: int = 10
  64. ) -> Any:
  65. stmt = select(TroovSession)
  66. if status:
  67. stmt = stmt.where(TroovSession.status == status)
  68. # 优化点 3: 引入你提供的 apply_keyword_search_stmt 范式
  69. if keyword:
  70. stmt = apply_keyword_search_stmt(
  71. stmt=stmt,
  72. model=TroovSession,
  73. keyword=keyword,
  74. fields=["session_id", "source"] # 根据业务需求定义支持模糊搜索的字段
  75. )
  76. stmt = stmt.order_by(TroovSession.created_at.desc())
  77. # 优化点 4: 使用统一的 paginate 方法取代手动 offset 和 limit
  78. return await paginate(db, stmt, page, size)