main.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
import asyncio

from fastapi import Depends, FastAPI, Request
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.api import router
from app.core.redis import get_redis_client
from app.core.database import AsyncSessionLocal
from app.core.auth import RoleLevel, require_min_role
from app.core.config import settings
from app.core.payment import init_stripe
from app.core.biz_exception import BizException
from app.core.logger import logger
from app.tasks.notification_task import notification_consumer
  17. app = FastAPI(title=settings.app_name)
  18. # -----------------------
  19. # Startup
  20. # -----------------------
  21. @app.on_event("startup")
  22. async def startup():
  23. # 支付配置
  24. init_stripe()
  25. logger.info("🟢 Stripe config done")
  26. # 通知服务启动
  27. redis_client = await get_redis_client()
  28. asyncio.create_task(notification_consumer(AsyncSessionLocal, redis_client))
  29. logger.info("🟢 Notification consumer started")
  30. # -----------------------
  31. # Exception Handlers
  32. # -----------------------
  33. @app.exception_handler(BizException)
  34. async def biz_exception_handler(request: Request, exc: BizException):
  35. return JSONResponse(
  36. status_code=exc.http_status,
  37. content={
  38. "code": exc.code,
  39. "message": exc.message,
  40. "data": exc.extra,
  41. },
  42. )
  43. @app.exception_handler(Exception)
  44. async def unhandled_exception_handler(request: Request, exc: Exception):
  45. logger.error("Unhandled exception", exc_info=exc)
  46. return JSONResponse(
  47. status_code=500,
  48. content={
  49. "code": 50000,
  50. "message": "Internal Server Error",
  51. "data": None,
  52. },
  53. )
  54. @app.exception_handler(RequestValidationError)
  55. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  56. """
  57. 接管 FastAPI 默认的 422 数据校验错误
  58. """
  59. # 1. 获取所有的错误详情
  60. errors = exc.errors()
  61. # 2. 提取第一个错误作为主要提示信息(通常用户只关心第一个报错)
  62. if errors:
  63. first_error = errors[0]
  64. # 处理路径:去掉 'body', 'query' 等前缀,只保留字段名,用点连接
  65. # 例如: ['body', 'user', 'age'] -> 'user.age'
  66. # 如果路径只有 ['body'] (比如传了空JSON),则通过 [1:] 过滤可能为空,要做个保护
  67. raw_path = first_error.get("loc", [])
  68. path_parts = [str(p) for p in raw_path if p not in ('body', 'query', 'path')]
  69. field_path = ".".join(path_parts) if path_parts else "request_body"
  70. err_msg = first_error.get("msg")
  71. # 优化提示文案
  72. message = f"Invalid parameter [{field_path}]: {err_msg}"
  73. else:
  74. message = "Invalid request parameters"
  75. # 3. 返回标准格式
  76. return JSONResponse(
  77. status_code=422,
  78. content={
  79. "code": 42200, # 或者是你们约定的参数错误码
  80. "message": message, # 转换后的人类可读提示
  81. "data": errors # 保留原始错误详情,方便前端开发排查
  82. },
  83. )
  84. @app.exception_handler(StarletteHTTPException)
  85. async def http_exception_handler(request: Request, exc: StarletteHTTPException):
  86. """
  87. 接管所有 HTTP 错误,包括 404 Not Found 和 405 Method Not Allowed
  88. """
  89. return JSONResponse(
  90. status_code=exc.status_code,
  91. content={
  92. # 这里你可以根据 status_code 生成你的业务 code,比如 404 -> 40400
  93. "code": exc.status_code*100,
  94. "message": exc.detail, # 这里通常是 "Not Found"
  95. "data": None,
  96. },
  97. )
  98. # -----------------------
  99. # CORS
  100. # -----------------------
  101. app.add_middleware(
  102. CORSMiddleware,
  103. allow_origins=["*"],
  104. allow_methods=["*"],
  105. allow_headers=["*"],
  106. )
  107. # -----------------------
  108. # Routers
  109. # -----------------------
  110. app.include_router(
  111. router.public_router,
  112. prefix="/api"
  113. )
  114. app.include_router(
  115. router.protected_router,
  116. prefix="/api",
  117. dependencies=[Depends(require_min_role(RoleLevel.user))]
  118. )
  119. app.include_router(
  120. router.admin_required_router,
  121. prefix="/api",
  122. dependencies=[Depends(require_min_role(RoleLevel.admin))]
  123. )
  124. # -----------------------
  125. # Swagger Bearer Token
  126. # -----------------------
  127. def custom_openapi():
  128. if app.openapi_schema:
  129. return app.openapi_schema
  130. openapi_schema = get_openapi(
  131. title=app.title,
  132. version="1.0.0",
  133. description="API documentation",
  134. routes=app.routes,
  135. )
  136. openapi_schema.setdefault("components", {})
  137. openapi_schema["components"]["securitySchemes"] = {
  138. "BearerAuth": {
  139. "type": "http",
  140. "scheme": "bearer",
  141. "bearerFormat": "JWT",
  142. }
  143. }
  144. for path in openapi_schema["paths"].values():
  145. for method in path.values():
  146. method.setdefault("security", [])
  147. method["security"].append({"BearerAuth": []})
  148. app.openapi_schema = openapi_schema
  149. return app.openapi_schema
  150. app.openapi = custom_openapi