main.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
import os

from fastapi import FastAPI, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.api import router
from app.core.database import AsyncSessionLocal
from app.core.auth import RoleLevel, require_min_role
from app.core.config import settings
from app.core.payment import init_stripe
from app.core.biz_exception import BizException
from app.core.logger import logger
  15. app = FastAPI(title=settings.app_name)
  16. # -----------------------
  17. # Startup
  18. # -----------------------
  19. @app.on_event("startup")
  20. async def startup():
  21. # 支付配置
  22. init_stripe()
  23. logger.info("🟢 Stripe config done")
  24. # Notification processing is handled by an external daemon.
  25. # -----------------------
  26. # Exception Handlers
  27. # -----------------------
  28. @app.exception_handler(BizException)
  29. async def biz_exception_handler(request: Request, exc: BizException):
  30. return JSONResponse(
  31. status_code=exc.http_status,
  32. content={
  33. "code": exc.code,
  34. "message": exc.message,
  35. "data": exc.extra,
  36. },
  37. )
  38. @app.exception_handler(Exception)
  39. async def unhandled_exception_handler(request: Request, exc: Exception):
  40. logger.error("Unhandled exception", exc_info=exc)
  41. return JSONResponse(
  42. status_code=500,
  43. content={
  44. "code": 50000,
  45. "message": "Internal Server Error",
  46. "data": None,
  47. },
  48. )
  49. @app.exception_handler(RequestValidationError)
  50. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  51. """
  52. 接管 FastAPI 默认的 422 数据校验错误
  53. """
  54. # 1. 获取所有的错误详情
  55. errors = exc.errors()
  56. # 2. 提取第一个错误作为主要提示信息(通常用户只关心第一个报错)
  57. if errors:
  58. first_error = errors[0]
  59. # 处理路径:去掉 'body', 'query' 等前缀,只保留字段名,用点连接
  60. # 例如: ['body', 'user', 'age'] -> 'user.age'
  61. # 如果路径只有 ['body'] (比如传了空JSON),则通过 [1:] 过滤可能为空,要做个保护
  62. raw_path = first_error.get("loc", [])
  63. path_parts = [str(p) for p in raw_path if p not in ('body', 'query', 'path')]
  64. field_path = ".".join(path_parts) if path_parts else "request_body"
  65. err_msg = first_error.get("msg")
  66. # 优化提示文案
  67. message = f"Invalid parameter [{field_path}]: {err_msg}"
  68. else:
  69. message = "Invalid request parameters"
  70. # 3. 返回标准格式
  71. return JSONResponse(
  72. status_code=422,
  73. content={
  74. "code": 42200, # 或者是你们约定的参数错误码
  75. "message": message, # 转换后的人类可读提示
  76. "data": errors # 保留原始错误详情,方便前端开发排查
  77. },
  78. )
  79. @app.exception_handler(StarletteHTTPException)
  80. async def http_exception_handler(request: Request, exc: StarletteHTTPException):
  81. """
  82. 接管所有 HTTP 错误,包括 404 Not Found 和 405 Method Not Allowed
  83. """
  84. return JSONResponse(
  85. status_code=exc.status_code,
  86. content={
  87. # 这里你可以根据 status_code 生成你的业务 code,比如 404 -> 40400
  88. "code": exc.status_code*100,
  89. "message": exc.detail, # 这里通常是 "Not Found"
  90. "data": None,
  91. },
  92. )
  93. # -----------------------
  94. # CORS
  95. # -----------------------
  96. app.add_middleware(
  97. CORSMiddleware,
  98. allow_origins=["*"],
  99. allow_methods=["*"],
  100. allow_headers=["*"],
  101. )
  102. # -----------------------
  103. # Routers
  104. # -----------------------
  105. app.include_router(
  106. router.public_router,
  107. prefix="/api"
  108. )
  109. app.include_router(
  110. router.protected_router,
  111. prefix="/api",
  112. dependencies=[Depends(require_min_role(RoleLevel.user))]
  113. )
  114. app.include_router(
  115. router.admin_required_router,
  116. prefix="/api",
  117. dependencies=[Depends(require_min_role(RoleLevel.admin))]
  118. )
  119. # -----------------------
  120. # Swagger Bearer Token
  121. # -----------------------
  122. def custom_openapi():
  123. if app.openapi_schema:
  124. return app.openapi_schema
  125. openapi_schema = get_openapi(
  126. title=app.title,
  127. version="1.0.0",
  128. description="API documentation",
  129. routes=app.routes,
  130. )
  131. openapi_schema.setdefault("components", {})
  132. openapi_schema["components"]["securitySchemes"] = {
  133. "BearerAuth": {
  134. "type": "http",
  135. "scheme": "bearer",
  136. "bearerFormat": "JWT",
  137. }
  138. }
  139. for path in openapi_schema["paths"].values():
  140. for method in path.values():
  141. method.setdefault("security", [])
  142. method["security"].append({"BearerAuth": []})
  143. app.openapi_schema = openapi_schema
  144. return app.openapi_schema
# Replace FastAPI's default schema generator with custom_openapi, which adds
# the BearerAuth (JWT) security scheme so Swagger UI offers an Authorize button.
app.openapi = custom_openapi