main.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. from contextlib import asynccontextmanager
  2. from fastapi import FastAPI, Depends, Request
  3. from fastapi.responses import JSONResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from fastapi.openapi.utils import get_openapi
  6. from fastapi.exceptions import RequestValidationError
  7. from starlette.exceptions import HTTPException as StarletteHTTPException
  8. from app.api import router
  9. from app.core.database import AsyncSessionLocal
  10. from app.core.auth import RoleLevel, require_min_role
  11. from app.core.config import settings
  12. from app.core.payment import init_stripe
  13. from app.core.biz_exception import BizException
  14. from app.core.logger import logger
  @asynccontextmanager
  async def lifespan(app: FastAPI):
      """Application lifespan: configure Stripe and run the background scheduler.

      Startup initializes Stripe and starts the scheduler. Shutdown stops the
      very same scheduler instance that was started, and does so even when the
      application leaves the ``yield`` with an exception.
      """
      init_stripe()
      logger.info("🟢 Stripe config done")

      # Deferred import — presumably to avoid an import cycle at module load
      # (TODO confirm). ``visametric_slot_expire_checker`` is not referenced
      # here; it appears to be imported only for its import-time side effects
      # (likely job registration — verify in app.services.scheduler_handlers).
      from app.services.scheduler_handlers import (  # noqa: F401
          scheduler_service,
          visametric_slot_expire_checker,
      )

      scheduler_service.start()
      logger.info("🟢 Scheduler started")
      try:
          yield
      finally:
          # Reuse the instance started above instead of re-importing it from
          # another module, so start/stop are guaranteed to act on one object.
          scheduler_service.stop()
          logger.info("🟢 Scheduler stopped")
  # The ASGI application; its title comes from settings and ``lifespan``
  # handles startup/shutdown work (Stripe init, scheduler start/stop).
  app = FastAPI(
      title=settings.app_name,
      lifespan=lifespan,
  )
  # -----------------------
  # Exception Handlers
  # -----------------------
  @app.exception_handler(BizException)
  async def biz_exception_handler(request: Request, exc: BizException):
      """Render a ``BizException`` as the standard ``{code, message, data}`` envelope.

      The HTTP status is taken from ``exc.http_status``; ``exc.code`` and
      ``exc.message`` are passed through and ``exc.extra`` becomes ``data``.
      """
      status = exc.http_status
      payload = {
          "code": exc.code,
          "message": exc.message,
          "data": exc.extra,
      }
      return JSONResponse(status_code=status, content=payload)
  @app.exception_handler(Exception)
  async def unhandled_exception_handler(request: Request, exc: Exception):
      """Catch-all handler: log the exception and answer with a generic 500.

      The client only sees a fixed envelope; exception details go to the log
      (passed to the logger as ``exc_info``).

      NOTE(review): in Starlette, handlers for ``Exception`` run in the
      outermost ServerErrorMiddleware, i.e. outside CORSMiddleware, so these
      responses may lack CORS headers — confirm browsers can read them.
      """
      logger.error("Unhandled exception", exc_info=exc)
      body = {"code": 50000, "message": "Internal Server Error", "data": None}
      return JSONResponse(content=body, status_code=500)
  # Parameter sources FastAPI reports as the first element of an error ``loc``;
  # they describe where a value came from, not the field itself.
  _VALIDATION_LOC_SOURCES = frozenset({"body", "query", "path", "header", "cookie"})


  def _validation_field_path(loc) -> str:
      """Return a dotted field path for an error ``loc``, minus its source prefix.

      Only the leading element is stripped, so a nested field that happens to
      be named e.g. ``"body"`` is preserved. Example:
      ``("body", "user", "age")`` -> ``"user.age"``. Falls back to
      ``"request_body"`` when nothing remains (e.g. ``("body",)`` for an empty
      or missing JSON body).
      """
      parts = list(loc or ())
      if parts and parts[0] in _VALIDATION_LOC_SOURCES:
          parts = parts[1:]
      return ".".join(str(p) for p in parts) or "request_body"


  @app.exception_handler(RequestValidationError)
  async def validation_exception_handler(request: Request, exc: RequestValidationError):
      """Replace FastAPI's default 422 request-validation response.

      Returns HTTP 422 with the standard ``{code, message, data}`` envelope:
      ``message`` is a human-readable summary of the first error (usually the
      only one a user cares about) and ``data`` carries the full error list so
      frontend developers can debug.
      """
      # FastAPI is already a dependency of this module; imported locally to
      # keep the handler self-contained.
      from fastapi.encoders import jsonable_encoder

      errors = exc.errors()
      if errors:
          first_error = errors[0]
          field_path = _validation_field_path(first_error.get("loc"))
          message = f"Invalid parameter [{field_path}]: {first_error.get('msg')}"
      else:
          message = "Invalid request parameters"

      return JSONResponse(
          status_code=422,
          content={
              "code": 42200,  # the project's parameter-error business code
              "message": message,
              # Error dicts may contain values json.dumps cannot handle (e.g. an
              # exception object under "ctx", or raw "input"); encode them the
              # way FastAPI's default handler does so this handler cannot 500.
              "data": jsonable_encoder(errors),
          },
      )
  @app.exception_handler(StarletteHTTPException)
  async def http_exception_handler(request: Request, exc: StarletteHTTPException):
      """Render every HTTP error (404 Not Found, 405 Method Not Allowed, ...)
      in the standard ``{code, message, data}`` envelope.

      The business code is ``status_code * 100`` (e.g. 404 -> 40400) and
      ``message`` is ``exc.detail`` (typically the status phrase such as
      "Not Found"). Headers attached to the exception are forwarded.
      """
      from fastapi.responses import Response

      # Keep headers the raiser attached — e.g. ``Allow`` on 405,
      # ``WWW-Authenticate`` on 401, ``Retry-After`` on 429. Dropping them
      # breaks clients that rely on them.
      headers = getattr(exc, "headers", None)

      # 204 and 304 responses must not carry a body (mirrors Starlette's
      # default HTTPException handler).
      if exc.status_code in (204, 304):
          return Response(status_code=exc.status_code, headers=headers)

      return JSONResponse(
          status_code=exc.status_code,
          content={
              "code": exc.status_code * 100,
              "message": exc.detail,
              "data": None,
          },
          headers=headers,
      )
  # -----------------------
  # CORS
  # -----------------------
  # Fully permissive CORS: any origin, any method, any request header.
  # ``allow_credentials`` is left at its default (disabled).
  # NOTE(review): consider restricting ``allow_origins`` outside development.
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
      allow_headers=["*"],
      allow_methods=["*"],
  )
  # -----------------------
  # Routers
  # -----------------------
  # Every router is mounted under "/api". The protected and admin routers carry
  # a ``require_min_role`` dependency (user and admin level respectively) —
  # presumably rejecting callers below that role; see app.core.auth.
  app.include_router(router=router.public_router, prefix="/api")  # no router-level dependency
  app.include_router(
      router=router.protected_router,
      prefix="/api",
      dependencies=[Depends(require_min_role(RoleLevel.user))],
  )
  app.include_router(
      router=router.admin_required_router,
      prefix="/api",
      dependencies=[Depends(require_min_role(RoleLevel.admin))],
  )
  # -----------------------
  # Swagger Bearer Token
  # -----------------------
  # Keys of an OpenAPI path item that are operations; other keys (e.g.
  # "parameters", "summary", "servers") must not be treated as operations.
  _OPENAPI_HTTP_METHODS = frozenset(
      {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
  )


  def custom_openapi():
      """Build (once) and cache the OpenAPI schema with a JWT Bearer scheme.

      Registers ``BearerAuth`` (HTTP bearer, JWT format) under
      ``components.securitySchemes`` and adds it to the ``security`` list of
      every operation so Swagger UI offers a token "Authorize" button. The
      result is cached on ``app.openapi_schema``.
      """
      if app.openapi_schema:
          return app.openapi_schema

      openapi_schema = get_openapi(
          title=app.title,
          version="1.0.0",
          description="API documentation",
          routes=app.routes,
      )

      # Merge rather than overwrite: FastAPI may already have generated
      # security schemes (e.g. from HTTPBearer/OAuth2 dependencies) that
      # operations reference; replacing the dict would leave dangling refs.
      components = openapi_schema.setdefault("components", {})
      components.setdefault("securitySchemes", {})["BearerAuth"] = {
          "type": "http",
          "scheme": "bearer",
          "bearerFormat": "JWT",
      }

      for path_item in openapi_schema.get("paths", {}).values():
          for method, operation in path_item.items():
              if method not in _OPENAPI_HTTP_METHODS or not isinstance(operation, dict):
                  continue
              security = operation.setdefault("security", [])
              if {"BearerAuth": []} not in security:
                  security.append({"BearerAuth": []})

      app.openapi_schema = openapi_schema
      return app.openapi_schema


  app.openapi = custom_openapi