| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171 |
- from contextlib import asynccontextmanager
- from fastapi import FastAPI, Depends, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.openapi.utils import get_openapi
- from fastapi.exceptions import RequestValidationError
- from starlette.exceptions import HTTPException as StarletteHTTPException
- from app.api import router
- from app.core.database import AsyncSessionLocal
- from app.core.auth import RoleLevel, require_min_role
- from app.core.config import settings
- from app.core.payment import init_stripe
- from app.core.biz_exception import BizException
- from app.core.logger import logger
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook used by the FastAPI instance.

    On startup, Stripe is configured and a confirmation line is logged before
    control is handed back to the server. Nothing is torn down on shutdown.
    """
    # Startup phase: runs once before the first request is served.
    init_stripe()
    logger.info("🟢 Stripe config done")
    # Hand control to the running application; there is no shutdown work.
    yield None
# The ASGI application: its title is read from settings, and startup work is
# delegated to the `lifespan` context manager.
app = FastAPI(
    title=settings.app_name,
    lifespan=lifespan,
)
# -----------------------
# Exception Handlers
# -----------------------
@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
    """Render a BizException as the standard ``{code, message, data}`` envelope.

    The HTTP status is taken from ``exc.http_status``. ``exc.code`` and
    ``exc.message`` are copied verbatim, and ``exc.extra`` becomes ``data``.
    """
    http_status = exc.http_status
    envelope = {
        "code": exc.code,
        "message": exc.message,
        "data": exc.extra,
    }
    return JSONResponse(status_code=http_status, content=envelope)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Last-resort handler for exceptions no other handler claimed.

    The exception is logged (passed as ``exc_info``). The client receives a
    fixed, generic 500 envelope, so no internal details are exposed.
    """
    logger.error("Unhandled exception", exc_info=exc)
    generic_error = {"code": 50000, "message": "Internal Server Error", "data": None}
    return JSONResponse(status_code=500, content=generic_error)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Take over FastAPI's default 422 request-validation response.

    Responds with the project's ``{code, message, data}`` envelope:

    * ``message`` summarises only the FIRST validation error, since users
      usually care about the first problem. Its format is
      ``Invalid parameter [<field.path>]: <msg>``.
    * ``data`` carries the full error list, JSON-encoded, so front-end
      developers can still inspect every failure.
    """
    # Function-scope import keeps this handler self-contained.
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()

    if errors:
        first_error = errors[0]
        loc = list(first_error.get("loc", ()))
        # Strip only the LEADING request-location segment,
        # e.g. ['body', 'user', 'age'] -> 'user.age'. Filtering every element
        # would also drop a genuine field that happens to be named "body".
        if loc and loc[0] in ("body", "query", "path", "header", "cookie"):
            loc = loc[1:]
        # loc may be just ['body'] (e.g. an empty JSON body), which leaves
        # nothing to join; fall back to a generic label in that case.
        field_path = ".".join(str(part) for part in loc) or "request_body"
        err_msg = first_error.get("msg")
        message = f"Invalid parameter [{field_path}]: {err_msg}"
    else:
        message = "Invalid request parameters"

    return JSONResponse(
        status_code=422,
        content={
            "code": 42200,  # parameter-error business code (project convention)
            "message": message,  # human-readable summary built above
            # exc.errors() may contain values json cannot serialise. For
            # example, Pydantic v2 can put the validator's exception object
            # under ctx["error"]. Encode them the way FastAPI's default
            # handler does, instead of letting JSONResponse blow up into a 500.
            "data": jsonable_encoder(errors),
        },
    )
-
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Take over every HTTP error (404 Not Found, 405 Method Not Allowed, ...).

    The response uses the project's ``{code, message, data}`` envelope:

    * the business ``code`` is derived from the status (404 -> 40400);
    * ``message`` is ``exc.detail``, typically the status phrase such as
      "Not Found".

    Any headers attached to the exception are preserved on the response.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "code": exc.status_code * 100,
            "message": exc.detail,
            "data": None,
        },
        # Forward headers carried by the exception, as the default handler
        # does. Examples are the `Allow` header Starlette sets on 405 and
        # `WWW-Authenticate` on 401. Dropping them makes the response
        # non-compliant and breaks clients that rely on them.
        headers=exc.headers,
    )
# -----------------------
# CORS
# -----------------------
# Fully open cross-origin policy: every origin, method and request header is
# allowed. allow_credentials is not passed, so the middleware default applies.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# -----------------------
# Routers
# -----------------------
# Every router is mounted under "/api". Access is gated per router by a
# router-level dependency (presumably a minimum-role check; see
# app.core.auth.require_min_role).
app.include_router(router.public_router, prefix="/api")  # no auth dependency

app.include_router(
    router.protected_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.user))],
)

app.include_router(
    router.admin_required_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.admin))],
)
# -----------------------
# Swagger Bearer Token
# -----------------------
def custom_openapi():
    """Build (once) and return the OpenAPI schema with a JWT bearer scheme.

    The schema is generated on first call and cached on ``app.openapi_schema``.
    A ``BearerAuth`` HTTP bearer scheme is registered, and every operation
    gains a ``{"BearerAuth": []}`` security requirement, so Swagger UI shows
    an "Authorize" button.

    NOTE(review): the requirement is added to *all* operations, including
    those from the public router; confirm this is intended for the docs.
    """
    if app.openapi_schema:
        return app.openapi_schema

    schema = get_openapi(
        title=app.title,
        version="1.0.0",
        description="API documentation",
        routes=app.routes,
    )

    # Merge instead of overwriting. FastAPI may already have emitted schemes
    # (e.g. for HTTPBearer/OAuth2 dependencies) that operations reference by
    # name; replacing the whole mapping would leave those references dangling.
    components = schema.setdefault("components", {})
    components.setdefault("securitySchemes", {})["BearerAuth"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    }

    for path_item in schema.get("paths", {}).values():
        for operation in path_item.values():
            # Path items may also hold non-operation entries (e.g. a
            # "parameters" list); only operation objects take "security".
            if not isinstance(operation, dict):
                continue
            operation.setdefault("security", []).append({"BearerAuth": []})

    app.openapi_schema = schema
    return app.openapi_schema


app.openapi = custom_openapi
|