database.py 1.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
from collections.abc import AsyncIterator

from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import declarative_base

from app.core.config import settings
  8. # =========================
  9. # Async Engine
  10. # =========================
  11. engine = create_async_engine(
  12. settings.database_url, # ⚠️ 必须是 async URL
  13. echo=settings.debug,
  14. pool_pre_ping=True,
  15. pool_recycle=1800,
  16. )
  17. # =========================
  18. # Async Session 工厂
  19. # =========================
  20. AsyncSessionLocal = async_sessionmaker(
  21. bind=engine,
  22. class_=AsyncSession,
  23. autoflush=False,
  24. expire_on_commit=False,
  25. )
  26. # ORM 基类
  27. Base = declarative_base()
  28. # =========================
  29. # FastAPI 依赖
  30. # =========================
  31. async def get_db() -> AsyncSession:
  32. async with AsyncSessionLocal() as session:
  33. try:
  34. yield session
  35. except Exception:
  36. # --- 核心改进:全局兜底回滚 ---
  37. await session.rollback()
  38. raise
  39. finally:
  40. await session.close()