product_service.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # app/services/product_service.py
  2. from typing import Optional
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select
  5. from app.utils.search import apply_keyword_search_stmt
  6. from app.utils.pagination import paginate
  7. from app.core.biz_exception import NotFoundError
  8. from app.models.product import VasProduct
  9. from app.schemas.product import VasProductCreate, VasProductUpdate
  class ProductService:
      """Async create/read/update/list operations for ``VasProduct`` rows.

      Every method is a ``@staticmethod`` that receives the ``AsyncSession`` to
      operate on; the class itself holds no state.
      """

      @staticmethod
      async def create(
          db: AsyncSession,
          data: VasProductCreate,
      ) -> VasProduct:
          """Insert a new product built from ``data`` and return it.

          Args:
              db: Session used for the insert. It is committed by this method.
              data: Creation payload; ``data.dict()`` is expanded into the
                  ``VasProduct`` constructor as keyword arguments.

          Returns:
              The newly added ``VasProduct`` after ``db.refresh``.
          """
          rec = VasProduct(**data.dict())
          db.add(rec)
          await db.commit()
          # Reload the instance from the database after the commit.
          await db.refresh(rec)
          return rec

      @staticmethod
      async def get(
          db: AsyncSession,
          id: int,
      ) -> VasProduct:
          """Return the product whose ``id`` column equals ``id``.

          The parameter name ``id`` shadows the builtin; it is kept because it
          is part of the public signature.

          Args:
              db: Session used to run the query.
              id: Value compared against ``VasProduct.id``.

          Returns:
              The matching ``VasProduct``.

          Raises:
              NotFoundError: If no row matches.
          """
          stmt = select(VasProduct).where(VasProduct.id == id)
          obj = (await db.execute(stmt)).scalar_one_or_none()
          # Compare against None explicitly so that a found row is never
          # treated as missing because of its truthiness.
          if obj is None:
              raise NotFoundError("Product not exist")
          return obj

      @staticmethod
      async def update(
          db: AsyncSession,
          id: int,
          data: VasProductUpdate,
      ) -> VasProduct:
          """Apply the fields set on ``data`` to an existing product.

          Args:
              db: Session used for the lookup. It is committed by this method.
              id: Value compared against ``VasProduct.id``.
              data: Update payload (presumably a Pydantic model - confirm).
                  Only the items of ``data.dict(exclude_unset=True)`` are
                  written to the row.

          Returns:
              The updated ``VasProduct`` after ``db.refresh``.

          Raises:
              NotFoundError: If no row matches ``id``.
          """
          # Reuse get() so the lookup and the not-found error live in one place.
          rec = await ProductService.get(db, id)
          for field, value in data.dict(exclude_unset=True).items():
              setattr(rec, field, value)
          await db.commit()
          await db.refresh(rec)
          return rec

      @staticmethod
      async def _list_products(
          db: AsyncSession,
          *,
          country: Optional[str],
          visa_type: Optional[str],
          page: int,
          size: int,
          keyword: Optional[str],
          enabled_only: bool,
      ):
          """Build the shared product listing query and return one page of it.

          Shared implementation behind ``list_enable_product`` and
          ``list_product``; the two differ only in ``enabled_only``.

          NOTE(review): an earlier comment here stated that ``paginate`` and
          ``apply_keyword_search`` were still Query-based. The calls below pass
          a ``select()`` statement and await ``paginate`` - confirm that both
          helpers accept that.
          """
          stmt = select(VasProduct)
          if enabled_only:
              stmt = stmt.where(VasProduct.enabled == 1)
          # Falsy filter values (None or "") mean "do not filter on this column".
          if country:
              stmt = stmt.where(VasProduct.country == country)
          if visa_type:
              stmt = stmt.where(VasProduct.visa_type == visa_type)
          stmt = apply_keyword_search_stmt(
              stmt=stmt,
              model=VasProduct,
              keyword=keyword,
              fields=["title", "provider", "description", "country", "city"],
          )
          stmt = stmt.order_by(
              VasProduct.recommend_score.desc(),
              VasProduct.created_at.desc(),
          )
          return await paginate(db, stmt, page, size)

      @staticmethod
      async def list_enable_product(
          db: AsyncSession,
          country: Optional[str] = None,
          visa_type: Optional[str] = None,
          page: int = 0,
          size: int = 10,
          keyword: Optional[str] = None,
      ):
          """Return one page of products whose ``enabled`` column equals 1.

          Args:
              db: Session used to run the query.
              country: If truthy, keep only rows with this ``country``.
              visa_type: If truthy, keep only rows with this ``visa_type``.
              page: Passed to ``paginate`` unchanged.
              size: Passed to ``paginate`` unchanged.
              keyword: Passed to ``apply_keyword_search_stmt`` together with the
                  title, provider, description, country and city fields.

          Returns:
              Whatever ``paginate`` returns for the statement, ordered by
              ``recommend_score`` descending, then ``created_at`` descending.
          """
          return await ProductService._list_products(
              db,
              country=country,
              visa_type=visa_type,
              page=page,
              size=size,
              keyword=keyword,
              enabled_only=True,
          )

      @staticmethod
      async def list_product(
          db: AsyncSession,
          country: Optional[str] = None,
          visa_type: Optional[str] = None,
          page: int = 0,
          size: int = 10,
          keyword: Optional[str] = None,
      ):
          """Return one page of products with no filter on ``enabled``.

          Args:
              db: Session used to run the query.
              country: If truthy, keep only rows with this ``country``.
              visa_type: If truthy, keep only rows with this ``visa_type``.
              page: Passed to ``paginate`` unchanged.
              size: Passed to ``paginate`` unchanged.
              keyword: Passed to ``apply_keyword_search_stmt`` together with the
                  title, provider, description, country and city fields.

          Returns:
              Whatever ``paginate`` returns for the statement, ordered by
              ``recommend_score`` descending, then ``created_at`` descending.
          """
          return await ProductService._list_products(
              db,
              country=country,
              visa_type=visa_type,
              page=page,
              size=size,
              keyword=keyword,
              enabled_only=False,
          )