"""FastAPI application entry point.

Builds the ``app`` instance and wires up:
- a lifespan hook that configures Stripe at startup,
- exception handlers that wrap errors in a ``{"code", "message", "data"}`` envelope,
- permissive CORS,
- the public / user / admin routers under ``/api``,
- a JWT Bearer security scheme in the generated OpenAPI (Swagger) schema.
"""
from contextlib import asynccontextmanager

from fastapi import FastAPI, Depends, Request
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException

from app.api import router
from app.core.database import AsyncSessionLocal
from app.core.auth import RoleLevel, require_min_role
from app.core.config import settings
from app.core.payment import init_stripe
from app.core.biz_exception import BizException
from app.core.logger import logger


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run one-time startup work (Stripe configuration) before serving requests."""
    init_stripe()
    logger.info("🟢 Stripe config done")
    yield


app = FastAPI(title=settings.app_name, lifespan=lifespan)


# -----------------------
# Exception Handlers
# -----------------------

# Request-source markers FastAPI puts at the front of a validation error's
# ``loc`` (e.g. ``("body", "user", "age")``).
_REQUEST_LOCATIONS = frozenset({"body", "query", "path", "header", "cookie"})

# Statuses whose responses must not carry a body.
_BODYLESS_STATUSES = frozenset({204, 304})


@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
    """Render a BizException using its own HTTP status, business code, message and extra data."""
    return JSONResponse(
        status_code=exc.http_status,
        content={
            "code": exc.code,
            "message": exc.message,
            "data": exc.extra,
        },
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the exception and return a generic 500 envelope.

    The exception details are deliberately not exposed to the client.
    """
    # NOTE(review): assumes `logger` accepts stdlib-style `exc_info=`; confirm
    # against app.core.logger so the traceback is actually recorded.
    logger.error("Unhandled exception", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "code": 50000,
            "message": "Internal Server Error",
            "data": None,
        },
    )


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Replace FastAPI's default 422 validation response with the standard envelope.

    ``message`` summarizes only the first error (usually the one the caller
    cares about). ``data`` carries the full error list so frontend developers
    can debug.
    """
    # Encode up front: raw errors may hold values JSONResponse cannot serialize
    # (e.g. exception objects under "ctx"), which would turn this 422 into a 500.
    errors = jsonable_encoder(exc.errors())

    if errors:
        first_error = errors[0]
        loc = list(first_error.get("loc", []))
        # Strip only the leading source marker. A nested field that happens to
        # be named "path"/"body"/... is kept.
        # e.g. ['body', 'user', 'age'] -> 'user.age'
        if loc and loc[0] in _REQUEST_LOCATIONS:
            loc = loc[1:]
        # A bare ['body'] location (e.g. empty or malformed JSON) leaves nothing.
        field_path = ".".join(str(part) for part in loc) if loc else "request_body"
        message = f"Invalid parameter [{field_path}]: {first_error.get('msg')}"
    else:
        message = "Invalid request parameters"

    return JSONResponse(
        status_code=422,
        content={
            "code": 42200,  # parameter-error business code; align with project convention
            "message": message,  # human-readable summary of the first error
            "data": errors,  # full error details for debugging
        },
    )


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Wrap every HTTP error (404 Not Found, 405 Method Not Allowed, ...) in the envelope.

    The business code is the HTTP status times 100 (e.g. 404 -> 40400).
    """
    # Keep headers attached to the exception, e.g. `Allow` on 405 or
    # `WWW-Authenticate` on 401. Dropping them breaks clients and auth flows.
    headers = getattr(exc, "headers", None)
    if exc.status_code in _BODYLESS_STATUSES:
        return Response(status_code=exc.status_code, headers=headers)
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "code": exc.status_code * 100,
            "message": exc.detail,  # typically the reason phrase, e.g. "Not Found"
            "data": None,
        },
        headers=headers,
    )


# -----------------------
# CORS
# -----------------------
# Permissive: any origin, method and header (credentials are not enabled here).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# -----------------------
# Routers
# -----------------------
# Public endpoints: no authentication dependency.
app.include_router(
    router.public_router,
    prefix="/api",
)
# Endpoints that need at least the `user` role.
app.include_router(
    router.protected_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.user))],
)
# Endpoints that need at least the `admin` role.
app.include_router(
    router.admin_required_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.admin))],
)


# -----------------------
# Swagger Bearer Token
# -----------------------

# Keys of an OpenAPI path item that are operations (others, e.g. "parameters",
# "summary", "servers", are path-level metadata and must not be touched).
_HTTP_METHODS = frozenset(
    {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
)


def custom_openapi():
    """Build the OpenAPI schema once, add a JWT Bearer scheme, and cache the result.

    Every operation gets a ``BearerAuth`` security requirement, so Swagger UI
    sends the token it was given via "Authorize". This includes the public
    routes.

    Returns:
        The (cached) OpenAPI schema dict.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version="1.0.0",
        description="API documentation",
        routes=app.routes,
    )

    # Merge into any schemes FastAPI already generated. Overwriting them would
    # leave operations referencing schemes that no longer exist.
    components = openapi_schema.setdefault("components", {})
    components.setdefault("securitySchemes", {})["BearerAuth"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    }

    bearer_requirement = {"BearerAuth": []}
    for path_item in openapi_schema.get("paths", {}).values():
        for method, operation in path_item.items():
            if method not in _HTTP_METHODS or not isinstance(operation, dict):
                continue
            security = operation.setdefault("security", [])
            if bearer_requirement not in security:
                security.append(bearer_requirement)

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi