# main.py
  1. from fastapi import FastAPI, Depends, Request
  2. from fastapi.responses import JSONResponse
  3. from fastapi.middleware.cors import CORSMiddleware
  4. from fastapi.openapi.utils import get_openapi
  5. from fastapi.security import HTTPBearer
  6. from app.api import router
  7. from app.core.auth import verify_token
  8. from app.core.config import settings
  9. from app.core.payment import init_stripe
  10. from app.core.biz_exception import BizException
  11. from app.core.logger import logger
  12. app = FastAPI(title=settings.app_name)
  13. @app.on_event("startup")
  14. def startup():
  15. init_stripe()
  16. @app.exception_handler(BizException)
  17. async def biz_exception_handler(request: Request, exc: BizException):
  18. return JSONResponse(
  19. status_code=exc.http_status,
  20. content={
  21. "code": exc.code,
  22. "message": exc.message,
  23. "data": exc.extra,
  24. },
  25. )
  26. @app.exception_handler(Exception)
  27. async def unhandled_exception_handler(request: Request, exc: Exception):
  28. # ⚠️ 一定要打日志
  29. logger.error("Unhandled exception")
  30. return JSONResponse(
  31. status_code=500,
  32. content={
  33. "code": 50000,
  34. "message": "Internal Server Error",
  35. "data": None,
  36. },
  37. )
  38. # -----------------------
  39. # CORS(可选)
  40. # -----------------------
  41. app.add_middleware(
  42. CORSMiddleware,
  43. allow_origins=["*"],
  44. allow_methods=["*"],
  45. allow_headers=["*"],
  46. )
  47. # -----------------------
  48. # 路由注册
  49. # -----------------------
  50. # 公共路由,不鉴权
  51. app.include_router(
  52. router.public_router,
  53. prefix="/api"
  54. )
  55. # 需要鉴权的路由
  56. app.include_router(
  57. router.protected_router,
  58. prefix="/api",
  59. #dependencies=[Depends(verify_token)]
  60. )
  61. # -----------------------
  62. # Swagger 支持 Bearer Token
  63. # -----------------------
  64. def custom_openapi():
  65. if app.openapi_schema:
  66. return app.openapi_schema
  67. openapi_schema = get_openapi(
  68. title=app.title,
  69. version="1.0.0",
  70. description="API documentation",
  71. routes=app.routes,
  72. )
  73. # 添加全局 Bearer
  74. openapi_schema["components"]["securitySchemes"] = {
  75. "BearerAuth": {
  76. "type": "http",
  77. "scheme": "bearer",
  78. "bearerFormat": "JWT"
  79. }
  80. }
  81. for path in openapi_schema["paths"].values():
  82. for method in path.values():
  83. method["security"] = [{"BearerAuth": []}]
  84. app.openapi_schema = openapi_schema
  85. return app.openapi_schema
  86. app.openapi = custom_openapi