main.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import asyncio
  2. from fastapi import FastAPI, Depends, Request
  3. from fastapi.responses import JSONResponse
  4. from fastapi.middleware.cors import CORSMiddleware
  5. from fastapi.openapi.utils import get_openapi
  6. from app.api import router
  7. from app.core.redis import get_redis_client
  8. from app.core.auth import RoleLevel, require_min_role
  9. from app.core.config import settings
  10. from app.core.payment import init_stripe
  11. from app.core.biz_exception import BizException
  12. from app.core.logger import logger
  13. from app.tasks.notification_task import notification_consumer
  14. app = FastAPI(title=settings.app_name)
  15. # -----------------------
  16. # Startup
  17. # -----------------------
  18. @app.on_event("startup")
  19. async def startup():
  20. # 如果 init_stripe 是 async
  21. init_stripe()
  22. # 全局 Redis 客户端
  23. @app.on_event("startup")
  24. async def startup_event():
  25. """
  26. FastAPI 启动时执行
  27. """
  28. # 启动后台消费任务
  29. redis_client = await get_redis_client()
  30. asyncio.create_task(notification_consumer(redis_client))
  31. print("🟢 Notification consumer started")
  32. # -----------------------
  33. # Exception Handlers
  34. # -----------------------
  35. @app.exception_handler(BizException)
  36. async def biz_exception_handler(request: Request, exc: BizException):
  37. return JSONResponse(
  38. status_code=exc.http_status,
  39. content={
  40. "code": exc.code,
  41. "message": exc.message,
  42. "data": exc.extra,
  43. },
  44. )
  45. @app.exception_handler(Exception)
  46. async def unhandled_exception_handler(request: Request, exc: Exception):
  47. logger.error("Unhandled exception", exc_info=exc)
  48. return JSONResponse(
  49. status_code=500,
  50. content={
  51. "code": 50000,
  52. "message": "Internal Server Error",
  53. "data": None,
  54. },
  55. )
  56. # -----------------------
  57. # CORS
  58. # -----------------------
  59. app.add_middleware(
  60. CORSMiddleware,
  61. allow_origins=["*"],
  62. allow_methods=["*"],
  63. allow_headers=["*"],
  64. )
  65. # -----------------------
  66. # Routers
  67. # -----------------------
  68. app.include_router(
  69. router.public_router,
  70. prefix="/api"
  71. )
  72. app.include_router(
  73. router.protected_router,
  74. prefix="/api",
  75. dependencies=[Depends(require_min_role(RoleLevel.user))]
  76. )
  77. app.include_router(
  78. router.admin_required_router,
  79. prefix="/api",
  80. dependencies=[Depends(require_min_role(RoleLevel.admin))]
  81. )
  82. # -----------------------
  83. # Swagger Bearer Token
  84. # -----------------------
  85. def custom_openapi():
  86. if app.openapi_schema:
  87. return app.openapi_schema
  88. openapi_schema = get_openapi(
  89. title=app.title,
  90. version="1.0.0",
  91. description="API documentation",
  92. routes=app.routes,
  93. )
  94. openapi_schema.setdefault("components", {})
  95. openapi_schema["components"]["securitySchemes"] = {
  96. "BearerAuth": {
  97. "type": "http",
  98. "scheme": "bearer",
  99. "bearerFormat": "JWT",
  100. }
  101. }
  102. for path in openapi_schema["paths"].values():
  103. for method in path.values():
  104. method.setdefault("security", [])
  105. method["security"].append({"BearerAuth": []})
  106. app.openapi_schema = openapi_schema
  107. return app.openapi_schema
  108. app.openapi = custom_openapi