main.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from contextlib import asynccontextmanager
  2. from fastapi import FastAPI, Depends, Request
  3. from fastapi.responses import JSONResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from fastapi.openapi.utils import get_openapi
  6. from fastapi.exceptions import RequestValidationError
  7. from starlette.exceptions import HTTPException as StarletteHTTPException
  8. from app.api import router
  9. from app.core.database import AsyncSessionLocal
  10. from app.core.auth import RoleLevel, require_min_role
  11. from app.core.config import settings
  12. from app.core.payment import init_stripe
  13. from app.core.biz_exception import BizException
  14. from app.core.logger import logger
  @asynccontextmanager
  async def lifespan(app: FastAPI):
      """Lifespan context for the application.

      The code before ``yield`` runs when the lifespan context is entered:
      it initialises Stripe and logs that the configuration is done.
      Nothing follows the ``yield``, so no teardown work is performed.
      """
      init_stripe()
      logger.info("🟢 Stripe config done")
      yield


  # The ASGI application object; its title is taken from settings.app_name.
  app = FastAPI(
      title=settings.app_name,
      lifespan=lifespan,
  )
  # -----------------------
  # Exception Handlers
  # -----------------------
  @app.exception_handler(BizException)
  async def biz_exception_handler(request: Request, exc: BizException):
      """Render a BizException as the JSON envelope ``{code, message, data}``.

      The HTTP status comes from ``exc.http_status``; ``code``, ``message``
      and ``data`` come from ``exc.code``, ``exc.message`` and ``exc.extra``.
      """
      http_status = exc.http_status
      envelope = {
          "code": exc.code,
          "message": exc.message,
          "data": exc.extra,
      }
      return JSONResponse(status_code=http_status, content=envelope)
  @app.exception_handler(Exception)
  async def unhandled_exception_handler(request: Request, exc: Exception):
      """Log an otherwise-unhandled exception and reply with a generic 500.

      The response body never exposes details of ``exc``; they only go to
      the log.
      """
      logger.error("Unhandled exception", exc_info=exc)
      envelope = {
          "code": 50000,
          "message": "Internal Server Error",
          "data": None,
      }
      return JSONResponse(status_code=500, content=envelope)
  # Request sections that can appear as the first element of an error "loc".
  _LOC_SECTION_PREFIXES = ("body", "query", "path", "header", "cookie")


  def _validation_message(errors):
      """Build the human-readable message from the first validation error.

      Returns "Invalid request parameters" when ``errors`` is empty.
      Otherwise the field path is the error's ``loc`` with its leading
      section marker removed, joined with dots, e.g.
      ``['body', 'user', 'age'] -> 'user.age'``. A ``loc`` that is empty
      after stripping (e.g. just ``['body']``) is reported as
      ``request_body``.
      """
      if not errors:
          return "Invalid request parameters"

      first_error = errors[0]
      loc = list(first_error.get("loc") or ())
      # Strip only the leading section marker: a real field that happens to
      # be called "path" or "query" must stay in the reported field path.
      if loc and loc[0] in _LOC_SECTION_PREFIXES:
          loc = loc[1:]
      field_path = ".".join(str(part) for part in loc) if loc else "request_body"
      return f"Invalid parameter [{field_path}]: {first_error.get('msg')}"


  @app.exception_handler(RequestValidationError)
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
      """Replace FastAPI's default 422 validation response.

      The response keeps status 422 and uses the envelope
      ``{code: 42200, message, data}``: ``message`` summarises the first
      error (usually the only one the user cares about) and ``data`` carries
      the full error list so front-end developers can debug.
      """
      # Imported here so the handler does not depend on the module's import block.
      from fastapi.encoders import jsonable_encoder

      errors = exc.errors()
      return JSONResponse(
          status_code=422,
          content={
              "code": 42200,  # agreed business code for parameter errors
              "message": _validation_message(errors),
              # Error entries may hold values json.dumps cannot serialise
              # (exception objects in "ctx", bytes input); encode them first.
              "data": jsonable_encoder(errors),
          },
      )
  @app.exception_handler(StarletteHTTPException)
  async def http_exception_handler(request: Request, exc: StarletteHTTPException):
      """Render every HTTP error, including 404 and 405, in the JSON envelope.

      The business ``code`` is the HTTP status multiplied by 100
      (404 -> 40400) and ``message`` is ``exc.detail`` (e.g. "Not Found").
      Headers attached to the exception are forwarded to the response.
      """
      return JSONResponse(
          status_code=exc.status_code,
          content={
              "code": exc.status_code * 100,
              "message": exc.detail,
              "data": None,
          },
          # Keep headers such as WWW-Authenticate (401) or Allow (405);
          # clients rely on them and they were previously dropped.
          headers=getattr(exc, "headers", None),
      )
  # -----------------------
  # CORS
  # -----------------------
  # Wildcards: any origin, HTTP method and request header is accepted.
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
  # -----------------------
  # Routers
  # -----------------------
  # All three routers are mounted under /api.

  # Public routes: no role dependency attached.
  app.include_router(router.public_router, prefix="/api")

  # Protected routes: require at least RoleLevel.user.
  app.include_router(
      router.protected_router,
      prefix="/api",
      dependencies=[Depends(require_min_role(RoleLevel.user))],
  )

  # Admin routes: require at least RoleLevel.admin.
  app.include_router(
      router.admin_required_router,
      prefix="/api",
      dependencies=[Depends(require_min_role(RoleLevel.admin))],
  )
  # -----------------------
  # Swagger Bearer Token
  # -----------------------
  def custom_openapi():
      """Return the OpenAPI schema with a JWT bearer scheme on every operation.

      The schema is generated once and cached on ``app.openapi_schema``;
      later calls return the cached value. A ``BearerAuth`` HTTP bearer
      scheme is registered under ``components.securitySchemes`` and added
      as a security requirement to every operation.
      """
      if app.openapi_schema:
          return app.openapi_schema

      openapi_schema = get_openapi(
          title=app.title,
          version="1.0.0",
          description="API documentation",
          routes=app.routes,
      )

      # Merge instead of overwriting, so security schemes already generated
      # from the routes' own security dependencies are kept.
      components = openapi_schema.setdefault("components", {})
      components.setdefault("securitySchemes", {})["BearerAuth"] = {
          "type": "http",
          "scheme": "bearer",
          "bearerFormat": "JWT",
      }

      bearer_requirement = {"BearerAuth": []}
      for path_item in openapi_schema.get("paths", {}).values():
          for operation in path_item.values():
              # A path item can also hold non-operation entries (such as a
              # "parameters" list); only operation objects take "security".
              if not isinstance(operation, dict):
                  continue
              security = operation.setdefault("security", [])
              if bearer_requirement not in security:
                  security.append(bearer_requirement)

      app.openapi_schema = openapi_schema
      return app.openapi_schema


  # Make FastAPI use the customised schema generator.
  app.openapi = custom_openapi