schema_service.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # app/services/schema_service.py
  2. from typing import List
  3. from sqlalchemy.ext.asyncio import AsyncSession
  4. from sqlalchemy import select
  5. from app.core.biz_exception import NotFoundError
  6. from app.models.schema import VasSchema
  7. from app.schemas.schema import VasSchemaCreate, VasSchemaUpdate
  8. class SchemaService:
  9. @staticmethod
  10. async def create(
  11. db: AsyncSession,
  12. data: VasSchemaCreate,
  13. ) -> VasSchema:
  14. rec = VasSchema(**data.model_dump(by_alias=True))
  15. db.add(rec)
  16. await db.commit()
  17. await db.refresh(rec)
  18. return rec
  19. @staticmethod
  20. async def get(
  21. db: AsyncSession,
  22. id: int,
  23. ) -> VasSchema:
  24. stmt = select(VasSchema).where(VasSchema.id == id)
  25. obj = (await db.execute(stmt)).scalar_one_or_none()
  26. if not obj:
  27. raise NotFoundError("Schema not exist")
  28. return obj
  29. @staticmethod
  30. async def update(
  31. db: AsyncSession,
  32. id: int,
  33. data: VasSchemaUpdate,
  34. ) -> VasSchema:
  35. stmt = select(VasSchema).where(VasSchema.id == id)
  36. obj = (await db.execute(stmt)).scalar_one_or_none()
  37. if not obj:
  38. raise NotFoundError("Schema not exist")
  39. update_data = data.model_dump(exclude_unset=True, by_alias=True)
  40. for k, v in update_data.items():
  41. setattr(obj, k, v)
  42. await db.commit()
  43. await db.refresh(obj)
  44. return obj
  45. @staticmethod
  46. async def delete(
  47. db: AsyncSession,
  48. id: int,
  49. ) -> None:
  50. stmt = select(VasSchema).where(VasSchema.id == id)
  51. obj = (await db.execute(stmt)).scalar_one_or_none()
  52. if not obj:
  53. raise NotFoundError("Schema not exist")
  54. await db.delete(obj)
  55. await db.commit()
  56. @staticmethod
  57. async def list_all(
  58. db: AsyncSession,
  59. ) -> List[VasSchema]:
  60. stmt = select(VasSchema)
  61. result = await db.execute(stmt)
  62. return result.scalars().all()