import asyncio
import inspect

from fastapi import FastAPI, Depends, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi

from app.api import router
from app.core.redis import get_redis_client
from app.core.auth import RoleLevel, require_min_role
from app.core.config import settings
from app.core.payment import init_stripe
from app.core.biz_exception import BizException
from app.core.logger import logger
from app.tasks.notification_task import notification_consumer

app = FastAPI(title=settings.app_name)


# -----------------------
# Startup / Shutdown
# -----------------------
@app.on_event("startup")
async def startup():
    """Initialise the Stripe integration when the application starts.

    ``init_stripe`` is called directly. If it is (or becomes) a coroutine
    function, the returned awaitable is awaited so initialisation actually
    runs instead of being silently discarded.
    """
    result = init_stripe()
    if inspect.isawaitable(result):
        await result


def _log_consumer_exit(task: asyncio.Task) -> None:
    """Done-callback: log the notification consumer if it stops with an error."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Notification consumer crashed", exc_info=exc)


@app.on_event("startup")
async def startup_event():
    """Start the background notification consumer when FastAPI starts.

    The task is kept on ``app.state``: the event loop holds only a weak
    reference to tasks, so an unreferenced task may be garbage-collected
    mid-execution. Keeping it also lets shutdown cancel it cleanly.
    """
    redis_client = await get_redis_client()
    task = asyncio.create_task(notification_consumer(redis_client))
    # Surface crashes in the log instead of losing the exception.
    task.add_done_callback(_log_consumer_exit)
    app.state.notification_task = task
    logger.info("🟢 Notification consumer started")


@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the notification consumer on shutdown and wait for it to exit."""
    task = getattr(app.state, "notification_task", None)
    if task is None or task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected: we just cancelled it.
        pass


# -----------------------
# Exception Handlers
# -----------------------
@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
    """Translate a ``BizException`` into the standard JSON error envelope."""
    return JSONResponse(
        status_code=exc.http_status,
        content={
            "code": exc.code,
            "message": exc.message,
            "data": exc.extra,
        },
    )


# NOTE(review): Starlette installs a handler for ``Exception`` on its outermost
# server-error middleware, so these 500 responses may not pass through
# CORSMiddleware (no CORS headers for browsers) — confirm if that matters.
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log any unhandled exception and return a generic 500 JSON envelope."""
    logger.error("Unhandled exception", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content={
            "code": 50000,
            "message": "Internal Server Error",
            "data": None,
        },
    )


# -----------------------
# CORS
# -----------------------
# NOTE(review): any origin is allowed and allow_credentials is not set here;
# restrict allow_origins to known front-end hosts if cookies/credentials are
# ever enabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


# -----------------------
# Routers
# -----------------------
# Public endpoints: no role check.
app.include_router(
    router.public_router,
    prefix="/api",
)

# Endpoints requiring at least the ``user`` role.
app.include_router(
    router.protected_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.user))],
)

# Endpoints requiring at least the ``admin`` role.
app.include_router(
    router.admin_required_router,
    prefix="/api",
    dependencies=[Depends(require_min_role(RoleLevel.admin))],
)


# -----------------------
# Swagger Bearer Token
# -----------------------
def custom_openapi():
    """Build (once) and cache the OpenAPI schema with a JWT Bearer scheme.

    Every operation gets ``BearerAuth`` appended to its ``security`` list so
    Swagger UI's "Authorize" button attaches the token. Because the loop
    covers all paths, routes from ``public_router`` are documented as secured
    too.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version="1.0.0",
        description="API documentation",
        routes=app.routes,
    )

    # Merge into existing schemes rather than replacing the whole mapping:
    # schemes FastAPI generated from security dependencies are referenced by
    # operations, and dropping them would leave dangling references.
    security_schemes = openapi_schema.setdefault("components", {}).setdefault(
        "securitySchemes", {}
    )
    security_schemes["BearerAuth"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    }

    for path_item in openapi_schema.get("paths", {}).values():
        for operation in path_item.values():
            operation.setdefault("security", []).append({"BearerAuth": []})

    app.openapi_schema = openapi_schema
    return app.openapi_schema


app.openapi = custom_openapi