| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136 |
- import asyncio
- from fastapi import FastAPI, Depends, Request
- from fastapi.responses import JSONResponse
- from fastapi.middleware.cors import CORSMiddleware
- from fastapi.openapi.utils import get_openapi
- from app.api import router
- from app.core.redis import get_redis_client
- from app.core.auth import RoleLevel, require_min_role
- from app.core.config import settings
- from app.core.payment import init_stripe
- from app.core.biz_exception import BizException
- from app.core.logger import logger
- from app.tasks.notification_task import notification_consumer
# The ASGI application object; its title comes from project settings.
app = FastAPI(
    title=settings.app_name,
)
- # -----------------------
- # Startup
- # -----------------------
@app.on_event("startup")
async def startup():
    """Initialise the Stripe integration when the application starts.

    ``init_stripe`` may be either a plain function or a coroutine function;
    both are handled. Calling a coroutine function without awaiting it only
    creates a coroutine object that never runs, so Stripe would silently stay
    uninitialised.

    NOTE(review): ``on_event`` hooks are deprecated in recent FastAPI
    releases in favour of a ``lifespan`` handler; consider migrating all
    startup/shutdown hooks together.
    """
    import inspect  # local import: only needed for this one-off check

    result = init_stripe()
    # Only await when init_stripe turned out to be async.
    if inspect.isawaitable(result):
        await result
-
-
def _on_notification_consumer_done(task: asyncio.Task) -> None:
    """Log a crash of the background notification consumer task.

    Without this callback, an exception raised inside the consumer only
    surfaces as an "exception was never retrieved" warning when the task is
    garbage-collected. Cancellation (e.g. on shutdown) is not logged.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Notification consumer crashed", exc_info=exc)


@app.on_event("startup")
async def startup_event():
    """Start the background notification consumer when FastAPI starts.

    A Redis client is obtained from ``get_redis_client`` and passed to
    ``notification_consumer``, which runs as a separate asyncio task.
    """
    redis_client = await get_redis_client()
    task = asyncio.create_task(notification_consumer(redis_client))
    # The event loop holds only weak references to tasks; keep a strong
    # reference so the consumer cannot be garbage-collected mid-run, and so
    # the shutdown hook can cancel it.
    app.state.notification_task = task
    task.add_done_callback(_on_notification_consumer_done)
    logger.info("🟢 Notification consumer started")


@app.on_event("shutdown")
async def shutdown_event():
    """Cancel the notification consumer task (if still running) on shutdown."""
    task = getattr(app.state, "notification_task", None)
    if task is None or task.done():
        return
    task.cancel()
    # return_exceptions=True: wait for the task to finish unwinding without
    # letting its CancelledError (or any late error) abort the shutdown.
    await asyncio.gather(task, return_exceptions=True)
- # -----------------------
- # Exception Handlers
- # -----------------------
@app.exception_handler(BizException)
async def biz_exception_handler(request: Request, exc: BizException):
    """Translate a ``BizException`` into the API's JSON error envelope.

    The response status is ``exc.http_status``; the body has the shape
    ``{"code": exc.code, "message": exc.message, "data": exc.extra}``.
    """
    status = exc.http_status
    envelope = {"code": exc.code, "message": exc.message, "data": exc.extra}
    return JSONResponse(status_code=status, content=envelope)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception and return a generic 500 envelope.

    The client always receives code ``50000`` with a fixed message and no
    data, so internal error details are not exposed in the response body.

    NOTE(review): Starlette dispatches ``Exception`` handlers from its
    outermost error middleware, outside CORSMiddleware, so these 500
    responses may lack CORS headers — confirm whether browsers need them.
    """
    logger.error("Unhandled exception", exc_info=exc)
    generic_error = {
        "code": 50000,
        "message": "Internal Server Error",
        "data": None,
    }
    return JSONResponse(content=generic_error, status_code=500)
- # -----------------------
- # CORS
- # -----------------------
# Permissive CORS policy: every origin, HTTP method and request header is
# accepted. NOTE(review): consider narrowing allow_origins outside development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
- # -----------------------
- # Routers
- # -----------------------
def _mount_api(api_router, min_role=None):
    """Mount *api_router* on ``app`` under the ``/api`` prefix.

    When *min_role* is given, the router is mounted with a
    ``require_min_role(min_role)`` dependency; otherwise no extra
    dependencies are attached.
    """
    options = {"prefix": "/api"}
    if min_role is not None:
        options["dependencies"] = [Depends(require_min_role(min_role))]
    app.include_router(api_router, **options)


# No role dependency is attached to the public router.
_mount_api(router.public_router)
# Requires at least RoleLevel.user.
_mount_api(router.protected_router, RoleLevel.user)
# Requires at least RoleLevel.admin.
_mount_api(router.admin_required_router, RoleLevel.admin)
- # -----------------------
- # Swagger Bearer Token
- # -----------------------
# Path-item keys that denote operations (OpenAPI 3.x); other keys such as
# "parameters", "summary" or "servers" are not operation objects.
_OPENAPI_HTTP_METHODS = frozenset(
    {"get", "put", "post", "delete", "options", "head", "patch", "trace"}
)


def custom_openapi():
    """Build (once) and return the OpenAPI schema with a JWT Bearer scheme.

    On first call the schema is generated from ``app.routes`` and cached on
    ``app.openapi_schema``; later calls return the cached dict. A
    ``BearerAuth`` HTTP-bearer scheme is added to
    ``components.securitySchemes`` (merged with any existing schemes) and
    appended to the ``security`` list of every operation, so the docs can
    send a Bearer token.

    NOTE(review): this also marks the public router's operations as
    accepting BearerAuth in the docs.

    Returns:
        dict: the OpenAPI schema.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=app.title,
        version="1.0.0",
        description="API documentation",
        routes=app.routes,
    )

    # Merge rather than overwrite: schemes FastAPI generated for security
    # dependencies are still referenced by their operations and must stay.
    components = openapi_schema.setdefault("components", {})
    components.setdefault("securitySchemes", {})["BearerAuth"] = {
        "type": "http",
        "scheme": "bearer",
        "bearerFormat": "JWT",
    }

    for path_item in openapi_schema.get("paths", {}).values():
        for key, operation in path_item.items():
            # Skip non-operation fields (e.g. a path-level "parameters" list).
            if key not in _OPENAPI_HTTP_METHODS:
                continue
            operation.setdefault("security", []).append({"BearerAuth": []})

    app.openapi_schema = openapi_schema
    return app.openapi_schema
# Swap FastAPI's default schema builder for ``custom_openapi`` so the
# generated docs include the BearerAuth security scheme.
app.openapi = custom_openapi  # type: ignore[method-assign]
|