# main.py — FastAPI application setup
import asyncio
import os

from fastapi import FastAPI, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.encoders import jsonable_encoder
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.api import router
from app.core.redis import get_redis_client
from app.core.database import AsyncSessionLocal
from app.core.auth import RoleLevel, require_min_role
from app.core.config import settings
from app.core.payment import init_stripe
from app.core.biz_exception import BizException
from app.core.logger import logger
from app.tasks.notification_task import notification_consumer
  18. app = FastAPI(title=settings.app_name)
  19. # -----------------------
  20. # Startup
  21. # -----------------------
  22. @app.on_event("startup")
  23. async def startup():
  24. # 支付配置
  25. init_stripe()
  26. logger.info("🟢 Stripe config done")
  27. # 通知服务启动
  28. if os.environ.get("RUN_ON_MASTER", "1") == "1" and os.getppid() == 1:
  29. redis_client = await get_redis_client()
  30. asyncio.create_task(notification_consumer(AsyncSessionLocal, redis_client))
  31. logger.info("🟢 Notification consumer started")
  32. # -----------------------
  33. # Exception Handlers
  34. # -----------------------
  35. @app.exception_handler(BizException)
  36. async def biz_exception_handler(request: Request, exc: BizException):
  37. return JSONResponse(
  38. status_code=exc.http_status,
  39. content={
  40. "code": exc.code,
  41. "message": exc.message,
  42. "data": exc.extra,
  43. },
  44. )
  45. @app.exception_handler(Exception)
  46. async def unhandled_exception_handler(request: Request, exc: Exception):
  47. logger.error("Unhandled exception", exc_info=exc)
  48. return JSONResponse(
  49. status_code=500,
  50. content={
  51. "code": 50000,
  52. "message": "Internal Server Error",
  53. "data": None,
  54. },
  55. )
  56. @app.exception_handler(RequestValidationError)
  57. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  58. """
  59. 接管 FastAPI 默认的 422 数据校验错误
  60. """
  61. # 1. 获取所有的错误详情
  62. errors = exc.errors()
  63. # 2. 提取第一个错误作为主要提示信息(通常用户只关心第一个报错)
  64. if errors:
  65. first_error = errors[0]
  66. # 处理路径:去掉 'body', 'query' 等前缀,只保留字段名,用点连接
  67. # 例如: ['body', 'user', 'age'] -> 'user.age'
  68. # 如果路径只有 ['body'] (比如传了空JSON),则通过 [1:] 过滤可能为空,要做个保护
  69. raw_path = first_error.get("loc", [])
  70. path_parts = [str(p) for p in raw_path if p not in ('body', 'query', 'path')]
  71. field_path = ".".join(path_parts) if path_parts else "request_body"
  72. err_msg = first_error.get("msg")
  73. # 优化提示文案
  74. message = f"Invalid parameter [{field_path}]: {err_msg}"
  75. else:
  76. message = "Invalid request parameters"
  77. # 3. 返回标准格式
  78. return JSONResponse(
  79. status_code=422,
  80. content={
  81. "code": 42200, # 或者是你们约定的参数错误码
  82. "message": message, # 转换后的人类可读提示
  83. "data": errors # 保留原始错误详情,方便前端开发排查
  84. },
  85. )
  86. @app.exception_handler(StarletteHTTPException)
  87. async def http_exception_handler(request: Request, exc: StarletteHTTPException):
  88. """
  89. 接管所有 HTTP 错误,包括 404 Not Found 和 405 Method Not Allowed
  90. """
  91. return JSONResponse(
  92. status_code=exc.status_code,
  93. content={
  94. # 这里你可以根据 status_code 生成你的业务 code,比如 404 -> 40400
  95. "code": exc.status_code*100,
  96. "message": exc.detail, # 这里通常是 "Not Found"
  97. "data": None,
  98. },
  99. )
  100. # -----------------------
  101. # CORS
  102. # -----------------------
  103. app.add_middleware(
  104. CORSMiddleware,
  105. allow_origins=["*"],
  106. allow_methods=["*"],
  107. allow_headers=["*"],
  108. )
  109. # -----------------------
  110. # Routers
  111. # -----------------------
  112. app.include_router(
  113. router.public_router,
  114. prefix="/api"
  115. )
  116. app.include_router(
  117. router.protected_router,
  118. prefix="/api",
  119. dependencies=[Depends(require_min_role(RoleLevel.user))]
  120. )
  121. app.include_router(
  122. router.admin_required_router,
  123. prefix="/api",
  124. dependencies=[Depends(require_min_role(RoleLevel.admin))]
  125. )
  126. # -----------------------
  127. # Swagger Bearer Token
  128. # -----------------------
  129. def custom_openapi():
  130. if app.openapi_schema:
  131. return app.openapi_schema
  132. openapi_schema = get_openapi(
  133. title=app.title,
  134. version="1.0.0",
  135. description="API documentation",
  136. routes=app.routes,
  137. )
  138. openapi_schema.setdefault("components", {})
  139. openapi_schema["components"]["securitySchemes"] = {
  140. "BearerAuth": {
  141. "type": "http",
  142. "scheme": "bearer",
  143. "bearerFormat": "JWT",
  144. }
  145. }
  146. for path in openapi_schema["paths"].values():
  147. for method in path.values():
  148. method.setdefault("security", [])
  149. method["security"].append({"BearerAuth": []})
  150. app.openapi_schema = openapi_schema
  151. return app.openapi_schema
  152. app.openapi = custom_openapi