# main.py
  1. from fastapi import FastAPI, Depends, Request
  2. from fastapi.responses import JSONResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from fastapi.openapi.utils import get_openapi
  5. from fastapi.security import HTTPBearer
  6. from app.api import router
  7. from app.core.auth import RoleLevel, require_min_role
  8. from app.core.config import settings
  9. from app.core.payment import init_stripe
  10. from app.core.biz_exception import BizException
  11. from app.core.logger import logger
  12. app = FastAPI(title=settings.app_name)
  13. @app.on_event("startup")
  14. def startup():
  15. init_stripe()
  16. @app.exception_handler(BizException)
  17. async def biz_exception_handler(request: Request, exc: BizException):
  18. return JSONResponse(
  19. status_code=exc.http_status,
  20. content={
  21. "code": exc.code,
  22. "message": exc.message,
  23. "data": exc.extra,
  24. },
  25. )
  26. @app.exception_handler(Exception)
  27. async def unhandled_exception_handler(request: Request, exc: Exception):
  28. # ⚠️ 一定要打日志
  29. logger.error("Unhandled exception")
  30. return JSONResponse(
  31. status_code=500,
  32. content={
  33. "code": 50000,
  34. "message": "Internal Server Error",
  35. "data": None,
  36. },
  37. )
  38. # -----------------------
  39. # CORS(可选)
  40. # -----------------------
  41. app.add_middleware(
  42. CORSMiddleware,
  43. allow_origins=["*"],
  44. allow_methods=["*"],
  45. allow_headers=["*"],
  46. )
  47. # -----------------------
  48. # 路由注册
  49. # -----------------------
  50. # 公共路由,不鉴权
  51. app.include_router(
  52. router.public_router,
  53. prefix="/api"
  54. )
  55. # 需要鉴权的路由
  56. app.include_router(
  57. router.protected_router,
  58. prefix="/api",
  59. dependencies=[Depends(require_min_role(RoleLevel.user))]
  60. )
  61. # 需要管理员权限的路由
  62. app.include_router(
  63. router.admin_required_router,
  64. prefix="/api",
  65. dependencies=[Depends(require_min_role(RoleLevel.admin))]
  66. )
  67. # -----------------------
  68. # Swagger 支持 Bearer Token
  69. # -----------------------
  70. def custom_openapi():
  71. if app.openapi_schema:
  72. return app.openapi_schema
  73. openapi_schema = get_openapi(
  74. title=app.title,
  75. version="1.0.0",
  76. description="API documentation",
  77. routes=app.routes,
  78. )
  79. # 添加全局 Bearer
  80. openapi_schema["components"]["securitySchemes"] = {
  81. "BearerAuth": {
  82. "type": "http",
  83. "scheme": "bearer",
  84. "bearerFormat": "JWT"
  85. }
  86. }
  87. for path in openapi_schema["paths"].values():
  88. for method in path.values():
  89. method["security"] = [{"BearerAuth": []}]
  90. app.openapi_schema = openapi_schema
  91. return app.openapi_schema
  92. app.openapi = custom_openapi