jerry 4 месяцев назад
Родитель
Сommit
882a5ce465
53 измененных файлов с 3770 добавлено и 2071 удалено
  1. 1 2
      .env
  2. 263 193
      app/api/router.py
  3. 25 18
      app/core/auth.py
  4. 35 8
      app/core/config.py
  5. 26 25
      app/core/database.py
  6. 7 6
      app/core/redis.py
  7. 45 13
      app/main.py
  8. 0 46
      app/models/auto_booking.py
  9. 24 5
      app/models/payment.py
  10. 18 0
      app/models/payment_confirmation.py
  11. 1 0
      app/models/product_routing.py
  12. 1 1
      app/models/session.py
  13. 2 4
      app/models/verification_token.py
  14. 1 2
      app/schemas/auth.py
  15. 0 52
      app/schemas/auto_booking.py
  16. 4 4
      app/schemas/payment.py
  17. 30 0
      app/schemas/payment_confirmation.py
  18. 2 1
      app/schemas/product_routing.py
  19. 191 133
      app/services/auth_service.py
  20. 0 72
      app/services/auto_booking_service.py
  21. 37 16
      app/services/card_service.py
  22. 85 24
      app/services/configuration_service.py
  23. 469 441
      app/services/email_authorizations_service.py
  24. 57 21
      app/services/http_session_service.py
  25. 18 8
      app/services/notification_service.py
  26. 128 126
      app/services/order_service.py
  27. 103 55
      app/services/payment_provider_service.py
  28. 130 31
      app/services/payment_qr_service.py
  29. 384 78
      app/services/payment_service.py
  30. 37 12
      app/services/product_routing_service.py
  31. 63 30
      app/services/product_service.py
  32. 54 22
      app/services/schema_service.py
  33. 86 45
      app/services/seaweedfs_service.py
  34. 26 20
      app/services/session_service.py
  35. 45 18
      app/services/short_url_service.py
  36. 30 7
      app/services/slot_snapshot_service.py
  37. 52 22
      app/services/sms_service.py
  38. 157 90
      app/services/statistics_service.py
  39. 52 18
      app/services/task_service.py
  40. 17 12
      app/services/telegram_service.py
  41. 238 78
      app/services/ticket_service.py
  42. 125 56
      app/services/troov_service.py
  43. 70 34
      app/services/user_service.py
  44. 103 53
      app/services/vas_task_service.py
  45. 113 116
      app/services/webhook_service.py
  46. 31 12
      app/services/wechat_service.py
  47. 256 0
      app/tasks/notification_task.py
  48. 18 9
      app/utils/pagination.py
  49. 28 14
      app/utils/redis_utils.py
  50. 18 4
      app/utils/search.py
  51. 1 1
      app/utils/validation_utils.py
  52. 11 0
      app/utils/wrappers.py
  53. 52 13
      starter.py

+ 1 - 2
.env

@@ -1,5 +1,4 @@
-DATABASE_URL=mysql://root:GqLLL7Bofj0WaaOpp.0@visafly.top:3306/book_user_info?charset=utf8mb4
+DATABASE_URL=mysql+asyncmy://root:GqLLL7Bofj0WaaOpp.0@visafly.top:3306/book_user_info?charset=utf8mb4
 REDIS_URL=redis://:STEs2x6ML0U1HlpE9SojM6YU7QPhqzY8@45.137.220.138:6379/0
 REDIS_URL=redis://:STEs2x6ML0U1HlpE9SojM6YU7QPhqzY8@45.137.220.138:6379/0
-API_TOKEN=7x9EjFpmv7GjZc6AfVeqxuUBANpqkpkHAtxJM7CAW5oZhs0nEyCJBy39N4XXs5hgfYWXw3jFrcgXqQ42HAx9Qvwtk9vC2GvKBbWz
 OPENAI_API_KEY=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
 OPENAI_API_KEY=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
 STRIPE_API_KEY=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
 STRIPE_API_KEY=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

Разница между файлами не показана из-за своего большого размера
+ 263 - 193
app/api/router.py


+ 25 - 18
app/core/auth.py

@@ -1,40 +1,47 @@
 from enum import IntEnum
 from enum import IntEnum
-from fastapi import Depends, HTTPException, status
+from fastapi import Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.core.config import settings
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
+
 from app.core.database import get_db
 from app.core.database import get_db
+from app.core.biz_exception import PermissionDeniedError
 from app.services.session_service import SessionService
 from app.services.session_service import SessionService
 
 
-
-security = HTTPBearer()
+security = HTTPBearer(auto_error=False)
 
 
 
 
 class RoleLevel(IntEnum):
 class RoleLevel(IntEnum):
     user = 10
     user = 10
     admin = 100
     admin = 100
 
 
+
 ROLE_LEVEL_MAP = {
 ROLE_LEVEL_MAP = {
     "user": 10,
     "user": 10,
     "admin": 100,
     "admin": 100,
 }
 }
 
 
-def require_min_role(min_role: RoleLevel):
-    def checker(user=Depends(get_current_user)):
-        current_level = ROLE_LEVEL_MAP.get(user.role, 0)
-        if current_level < min_role:
-            raise PermissionDeniedError("Permission denied")
-        return user
-    return checker
 
 
-
-def get_current_user(
+async def get_current_user(
     credentials: HTTPAuthorizationCredentials = Depends(security),
     credentials: HTTPAuthorizationCredentials = Depends(security),
-    db: Session = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
 ):
 ):
+    if not credentials:
+        raise PermissionDeniedError("Missing token")
+
     token = credentials.credentials
     token = credentials.credentials
-    user = SessionService().get_user_by_token(db, token)
+    user = await SessionService().get_user_by_token(db, token)
+
     if not user:
     if not user:
         raise PermissionDeniedError("Invalid or expired token")
         raise PermissionDeniedError("Invalid or expired token")
-    return user
+
+    return user
+
+
+def require_min_role(min_role: RoleLevel):
+    async def checker(user=Depends(get_current_user)):
+        current_level = ROLE_LEVEL_MAP.get(user.role, 0)
+        if current_level < min_role:
+            raise PermissionDeniedError("Permission denied")
+        return user
+
+    return checker

+ 35 - 8
app/core/config.py

@@ -1,18 +1,45 @@
-from pydantic_settings import BaseSettings
+from functools import lru_cache
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import Field
+
 
 
 class Settings(BaseSettings):
 class Settings(BaseSettings):
+    # -----------------------
+    # App
+    # -----------------------
     app_name: str = "MyApp"
     app_name: str = "MyApp"
     debug: bool = False
     debug: bool = False
-    database_url: str
+
+    # -----------------------
+    # Database / Cache
+    # -----------------------
+    database_url: str = Field(..., description="Async database DSN")
     redis_url: str
     redis_url: str
-    api_token: str
+
+    # -----------------------
+    # Security / API Keys
+    # -----------------------
     openai_api_key: str
     openai_api_key: str
     stripe_api_key: str
     stripe_api_key: str
-    
-    class Config:
-        env_file = ".env"
 
 
-settings = Settings()
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        env_file_encoding="utf-8",
+        case_sensitive=False,
+    )
+
+
+@lru_cache
+def get_settings() -> Settings:
+    """
+    避免多次实例化 Settings(FastAPI 官方推荐)
+    """
+    return Settings()
+
 
 
-base_currency = "EUR"
+settings = get_settings()
 
 
+# -----------------------
+# Global constants
+# -----------------------
+BASE_CURRENCY = "EUR"

+ 26 - 25
app/core/database.py

@@ -1,25 +1,29 @@
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy.ext.asyncio import (
+    create_async_engine,
+    async_sessionmaker,
+    AsyncSession,
+)
+from sqlalchemy.orm import declarative_base
 from app.core.config import settings
 from app.core.config import settings
 
 
 # =========================
 # =========================
-# 数据库初始化
+# Async Engine
 # =========================
 # =========================
-# 建立 Engine
-engine = create_engine(
-    settings.database_url,
-    echo=settings.debug,        # 是否打印 SQL 日志
-    pool_pre_ping=True,         # 检测断开的连接
-    pool_recycle=1800,          # 连接回收时间,防止 MySQL 8 小时断开
-    future=True                 # 启用 SQLAlchemy 2.0 风格
+engine = create_async_engine(
+    settings.database_url,      # ⚠️ 必须是 async URL
+    echo=settings.debug,
+    pool_pre_ping=True,
+    pool_recycle=1800,
 )
 )
 
 
-# 建立 Session 工厂
-SessionLocal = sessionmaker(
-    autocommit=False,
-    autoflush=False,
+# =========================
+# Async Session 工厂
+# =========================
+AsyncSessionLocal = async_sessionmaker(
     bind=engine,
     bind=engine,
-    expire_on_commit=False
+    class_=AsyncSession,
+    autoflush=False,
+    expire_on_commit=False,
 )
 )
 
 
 # ORM 基类
 # ORM 基类
@@ -27,14 +31,11 @@ Base = declarative_base()
 
 
 
 
 # =========================
 # =========================
-# 数据库依赖
+# FastAPI 依赖
 # =========================
 # =========================
-def get_db():
-    """
-    FastAPI 的依赖注入函数,用于在请求周期内创建并关闭数据库会话。
-    """
-    db = SessionLocal()
-    try:
-        yield db
-    finally:
-        db.close()
+async def get_db() -> AsyncSession:
+    async with AsyncSessionLocal() as session:
+        try:
+            yield session
+        finally:
+            await session.close()

+ 7 - 6
app/core/redis.py

@@ -1,14 +1,15 @@
+
 from typing import Optional
 from typing import Optional
-from redis import Redis
+from redis.asyncio import Redis
 from app.core.config import settings
 from app.core.config import settings
 
 
 _redis_client: Optional[Redis] = None
 _redis_client: Optional[Redis] = None
 
 
-def get_redis_client() -> Redis:
-    """
-    同步依赖(FastAPI 可以直接注入)
-    """
+async def get_redis_client() -> Redis:
     global _redis_client
     global _redis_client
     if _redis_client is None:
     if _redis_client is None:
-        _redis_client = Redis.from_url(settings.redis_url, decode_responses=True)
+        _redis_client = Redis.from_url(
+            settings.redis_url,
+            decode_responses=True
+        )
     return _redis_client
     return _redis_client

+ 45 - 13
app/main.py

@@ -1,23 +1,48 @@
+import asyncio
+
 from fastapi import FastAPI, Depends, Request
 from fastapi import FastAPI, Depends, Request
 from fastapi.responses import JSONResponse
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.openapi.utils import get_openapi
 from fastapi.openapi.utils import get_openapi
-from fastapi.security import HTTPBearer
 
 
 from app.api import router
 from app.api import router
+from app.core.redis import get_redis_client
 from app.core.auth import RoleLevel, require_min_role
 from app.core.auth import RoleLevel, require_min_role
 from app.core.config import settings
 from app.core.config import settings
 from app.core.payment import init_stripe
 from app.core.payment import init_stripe
 from app.core.biz_exception import BizException
 from app.core.biz_exception import BizException
 from app.core.logger import logger
 from app.core.logger import logger
+from app.tasks.notification_task import notification_consumer
 
 
 
 
 app = FastAPI(title=settings.app_name)
 app = FastAPI(title=settings.app_name)
 
 
+# -----------------------
+# Startup
+# -----------------------
 @app.on_event("startup")
 @app.on_event("startup")
-def startup():
+async def startup():
+    # 如果 init_stripe 是 async
     init_stripe()
     init_stripe()
     
     
+ 
+# 全局 Redis 客户端
+
+
+@app.on_event("startup")
+async def startup_event():
+    """
+    FastAPI 启动时执行
+    """
+    # 启动后台消费任务
+    redis_client = await get_redis_client()
+    asyncio.create_task(notification_consumer(redis_client))
+    print("🟢 Notification consumer started")   
+
+
+# -----------------------
+# Exception Handlers
+# -----------------------
 @app.exception_handler(BizException)
 @app.exception_handler(BizException)
 async def biz_exception_handler(request: Request, exc: BizException):
 async def biz_exception_handler(request: Request, exc: BizException):
     return JSONResponse(
     return JSONResponse(
@@ -29,10 +54,10 @@ async def biz_exception_handler(request: Request, exc: BizException):
         },
         },
     )
     )
 
 
+
 @app.exception_handler(Exception)
 @app.exception_handler(Exception)
 async def unhandled_exception_handler(request: Request, exc: Exception):
 async def unhandled_exception_handler(request: Request, exc: Exception):
-    # ⚠️ 一定要打日志
-    logger.error("Unhandled exception")
+    logger.error("Unhandled exception", exc_info=exc)
 
 
     return JSONResponse(
     return JSONResponse(
         status_code=500,
         status_code=500,
@@ -43,8 +68,9 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
         },
         },
     )
     )
 
 
+
 # -----------------------
 # -----------------------
-# CORS(可选)
+# CORS
 # -----------------------
 # -----------------------
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
@@ -53,52 +79,58 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
+
 # -----------------------
 # -----------------------
-# 路由注册
+# Routers
 # -----------------------
 # -----------------------
-# 公共路由,不鉴权
 app.include_router(
 app.include_router(
     router.public_router,
     router.public_router,
     prefix="/api"
     prefix="/api"
 )
 )
-# 需要鉴权的路由
+
 app.include_router(
 app.include_router(
     router.protected_router,
     router.protected_router,
     prefix="/api",
     prefix="/api",
     dependencies=[Depends(require_min_role(RoleLevel.user))]
     dependencies=[Depends(require_min_role(RoleLevel.user))]
 )
 )
 
 
-# 需要管理员权限的路由
 app.include_router(
 app.include_router(
     router.admin_required_router,
     router.admin_required_router,
     prefix="/api",
     prefix="/api",
     dependencies=[Depends(require_min_role(RoleLevel.admin))]
     dependencies=[Depends(require_min_role(RoleLevel.admin))]
 )
 )
 
 
+
 # -----------------------
 # -----------------------
-# Swagger 支持 Bearer Token
+# Swagger Bearer Token
 # -----------------------
 # -----------------------
 def custom_openapi():
 def custom_openapi():
     if app.openapi_schema:
     if app.openapi_schema:
         return app.openapi_schema
         return app.openapi_schema
+
     openapi_schema = get_openapi(
     openapi_schema = get_openapi(
         title=app.title,
         title=app.title,
         version="1.0.0",
         version="1.0.0",
         description="API documentation",
         description="API documentation",
         routes=app.routes,
         routes=app.routes,
     )
     )
-    # 添加全局 Bearer
+
+    openapi_schema.setdefault("components", {})
     openapi_schema["components"]["securitySchemes"] = {
     openapi_schema["components"]["securitySchemes"] = {
         "BearerAuth": {
         "BearerAuth": {
             "type": "http",
             "type": "http",
             "scheme": "bearer",
             "scheme": "bearer",
-            "bearerFormat": "JWT"
+            "bearerFormat": "JWT",
         }
         }
     }
     }
+
     for path in openapi_schema["paths"].values():
     for path in openapi_schema["paths"].values():
         for method in path.values():
         for method in path.values():
-            method["security"] = [{"BearerAuth": []}]
+            method.setdefault("security", [])
+            method["security"].append({"BearerAuth": []})
+
     app.openapi_schema = openapi_schema
     app.openapi_schema = openapi_schema
     return app.openapi_schema
     return app.openapi_schema
 
 
+
 app.openapi = custom_openapi
 app.openapi = custom_openapi

+ 0 - 46
app/models/auto_booking.py

@@ -1,46 +0,0 @@
-from sqlalchemy import Column, BigInteger, String, Integer, Text, Date, DateTime
-from app.core.database import Base
-
-class AutoBooking(Base):
-    __tablename__ = "auto_booking"
-
-    id = Column(BigInteger, primary_key=True, autoincrement=True)
-    provider = Column(String(100))
-    visa_center = Column(String(100))
-    order_no = Column(String(100))
-    social_account = Column(String(100))
-    account = Column(String(100))
-    password = Column(String(100))
-    last_name = Column(String(100))
-    first_name = Column(String(100))
-    gender = Column(String(10))
-    birthday = Column(Date)
-    email = Column(String(150))
-    alias_email = Column(String(150))
-    phone_country_code = Column(String(20))
-    phone_no = Column(String(50))
-    passport_no = Column(String(50))
-    nationality = Column(String(50))
-    passport_expiry_date = Column(Date)
-    address_line1 = Column(Text)
-    address_line2 = Column(Text)
-    state = Column(String(100))
-    city = Column(String(100))
-    postcode = Column(String(100))
-    travel_date = Column(Date)
-    cover_letter = Column(String(100))
-    passport_image_url = Column(Text)
-    selfie_image_url = Column(Text)
-    application_form_url = Column(Text)
-    priority = Column(Integer, default=0)
-    expected_submit_start = Column(Date)
-    expected_submit_end = Column(Date)
-    rules = Column(Text)
-    status = Column(Integer, default=0)
-    placeholder = Column(Integer, default=0)
-    appointment_datetime = Column(DateTime)
-    appointment_letter_url = Column(Text)
-    pnr_number = Column(String(100))
-    payment_link = Column(Text)
-    payment_help = Column(Integer, default=0)
-    note = Column(Text)

+ 24 - 5
app/models/payment.py

@@ -2,21 +2,31 @@ from sqlalchemy import Column, Integer, String, Text, DateTime, Enum, JSON, DECI
 from datetime import datetime
 from datetime import datetime
 from app.core.database import Base
 from app.core.database import Base
 
 
-
 class VasPayment(Base):
 class VasPayment(Base):
     __tablename__ = "vas_payment"
     __tablename__ = "vas_payment"
 
 
     id = Column(Integer, primary_key=True, autoincrement=True)
     id = Column(Integer, primary_key=True, autoincrement=True)
     order_id = Column(String(128), nullable=False)
     order_id = Column(String(128), nullable=False)
 
 
-    provider = Column(Enum('stripe','wechat','alipay'), nullable=False)
+    provider = Column(Enum('stripe', 'wechat', 'alipay'), nullable=False)
     channel = Column(String(50), nullable=False)
     channel = Column(String(50), nullable=False)
+    
+    # 支付意向ID (Stripe PaymentIntent ID)
     payment_intent_id = Column(String(255))
     payment_intent_id = Column(String(255))
+    # 外部交易号 (微信/支付宝的 transaction_id)
     external_trade_no = Column(String(255))
     external_trade_no = Column(String(255))
 
 
+    # --- 修改点 1: 扩充状态枚举 ---
     status = Column(
     status = Column(
-        Enum('pending','succeeded','failed','expired', 'late_paid'),
-        default='pending'
+        Enum(
+            'pending',              # 待支付
+            'succeeded',            # 支付成功
+            'failed',               # 支付失败
+            'expired',              # 支付超时
+            'late_paid',            # 逾期支付(极少见)
+            'refunded',             # 已全额退款
+        ),
+        default='pending',
     )
     )
     
     
     base_amount = Column(Integer, nullable=False)
     base_amount = Column(Integer, nullable=False)
@@ -33,6 +43,15 @@ class VasPayment(Base):
     expire_at = Column(DateTime)
     expire_at = Column(DateTime)
 
 
     provider_payload = Column(JSON)
     provider_payload = Column(JSON)
+    
+    # 外部退款单号 (Stripe Refund ID / 微信退款单号 / 支付宝退款单号)
+    external_refund_no = Column(String(255))
+    
+    # 退款时间
+    refunded_at = Column(DateTime)
+    
+    # 退款原因/备注
+    refund_reason = Column(String(255))
 
 
     created_at = Column(DateTime, default=datetime.utcnow)
     created_at = Column(DateTime, default=datetime.utcnow)
-    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

+ 18 - 0
app/models/payment_confirmation.py

@@ -0,0 +1,18 @@
+from sqlalchemy import Column, BigInteger, Integer, String, Enum, DateTime, func
+from app.core.database import Base
+
+
+class VasPaymentConfirmation(Base):
+    __tablename__ = "vas_payment_confirm"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True)
+    payment_id = Column(BigInteger, nullable=False)
+    amount = Column(Integer, nullable=False)
+    currency = Column(String(10), nullable=False)
+    random_offset = Column(Integer, nullable=False)
+    user_id = Column(String(128), nullable=False)
+    status = Column(Enum("pending","confirmed","ignored"), nullable=False, default="pending")
+    created_at = Column(DateTime, nullable=False, server_default=func.now())
+    confirmed_at = Column(DateTime, nullable=True)
+    admin_id = Column(String(128), nullable=True)
+    admin_confirmed_at = Column(DateTime, nullable=True)

+ 1 - 0
app/models/product_routing.py

@@ -12,6 +12,7 @@ class VasProductRouting(Base):
     routing_key = Column(String(255), nullable=False)
     routing_key = Column(String(255), nullable=False)
     script_version = Column(String(50), nullable=False)
     script_version = Column(String(50), nullable=False)
     is_active = Column(Integer, default=1)
     is_active = Column(Integer, default=1)
+    priority = Column(Integer, default=10)
 
 
     config = Column(JSON)
     config = Column(JSON)
 
 

+ 1 - 1
app/models/session.py

@@ -8,7 +8,7 @@ class VasSession(Base):
 
 
     id = Column(String(128), primary_key=True)   # token
     id = Column(String(128), primary_key=True)   # token
     user_id = Column(String(128), nullable=False)
     user_id = Column(String(128), nullable=False)
-    user_agent = Column(String(128), nullable=False)
+    user_agent = Column(String(255), nullable=False)
     ip = Column(String(128), nullable=False)
     ip = Column(String(128), nullable=False)
     expire_at = Column(DateTime, nullable=False)
     expire_at = Column(DateTime, nullable=False)
 
 

+ 2 - 4
app/models/email_verification.py → app/models/verification_token.py

@@ -3,12 +3,10 @@ from datetime import datetime
 from app.core.database import Base
 from app.core.database import Base
 
 
 
 
-class VasEmailVerification(Base):
-    __tablename__ = "vas_email_verification"
+class VasVerificationToken(Base):
+    __tablename__ = "vas_verification_token"
 
 
     id = Column(Integer, primary_key=True, autoincrement=True)
     id = Column(Integer, primary_key=True, autoincrement=True)
-    user_id = Column(String(64), nullable=False)
-    email = Column(String(255), nullable=False)
     token = Column(String(128), nullable=False)
     token = Column(String(128), nullable=False)
     used = Column(Integer, default=0)
     used = Column(Integer, default=0)
     expire_at = Column(DateTime, nullable=False)
     expire_at = Column(DateTime, nullable=False)

+ 1 - 2
app/schemas/auth.py

@@ -4,8 +4,7 @@ from app.schemas.common import ApiResponse
 from app.schemas.user import VasUserOut
 from app.schemas.user import VasUserOut
 
 
 class AutoRegisterRequest(BaseModel):
 class AutoRegisterRequest(BaseModel):
-    user_agent: Optional[str] = None
-    register_ip: str
+    pass
 
 
 class AutoRegisterData(BaseModel):
 class AutoRegisterData(BaseModel):
     user: VasUserOut
     user: VasUserOut

+ 0 - 52
app/schemas/auto_booking.py

@@ -1,52 +0,0 @@
-from pydantic import BaseModel
-from typing import Optional
-from datetime import date, datetime
-
-class AutoBookingBase(BaseModel):
-    provider: Optional[str] = None
-    visa_center: Optional[str] = None
-    order_no: Optional[str] = None
-    social_account: Optional[str] = None
-    account: Optional[str] = None
-    password: Optional[str] = None
-    last_name: Optional[str] = None
-    first_name: Optional[str] = None
-    gender: Optional[str] = None
-    birthday: Optional[date] = None
-    email: Optional[str] = None
-    alias_email: Optional[str] = None
-    phone_country_code: Optional[str] = None
-    phone_no: Optional[str] = None
-    passport_no: Optional[str] = None
-    nationality: Optional[str] = None
-    passport_expiry_date: Optional[date] = None
-    address_line1: Optional[str] = None
-    address_line2: Optional[str] = None
-    state: Optional[str] = None
-    city: Optional[str] = None
-    postcode: Optional[str] = None
-    travel_date: Optional[date] = None
-    cover_letter: Optional[str] = None
-    passport_image_url: Optional[str] = None
-    selfie_image_url: Optional[str] = None
-    application_form_url: Optional[str] = None
-    priority: Optional[int] = None
-    expected_submit_start: Optional[date] = None
-    expected_submit_end: Optional[date] = None
-    rules: Optional[str] = None
-    status: Optional[int] = None
-    placeholder: Optional[int] = None
-    appointment_datetime: Optional[datetime] = None
-    appointment_letter_url: Optional[str] = None
-    pnr_number: Optional[str] = None
-    payment_link: Optional[str] = None
-    payment_help: Optional[int] = None
-    note: Optional[str] = None
-
-class AutoBookingCreate(AutoBookingBase):
-    pass
-
-class AutoBookingOut(AutoBookingBase):
-    id: int
-    class Config:
-        orm_mode = True

+ 4 - 4
app/schemas/payment.py

@@ -5,7 +5,7 @@ from datetime import datetime
 
 
 
 
 class VasPaymentBase(BaseModel):
 class VasPaymentBase(BaseModel):
-    status: Optional[Literal['pending', 'succeeded', 'failed', 'expired', 'late_paid']] = None
+    status: Optional[Literal['pending', 'succeeded', 'failed', 'expired', 'late_paid', 'refunded']] = None
 
 
     qr_id: Optional[int] = None
     qr_id: Optional[int] = None
     payment_url: Optional[str] = None
     payment_url: Optional[str] = None
@@ -28,9 +28,6 @@ class VasPaymentCreate(BaseModel):
     order_id: str
     order_id: str
     provider: Literal['stripe', 'wechat', 'alipay']
     provider: Literal['stripe', 'wechat', 'alipay']
 
 
-class VasPaymentUpdate(VasPaymentBase):
-    pass
-
 class VasPaymentOut(VasPaymentBase):
 class VasPaymentOut(VasPaymentBase):
     id: int
     id: int
     order_id: str
     order_id: str
@@ -50,6 +47,9 @@ class VasPaymentOut(VasPaymentBase):
     
     
     exchange_rate: float  # 注意:仅用于展示,DB 里是 DECIMAL
     exchange_rate: float  # 注意:仅用于展示,DB 里是 DECIMAL
 
 
+    external_refund_no: Optional[str]
+    refund_reason: Optional[str]
+    refunded_at: Optional[datetime]
     created_at: datetime
     created_at: datetime
     updated_at: datetime
     updated_at: datetime
 
 

+ 30 - 0
app/schemas/payment_confirmation.py

@@ -0,0 +1,30 @@
+from pydantic import BaseModel
+from typing import Optional, Dict, Any, Literal, List
+from datetime import datetime
+
+class VasPaymentConfirmationBase(BaseModel):
+    payment_id: int
+    amount: int
+    currency: str
+    random_offset: int
+    confirmed_at: datetime
+
+class VasPaymentConfirmationCreate(VasPaymentConfirmationBase):
+    pass
+
+class VasPaymentConfirmationUpdate(BaseModel):
+    status: Optional[Literal['confirmed', 'ignored']] = None
+    admin_id: Optional[str] = None
+    admin_confirmed_at: Optional[datetime] = None
+
+class VasPaymentConfirmationOut(VasPaymentConfirmationBase):
+    id: int
+    user_id: str
+    status: str
+    created_at: datetime
+    admin_id: Optional[str] = None
+    admin_confirmed_at: Optional[datetime] = None
+
+    model_config = {
+        "from_attributes": True
+    }

+ 2 - 1
app/schemas/product_routing.py

@@ -22,6 +22,7 @@ class VasProductRoutingCreate(BaseModel):
     product_id: int
     product_id: int
     routing_key: str
     routing_key: str
     script_version: str
     script_version: str
+    priority: int
     config: Dict[str, Any]
     config: Dict[str, Any]
 
 
 class VasProductRoutingUpdate(VasProductRoutingBase):
 class VasProductRoutingUpdate(VasProductRoutingBase):
@@ -30,7 +31,7 @@ class VasProductRoutingUpdate(VasProductRoutingBase):
 class VasProductRoutingOut(VasProductRoutingBase):
 class VasProductRoutingOut(VasProductRoutingBase):
     id: int
     id: int
     product_id: int
     product_id: int
-
+    priority: int
     routing_key: str
     routing_key: str
     script_version: str
     script_version: str
 
 

+ 191 - 133
app/services/auth_service.py

@@ -1,27 +1,56 @@
-import uuid, bcrypt, random, string
-from sqlalchemy.orm import Session
+# app/services/auth_service.py
+
+import uuid
+import bcrypt
+import random
+import string
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
+from typing import Dict
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
 from redis.asyncio import Redis
 from redis.asyncio import Redis
-from app.utils.redis_utils import redis_qpush
-from app.core.auth import get_current_user
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from app.core.biz_exception import (
+    NotFoundError,
+    PermissionDeniedError,
+    BizLogicError,
+)
 from app.models.user import VasUser
 from app.models.user import VasUser
 from app.models.session import VasSession
 from app.models.session import VasSession
-from app.models.email_verification import VasEmailVerification
-from app.schemas.auth import AutoRegisterRequest, SendBindCodeRequest, SendResetCodeRequest, BindEmailRequest, ResetPasswordRequest, LoginRequest
+from app.models.verification_token import VasVerificationToken
+from app.schemas.auth import (
+    AutoRegisterRequest,
+    SendBindCodeRequest,
+    SendResetCodeRequest,
+    BindEmailRequest,
+    ResetPasswordRequest,
+    LoginRequest,
+)
 from app.services.notification_service import NotificationService
 from app.services.notification_service import NotificationService
 
 
 
 
-def _random_password(length=16):
-    return ''.join(random.choices(string.ascii_letters + string.digits + "!@#$%", k=length))
+def _random_password(length: int = 16) -> str:
+    return "".join(
+        random.choices(
+            string.ascii_letters + string.digits + "!@#$%",
+            k=length,
+        )
+    )
+
 
 
 class AuthService:
 class AuthService:
-    # -----------------------
-    # 自动注册
-    # -----------------------
+    # =========================
+    # 自动注册(游客)
+    # =========================
     @staticmethod
     @staticmethod
-    def auto_register(db: Session, req:AutoRegisterRequest):
-        uid = f'usr-{uuid.uuid4().hex[:8]}'
+    async def auto_register(
+        db: AsyncSession,
+        req: AutoRegisterRequest,
+        ip: str = None,
+        user_agent = None
+    ) -> Dict:
+        uid = f"usr-{uuid.uuid4().hex[:8]}"
 
 
         user = VasUser(
         user = VasUser(
             id=uid,
             id=uid,
@@ -29,197 +58,226 @@ class AuthService:
             nickname="anonymous visitor",
             nickname="anonymous visitor",
             preferred_language="en",
             preferred_language="en",
             timezone="Asia/Shanghai",
             timezone="Asia/Shanghai",
-            register_ip=req.register_ip,
+            register_ip=ip or '',
         )
         )
         db.add(user)
         db.add(user)
-        db.commit()
-
-        # 创建 session
-        token = f"tok_{uuid.uuid4().hex}"
 
 
+        token = "tok_" + uuid.uuid4().hex
         session = VasSession(
         session = VasSession(
             id=token,
             id=token,
             user_id=uid,
             user_id=uid,
-            user_agent=req.user_agent or "",
-            ip=req.register_ip,
-            expire_at=datetime.utcnow() + timedelta(days=7)
+            user_agent=user_agent or '',
+            ip=ip or '',
+            expire_at=datetime.utcnow() + timedelta(days=7),
         )
         )
         db.add(session)
         db.add(session)
-        db.commit()
-        return {
-            "user": user,
-            "token": token
-        }
-        
-    def send_bind_code(db: Session, payload: SendBindCodeRequest, auth_user: VasUser, redis_client:Redis):
-        token = uuid.uuid4().hex[0:6]
-        record = VasEmailVerification(
-            user_id=auth_user.id,
-            email=payload.email,
+
+        await db.commit()
+        await db.refresh(user)
+
+        return {"user": user, "token": token}
+
+    # =========================
+    # 发送绑定邮箱验证码
+    # =========================
+    @staticmethod
+    async def send_bind_code(
+        db: AsyncSession,
+        payload: SendBindCodeRequest,
+        auth_user: VasUser,
+        redis_client: Redis,
+    ):
+        token = uuid.uuid4().hex[:6]
+
+        record = VasVerificationToken(
             token=token,
             token=token,
-            expire_at=datetime.utcnow() + timedelta(minutes=30)
+            expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         )
         db.add(record)
         db.add(record)
-        db.commit()
-        
-        print(f"📧 send verification email token={token}")
-        NotificationService.create(
+        await db.commit()
+
+        await NotificationService.create(
             redis_client=redis_client,
             redis_client=redis_client,
-            ntype="email verification email",
+            ntype="email_verification",
             user_id=auth_user.id,
             user_id=auth_user.id,
             channels=["email"],
             channels=["email"],
             template_id="email_verification_for_bind",
             template_id="email_verification_for_bind",
-            payload={
-                "token": token
-            }
+            payload={"token": token},
         )
         )
-        
-    def send_reset_code(db: Session, payload: SendResetCodeRequest, redis_client:Redis):
-        user = db.query(VasUser).filter(
+
+    # =========================
+    # 发送重置密码验证码
+    # =========================
+    @staticmethod
+    async def send_reset_code(
+        db: AsyncSession,
+        payload: SendResetCodeRequest,
+        redis_client: Redis,
+    ):
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
+            VasUser.email_verified == 1,
+        )
+        user = (await db.execute(stmt)).scalar_one_or_none()
+
         if not user:
         if not user:
             raise BizLogicError("User not exist")
             raise BizLogicError("User not exist")
-        
-        token = uuid.uuid4().hex[0:6]
-        record = VasEmailVerification(
-            user_id=user.id,
-            email=payload.email,
+
+        token = uuid.uuid4().hex[:6]
+        record = VasVerificationToken(
             token=token,
             token=token,
-            expire_at=datetime.utcnow() + timedelta(minutes=30)
+            expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         )
         db.add(record)
         db.add(record)
-        db.commit()
-        
-        print(f"📧 send verification email token={token}")
-        NotificationService.create(
+        await db.commit()
+
+        await NotificationService.create(
             redis_client=redis_client,
             redis_client=redis_client,
-            ntype="email verification email",
+            ntype="email_verification",
             user_id=user.id,
             user_id=user.id,
             channels=["email"],
             channels=["email"],
             template_id="email_verification_for_reset",
             template_id="email_verification_for_reset",
-            payload={
-                "token": token
-            }
+            payload={"token": token},
         )
         )
-    # -----------------------
+
+    # =========================
     # 绑定邮箱
     # 绑定邮箱
-    # -----------------------
+    # =========================
     @staticmethod
     @staticmethod
-    def bind_email(db: Session, payload: BindEmailRequest, auth_user: VasUser, redis_client:Redis):
-        user = db.query(VasUser).filter(
+    async def bind_email(
+        db: AsyncSession,
+        payload: BindEmailRequest,
+        auth_user: VasUser,
+        redis_client: Redis,
+        ip: str = None,
+        user_agent = None
+    ) -> Dict:
+        # 邮箱是否已被绑定
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
-        if user:
-            raise BizLogicError("Email has been bound")
-        
-        record = (
-            db.query(VasEmailVerification)
-            .filter_by(token=payload.code, used=0)
-            .first()
+            VasUser.email_verified == 1,
+        )
+        if (await db.execute(stmt)).scalar_one_or_none():
+            raise BizLogicError("Email already bound")
+
+        # 校验验证码
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == payload.code,
+            VasVerificationToken.used == 0,
         )
         )
+        record = (await db.execute(stmt)).scalar_one_or_none()
         if not record:
         if not record:
             raise BizLogicError("Token invalid")
             raise BizLogicError("Token invalid")
 
 
         if record.expire_at < datetime.utcnow():
         if record.expire_at < datetime.utcnow():
             raise BizLogicError("Token expired")
             raise BizLogicError("Token expired")
- 
-        # 更新 user.email
-        user = db.query(VasUser).filter_by(id=record.user_id).first()
-        user.email = payload.email
 
 
-        # 随机密码
-        plain = _random_password()
-        hashed = bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode()
-        user.password_hash = hashed
+        user = await db.get(VasUser, auth_user.id)
+
+        plain_pwd = _random_password()
+        hashed_pwd = bcrypt.hashpw(
+            plain_pwd.encode(),
+            bcrypt.gensalt(),
+        ).decode()
+
+        user.email = payload.email
+        user.password_hash = hashed_pwd
         user.email_verified = 1
         user.email_verified = 1
         record.used = 1
         record.used = 1
-        
-        # 创建 session
-        session_id = "tok_" + uuid.uuid4().hex
 
 
+        token = "tok_" + uuid.uuid4().hex
         session = VasSession(
         session = VasSession(
-            id=session_id,
+            id=token,
             user_id=user.id,
             user_id=user.id,
-            user_agent="",
-            ip="",
-            expire_at=datetime.utcnow() + timedelta(days=30)
+            ip=ip or '',
+            user_agent=user_agent or '',
+            expire_at=datetime.utcnow() + timedelta(days=30),
         )
         )
         db.add(session)
         db.add(session)
-        db.commit()
-        db.refresh(user)
-        
-        print(f"📧 send login email and password")
-        NotificationService.create(
+
+        await db.commit()
+        await db.refresh(user)
+
+        await NotificationService.create(
             redis_client=redis_client,
             redis_client=redis_client,
-            ntype="login credentials",
+            ntype="login_credentials",
             user_id=user.id,
             user_id=user.id,
             channels=["email"],
             channels=["email"],
             template_id="login_credentials",
             template_id="login_credentials",
             payload={
             payload={
                 "username": payload.email,
                 "username": payload.email,
-                "password": plain
-            }
+                "password": plain_pwd,
+            },
         )
         )
-        
-        return {
-            "user": user,
-            "token": session_id
-        }
-    
-    def reset_password(db: Session, payload: ResetPasswordRequest):
-        user = db.query(VasUser).filter(
+
+        return {"user": user, "token": token}
+
+    # =========================
+    # 重置密码
+    # =========================
+    @staticmethod
+    async def reset_password(
+        db: AsyncSession,
+        payload: ResetPasswordRequest,
+    ) -> bool:
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
+            VasUser.email_verified == 1,
+        )
+        user = (await db.execute(stmt)).scalar_one_or_none()
         if not user:
         if not user:
             raise BizLogicError("User not exist")
             raise BizLogicError("User not exist")
-        
-        record = (
-            db.query(VasEmailVerification)
-            .filter_by(token=payload.code, used=0)
-            .first()
+
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == payload.code,
+            VasVerificationToken.used == 0,
         )
         )
+        record = (await db.execute(stmt)).scalar_one_or_none()
         if not record:
         if not record:
             raise BizLogicError("Token invalid")
             raise BizLogicError("Token invalid")
 
 
         if record.expire_at < datetime.utcnow():
         if record.expire_at < datetime.utcnow():
             raise BizLogicError("Token expired")
             raise BizLogicError("Token expired")
-        
-        hashed = bcrypt.hashpw(payload.new_password.encode(), bcrypt.gensalt()).decode()
-        user.password_hash = hashed
+
+        user.password_hash = bcrypt.hashpw(
+            payload.new_password.encode(),
+            bcrypt.gensalt(),
+        ).decode()
         record.used = 1
         record.used = 1
-        db.commit()
+
+        await db.commit()
         return True
         return True
-    # -----------------------
-    # 用户登录
-    # -----------------------
+
+    # =========================
+    # 登录
+    # =========================
     @staticmethod
     @staticmethod
-    def login(db: Session, req:LoginRequest):
-        user = db.query(VasUser).filter_by(email=req.email).first()
+    async def login(
+        db: AsyncSession,
+        req: LoginRequest,
+        ip: str = None,
+        user_agent: str = None
+    ) -> Dict:
+        stmt = select(VasUser).where(VasUser.email == req.email)
+        user = (await db.execute(stmt)).scalar_one_or_none()
         if not user:
         if not user:
             raise NotFoundError("User not found")
             raise NotFoundError("User not found")
 
 
-        # 对比密码
-        if not bcrypt.checkpw(req.password.encode(), user.password_hash.encode()):
+        if not bcrypt.checkpw(
+            req.password.encode(),
+            user.password_hash.encode(),
+        ):
             raise PermissionDeniedError("Password incorrect")
             raise PermissionDeniedError("Password incorrect")
 
 
-        # 创建 session
         token = "tok_" + uuid.uuid4().hex
         token = "tok_" + uuid.uuid4().hex
-
         session = VasSession(
         session = VasSession(
             id=token,
             id=token,
             user_id=user.id,
             user_id=user.id,
-            user_agent="",
-            ip="",
-            expire_at=datetime.utcnow() + timedelta(days=7)
+            user_agent=user_agent or "",
+            ip=ip or "",
+            expire_at=datetime.utcnow() + timedelta(days=7),
         )
         )
         db.add(session)
         db.add(session)
-        db.commit()
+        await db.commit()
 
 
-        return {
-            "user": user,
-            "token": token
-        }
+        return {"user": user, "token": token}

+ 0 - 72
app/services/auto_booking_service.py

@@ -1,72 +0,0 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import func
-from app.models.auto_booking import AutoBooking
-from app.schemas.auto_booking import AutoBookingCreate
-from typing import List
-
-class AutoBookingService:
-
-    @staticmethod
-    def create(db: Session, obj_in: AutoBookingCreate) -> AutoBooking:
-        db_obj = AutoBooking(**obj_in.dict())
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
-        return db_obj
-
-    @staticmethod
-    def get_by_id(db: Session, id: int):
-        return db.query(AutoBooking).filter(AutoBooking.id == id).first()
-
-    @staticmethod
-    def delete_by_id(db: Session, id: int):
-        obj = db.query(AutoBooking).filter(AutoBooking.id == id).first()
-        if obj:
-            db.delete(obj)
-            db.commit()
-            return True
-        return False
-
-    @staticmethod
-    def update_by_id(db: Session, id: int, updated_data: dict):
-        obj = db.query(AutoBooking).filter(AutoBooking.id == id).first()
-        if not obj:
-            return None
-        for key, value in updated_data.items():
-            setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
-        return obj
-
-    @staticmethod
-    def get_paginated(db: Session, tech_provider: str, keyword: str, page: int, size: int):
-        query = db.query(AutoBooking)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        if keyword:
-            like_str = f"%{keyword}%"
-            query = query.filter(
-                (AutoBooking.first_name.like(like_str)) |
-                (AutoBooking.last_name.like(like_str)) |
-                (AutoBooking.email.like(like_str)) |
-                (AutoBooking.visa_center.like(like_str))
-            )
-        return query.offset(page * size).limit(size).all()
-
-    @staticmethod
-    def batch_get_by_ids(db: Session, ids: List[int]):
-        return db.query(AutoBooking).filter(AutoBooking.id.in_(ids)).all()
-
-    @staticmethod
-    def statistics(db: Session, tech_provider: str):
-        query = db.query(AutoBooking.provider, func.count(AutoBooking.id)).group_by(AutoBooking.provider)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        return query.all()
-
-    @staticmethod
-    def get_pending(db: Session, tech_provider: str):
-        query = db.query(AutoBooking).filter(AutoBooking.status == 0)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        return query.all()

+ 37 - 16
app/services/card_service.py

@@ -1,30 +1,51 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import text
-from typing import List, Optional
-from app.utils.search import apply_keyword_search
-from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+# app/services/card_service.py
+
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
 from app.models.card import Card
 from app.models.card import Card
 from app.schemas.card import CardCreate
 from app.schemas.card import CardCreate
+from app.utils.search import apply_keyword_search_stmt
+from app.utils.pagination import paginate
+from app.core.biz_exception import BizLogicError
 
 
 
 
 class CardService:
 class CardService:
+
+    # =========================
+    # 创建 Card
+    # =========================
     @staticmethod
     @staticmethod
-    def create(db: Session, obj_in: CardCreate) -> Card:
+    async def create(
+        db: AsyncSession,
+        obj_in: CardCreate,
+    ) -> Card:
         db_obj = Card(**obj_in.dict())
         db_obj = Card(**obj_in.dict())
         db.add(db_obj)
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
         return db_obj
 
 
+    # =========================
+    # 关键字分页查询
+    # =========================
     @staticmethod
     @staticmethod
-    def list_by_keyword(db: Session, keyword: str = None, page: int = 0, size: int = 10, culture: str = "english"):
-        query = db.query(Card).filter(Card.culture == culture)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_keyword(
+        db: AsyncSession,
+        keyword: Optional[str] = None,
+        page: int = 0,
+        size: int = 10,
+        culture: str = "english",
+    ):
+        stmt = select(Card).where(Card.culture == culture)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=Card,
             model=Card,
             keyword=keyword,
             keyword=keyword,
-            fields=["title", "content", "label"]
+            fields=["title", "content", "label"],
         )
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 85 - 24
app/services/configuration_service.py

@@ -1,50 +1,111 @@
-from sqlalchemy.orm import Session
+# app/services/configuration_service.py
+
 from typing import List, Optional
 from typing import List, Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.configuration import Configuration
 from app.models.configuration import Configuration
 from app.schemas.configuration import ConfigurationCreate, ConfigurationUpdate
 from app.schemas.configuration import ConfigurationCreate, ConfigurationUpdate
 
 
 
 
 class ConfigurationService:
 class ConfigurationService:
+
+    # =========================
+    # 创建配置
+    # =========================
     @staticmethod
     @staticmethod
-    def create(db: Session, config_in: ConfigurationCreate) -> Configuration:
-        config = db.query(Configuration).filter(Configuration.config_key == config_in.config_key).first()
-        if config:
-            raise BizLogicError(f"Config Key '{config_in.config_key}' already exist")
+    async def create(
+        db: AsyncSession,
+        config_in: ConfigurationCreate,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_in.config_key
+        )
+        existing = (await db.execute(stmt)).scalar_one_or_none()
+        if existing:
+            raise BizLogicError(
+                f"Config Key '{config_in.config_key}' already exist"
+            )
+
         db_obj = Configuration(**config_in.dict())
         db_obj = Configuration(**config_in.dict())
         db.add(db_obj)
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
         return db_obj
 
 
+    # =========================
+    # 获取全部配置
+    # =========================
     @staticmethod
     @staticmethod
-    def get_all(db: Session) -> List[Configuration]:
-        return db.query(Configuration).order_by(Configuration.id.desc()).all()
+    async def get_all(
+        db: AsyncSession,
+    ) -> List[Configuration]:
+        stmt = select(Configuration).order_by(Configuration.id.desc())
+        result = await db.execute(stmt)
+        return result.scalars().all()
 
 
+    # =========================
+    # 根据 key 获取配置
+    # =========================
     @staticmethod
     @staticmethod
-    def get_by_key(db: Session, config_key: str) -> Optional[Configuration]:
-        config = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def get_by_key(
+        db: AsyncSession,
+        config_key: str,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        config = (await db.execute(stmt)).scalar_one_or_none()
         if not config:
         if not config:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
         return config
         return config
 
 
+    # =========================
+    # 根据 key 更新配置
+    # =========================
     @staticmethod
     @staticmethod
-    def update_by_key(db: Session, config_key: str, config_in: ConfigurationUpdate) -> Optional[Configuration]:
-        db_obj = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def update_by_key(
+        db: AsyncSession,
+        config_key: str,
+        config_in: ConfigurationUpdate,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
         if not db_obj:
         if not db_obj:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
+
         for field, value in config_in.dict(exclude_unset=True).items():
         for field, value in config_in.dict(exclude_unset=True).items():
             setattr(db_obj, field, value)
             setattr(db_obj, field, value)
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
         return db_obj
 
 
+    # =========================
+    # 根据 key 删除配置
+    # =========================
     @staticmethod
     @staticmethod
-    def delete_by_key(db: Session, config_key: str) -> Optional[Configuration]:
-        db_obj = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def delete_by_key(
+        db: AsyncSession,
+        config_key: str,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
         if not db_obj:
         if not db_obj:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
-        db.delete(db_obj)
-        db.commit()
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
+
+        await db.delete(db_obj)
+        await db.commit()
         return db_obj
         return db_obj

Разница между файлами не показана из-за своего большого размера
+ 469 - 441
app/services/email_authorizations_service.py


+ 57 - 21
app/services/http_session_service.py

@@ -1,43 +1,79 @@
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from typing import Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
+
+from app.core.biz_exception import NotFoundError
 from app.models.http_session import HttpSession
 from app.models.http_session import HttpSession
 from app.schemas.http_session import HttpSessionCreate, HttpSessionUpdate
 from app.schemas.http_session import HttpSessionCreate, HttpSessionUpdate
-from typing import Optional
+
 
 
 class HttpSessionService:
 class HttpSessionService:
 
 
+    # ============================
+    # 创建 Session
+    # ============================
     @staticmethod
     @staticmethod
-    def create(db: Session, data: HttpSessionCreate) -> HttpSession:
+    async def create(db: AsyncSession, data: HttpSessionCreate) -> HttpSession:
         obj = HttpSession(**data.dict())
         obj = HttpSession(**data.dict())
         db.add(obj)
         db.add(obj)
-        db.commit()
-        db.refresh(obj)
+        await db.commit()
+        await db.refresh(obj)
         return obj
         return obj
 
 
+    # ============================
+    # 根据 session_id 获取
+    # ============================
     @staticmethod
     @staticmethod
-    def get_by_sid(db: Session, session_id: str) -> Optional[HttpSession]:
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
+    async def get_by_sid(
+        db: AsyncSession,
+        session_id: str
+    ) -> HttpSession:
+        stmt = select(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("Session not found")
             raise NotFoundError("Session not found")
+
         return obj
         return obj
 
 
+    # ============================
+    # 根据 session_id 删除
+    # ============================
     @staticmethod
     @staticmethod
-    def delete_by_sid(db: Session, session_id: str) -> bool:
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
-        if not obj:
+    async def delete_by_sid(
+        db: AsyncSession,
+        session_id: str
+    ) -> bool:
+        stmt = delete(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+
+        if result.rowcount == 0:
             raise NotFoundError("Session not found")
             raise NotFoundError("Session not found")
-        db.delete(obj)
-        db.commit()
-        return obj
 
 
+        await db.commit()
+        return True
+
+    # ============================
+    # 根据 session_id 更新
+    # ============================
     @staticmethod
     @staticmethod
-    def update_by_sid(db: Session, session_id: str, data: HttpSessionUpdate):
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
+    async def update_by_sid(
+        db: AsyncSession,
+        session_id: str,
+        data: HttpSessionUpdate
+    ) -> HttpSession:
+        stmt = select(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("Session not found")
             raise NotFoundError("Session not found")
-        for k, v in data.dict().items():
-            if v is not None:
-                setattr(obj, k, v)
-        db.commit()
-        db.refresh(obj)
+
+        for field, value in data.dict(exclude_unset=True).items():
+            setattr(obj, field, value)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
         return obj

+ 18 - 8
app/services/notification_service.py

@@ -1,15 +1,25 @@
-# app/services/product_service.py
+# app/services/notification_service.py
+
 import uuid
 import uuid
-from sqlalchemy.orm import Session
-from typing import Optional, List, Dict
+from typing import List, Dict, Any
+
 from redis.asyncio import Redis
 from redis.asyncio import Redis
-from app.utils.redis_utils import redis_qpush
+from app.utils.redis_utils import redis_qpush, redis_qpop
+
 
 
 class NotificationService:
 class NotificationService:
 
 
-    def create(redis_client: Redis, ntype: str, user_id:str, channels:List[str], template_id=str, payload=Dict):
+    @staticmethod
+    async def create(
+        redis_client: Redis,
+        ntype: str,
+        user_id: str,
+        channels: List[str],
+        template_id: str,
+        payload: Dict[str, Any]
+    ) -> None:
         notification_payload = {
         notification_payload = {
-            "notification_id": f'nid_{uuid.uuid4().hex}',
+            "notification_id": f"nid_{uuid.uuid4().hex}",
             "type": ntype,
             "type": ntype,
             "user_id": user_id,
             "user_id": user_id,
             "channels": channels,
             "channels": channels,
@@ -17,10 +27,10 @@ class NotificationService:
             "payload": payload
             "payload": payload
         }
         }
 
 
-        redis_qpush(
+        await redis_qpush(
             redis_client,
             redis_client,
             "vas_notification_queue",
             "vas_notification_queue",
             notification_payload
             notification_payload
         )
         )
 
 
-
+    

+ 128 - 126
app/services/order_service.py

@@ -1,32 +1,42 @@
 # app/services/order_service.py
 # app/services/order_service.py
 import uuid
 import uuid
 import json
 import json
-from redis.asyncio import Redis
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
-from typing import List
-from app.utils.search import apply_keyword_search
+from typing import List, Optional
+
+from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.core.auth import get_current_user
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.user import VasUser
 from app.models.user import VasUser
 from app.models.order import VasOrder
 from app.models.order import VasOrder
 from app.models.vas_task import VasTask
 from app.models.vas_task import VasTask
 from app.models.product import VasProduct
 from app.models.product import VasProduct
 from app.models.product_routing import VasProductRouting
 from app.models.product_routing import VasProductRouting
 from app.schemas.order import VasOrderCreate, VasOrderPatchUserInputs
 from app.schemas.order import VasOrderCreate, VasOrderPatchUserInputs
-from app.services.notification_service import NotificationService
+
 
 
 class OrderService:
 class OrderService:
-    
+
+    # --------------------------------------------------
+    # 管理员强制标记为已支付
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def mark_as_admin_paid(db: Session, order: VasOrder, admin_user):
+    async def mark_as_admin_paid(
+        db: AsyncSession,
+        order: VasOrder,
+        admin_user: VasUser,
+    ) -> VasOrder:
+
         if order.status == "paid":
         if order.status == "paid":
             return order
             return order
 
 
         order.status = "paid"
         order.status = "paid"
 
 
-        # ===== 核心修复点 =====
+        # ===== user_inputs 安全修复 =====
         raw_inputs = order.user_inputs
         raw_inputs = order.user_inputs
 
 
         if isinstance(raw_inputs, str):
         if isinstance(raw_inputs, str):
@@ -34,12 +44,9 @@ class OrderService:
                 order.user_inputs = json.loads(raw_inputs)
                 order.user_inputs = json.loads(raw_inputs)
             except Exception:
             except Exception:
                 order.user_inputs = {}
                 order.user_inputs = {}
-        elif raw_inputs is None:
-            order.user_inputs = {}
-        elif not isinstance(raw_inputs, dict):
+        elif raw_inputs is None or not isinstance(raw_inputs, dict):
             order.user_inputs = {}
             order.user_inputs = {}
 
 
-        # 记录绕过支付的原因(非常重要)
         order.user_inputs["_admin_bypass"] = {
         order.user_inputs["_admin_bypass"] = {
             "enabled": True,
             "enabled": True,
             "by": admin_user.id,
             "by": admin_user.id,
@@ -48,55 +55,50 @@ class OrderService:
         }
         }
 
 
         db.add(order)
         db.add(order)
-        db.commit()
-        db.refresh(order)
+        await db.commit()
+        await db.refresh(order)
 
 
         return order
         return order
-    
+
+    # --------------------------------------------------
+    # 为订单创建任务(幂等)
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def create_tasks_for_order(db: Session, order: VasOrder):
-        """
-        为已支付订单创建任务(幂等)
-        """
+    async def create_tasks_for_order(
+        db: AsyncSession,
+        order: VasOrder
+    ) -> List[VasTask]:
+
         if order.status != "paid":
         if order.status != "paid":
             return []
             return []
 
 
-        # ---------- 1. 查 routing ----------
-        routings = (
-            db.query(VasProductRouting)
-            .filter(
-                VasProductRouting.product_id == order.product_id,
-                VasProductRouting.is_active == 1
-            )
-            .all()
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1
         )
         )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
 
 
         if not routings:
         if not routings:
             return []
             return []
 
 
-        created_tasks = []
+        created_tasks: List[VasTask] = []
 
 
         for routing in routings:
         for routing in routings:
-
-            # ---------- 2. 幂等判断 ----------
-            exists = (
-                db.query(VasTask)
-                .filter(
-                    VasTask.order_id == order.id,
-                    VasTask.routing_key == routing.routing_key,
-                    VasTask.script_version == routing.script_version,
-                )
-                .first()
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
             )
             )
+            exists = (await db.execute(exists_stmt)).scalar_one_or_none()
             if exists:
             if exists:
                 continue
                 continue
 
 
-            # ---------- 3. 创建 task ----------
             task = VasTask(
             task = VasTask(
                 order_id=order.id,
                 order_id=order.id,
                 routing_key=routing.routing_key,
                 routing_key=routing.routing_key,
                 script_version=routing.script_version,
                 script_version=routing.script_version,
-                priority=10,
+                priority=routing.priority,
                 status="pending",
                 status="pending",
                 user_inputs=order.user_inputs,
                 user_inputs=order.user_inputs,
                 config=routing.config,
                 config=routing.config,
@@ -105,113 +107,113 @@ class OrderService:
                 expire_at=datetime.utcnow() + timedelta(days=7),
                 expire_at=datetime.utcnow() + timedelta(days=7),
                 created_at=datetime.utcnow(),
                 created_at=datetime.utcnow(),
             )
             )
-
             db.add(task)
             db.add(task)
             created_tasks.append(task)
             created_tasks.append(task)
 
 
-        db.commit()
-
+        await db.commit()
         return created_tasks
         return created_tasks
-    
+
+    # --------------------------------------------------
+    # 创建订单
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def cancel_order(db, order_id, reason, admin_id):
-        if order.status in (OrderStatus.cancelled, OrderStatus.completed):
-            return order
+    async def create(
+        db: AsyncSession,
+        data: VasOrderCreate,
+        product: VasProduct,
+        auth_user: VasUser,
+        redis_client: Redis,
+    ) -> VasOrder:
 
 
-        if order.status == OrderStatus.paid:
-            raise HTTPException(
-                400,
-                "Paid order must be refunded",
+        if not auth_user.email:
+            raise BizLogicError(
+                "Your account must be linked to an email address before you can place an order."
             )
             )
 
 
-        # 2️⃣ user_inputs 写入取消信息
-        user_inputs = order.user_inputs or {}
-        user_inputs["cancel"] = {
-            "reason": reason,
-            "by": "admin",
-            "admin_id": admin.user_id,
-            "at": datetime.utcnow().isoformat(),
-        }
-        order.user_inputs = user_inputs
+        order_id = f"ORD-{datetime.utcnow():%Y%m%d%H%M%S}-{uuid.uuid4().hex[:8]}"
 
 
-        # payment
-        for payment in order.payments:
-            if payment.status in (PaymentStatus.pending,):
-                payment.status = PaymentStatus.expired
+        order = VasOrder(
+            id=order_id,
+            **data.dict(),
+            product_name=product.title,
+            base_amount=product.price_amount,
+            base_currency=product.price_currency,
+            user_id=auth_user.id,
+        )
 
 
-        # task
-        for task in order.tasks:
-            task.status = TaskStatus.cancelled
+        db.add(order)
+        await db.commit()
+        await db.refresh(order)
 
 
         return order
         return order
-    
-    @staticmethod
-    def create(db: Session, data: VasOrderCreate, product: VasProduct, auth_user: VasUser, redis_client:Redis):
-        if not auth_user.email:
-            raise BizLogicError('Your account must be linked to an email address before you can place an order.')
-        order_id = f"ORD-{datetime.utcnow().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
-        rec = VasOrder(id=order_id, **data.dict())
-        rec.product_name = product.title
-        rec.base_amount = product.price_amount
-        rec.base_currency = product.price_currency
-        rec.user_id = auth_user.id
-        db.add(rec)
-        db.commit()
-        db.refresh(rec)
-        
-        print(f"📧 send order created notification email")
-        NotificationService.create(
-            redis_client=redis_client,
-            ntype="order create notify",
-            user_id=auth_user.id,
-            channels=["email"],
-            template_id="order_create_notify",
-            payload={
-                "order_id": rec.id
-            }
-        )
-        return rec
 
 
+    # --------------------------------------------------
+    # 获取订单
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def get(db: Session, id: str):
-        return db.query(VasOrder).filter_by(id=id).first()
+    async def get(db: AsyncSession, order_id: str) -> Optional[VasOrder]:
+        stmt = select(VasOrder).where(VasOrder.id == order_id)
+        return (await db.execute(stmt)).scalar_one_or_none()
 
 
+    # --------------------------------------------------
+    # 用户订单列表
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def list_by_user(db: Session, user_id: str, page: int=0, size: int=10, keyword: str=None):
-        query = db.query(VasOrder).filter_by(user_id=user_id)
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_user(
+        db: AsyncSession,
+        user_id: str,
+        page: int = 0,
+        size: int = 10,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasOrder).where(VasOrder.user_id == user_id)
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasOrder,
             model=VasOrder,
             keyword=keyword,
             keyword=keyword,
-            fields=["id", "user_id", "product_name"]
-        )
-        query = query.order_by(
-            VasOrder.created_at.desc()
-        )
-        return paginate(query, page, size)
-    
+            fields=["id", "user_id", "product_name"],
+        ).order_by(VasOrder.created_at.desc())
+
+        return await paginate(db, stmt, page, size)
+
+    # --------------------------------------------------
+    # 管理员订单列表
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def list_all(db: Session, page: int=0, size: int=10, keyword: str=None):
-        query = db.query(VasOrder)
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 0,
+        size: int = 10,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasOrder)
+        query = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasOrder,
             model=VasOrder,
             keyword=keyword,
             keyword=keyword,
-            fields=["id", "user_id", "user_name", "product_name", "user_inputs"]
-        )
-        query = query.order_by(
-            VasOrder.created_at.desc()
-        )
-        return paginate(query, page, size)
-    
+            fields=["id", "user_id", "user_name", "product_name"],
+        ).order_by(VasOrder.created_at.desc())
+
+        return await paginate(db, query, page, size)
+
+    # --------------------------------------------------
+    # 更新 user_inputs
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def patch_user_inputs(db: Session, order_id: str, payload: VasOrderPatchUserInputs):
-        order = db.query(VasOrder).filter_by(id=order_id).first()
+    async def patch_user_inputs(
+        db: AsyncSession,
+        order_id: str,
+        payload: VasOrderPatchUserInputs,
+    ) -> VasOrder:
+
+        stmt = select(VasOrder).where(VasOrder.id == order_id)
+        order = (await db.execute(stmt)).scalar_one_or_none()
+
         if not order:
         if not order:
             raise NotFoundError("Order not exist")
             raise NotFoundError("Order not exist")
+
         order.user_inputs = payload.user_inputs
         order.user_inputs = payload.user_inputs
-        db.commit()
-        db.refresh(order)
+        await db.commit()
+        await db.refresh(order)
+
         return order
         return order
-    
-    

+ 103 - 55
app/services/payment_provider_service.py

@@ -1,26 +1,34 @@
 # app/services/payment_provider.py
 # app/services/payment_provider.py
 
 
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.models.payment_provider import VasPaymentProvider
-from app.schemas.payment_provider import VasPaymentProviderCreate, VasPaymentProviderUpdate
+from typing import List, Optional
 
 
-class PaymentProviderSerivce:
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
 
 
-    def create(
-        db: Session,
-        data: VasPaymentProviderCreate
+from app.core.biz_exception import NotFoundError, BizLogicError
+from app.models.payment_provider import VasPaymentProvider
+from app.schemas.payment_provider import (
+    VasPaymentProviderCreate,
+    VasPaymentProviderUpdate,
+)
+
+
+class PaymentProviderService:
+    # --------------------------------------------------
+    # 创建支付提供商(防重复)
+    # --------------------------------------------------
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasPaymentProviderCreate,
     ) -> VasPaymentProvider:
     ) -> VasPaymentProvider:
-        # 防止重复注册
-        exists = (
-            db.query(VasPaymentProvider)
-            .filter(
-                VasPaymentProvider.name == data.name,
-                VasPaymentProvider.channel == data.channel,
-                VasPaymentProvider.currency == data.currency,
-            )
-            .first()
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.name == data.name,
+            VasPaymentProvider.channel == data.channel,
+            VasPaymentProvider.currency == data.currency,
         )
         )
+        exists = (await db.execute(stmt)).scalar_one_or_none()
 
 
         if exists:
         if exists:
             raise BizLogicError("Payment provider already exists")
             raise BizLogicError("Payment provider already exists")
@@ -35,64 +43,104 @@ class PaymentProviderSerivce:
         )
         )
 
 
         db.add(provider)
         db.add(provider)
-        db.commit()
-        db.refresh(provider)
+        await db.commit()
+        await db.refresh(provider)
         return provider
         return provider
 
 
-    def update(
-        db: Session,
+    # --------------------------------------------------
+    # 更新支付提供商(禁止修改 name/channel/currency)
+    # --------------------------------------------------
+    @staticmethod
+    async def update(
+        db: AsyncSession,
         provider_id: int,
         provider_id: int,
-        data: VasPaymentProviderUpdate
+        data: VasPaymentProviderUpdate,
     ) -> VasPaymentProvider:
     ) -> VasPaymentProvider:
-        provider = db.query(VasPaymentProvider).get(provider_id)
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
         if not provider:
         if not provider:
-            raise BizLogicError("Payment provider not found")
+            raise NotFoundError("Payment provider not found")
 
 
         update_data = data.dict(exclude_unset=True)
         update_data = data.dict(exclude_unset=True)
 
 
-        # 安全起见,禁止修改三元组
+        # 🚫 禁止修改三元组
         for forbidden in ("name", "channel", "currency"):
         for forbidden in ("name", "channel", "currency"):
             update_data.pop(forbidden, None)
             update_data.pop(forbidden, None)
 
 
         for key, value in update_data.items():
         for key, value in update_data.items():
             setattr(provider, key, value)
             setattr(provider, key, value)
 
 
-        db.commit()
-        db.refresh(provider)
+        await db.commit()
+        await db.refresh(provider)
         return provider
         return provider
-    
-    def delete(db: Session, id: int):
-        provider = db.query(VasPaymentProvider).filter_by(id=id).first()
+
+    # --------------------------------------------------
+    # 删除支付提供商
+    # --------------------------------------------------
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        provider_id: int,
+    ) -> bool:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
         if not provider:
         if not provider:
             raise NotFoundError("Provider not exist")
             raise NotFoundError("Provider not exist")
-        db.delete(provider)
-        db.commit()
-    
-    def list_all(db: Session):
-        return db.query(VasPaymentProvider).all()
-
-    def list_enabled(
-        db: Session,
-        currency: str = None
-    ):
-        q = db.query(VasPaymentProvider).filter(
+
+        await db.delete(provider)
+        await db.commit()
+        return True
+
+    # --------------------------------------------------
+    # 所有支付提供商
+    # --------------------------------------------------
+    @staticmethod
+    async def list_all(
+        db: AsyncSession,
+    ) -> List[VasPaymentProvider]:
+
+        result = await db.execute(select(VasPaymentProvider))
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 可用的支付提供商(可按币种)
+    # --------------------------------------------------
+    @staticmethod
+    async def list_enabled(
+        db: AsyncSession,
+        currency: Optional[str] = None,
+    ) -> List[VasPaymentProvider]:
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1
             VasPaymentProvider.enabled == 1
         )
         )
 
 
         if currency:
         if currency:
-            q = q.filter(VasPaymentProvider.currency == currency)
-
-        return q.all()
-    
-    def get_by_name(
-        db: Session,
-        name: str
-    ):
-        q = db.query(VasPaymentProvider).filter(
-            VasPaymentProvider.enabled == 1
+            stmt = stmt.where(VasPaymentProvider.currency == currency)
+
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 name 获取(只返回 enabled)
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_name(
+        db: AsyncSession,
+        name: str,
+    ) -> Optional[VasPaymentProvider]:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.enabled == 1,
+            VasPaymentProvider.name == name,
         )
         )
 
 
-        if name:
-            q = q.filter(VasPaymentProvider.name == name)
-
-        return q.first()
+        return (await db.execute(stmt)).scalar_one_or_none()

+ 130 - 31
app/services/payment_qr_service.py

@@ -1,50 +1,149 @@
 # app/services/payment_qr_service.py
 # app/services/payment_qr_service.py
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from typing import List
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_qr import VasPaymentQR
 from app.models.payment_qr import VasPaymentQR
-from app.schemas.payment_qr import VasPaymentQrCreate, VasPaymentQrSetEnableIn
+from app.schemas.payment_qr import (
+    VasPaymentQrCreate,
+    VasPaymentQrSetEnableIn,
+)
+
 
 
 class PaymentQrService:
 class PaymentQrService:
 
 
-    def create(db: Session, data: VasPaymentQrCreate):
+    # --------------------------------------------------
+    # 创建支付二维码
+    # --------------------------------------------------
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasPaymentQrCreate,
+    ) -> VasPaymentQR:
+
         rec = VasPaymentQR(**data.dict())
         rec = VasPaymentQR(**data.dict())
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
-    
-    def get_by_id(db: Session, id: int):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 根据 ID 获取
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_id(
+        db: AsyncSession,
+        qr_id: int,
+    ) -> VasPaymentQR:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("QR not exist")
             raise NotFoundError("QR not exist")
+
         return obj
         return obj
-    
-    def set_enable(db: Session, id: int, payload: VasPaymentQrSetEnableIn):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 启用 / 禁用 QR
+    # --------------------------------------------------
+    @staticmethod
+    async def set_enable(
+        db: AsyncSession,
+        qr_id: int,
+        payload: VasPaymentQrSetEnableIn,
+    ) -> VasPaymentQR:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("QR not exist")
             raise NotFoundError("QR not exist")
+
         obj.is_active = payload.is_active
         obj.is_active = payload.is_active
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
         return obj
-    
-    def delete(db: Session, id: int):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 删除 QR
+    # --------------------------------------------------
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        qr_id: int,
+    ) -> bool:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("QR not exist")
             raise NotFoundError("QR not exist")
-        db.delete(obj)
-        db.commit()
-    
-    def get_by_devid(db: Session, devid: str):
-        return db.query(VasPaymentQR).filter(VasPaymentQR.devid == devid).all()
-
-    def get_by_provider(db: Session, provider: str):
-        return db.query(VasPaymentQR).filter(VasPaymentQR.provider == provider).all()
-    
-    def list_by_provider(db: Session, provider_id: int):
-        obj = db.query(VasPaymentProvider).filter(VasPaymentProvider.id==provider_id).first()
-        if not obj:
+
+        await db.delete(obj)
+        await db.commit()
+        return True
+
+    # --------------------------------------------------
+    # 根据设备 ID 查询
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_devid(
+        db: AsyncSession,
+        devid: str,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.devid == devid
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 provider 名称查询
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_provider(
+        db: AsyncSession,
+        provider: str,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 provider_id 查询 QR(安全校验)
+    # --------------------------------------------------
+    @staticmethod
+    async def list_by_provider(
+        db: AsyncSession,
+        provider_id: int,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
+        if not provider:
             raise NotFoundError("Provider not exist")
             raise NotFoundError("Provider not exist")
-        
-        return db.query(VasPaymentQR).filter(VasPaymentQR.provider==obj.name).all()
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider.name
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()

+ 384 - 78
app/services/payment_service.py

@@ -1,72 +1,358 @@
 # app/services/payment_service.py
 # app/services/payment_service.py
+
 import time
 import time
 import stripe
 import stripe
 import random
 import random
-from typing import List,Dict
+import uuid
+from typing import Dict, List, Optional
+from redis.asyncio import Redis
 from decimal import Decimal, ROUND_HALF_UP
 from decimal import Decimal, ROUND_HALF_UP
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
+from app.utils.pagination import paginate
+from app.core.biz_exception import NotFoundError, BizLogicError
+from app.models.user import VasUser
 from app.models.order import VasOrder
 from app.models.order import VasOrder
-from app.models.product import VasProduct
+from app.models.vas_task import VasTask
 from app.models.payment import VasPayment
 from app.models.payment import VasPayment
+from app.models.ticket import VasTicket
+from app.models.payment_event import VasPaymentEvent
+from app.models.product_routing import VasProductRouting
+from app.models.verification_token import VasVerificationToken
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_qr import VasPaymentQR
 from app.models.payment_qr import VasPaymentQR
+from app.models.payment_confirmation import VasPaymentConfirmation
 from app.schemas.payment import VasPaymentCreate
 from app.schemas.payment import VasPaymentCreate
+from app.schemas.payment_confirmation import VasPaymentConfirmationCreate, VasPaymentConfirmationUpdate
+from app.services.notification_service import NotificationService
+
 
 
 
 
 class PaymentService:
 class PaymentService:
-    
+
+    # --------------------------------------------------
+    # 创建支付(统一入口)
+    # --------------------------------------------------
     @staticmethod
     @staticmethod
-    def create_payment(db: Session, payload: VasPaymentCreate, rate_table: Dict):
-        # ① 锁住订单,防止并发创建 payment
-        order = (
-            db.query(VasOrder)
-            .filter(VasOrder.id == payload.order_id)
+    async def create_payment(
+        db: AsyncSession,
+        payload: VasPaymentCreate,
+        rate_table: Dict,
+        redis_client: Redis
+    ) -> VasPayment:
+
+        # ① 锁住订单(防并发)
+        stmt = (
+            select(VasOrder)
+            .where(VasOrder.id == payload.order_id)
             .with_for_update()
             .with_for_update()
-            .one()
         )
         )
+        order = (await db.execute(stmt)).scalar_one_or_none()
+        if not order:
+            raise NotFoundError("Order not found")
 
 
-        # ② 是否已有进行中的 payment
-        active_payment = (
-            db.query(VasPayment)
-            .filter(
-                VasPayment.order_id == order.id,
-                VasPayment.status == "pending"
-            )
-            .first()
+        # ② 是否已有 pending payment(幂等)
+        stmt = select(VasPayment).where(
+            VasPayment.order_id == order.id,
+            VasPayment.status == "pending",
         )
         )
+        active_payment = (await db.execute(stmt)).scalar_one_or_none()
+
         if active_payment:
         if active_payment:
             if active_payment.provider == payload.provider:
             if active_payment.provider == payload.provider:
-                return active_payment  # 直接返回旧的,不报错(幂等性)
+                return active_payment
             else:
             else:
-                active_payment.status = 'failed' 
-                db.add(active_payment)
+                active_payment.status = "failed"
 
 
+        # ③ 根据 provider 创建
         if payload.provider in ("wechat", "alipay"):
         if payload.provider in ("wechat", "alipay"):
-            payment = PaymentService.create_offline_payment(db=db, order=order, provider_name=payload.provider, rate_table=rate_table)
-            db.commit()
+            payment = await PaymentService.create_offline_payment(
+                db=db,
+                order=order,
+                provider_name=payload.provider,
+                rate_table=rate_table,
+            )
+            await db.commit()
             return payment
             return payment
 
 
         if payload.provider == "stripe":
         if payload.provider == "stripe":
-            payment = PaymentService.create_stripe_payment(db=db, order=order, rate_table=rate_table)
-            db.commit()
+            payment = await PaymentService.create_stripe_payment(
+                db=db,
+                order=order,
+                rate_table=rate_table,
+            )
+            await db.commit()
             return payment
             return payment
 
 
         raise BizLogicError("Unsupported provider")
         raise BizLogicError("Unsupported provider")
     
     
     @staticmethod
     @staticmethod
-    def create_offline_payment(db, order, provider_name: str, rate_table: dict):
+    async def confirm_by_user(
+        db: AsyncSession,
+        payload: VasPaymentConfirmationCreate,
+        current_user: VasUser,
+        redis_client: Redis
+    ):
+        """
+        用户点击“我已支付”
+        """
+        # 1️⃣ 查询是否存在对应 payment 确认记录
+        result = await db.execute(
+            select(VasPaymentConfirmation)
+            .where(VasPaymentConfirmation.payment_id == payload.payment_id)
+            .where(VasPaymentConfirmation.user_id == current_user.id)
+        )
+        record = result.scalar_one_or_none()
+
+        if not record:
+            # 没有则创建一条 pending -> confirmed 记录
+            record = VasPaymentConfirmation(
+                payment_id=payload.payment_id,
+                amount=payload.amount,
+                currency=payload.currency,
+                random_offset=payload.random_offset,
+                user_id=current_user.id,
+                status="pending",
+                confirmed_at=payload.confirmed_at
+            )
+            db.add(record)
+            await db.commit()
+            await db.refresh(record)
+
+        # 2️⃣ 推送异步通知给管理员
+        # await NotificationService.create(
+        #     redis_client=redis_client,
+        #     ntype="payment_user_confirmed",
+        #     user_id=current_user.id,
+        #     channels=["wechat"],
+        #     template_id="payment_user_confirmed",
+        #     payload={
+        #         "payment_id": payload.payment_id,
+        #         "user_id": current_user.id,
+        #         "confirmed_at": record.confirmed_at.isoformat()
+        #     }
+        # )
+
+        return record
+    
+    @staticmethod
+    async def confirm_by_admin(
+        db: AsyncSession,
+        id: int,
+        payload: VasPaymentConfirmationUpdate,
+        current_user: VasUser
+    ):
+        """
+        管理员确认用户的支付
+        """
+        # 1️⃣ 查询对应确认记录
+        result = await db.execute(
+            select(VasPaymentConfirmation)
+            .where(VasPaymentConfirmation.id == id)
+        )
+        record = result.scalar_one_or_none()
+
+        if not record:
+            raise NotFoundError("Payment confirmation record not found")
+
+        # 3️⃣ 更新管理员确认状态
+        record.admin_id = current_user.id
+        record.admin_confirmed_at = datetime.utcnow()
+        record.status = 'confirmed'
+        await PaymentService._confirm_payment_action(db, record.payment_id)
+        
+        await db.commit()
+        await db.refresh(record)
+
+        return record
+    
+    @staticmethod
+    async def list_payment_confirmation(
+        db: AsyncSession,
+        keyword: Optional[str] = None,
+        page: int = 1,
+        size: int = 20,
+    ):
+        stmt = select(VasPaymentConfirmation)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
+            model=VasPaymentConfirmation,
+            keyword=keyword,
+            fields=["user_id"],
+        )
+
+        stmt = stmt.order_by(VasPaymentConfirmation.id.desc())
+
+        return await paginate(db, stmt, page, size)
+    
+    @staticmethod
+    async def _create_task_if_not_exists(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> List[VasTask]:
+
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1,
+        )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
+
+        if not routings:
+            return []
+
+        created_tasks: List[VasTask] = []
+
+        for routing in routings:
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
+            )
+            exists_result = await db.execute(exists_stmt)
+            exists = exists_result.scalar_one_or_none()
+
+            if exists:
+                continue
+
+            task = VasTask(
+                order_id=order.id,
+                routing_key=routing.routing_key,
+                script_version=routing.script_version,
+                priority=routing.priority,
+                status="pending",
+                user_inputs=order.user_inputs,
+                config=routing.config,
+                attempt_count=0,
+                notify_count=0,
+                expire_at=datetime.utcnow() + timedelta(days=60),
+                created_at=datetime.utcnow(),
+            )
+            db.add(task)
+            created_tasks.append(task)
+
+        return created_tasks
+    
+    @staticmethod
+    async def confirm_payment(
+        db: AsyncSession,
+        payment_id: int,
+        token: str
+    ):
+        # 校验验证码
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == token,
+            VasVerificationToken.used == 0,
+        )
+        token_obj = (await db.execute(stmt)).scalar_one_or_none()
+        if not token_obj:
+            raise BizLogicError("Token invalid")
+
+        if token_obj.expire_at < datetime.utcnow():
+            raise BizLogicError("Token expired")
+        
+        payment = await PaymentService._confirm_payment_action(db, payment_id)
+        token_obj.used = 1
+        await db.commit()
+        return payment
+    
+    @staticmethod
+    async def _confirm_payment_action(db: AsyncSession, payment_id: int):
+        # ---------- 查找 payment ----------
+        pay_stmt = (
+            select(VasPayment)
+            .where(
+                VasPayment.id == payment_id,
+                VasPayment.status == "pending",
+            )
+            .order_by(VasPayment.created_at.desc())
+        )
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
+        if not payment:
+            raise BizLogicError("Payment not found")
+        
+        event = VasPaymentEvent(
+            provider=payment.provider,
+            event_type="payment_received",
+            title='confirm payment',
+            content='confirm payment by admin',
+            parsed_amount=payment.amount,
+            parsed_currency=payment.currency,
+            parsed_device='',
+            status="received",
+        )
+        db.add(event)
+        await db.commit()
+        await db.refresh(event)
+
+        if payment.status in ("succeeded", "late_paid"):
+            event.status = "duplicate"
+            event.matched_payment_id = payment.id
+            event.matched_order_id = payment.order_id
+            await db.commit()
+            return None
+
+        now = datetime.utcnow()
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
+
+        payment.provider_payload = {
+            "title": "confirm by admin",
+            "received_at": now.isoformat(),
+        }
+
+        order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
+        if order and order.status != "paid":
+            order.status = "paid"
+
+        await PaymentService._create_task_if_not_exists(db, order)
+
+        event.status = "applied"
+        event.matched_payment_id = payment.id
+        event.matched_order_id = payment.order_id
+
+        
+        await db.commit()
+        await db.refresh(payment)
+        
+        return payment
+
+    @staticmethod
+    async def create_offline_payment(
+        db: AsyncSession,
+        order: VasOrder,
+        provider_name: str,
+        rate_table: Dict,
+    ) -> VasPayment:
+
         payment = (
         payment = (
-            PaymentService._create_wechat_payment(db, order)
+            await PaymentService._create_wechat_payment(db, order)
             if provider_name == "wechat"
             if provider_name == "wechat"
-            else PaymentService._create_alipay_payment(db, order)
+            else await PaymentService._create_alipay_payment(db, order)
         )
         )
-        provider = db.query(VasPaymentProvider).filter(
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1,
             VasPaymentProvider.enabled == 1,
-            VasPaymentProvider.name == provider_name
-        ).first()
-        qrs = db.query(VasPaymentQR).filter(VasPaymentQR.provider == provider_name).all()
+            VasPaymentProvider.name == provider_name,
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+        if not provider:
+            raise BizLogicError("Payment provider not available")
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider_name,
+            VasPaymentQR.is_active == 1,
+        )
+        qrs = (await db.execute(stmt)).scalars().all()
         if not qrs:
         if not qrs:
             raise BizLogicError("No payment QR available")
             raise BizLogicError("No payment QR available")
 
 
@@ -79,25 +365,34 @@ class PaymentService:
         converted = (
         converted = (
             Decimal(payment.base_amount) * exchange_rate
             Decimal(payment.base_amount) * exchange_rate
         ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
         ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
-        
+
         max_discount = min(99, int(converted * Decimal("0.01")))
         max_discount = min(99, int(converted * Decimal("0.01")))
         discount = random.randint(1, max_discount) if max_discount >= 1 else 0
         discount = random.randint(1, max_discount) if max_discount >= 1 else 0
 
 
-        final_amount = int(converted) - discount
-
         payment.exchange_rate = exchange_rate
         payment.exchange_rate = exchange_rate
-        payment.amount = final_amount
+        payment.amount = int(converted) - discount
         payment.currency = provider.currency
         payment.currency = provider.currency
         payment.random_offset = discount
         payment.random_offset = discount
+
         return payment
         return payment
-    
+
     @staticmethod
     @staticmethod
-    def create_stripe_payment(db, order, rate_table: dict):
-        payment = PaymentService._create_stripe_payment(db, order)
-        provider = db.query(VasPaymentProvider).filter(
+    async def create_stripe_payment(
+        db: AsyncSession,
+        order: VasOrder,
+        rate_table: Dict,
+    ) -> VasPayment:
+
+        payment = await PaymentService._create_stripe_payment(db, order)
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1,
             VasPaymentProvider.enabled == 1,
-            VasPaymentProvider.name == 'stripe'
-        ).first()
+            VasPaymentProvider.name == "stripe",
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+        if not provider:
+            raise BizLogicError("Stripe provider not enabled")
+
         rate_key = f"{order.base_currency}->{provider.currency}".upper()
         rate_key = f"{order.base_currency}->{provider.currency}".upper()
         exchange_rate = Decimal(rate_table[rate_key])
         exchange_rate = Decimal(rate_table[rate_key])
 
 
@@ -113,8 +408,8 @@ class PaymentService:
         stripe_session = PaymentService.create_checkout_session(
         stripe_session = PaymentService.create_checkout_session(
             order=order,
             order=order,
             payment=payment,
             payment=payment,
-            success_url="https://yourdomain.com/pay/success",
-            cancel_url="https://yourdomain.com/pay/cancel"
+            success_url="https://visafly.top/dashboard",
+            cancel_url="https://visafly.top/dashboard",
         )
         )
 
 
         payment.payment_intent_id = stripe_session.id
         payment.payment_intent_id = stripe_session.id
@@ -124,22 +419,16 @@ class PaymentService:
 
 
     @staticmethod
     @staticmethod
     def create_checkout_session(
     def create_checkout_session(
-        order,
-        payment,
+        order: VasOrder,
+        payment: VasPayment,
         success_url: str,
         success_url: str,
         cancel_url: str,
         cancel_url: str,
     ):
     ):
-        """
-        order.base_amount  单位:cent
-        payment.amount     单位:cent
-        """
-        
-        expires_at = int(time.time()) + 30 * 60  # Stripe 专用
+        expires_at = int(time.time()) + 30 * 60
 
 
-        session = stripe.checkout.Session.create(
+        return stripe.checkout.Session.create(
             mode="payment",
             mode="payment",
             payment_method_types=["card"],
             payment_method_types=["card"],
-
             line_items=[
             line_items=[
                 {
                 {
                     "price_data": {
                     "price_data": {
@@ -152,23 +441,22 @@ class PaymentService:
                     "quantity": 1,
                     "quantity": 1,
                 }
                 }
             ],
             ],
-
             metadata={
             metadata={
                 "order_id": order.id,
                 "order_id": order.id,
                 "payment_id": payment.id,
                 "payment_id": payment.id,
                 "user_id": order.user_id,
                 "user_id": order.user_id,
             },
             },
-
-            success_url=success_url + "?session_id={CHECKOUT_SESSION_ID}",
+            success_url=success_url,
             cancel_url=cancel_url,
             cancel_url=cancel_url,
-
             expires_at=expires_at,
             expires_at=expires_at,
         )
         )
 
 
-        return session
-
     @staticmethod
     @staticmethod
-    def _create_wechat_payment(db: Session, order: VasOrder):
+    async def _create_wechat_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
         payment = VasPayment(
             order_id=order.id,
             order_id=order.id,
             provider="wechat",
             provider="wechat",
@@ -183,11 +471,15 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         )
         db.add(payment)
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
         return payment
-    
+
     @staticmethod
     @staticmethod
-    def _create_alipay_payment(db: Session, order: VasOrder):
+    async def _create_alipay_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
         payment = VasPayment(
             order_id=order.id,
             order_id=order.id,
             provider="alipay",
             provider="alipay",
@@ -202,11 +494,15 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         )
         db.add(payment)
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
         return payment
-    
+
     @staticmethod
     @staticmethod
-    def _create_stripe_payment(db: Session, order: VasOrder):
+    async def _create_stripe_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
         payment = VasPayment(
             order_id=order.id,
             order_id=order.id,
             provider="stripe",
             provider="stripe",
@@ -221,16 +517,26 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         )
         db.add(payment)
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
         return payment
-    
+
     @staticmethod
     @staticmethod
-    def list_by_order(db: Session, order_id: str):
-        payments = (
-            db.query(VasPayment)
-            .filter(
-                VasPayment.order_id == order_id
-            )
-            .all()
+    async def list_by_order(
+        db: AsyncSession,
+        order_id: int,
+    ) -> List[VasPayment]:
+
+        stmt = select(VasPayment).where(
+            VasPayment.order_id == order_id
         )
         )
-        return payments
+        result = await db.execute(stmt)
+        return result.scalars().all()
+    
+    @staticmethod
+    async def get_by_id(
+        db: AsyncSession,
+        id: int,
+    ) -> VasPayment:
+        stmt = select(VasPayment).where(VasPayment.id == id)
+        return (await db.execute(stmt)).scalar_one_or_none()
+

+ 37 - 12
app/services/product_routing_service.py

@@ -1,23 +1,48 @@
 # app/services/product_routing_service.py
 # app/services/product_routing_service.py
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
+
+from app.core.biz_exception import NotFoundError
 from app.models.product_routing import VasProductRouting
 from app.models.product_routing import VasProductRouting
 from app.schemas.product_routing import VasProductRoutingCreate
 from app.schemas.product_routing import VasProductRoutingCreate
 
 
+
 class ProductRoutingService:
 class ProductRoutingService:
-    def create(db: Session, data: VasProductRoutingCreate):
+
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasProductRoutingCreate,
+    ) -> VasProductRouting:
         rec = VasProductRouting(**data.dict())
         rec = VasProductRouting(**data.dict())
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def list_by_product(db: Session, product_id:int):
-        return db.query(VasProductRouting).filter_by(product_id=product_id).all()
-    
-    def delete(db: Session, id: int):
-        obj = db.query(VasProductRouting).filter_by(id=id).first()
+    @staticmethod
+    async def list_by_product(
+        db: AsyncSession,
+        product_id: int,
+    ):
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == product_id
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        id: int,
+    ):
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.id == id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
         if not obj:
             raise NotFoundError("Product routing not exist")
             raise NotFoundError("Product routing not exist")
-        db.delete(obj)
-        db.commit()
+
+        await db.delete(obj)
+        await db.commit()

+ 63 - 30
app/services/product_service.py

@@ -1,50 +1,83 @@
 # app/services/product_service.py
 # app/services/product_service.py
-from sqlalchemy.orm import Session
-from typing import Optional, List
-from app.utils.search import apply_keyword_search
+
+from typing import Optional
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.product import VasProduct
 from app.models.product import VasProduct
-from app.schemas.product import VasProductCreate, VasProductUpdate, VasProductOut
+from app.schemas.product import VasProductCreate, VasProductUpdate
+
 
 
 class ProductService:
 class ProductService:
 
 
-    def create(db: Session, data: VasProductCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasProductCreate,
+    ) -> VasProduct:
         rec = VasProduct(**data.dict())
         rec = VasProduct(**data.dict())
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def get(db: Session, id:int):
-        obj = db.query(VasProduct).filter_by(id=id).first()
+    @staticmethod
+    async def get(
+        db: AsyncSession,
+        id: int,
+    ) -> VasProduct:
+        stmt = select(VasProduct).where(VasProduct.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
         if not obj:
-            raise NotFoundError('Product not exist')
+            raise NotFoundError("Product not exist")
         return obj
         return obj
 
 
-    def update(db: Session, id:int, data: VasProductUpdate):
-        rec = db.query(VasProduct).filter_by(id=id).first()
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        data: VasProductUpdate,
+    ) -> VasProduct:
+        stmt = select(VasProduct).where(VasProduct.id == id)
+        rec = (await db.execute(stmt)).scalar_one_or_none()
         if not rec:
         if not rec:
-            raise NotFoundError('Product not exist')
-        for k,v in data.dict(exclude_unset=True).items():
-            setattr(rec,k,v)
-        db.commit()
-        db.refresh(rec)
+            raise NotFoundError("Product not exist")
+
+        for k, v in data.dict(exclude_unset=True).items():
+            setattr(rec, k, v)
+
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
-    
-    def list_product(db: Session, country: str = None, visa_type: str = None, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasProduct)
-        
+
+    @staticmethod
+    async def list_product(
+        db: AsyncSession,
+        country: str = None,
+        visa_type: str = None,
+        page: int = 0,
+        size: int = 10,
+        keyword: str = None,
+    ):
+        # ⚠️ paginate / apply_keyword_search 仍然基于 Query
+        # 如果你当前 paginate 是同步实现,这里保持与你原项目一致
+
+        stmt = select(VasProduct)
+
         if country:
         if country:
-            query = query.filter(VasProduct.country == country)
-            
+            stmt = stmt.where(VasProduct.country == country)
+
         if visa_type:
         if visa_type:
-            query = query.filter(VasProduct.visa_type == visa_type)
-        
-        query = apply_keyword_search(
-            query=query,
+            stmt = stmt.where(VasProduct.visa_type == visa_type)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasProduct,
             model=VasProduct,
             keyword=keyword,
             keyword=keyword,
-            fields=["title", "provider", "description"]
+            fields=["title", "provider", "description"],
         )
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 54 - 22
app/services/schema_service.py

@@ -1,41 +1,73 @@
 # app/services/schema_service.py
 # app/services/schema_service.py
-from sqlalchemy.orm import Session
-from typing import Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from typing import List
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError
 from app.models.schema import VasSchema
 from app.models.schema import VasSchema
 from app.schemas.schema import VasSchemaCreate, VasSchemaUpdate
 from app.schemas.schema import VasSchemaCreate, VasSchemaUpdate
 
 
+
 class SchemaService:
 class SchemaService:
 
 
-    def create(db: Session, data: VasSchemaCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasSchemaCreate,
+    ) -> VasSchema:
         rec = VasSchema(**data.dict())
         rec = VasSchema(**data.dict())
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def get(db: Session, id: int):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def get(
+        db: AsyncSession,
+        id: int,
+    ) -> VasSchema:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
         if not obj:
-            raise NotFoundError('Schema not exist')
+            raise NotFoundError("Schema not exist")
         return obj
         return obj
 
 
-    def update(db: Session, id: int, data: VasSchemaUpdate):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        data: VasSchemaUpdate,
+    ) -> VasSchema:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
         if not obj:
-            raise NotFoundError('Schema not exist')
-        for k,v in data.dict(exclude_unset=True).items():
+            raise NotFoundError("Schema not exist")
+
+        for k, v in data.dict(exclude_unset=True).items():
             setattr(obj, k, v)
             setattr(obj, k, v)
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
         return obj
 
 
-    def delete(db: Session, id:int):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        id: int,
+    ) -> None:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
         if not obj:
-            raise NotFoundError('Schema not exist')
-        db.delete(obj)
-        db.commit()
+            raise NotFoundError("Schema not exist")
+
+        await db.delete(obj)
+        await db.commit()
 
 
-    def list_all(db: Session):
-        return db.query(VasSchema).all()
+    @staticmethod
+    async def list_all(
+        db: AsyncSession,
+    ) -> List[VasSchema]:
+        stmt = select(VasSchema)
+        result = await db.execute(stmt)
+        return result.scalars().all()

+ 86 - 45
app/services/seaweedfs_service.py

@@ -1,71 +1,112 @@
-import requests
+# app/services/seaweedfs_service.py
+
+import httpx
 from fastapi import UploadFile
 from fastapi import UploadFile
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import BizLogicError
 from app.core.logger import logger
 from app.core.logger import logger
 
 
 
 
 class SeaweedFSService:
 class SeaweedFSService:
-    MASTER_URL = "http://127.0.0.1:9333"  # 你的 SeaweedFS master 地址
+    MASTER_URL = "http://127.0.0.1:9333"  # SeaweedFS master 地址
+    DOWNLOAD_GATEWAY = "http://45.137.220.138:8888/api/resource/download_file"
 
 
     @classmethod
     @classmethod
-    def upload(cls, file: UploadFile):
-        """上传文件到 SeaweedFS"""
+    async def upload(cls, file: UploadFile):
+        """上传文件到 SeaweedFS(异步)"""
         try:
         try:
-            # 1️⃣ 获取可上传的 volume 地址
-            assign_resp = requests.get(f"{cls.MASTER_URL}/dir/assign", timeout=5)
-            assign_data = assign_resp.json()
-            fid = assign_data["fid"]
-            public_url = assign_data["publicUrl"]
-
-            # 2️⃣ 上传文件数据
-            upload_url = f"http://{public_url}/{fid}"
-            
-            download_url = f"http://45.137.220.138:8888/api/resource/download_file?fid={fid}"
-            files = {"file": (file.filename, file.file, file.content_type)}
-            upload_resp = requests.post(upload_url, files=files, timeout=10)
-
-            if upload_resp.status_code == 201:
-                return {"fid": fid, "url": download_url}
-            else:
+            async with httpx.AsyncClient(timeout=10) as client:
+                # 1️⃣ 获取 volume
+                assign_resp = await client.get(f"{cls.MASTER_URL}/dir/assign")
+                assign_resp.raise_for_status()
+                assign_data = assign_resp.json()
+
+                fid = assign_data["fid"]
+                public_url = assign_data["publicUrl"]
+
+                upload_url = f"http://{public_url}/{fid}"
+                download_url = f"{cls.DOWNLOAD_GATEWAY}?fid={fid}"
+
+                # 2️⃣ 上传文件
+                files = {
+                    "file": (
+                        file.filename,
+                        await file.read(),
+                        file.content_type,
+                    )
+                }
+
+                upload_resp = await client.post(upload_url, files=files)
+
+                if upload_resp.status_code == 201:
+                    return {
+                        "fid": fid,
+                        "url": download_url,
+                    }
+
                 raise BizLogicError(f"file upload error: {upload_resp.text}")
                 raise BizLogicError(f"file upload error: {upload_resp.text}")
+
         except Exception as e:
         except Exception as e:
+            logger.exception("SeaweedFS upload failed")
             raise BizLogicError(f"file upload exception: {e}")
             raise BizLogicError(f"file upload exception: {e}")
 
 
     @classmethod
     @classmethod
-    def get(cls, fid: str):
-        """根据 fid 读取文件"""
+    async def get(cls, fid: str):
+        """根据 fid 读取文件(异步)"""
         try:
         try:
-            resp = requests.get(f"{cls.MASTER_URL}/dir/lookup?volumeId={fid.split(',')[0]}", timeout=5)
-            data = resp.json()
-            if not data.get("locations"):
-                return None
+            volume_id = fid.split(",")[0]
+
+            async with httpx.AsyncClient(timeout=10) as client:
+                lookup_resp = await client.get(
+                    f"{cls.MASTER_URL}/dir/lookup",
+                    params={"volumeId": volume_id},
+                )
+                lookup_resp.raise_for_status()
+                data = lookup_resp.json()
+
+                if not data.get("locations"):
+                    return None
+
+                public_url = data["locations"][0]["publicUrl"]
+                file_url = f"http://{public_url}/{fid}"
 
 
-            public_url = data["locations"][0]["publicUrl"]
-            file_url = f"http://{public_url}/{fid}"
+                file_resp = await client.get(file_url)
+                if file_resp.status_code == 200:
+                    return (
+                        file_resp.content,
+                        file_resp.headers.get(
+                            "Content-Type", "application/octet-stream"
+                        ),
+                    )
 
 
-            file_resp = requests.get(file_url, timeout=10)
-            if file_resp.status_code == 200:
-                return file_resp.content, file_resp.headers.get("Content-Type", "application/octet-stream")
-            else:
                 return None
                 return None
+
         except Exception as e:
         except Exception as e:
-            logger.exception(f"SeaweedFS 读取异常, 原因={e}")
+            logger.exception(f"SeaweedFS get failed, reason={e}")
             return None
             return None
 
 
     @classmethod
     @classmethod
-    def delete(cls, fid: str):
-        """删除文件"""
+    async def delete(cls, fid: str) -> bool:
+        """删除文件(异步)"""
         try:
         try:
-            resp = requests.get(f"{cls.MASTER_URL}/dir/lookup?volumeId={fid.split(',')[0]}", timeout=5)
-            data = resp.json()
-            if not data.get("locations"):
-                return False
+            volume_id = fid.split(",")[0]
+
+            async with httpx.AsyncClient(timeout=10) as client:
+                lookup_resp = await client.get(
+                    f"{cls.MASTER_URL}/dir/lookup",
+                    params={"volumeId": volume_id},
+                )
+                lookup_resp.raise_for_status()
+                data = lookup_resp.json()
+
+                if not data.get("locations"):
+                    return False
+
+                public_url = data["locations"][0]["publicUrl"]
+                delete_url = f"http://{public_url}/{fid}"
 
 
-            public_url = data["locations"][0]["publicUrl"]
-            delete_url = f"http://{public_url}/{fid}"
+                del_resp = await client.delete(delete_url)
+                return del_resp.status_code == 202
 
 
-            del_resp = requests.delete(delete_url, timeout=10)
-            return del_resp.status_code == 202
         except Exception as e:
         except Exception as e:
-            logger.exception(f"SeaweedFS 删除异常, 原因={e}")
+            logger.exception(f"SeaweedFS delete failed, reason={e}")
             return False
             return False

+ 26 - 20
app/services/session_service.py

@@ -1,10 +1,10 @@
 # app/services/session_service.py
 # app/services/session_service.py
 
 
-from datetime import datetime, timedelta
-import uuid
+from datetime import datetime
+from typing import Optional
 
 
-from sqlalchemy.orm import Session as DBSession
-from sqlalchemy import delete
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
 
 
 from app.models.session import VasSession
 from app.models.session import VasSession
 from app.models.user import VasUser
 from app.models.user import VasUser
@@ -12,28 +12,34 @@ from app.models.user import VasUser
 
 
 class SessionService:
 class SessionService:
 
 
-    # ============================
-    # token → user(鉴权用)
-    # ============================
     @staticmethod
     @staticmethod
-    def get_user_by_token(db: DBSession, session_id: str) -> VasUser:
-        session_obj = db.query(VasSession).filter(VasSession.id == session_id).first()
+    async def get_user_by_token(
+        db: AsyncSession,
+        session_id: str
+    ) -> Optional[VasUser]:
+        result = await db.execute(
+            select(VasSession).where(VasSession.id == session_id)
+        )
+        session_obj = result.scalar_one_or_none()
+
         if not session_obj:
         if not session_obj:
             return None
             return None
 
 
-        # session 是否过期
         if session_obj.expire_at < datetime.utcnow():
         if session_obj.expire_at < datetime.utcnow():
-            # 自动删除过期 session
-            SessionService.delete_session(db, session_id)
+            await SessionService.delete_session(db, session_id)
             return None
             return None
 
 
-        user = db.query(VasUser).filter(VasUser.id == session_obj.user_id).first()
-        return user
+        result = await db.execute(
+            select(VasUser).where(VasUser.id == session_obj.user_id)
+        )
+        return result.scalar_one_or_none()
 
 
-    # ============================
-    # 删除 session(登出)
-    # ============================
     @staticmethod
     @staticmethod
-    def delete_session(db: DBSession, session_id: str):
-        db.query(VasSession).filter(VasSession.id == session_id).delete()
-        db.commit()
+    async def delete_session(
+        db: AsyncSession,
+        session_id: str
+    ) -> None:
+        await db.execute(
+            delete(VasSession).where(VasSession.id == session_id)
+        )
+        await db.commit()

+ 45 - 18
app/services/short_url_service.py

@@ -1,40 +1,67 @@
+# app/services/short_url_service.py
+
 import string
 import string
 import random
 import random
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.short_url import ShortUrl
 from app.models.short_url import ShortUrl
 
 
 
 
 class ShortUrlService:
 class ShortUrlService:
+
     @staticmethod
     @staticmethod
     def generate_short_key(length: int = 8) -> str:
     def generate_short_key(length: int = 8) -> str:
-        """生成随机短 Key(字母+数字组成)"""
+        """生成随机短 Key(字母 + 数字)"""
         chars = string.ascii_letters + string.digits
         chars = string.ascii_letters + string.digits
         return ''.join(random.choices(chars, k=length))
         return ''.join(random.choices(chars, k=length))
 
 
     @staticmethod
     @staticmethod
-    def create_short_url(db: Session, long_url: str) -> ShortUrl:
-        """创建短链接"""
-        # 检查是否已经存在相同的长链接
-        existing = db.query(ShortUrl).filter(ShortUrl.long_url == long_url).first()
+    async def create_short_url(
+        db: AsyncSession,
+        long_url: str
+    ) -> ShortUrl:
+        """创建短链接(异步 + 并发安全)"""
+
+        # 1️⃣ 是否已有相同 long_url
+        stmt = select(ShortUrl).where(ShortUrl.long_url == long_url)
+        existing = await db.scalar(stmt)
         if existing:
         if existing:
             raise BizLogicError("Short url already exist")
             raise BizLogicError("Short url already exist")
 
 
-        # 生成唯一 short_key
-        short_key = ShortUrlService.generate_short_key()
-        while db.query(ShortUrl).filter(ShortUrl.short_key == short_key).first():
+        # 2️⃣ 生成 short_key(循环直到成功)
+        while True:
             short_key = ShortUrlService.generate_short_key()
             short_key = ShortUrlService.generate_short_key()
 
 
-        db_obj = ShortUrl(short_key=short_key, long_url=long_url)
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
-        return db_obj
+            db_obj = ShortUrl(
+                short_key=short_key,
+                long_url=long_url
+            )
+            db.add(db_obj)
+
+            try:
+                await db.commit()
+                await db.refresh(db_obj)
+                return db_obj
+
+            except IntegrityError:
+                # short_key 冲突,重试
+                await db.rollback()
+                continue
 
 
     @staticmethod
     @staticmethod
-    def get_long_url(db: Session, short_key: str) -> str:
-        """通过短 key 获取原始长链接"""
-        record = db.query(ShortUrl).filter(ShortUrl.short_key == short_key).first()
+    async def get_long_url(
+        db: AsyncSession,
+        short_key: str
+    ) -> str:
+        """通过短 key 获取原始长链接(异步)"""
+
+        stmt = select(ShortUrl).where(ShortUrl.short_key == short_key)
+        record = await db.scalar(stmt)
+
         if not record:
         if not record:
             raise NotFoundError("Short url not found")
             raise NotFoundError("Short url not found")
+
         return record.long_url
         return record.long_url

+ 30 - 7
app/services/slot_snapshot_service.py

@@ -1,17 +1,40 @@
 # app/services/slot_snapshot_service.py
 # app/services/slot_snapshot_service.py
-from sqlalchemy.orm import Session
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
 from app.models.slot_snapshot import VasSlotSnapshot
 from app.models.slot_snapshot import VasSlotSnapshot
 from app.schemas.slot_snapshot import SlotSnapshotCreate
 from app.schemas.slot_snapshot import SlotSnapshotCreate
-from datetime import datetime
+
 
 
 class SlotSnapshotService:
 class SlotSnapshotService:
 
 
-    def create(db: Session, data: SlotSnapshotCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: SlotSnapshotCreate
+    ) -> VasSlotSnapshot:
         rec = VasSlotSnapshot(**data.dict())
         rec = VasSlotSnapshot(**data.dict())
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def latest_for(db: Session, country:str, city:str, visa_type:str):
-        return db.query(VasSlotSnapshot).filter_by(country=country, city=city, visa_type=visa_type).order_by(VasSlotSnapshot.snapshot_at.desc()).first()
+    @staticmethod
+    async def latest_for(
+        db: AsyncSession,
+        country: str,
+        city: str,
+        visa_type: str
+    ) -> VasSlotSnapshot:
+        stmt = (
+            select(VasSlotSnapshot)
+            .where(
+                VasSlotSnapshot.country == country,
+                VasSlotSnapshot.city == city,
+                VasSlotSnapshot.visa_type == visa_type,
+            )
+            .order_by(VasSlotSnapshot.snapshot_at.desc())
+            .limit(1)
+        )
+
+        return await db.scalar(stmt)

+ 52 - 22
app/services/sms_service.py

@@ -1,56 +1,86 @@
+# app/services/sms_service.py
+
 import json
 import json
-import time
-import requests
 from typing import List
 from typing import List
+from redis.asyncio import Redis
 from app.schemas.sms import ShortMessageDetail
 from app.schemas.sms import ShortMessageDetail
 
 
 
 
-def save_short_message(redis_client, phone: str, message: str, received_at: str, max_ttl: int) -> ShortMessageDetail:
+async def save_short_message(
+    redis_client: Redis,
+    phone: str,
+    message: str,
+    received_at: str,
+    max_ttl: int
+) -> ShortMessageDetail:
     """
     """
-    将短信保存到 Redis。
-    键格式: sms:{phone}
-    值为 JSON 数组(保存最近多条短信)
+    将短信保存到 Redis(异步版)
+    key: sms:{phone}
+    value: JSON 数组(最多保留最近 20 条
     """
     """
     key = f"sms:{phone}"
     key = f"sms:{phone}"
 
 
-    # 取出已有数据
-    existing_data = redis_client.get(key)
+    # 1️⃣ 读取已有短信
+    existing_data = await redis_client.get(key)
     if existing_data:
     if existing_data:
         messages = json.loads(existing_data)
         messages = json.loads(existing_data)
     else:
     else:
         messages = []
         messages = []
 
 
-    # 添加新短信
-    new_msg = ShortMessageDetail(phone=phone, message=message, received_at=received_at)
+    # 2️⃣ 添加新短信
+    new_msg = ShortMessageDetail(
+        phone=phone,
+        message=message,
+        received_at=received_at
+    )
     messages.append(new_msg.dict())
     messages.append(new_msg.dict())
 
 
-    # 最多保留最近 20 条(可调整)
+    # 3️⃣ 保留最近 20 条
     messages = messages[-20:]
     messages = messages[-20:]
 
 
-    # 保存回 Redis
-    redis_client.setex(key, max_ttl, json.dumps(messages))
+    # 4️⃣ 写回 Redis(重置 TTL)
+    await redis_client.setex(
+        key,
+        max_ttl,
+        json.dumps(messages)
+    )
 
 
     return new_msg
     return new_msg
 
 
 
 
-def query_short_message(redis_client, phone: str, keyword: str, sent_at: str) -> List[ShortMessageDetail]:
+async def query_short_message(
+    redis_client: Redis,
+    phone: str,
+    keyword: str = None,
+    sent_at: str = None
+) -> List[ShortMessageDetail]:
     """
     """
-    从 Redis 查询短信。
-    支持按关键字和时间过滤。
+    从 Redis 查询短信(异步版)
+    支持关键字 / 时间过滤
     """
     """
     key = f"sms:{phone}"
     key = f"sms:{phone}"
-    existing_data = redis_client.get(key)
+
+    existing_data = await redis_client.get(key)
     if not existing_data:
     if not existing_data:
         return []
         return []
 
 
-    messages = [ShortMessageDetail(**m) for m in json.loads(existing_data)]
+    messages = [
+        ShortMessageDetail(**m)
+        for m in json.loads(existing_data)
+    ]
 
 
     # 关键字过滤
     # 关键字过滤
     if keyword:
     if keyword:
-        messages = [m for m in messages if keyword in m.message]
+        messages = [
+            m for m in messages
+            if keyword in m.message
+        ]
 
 
-    # 时间过滤(如果需要更复杂可转为时间戳比较)
+    # 时间过滤(字符串比较,ISO8601 安全
     if sent_at:
     if sent_at:
-        messages = [m for m in messages if m.received_at >= sent_at]
+        messages = [
+            m for m in messages
+            if m.received_at >= sent_at
+        ]
 
 
-    return messages
+    return messages

+ 157 - 90
app/services/statistics_service.py

@@ -1,5 +1,7 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import func, desc, and_, case
+# app/services/statistics_service.py
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import func, desc, case, select
 from datetime import datetime, timedelta, date
 from datetime import datetime, timedelta, date
 from typing import Dict, Any
 from typing import Dict, Any
 
 
@@ -9,8 +11,11 @@ from app.models.vas_task import VasTask
 from app.models.user import VasUser
 from app.models.user import VasUser
 from app.models.product import VasProduct
 from app.models.product import VasProduct
 
 
-# 静态汇率配置 (基准: CNY)
-# 实际生产环境建议从数据库或缓存获取实时汇率
+
+# ======================
+# 汇率 & 货币符号
+# ======================
+
 EXCHANGE_RATES = {
 EXCHANGE_RATES = {
     "CNY": 1.0,
     "CNY": 1.0,
     "USD": 7.25,
     "USD": 7.25,
@@ -24,161 +29,223 @@ CURRENCY_SYMBOLS = {
     "CNY": "¥", "USD": "$", "EUR": "€", "GBP": "£", "HKD": "HK$", "JPY": "¥"
     "CNY": "¥", "USD": "$", "EUR": "€", "GBP": "£", "HKD": "HK$", "JPY": "¥"
 }
 }
 
 
+
 class StatisticsService:
 class StatisticsService:
+
+    # ======================
+    # 工具方法
+    # ======================
+
     @staticmethod
     @staticmethod
     def _convert_to_cny(amount: any, currency: str) -> int:
     def _convert_to_cny(amount: any, currency: str) -> int:
-        """
-        辅助函数:将金额(分)转换为人民币(分)
-        """
+        """金额(分) → CNY(分)"""
         if not amount:
         if not amount:
             return 0
             return 0
-        rate = EXCHANGE_RATES.get(currency, 1.0) # 未知货币默认按 1:1 处理
-        
-        # 修复点:将 amount 转换为 float,解决 Decimal * float 报错的问题
+        rate = EXCHANGE_RATES.get(currency, 1.0)
         return int(float(amount) * rate)
         return int(float(amount) * rate)
 
 
+    # ======================
+    # 核心接口
+    # ======================
+
     @staticmethod
     @staticmethod
-    def overview(db: Session) -> Dict[str, Any]:
+    async def overview(db: AsyncSession) -> Dict[str, Any]:
         """
         """
-        获取后台概览数据 (统一换算为 CNY 统计)
+        后台统计概览(Async 版)
         """
         """
-        # --- 1. 核心指标卡片 ---
-        
-        # 1.1 总营收 (按币种分组求和,再换算)
-        revenue_groups = db.query(
-            VasOrder.base_currency,
-            func.sum(VasOrder.base_amount)
-        ).filter(
-            VasOrder.status.in_(['paid', 'completed', 'succeeded'])
-        ).group_by(VasOrder.base_currency).all()
-
-        total_revenue_cny = 0
-        for currency, amount in revenue_groups:
-            total_revenue_cny += StatisticsService._convert_to_cny(amount, currency)
+
+        # --------------------------------------------------
+        # 1. 核心指标
+        # --------------------------------------------------
+
+        # 1.1 总营收(按币种分组)
+        revenue_stmt = (
+            select(
+                VasOrder.base_currency,
+                func.sum(VasOrder.base_amount)
+            )
+            .where(VasOrder.status.in_(["paid", "completed", "succeeded"]))
+            .group_by(VasOrder.base_currency)
+        )
+
+        revenue_rows = (await db.execute(revenue_stmt)).all()
+
+        total_revenue_cny = sum(
+            StatisticsService._convert_to_cny(amount, currency)
+            for currency, amount in revenue_rows
+        )
 
 
         # 1.2 活跃订单数
         # 1.2 活跃订单数
-        total_orders = db.query(func.count(VasOrder.id))\
-            .filter(VasOrder.status != 'closed')\
-            .scalar() or 0
+        total_orders = (
+            await db.scalar(
+                select(func.count(VasOrder.id))
+                .where(VasOrder.status != "closed")
+            )
+        ) or 0
 
 
         # 1.3 活跃用户数
         # 1.3 活跃用户数
-        active_users = db.query(func.count(VasUser.id)).scalar() or 0
+        active_users = (
+            await db.scalar(select(func.count(VasUser.id)))
+        ) or 0
 
 
         # 1.4 待处理工单
         # 1.4 待处理工单
-        pending_tickets = db.query(func.count(VasTicket.id))\
-            .filter(VasTicket.status.in_(['pending', 'info_required']))\
-            .scalar() or 0
+        pending_tickets = (
+            await db.scalar(
+                select(func.count(VasTicket.id))
+                .where(VasTicket.status.in_(["pending", "info_required"]))
+            )
+        ) or 0
 
 
         # 1.5 任务成功率
         # 1.5 任务成功率
-        task_counts = db.query(
-            func.count(VasTask.id).label('total'),
-            func.sum(case((VasTask.status == 'completed', 1), else_=0)).label('success')
-        ).first()
-        
+        task_stmt = select(
+            func.count(VasTask.id).label("total"),
+            func.sum(
+                case((VasTask.status == "completed", 1), else_=0)
+            ).label("success")
+        )
+
+        task_counts = (await db.execute(task_stmt)).first()
+
         success_rate_str = "0%"
         success_rate_str = "0%"
         if task_counts and task_counts.total > 0:
         if task_counts and task_counts.total > 0:
             rate = (task_counts.success / task_counts.total) * 100
             rate = (task_counts.success / task_counts.total) * 100
             success_rate_str = f"{rate:.1f}%"
             success_rate_str = f"{rate:.1f}%"
 
 
-        # --- 2. 营收趋势图 (Last 7 Days) ---
-        
+        # --------------------------------------------------
+        # 2. 最近 7 天营收趋势
+        # --------------------------------------------------
+
         revenue_trend = []
         revenue_trend = []
         today = date.today()
         today = date.today()
-        
+
         for i in range(6, -1, -1):
         for i in range(6, -1, -1):
             target_date = today - timedelta(days=i)
             target_date = today - timedelta(days=i)
             start_dt = datetime.combine(target_date, datetime.min.time())
             start_dt = datetime.combine(target_date, datetime.min.time())
             end_dt = datetime.combine(target_date, datetime.max.time())
             end_dt = datetime.combine(target_date, datetime.max.time())
-            
-            # 按币种分组查询当天的营收
-            daily_groups = db.query(
-                VasOrder.base_currency,
-                func.sum(VasOrder.base_amount).label('amount'),
-                func.count(VasOrder.id).label('orders')
-            ).filter(
-                VasOrder.created_at >= start_dt,
-                VasOrder.created_at <= end_dt,
-                VasOrder.status.in_(['paid', 'completed', 'succeeded'])
-            ).group_by(VasOrder.base_currency).all()
+
+            daily_stmt = (
+                select(
+                    VasOrder.base_currency,
+                    func.sum(VasOrder.base_amount).label("amount"),
+                    func.count(VasOrder.id).label("orders")
+                )
+                .where(
+                    VasOrder.created_at >= start_dt,
+                    VasOrder.created_at <= end_dt,
+                    VasOrder.status.in_(["paid", "completed", "succeeded"])
+                )
+                .group_by(VasOrder.base_currency)
+            )
+
+            daily_rows = (await db.execute(daily_stmt)).all()
 
 
             daily_amount_cny = 0
             daily_amount_cny = 0
             daily_order_count = 0
             daily_order_count = 0
-            
-            for curr, amt, cnt in daily_groups:
+
+            for curr, amt, cnt in daily_rows:
                 daily_amount_cny += StatisticsService._convert_to_cny(amt, curr)
                 daily_amount_cny += StatisticsService._convert_to_cny(amt, curr)
                 daily_order_count += cnt
                 daily_order_count += cnt
 
 
             revenue_trend.append({
             revenue_trend.append({
                 "date": target_date.strftime("%m-%d"),
                 "date": target_date.strftime("%m-%d"),
-                "amount": float(daily_amount_cny) / 100.0, # 转为元 (浮点数)
+                "amount": daily_amount_cny / 100.0,
                 "orders": daily_order_count
                 "orders": daily_order_count
             })
             })
 
 
-        # --- 3. 商品销量分布 ---
-        
-        product_stats = db.query(
-            VasProduct.title,
-            func.count(VasOrder.id).label('count')
-        ).join(VasOrder, VasOrder.product_id == VasProduct.id)\
-         .filter(VasOrder.status.in_(['paid', 'completed', 'succeeded']))\
-         .group_by(VasProduct.title)\
-         .order_by(desc('count'))\
-         .limit(5).all()
-
-        product_dist = [{"name": p.title, "value": p.count} for p in product_stats]
-
-        # --- 4. 最新动态 ---
-        
+        # --------------------------------------------------
+        # 3. 商品销量分布(Top 5)
+        # --------------------------------------------------
+
+        product_stmt = (
+            select(
+                VasProduct.title,
+                func.count(VasOrder.id).label("count")
+            )
+            .join(VasOrder, VasOrder.product_id == VasProduct.id)
+            .where(VasOrder.status.in_(["paid", "completed", "succeeded"]))
+            .group_by(VasProduct.title)
+            .order_by(desc("count"))
+            .limit(5)
+        )
+
+        product_rows = (await db.execute(product_stmt)).all()
+
+        product_dist = [
+            {"name": title, "value": count}
+            for title, count in product_rows
+        ]
+
+        # --------------------------------------------------
+        # 4. 最新动态
+        # --------------------------------------------------
+
         activities = []
         activities = []
-        
-        # 订单动态
-        recent_orders = db.query(VasOrder).order_by(desc(VasOrder.created_at)).limit(5).all()
+
+        # 最近订单
+        order_stmt = (
+            select(VasOrder)
+            .order_by(desc(VasOrder.created_at))
+            .limit(5)
+        )
+        recent_orders = (await db.execute(order_stmt)).scalars().all()
+
         for o in recent_orders:
         for o in recent_orders:
             symbol = CURRENCY_SYMBOLS.get(o.base_currency, o.base_currency)
             symbol = CURRENCY_SYMBOLS.get(o.base_currency, o.base_currency)
             amt_display = f"{symbol}{o.base_amount / 100}"
             amt_display = f"{symbol}{o.base_amount / 100}"
-            
+
             activities.append({
             activities.append({
                 "id": f"order_{o.id}",
                 "id": f"order_{o.id}",
                 "text": f"用户下单: {o.product_name or '未知商品'} ({amt_display})",
                 "text": f"用户下单: {o.product_name or '未知商品'} ({amt_display})",
                 "time": o.created_at,
                 "time": o.created_at,
-                "type": "order" if o.status == 'pending' else "money"
+                "type": "order" if o.status == "pending" else "money"
             })
             })
 
 
-        # 工单动态
-        recent_tickets = db.query(VasTicket).order_by(desc(VasTicket.created_at)).limit(5).all()
+        # 最近工单
+        ticket_stmt = (
+            select(VasTicket)
+            .order_by(desc(VasTicket.created_at))
+            .limit(5)
+        )
+        recent_tickets = (await db.execute(ticket_stmt)).scalars().all()
+
         for t in recent_tickets:
         for t in recent_tickets:
-            reason_preview = t.reason[:20] + "..." if len(t.reason) > 20 else t.reason
+            reason_preview = (
+                t.reason[:20] + "..."
+                if t.reason and len(t.reason) > 20
+                else t.reason
+            )
             activities.append({
             activities.append({
                 "id": f"ticket_{t.id}",
                 "id": f"ticket_{t.id}",
                 "text": f"新工单 #{t.id}: {reason_preview}",
                 "text": f"新工单 #{t.id}: {reason_preview}",
                 "time": t.created_at,
                 "time": t.created_at,
                 "type": "ticket"
                 "type": "ticket"
             })
             })
-        
-        # 排序与时间格式
-        activities.sort(key=lambda x: x['time'], reverse=True)
+
+        # 排序 + 时间人性
+        activities.sort(key=lambda x: x["time"], reverse=True)
         activities = activities[:10]
         activities = activities[:10]
 
 
         now = datetime.now()
         now = datetime.now()
         for act in activities:
         for act in activities:
-            dt = act['time']
-            if not isinstance(dt, datetime):
-                 continue
+            dt = act["time"]
             diff = now - dt
             diff = now - dt
             if diff.days > 0:
             if diff.days > 0:
-                time_str = f"{diff.days}天前"
+                act["time"] = f"{diff.days}天前"
             elif diff.seconds > 3600:
             elif diff.seconds > 3600:
-                time_str = f"{diff.seconds // 3600}小时前"
+                act["time"] = f"{diff.seconds // 3600}小时前"
             elif diff.seconds > 60:
             elif diff.seconds > 60:
-                time_str = f"{diff.seconds // 60}分钟前"
+                act["time"] = f"{diff.seconds // 60}分钟前"
             else:
             else:
-                time_str = "刚刚"
-            act['time'] = time_str
+                act["time"] = "刚刚"
+
+        # --------------------------------------------------
+        # 返回结果
+        # --------------------------------------------------
 
 
         return {
         return {
             "stats": {
             "stats": {
                 "totalOrders": total_orders,
                 "totalOrders": total_orders,
-                "totalRevenue": total_revenue_cny, # 单位:分 (CNY)
+                "totalRevenue": total_revenue_cny,  # CNY 分
                 "activeUsers": active_users,
                 "activeUsers": active_users,
                 "pendingTickets": pending_tickets,
                 "pendingTickets": pending_tickets,
                 "successRate": success_rate_str
                 "successRate": success_rate_str
@@ -186,4 +253,4 @@ class StatisticsService:
             "revenue_trend": revenue_trend,
             "revenue_trend": revenue_trend,
             "product_dist": product_dist,
             "product_dist": product_dist,
             "recent_activities": activities
             "recent_activities": activities
-        }
+        }

+ 52 - 18
app/services/task_service.py

@@ -1,34 +1,56 @@
-import json
-from sqlalchemy.orm import Session
+# app/services/task_service.py
+
 from typing import List, Optional
 from typing import List, Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, update
+
+from app.core.biz_exception import NotFoundError
 from app.models.task import Task
 from app.models.task import Task
 from app.schemas.task import TaskCreate, TaskUpdate
 from app.schemas.task import TaskCreate, TaskUpdate
 
 
 
 
 class TaskService:
 class TaskService:
+
+    # ======================
+    # 创建任务
+    # ======================
     @staticmethod
     @staticmethod
-    def create(db: Session, obj_in: TaskCreate) -> Task:
+    async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
         db_obj = Task(
         db_obj = Task(
             command=obj_in.command,
             command=obj_in.command,
             args=obj_in.args,
             args=obj_in.args,
             status=obj_in.status or 0,
             status=obj_in.status or 0,
         )
         )
         db.add(db_obj)
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
         return db_obj
 
 
+    # ======================
+    # 根据 ID 获取
+    # ======================
     @staticmethod
     @staticmethod
-    def get_by_id(db: Session, task_id: int) -> Optional[Task]:
-        obj = db.query(Task).filter(Task.id == task_id).first()
+    async def get_by_id(db: AsyncSession, task_id: int) -> Task:
+        stmt = select(Task).where(Task.id == task_id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("Task not exist")
             raise NotFoundError("Task not exist")
         return obj
         return obj
 
 
+    # ======================
+    # 更新任务
+    # ======================
     @staticmethod
     @staticmethod
-    def update(db: Session, task_id: int, obj_in: TaskUpdate) -> Optional[Task]:
-        db_obj = db.query(Task).filter(Task.id == task_id).first()
+    async def update(
+        db: AsyncSession,
+        task_id: int,
+        obj_in: TaskUpdate
+    ) -> Task:
+        stmt = select(Task).where(Task.id == task_id)
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not db_obj:
         if not db_obj:
             raise NotFoundError("Task not exist")
             raise NotFoundError("Task not exist")
 
 
@@ -37,19 +59,31 @@ class TaskService:
         if obj_in.status is not None:
         if obj_in.status is not None:
             db_obj.status = obj_in.status
             db_obj.status = obj_in.status
 
 
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
         return db_obj
 
 
+    # ======================
+    # 获取待处理任务(分页)
+    # ======================
     @staticmethod
     @staticmethod
-    def get_pending(db: Session, command: str, page: int, size: int) -> List[Task]:
+    async def get_pending(
+        db: AsyncSession,
+        command: str,
+        page: int,
+        size: int
+    ) -> List[Task]:
         offset = page * size
         offset = page * size
-        return (
-            db.query(Task)
-            .filter(Task.command == command, Task.status == 0)
+
+        stmt = (
+            select(Task)
+            .where(
+                Task.command == command,
+                Task.status == 0
+            )
             .order_by(Task.create_at.asc())
             .order_by(Task.create_at.asc())
             .offset(offset)
             .offset(offset)
             .limit(size)
             .limit(size)
-            .all()
         )
         )
+
+        return (await db.execute(stmt)).scalars().all()

+ 17 - 12
app/services/telegram_service.py

@@ -1,21 +1,26 @@
-import json
-import time
-import requests
-from typing import List
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+# app/services/telegram_service.py
+
+import httpx
+from app.core.biz_exception import BizLogicError
 from app.schemas.telegram import TelegramIn
 from app.schemas.telegram import TelegramIn
 
 
 
 
 class TelegramService:
 class TelegramService:
-    def push_to_telegram(payload: TelegramIn):
+
+    @staticmethod
+    async def push_to_telegram(payload: TelegramIn):
         url = f"https://api.telegram.org/bot{payload.api_token}/sendMessage"
         url = f"https://api.telegram.org/bot{payload.api_token}/sendMessage"
-        payload = {
+
+        body = {
             "chat_id": payload.chat_id,
             "chat_id": payload.chat_id,
             "text": payload.message,
             "text": payload.message,
-            "parse_mode": "HTML"
+            "parse_mode": "HTML",
         }
         }
 
 
-        response = requests.post(url, json=payload, timeout=10)
-        if response.status_code != 200:
-            raise BizLogicError("Telegram push failed")
-    
+        async with httpx.AsyncClient(timeout=10) as client:
+            resp = await client.post(url, json=body)
+
+        if resp.status_code != 200:
+            raise BizLogicError(
+                f"Telegram push failed: {resp.status_code}, {resp.text}"
+            )

+ 238 - 78
app/services/ticket_service.py

@@ -1,12 +1,17 @@
-# app/services/ticket_service.py
 from datetime import datetime
 from datetime import datetime
+from typing import List, Optional
+
 from redis.asyncio import Redis
 from redis.asyncio import Redis
-from typing import List
-from sqlalchemy.orm import Session
-from app.utils.search import apply_keyword_search
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.utils.pagination import paginate
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 from app.models.user import VasUser
 from app.models.user import VasUser
+from app.models.order import VasOrder
+from app.models.vas_task import VasTask
+from app.models.payment import VasPayment
 from app.models.ticket import VasTicket
 from app.models.ticket import VasTicket
 from app.models.ticket_message import VasTicketMessage
 from app.models.ticket_message import VasTicketMessage
 from app.schemas.ticket import VasTicketCreate
 from app.schemas.ticket import VasTicketCreate
@@ -14,135 +19,290 @@ from app.services.notification_service import NotificationService
 
 
 
 
 class TicketService:
 class TicketService:
-    
+
     @staticmethod
     @staticmethod
-    def create(db: Session, data: VasTicketCreate, current_user: VasUser, redis_client: Redis):
-        rec = VasTicket(**data.dict(), status='pending', created_at=datetime.utcnow())
-        rec.user_id = current_user.id
-        db.add(rec)
-        db.commit()
-        db.refresh(rec)
-        print(f"📧 send ticket created notification email")
-        NotificationService.create(
+    async def create(
+        db: AsyncSession,
+        data: VasTicketCreate,
+        current_user: VasUser,
+        redis_client: Redis
+    ):
+        ticket = VasTicket(
+            **data.dict(),
+            user_id=current_user.id,
+            status="pending",
+            created_at=datetime.utcnow(),
+            updated_at=datetime.utcnow(),
+        )
+
+        db.add(ticket)
+        await db.commit()
+        await db.refresh(ticket)
+
+        await NotificationService.create(
             redis_client=redis_client,
             redis_client=redis_client,
             ntype="ticket created",
             ntype="ticket created",
             user_id=current_user.id,
             user_id=current_user.id,
             channels=["email"],
             channels=["email"],
             template_id="ticket_created",
             template_id="ticket_created",
             payload={
             payload={
-                "ticket_id": rec.id,
-                "order_id": rec.order_id
-            }
+                "ticket_id": ticket.id,
+            },
         )
         )
-        return rec
+
+        return ticket
 
 
     @staticmethod
     @staticmethod
-    def update_status(db: Session, ticket_id, status, comment, admin_id):
-        ticket = db.query(VasTicket).filter_by(id=ticket_id).first()
+    async def update_status(
+        db: AsyncSession,
+        ticket_id: int,
+        status: str,
+        comment: str,
+        admin_id: str,
+    ) -> VasTicket:
+
+        # 🔒 锁住 ticket,防并发管理员操作
+        result = await db.execute(
+            select(VasTicket)
+            .where(VasTicket.id == ticket_id)
+            .with_for_update()
+        )
+        ticket = result.scalar_one_or_none()
+
         if not ticket:
         if not ticket:
             raise NotFoundError("Ticket not exist")
             raise NotFoundError("Ticket not exist")
+
         ticket.status = status
         ticket.status = status
         ticket.admin_comment = comment
         ticket.admin_comment = comment
+        ticket.updated_at = datetime.utcnow()
+
         db.add(
         db.add(
             VasTicketMessage(
             VasTicketMessage(
-                ticket_id=ticket_id,
+                ticket_id=ticket.id,
                 sender_type="admin",
                 sender_type="admin",
-                sender_id=admin_id,
-                content=comment
+                sender_id=str(admin_id),
+                content=comment,
+                created_at=datetime.utcnow(),
             )
             )
         )
         )
-        db.commit()
-        db.refresh(ticket)
+
+        if status == "resolved":
+            await TicketService._handle_resolution(db, ticket)
+
+        elif status == "rejected":
+            await TicketService._handle_rejection(db, ticket)
+
+        await db.commit()
+        await db.refresh(ticket)
         return ticket
         return ticket
-        
+
+    # =========================
+    # 工单解决逻辑
+    # =========================
+    @staticmethod
+    async def _handle_resolution(
+        db: AsyncSession,
+        ticket: VasTicket,
+    ) -> None:
+
+        if not ticket.order_id:
+            return
+
+        result = await db.execute(
+            select(VasOrder)
+            .where(VasOrder.id == ticket.order_id)
+            .with_for_update()
+        )
+        order = result.scalar_one_or_none()
+        if not order:
+            return
+
+        # ---------- 退款 ----------
+        if ticket.type == "refund":
+            order.status = "closed"
+
+            pay_res = await db.execute(
+                select(VasPayment).where(
+                    VasPayment.order_id == order.id,
+                    VasPayment.status.in_(["succeeded", "late_paid"]),
+                )
+            )
+            payment = pay_res.scalar_one_or_none()
+            if payment:
+                payment.status = "refunded"
+                payment.refunded_at = datetime.utcnow()
+                payment.refund_reason = ticket.reason
+
+            task_res = await db.execute(
+                select(VasTask).where(
+                    VasTask.order_id == order.id,
+                    VasTask.status.in_(["pending", "grabbed", "running"]),
+                )
+            )
+            for task in task_res.scalars().all():
+                task.status = "cancelled"
+
+        # ---------- 变更请求 ----------
+        elif ticket.type == "change_request":
+
+            # 1️⃣ 取消旧任务
+            task_res = await db.execute(
+                select(VasTask).where(
+                    VasTask.order_id == order.id,
+                    VasTask.status.in_(["pending", "grabbed", "running"]),
+                )
+            )
+            for task in task_res.scalars().all():
+                task.status = "cancelled"
+
+            # 2️⃣ 重新生成任务(幂等)
+            routing_res = await db.execute(
+                select(VasProductRouting).where(
+                    VasProductRouting.product_id == order.product_id,
+                    VasProductRouting.is_active == 1,
+                )
+            )
+            routings = routing_res.scalars().all()
+
+            for routing in routings:
+                exists_res = await db.execute(
+                    select(VasTask).where(
+                        VasTask.order_id == order.id,
+                        VasTask.routing_key == routing.routing_key,
+                        VasTask.script_version == routing.script_version,
+                    )
+                )
+                if exists_res.scalar_one_or_none():
+                    continue
+
+                db.add(
+                    VasTask(
+                        order_id=order.id,
+                        routing_key=routing.routing_key,
+                        script_version=routing.script_version,
+                        priority=10,
+                        status="pending",
+                        user_inputs=order.user_inputs,
+                        config=routing.config,
+                        attempt_count=0,
+                        notify_count=0,
+                        expire_at=datetime.utcnow() + timedelta(days=60),
+                        created_at=datetime.utcnow(),
+                    )
+                )
+
+    # =========================
+    # 工单拒绝逻辑
+    # =========================
+    @staticmethod
+    async def _handle_rejection(
+        db: AsyncSession,
+        ticket: VasTicket,
+    ) -> None:
+
+        if not ticket.order_id:
+            return
+
+        result = await db.execute(
+            select(VasOrder)
+            .where(VasOrder.id == ticket.order_id)
+            .with_for_update()
+        )
+        order = result.scalar_one_or_none()
+        if not order:
+            return
+
+        if ticket.type == "refund" and order.status == "refund_pending":
+            order.status = "paid"
+
     @staticmethod
     @staticmethod
-    def add_message(
-        db: Session,
+    async def add_message(
+        db: AsyncSession,
         ticket_id: int,
         ticket_id: int,
         sender_type: str,   # "user" | "admin" | "system"
         sender_type: str,   # "user" | "admin" | "system"
-        sender_id: str = None,
+        sender_id: Optional[str] = None,
         content: str = "",
         content: str = "",
-        attachments: dict = None
+        attachments: Optional[dict] = None,
     ):
     ):
-        # 1️⃣ 校验工单是否存在
-        ticket = db.query(VasTicket).filter(
-            VasTicket.id == ticket_id
-        ).first()
-
+        ticket = await db.get(VasTicket, ticket_id)
         if not ticket:
         if not ticket:
             raise NotFoundError("Ticket not exist")
             raise NotFoundError("Ticket not exist")
 
 
-        # 2️⃣ 创建消息
         message = VasTicketMessage(
         message = VasTicketMessage(
             ticket_id=ticket_id,
             ticket_id=ticket_id,
             sender_type=sender_type,
             sender_type=sender_type,
             sender_id=sender_id,
             sender_id=sender_id,
             content=content,
             content=content,
             attachments=attachments,
             attachments=attachments,
-            created_at=datetime.utcnow()
+            created_at=datetime.utcnow(),
         )
         )
 
 
-        # 3️⃣ 写入数据库
         db.add(message)
         db.add(message)
-
-        # 4️⃣ 更新工单更新时间(非常重要)
         ticket.updated_at = datetime.utcnow()
         ticket.updated_at = datetime.utcnow()
 
 
-        db.commit()
-        db.refresh(message)
-
+        await db.commit()
+        await db.refresh(message)
         return message
         return message
-    
+
     @staticmethod
     @staticmethod
-    def list_messages(
-        db: Session,
+    async def list_messages(
+        db: AsyncSession,
         ticket_id: int,
         ticket_id: int,
         page: int = 1,
         page: int = 1,
-        size: int = 20
+        size: int = 20,
     ):
     ):
-        # 1️⃣ 校验 ticket 是否存在
-        exists = db.query(VasTicket.id).filter(
-            VasTicket.id == ticket_id
-        ).first()
-
+        # 校验 ticket 是否存在
+        exists = await db.scalar(
+            select(VasTicket.id).where(VasTicket.id == ticket_id)
+        )
         if not exists:
         if not exists:
             raise NotFoundError("Ticket not exist")
             raise NotFoundError("Ticket not exist")
 
 
-        # 2️⃣ 查询消息(按时间正序)
-        query = (
-            db.query(VasTicketMessage)
-            .filter(VasTicketMessage.ticket_id == ticket_id)
+        stmt = (
+            select(VasTicketMessage)
+            .where(VasTicketMessage.ticket_id == ticket_id)
             .order_by(VasTicketMessage.created_at.desc())
             .order_by(VasTicketMessage.created_at.desc())
         )
         )
 
 
-        return paginate(query, page, size)
-    
+        return await paginate(db, stmt, page, size)
+
     @staticmethod
     @staticmethod
-    def list_by_user(db: Session, user_id: str, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasTicket).filter_by(user_id=user_id)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_user(
+        db: AsyncSession,
+        user_id: str,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasTicket).where(VasTicket.user_id == user_id)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTicket,
             model=VasTicket,
             keyword=keyword,
             keyword=keyword,
-            fields=["order_id", "user_id", "reason", "admin_comment"]
+            fields=["order_id", "user_id", "reason", "admin_comment"],
         )
         )
-        query = query.order_by(
-            VasTicket.id.desc()
-        )
-        return paginate(query, page, size)  
-      
+
+        stmt = stmt.order_by(VasTicket.id.desc())
+
+        return await paginate(db, stmt, page, size)
+
     @staticmethod
     @staticmethod
-    def list_all(db: Session, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasTicket)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasTicket)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTicket,
             model=VasTicket,
             keyword=keyword,
             keyword=keyword,
-            fields=["order_id", "user_id", "reason", "admin_comment"]
-        )
-        query = query.order_by(
-            VasTicket.id.desc()
+            fields=["order_id", "user_id", "reason", "admin_comment"],
         )
         )
-        return paginate(query, page, size)
+
+        stmt = stmt.order_by(VasTicket.id.desc())
+
+        return await paginate(db, stmt, page, size)

+ 125 - 56
app/services/troov_service.py

@@ -1,80 +1,149 @@
 import json
 import json
 import time
 import time
-import requests
-from typing import List
-from fastapi import Depends
+import random
+import asyncio
+import aiohttp
+from typing import List, Optional
+
+from redis.asyncio import Redis
+from starlette.concurrency import run_in_threadpool
 from app.schemas.troov import TroovRate
 from app.schemas.troov import TroovRate
-from app.utils.france_slot_api import *
+from app.utils.france_slot_api import troov_create_session_old
 from app.utils.proxy_utils import load_proxies_from_json
 from app.utils.proxy_utils import load_proxies_from_json
+from app.core.logger import logger
+
 
 
+# =========================================================
+# Redis 原子弹出 token(不使用 KEYS,避免阻塞)
+# =========================================================
 
 
-def pop_redis_value_token(redis_client):
-    lua_script = '''
-local keys = redis.call('keys', 'token:*')
+POP_TOKEN_LUA = """
+local cursor = "0"
 local max_ttl = -1
 local max_ttl = -1
 local max_key = nil
 local max_key = nil
 
 
-for _, key in ipairs(keys) do
-    local ttl = redis.call('ttl', key)
-    if ttl > 0 and ttl > max_ttl then
-        max_ttl = ttl
-        max_key = key
+repeat
+    local result = redis.call('SCAN', cursor, 'MATCH', 'token:*', 'COUNT', 50)
+    cursor = result[1]
+    local keys = result[2]
+
+    for _, key in ipairs(keys) do
+        local ttl = redis.call('TTL', key)
+        if ttl > max_ttl then
+            max_ttl = ttl
+            max_key = key
+        end
     end
     end
-end
+until cursor == "0"
 
 
 if max_key then
 if max_key then
-    local value = redis.call('get', max_key)
-    redis.call('del', max_key)
+    local value = redis.call('GET', max_key)
+    redis.call('DEL', max_key)
     return {max_key, value, max_ttl}
     return {max_key, value, max_ttl}
-else
-    return nil
 end
 end
-'''
-    result = redis_client.eval(lua_script, 0)
-    return result
 
 
-def fetch_rate(session_dic, date):
-    url = f"https://api.consulat.gouv.fr/api/team/621540d353069dec25bd0045/reservations/availability?name=Visas&date={date}&places=-5&matching=&maxCapacity=-5&sessionId={session_dic['session_id']}"
+return nil
+"""
+
+
+async def pop_redis_value_token(redis_client: Redis):
+    """
+    原子性获取 TTL 最大的 token 并删除
+    """
+    return await redis_client.eval(POP_TOKEN_LUA, 0)
+
+
+# =========================================================
+# 请求法国 Troov 接口(async,不阻塞)
+# =========================================================
+
+async def fetch_rate(session_dic: dict, date: str) -> str:
+    url = (
+        "https://api.consulat.gouv.fr/api/team/"
+        "621540d353069dec25bd0045/reservations/availability"
+        f"?name=Visas&date={date}&places=-5&matching=&maxCapacity=-5"
+        f"&sessionId={session_dic['session_id']}"
+    )
+
     headers = {
     headers = {
-        'accept': 'application/json, text/plain, */*',
-        'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8',
-        'origin': 'https://consulat.gouv.fr',
-        'referer': 'https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas',
-        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36',
-        'x-gouv-app-id': session_dic['x_gouv_app_id'],
-        'x-gouv-web': 'fr.gouv.consulat'
+        "accept": "application/json, text/plain, */*",
+        "accept-language": "zh-CN,zh;q=0.9,en;q=0.8",
+        "origin": "https://consulat.gouv.fr",
+        "referer": "https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas",
+        "user-agent": (
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+            "AppleWebKit/537.36 (KHTML, like Gecko) "
+            "Chrome/141.0.0.0 Safari/537.36"
+        ),
+        "x-gouv-app-id": session_dic["x_gouv_app_id"],
+        "x-gouv-web": "fr.gouv.consulat",
     }
     }
-    try:
-        response = requests.get(url, headers=headers)
-        return response.text
-    except Exception as e:
-        return f'Exception info={e}'
 
 
-def get_rate_by_date(redis_client, date: str) -> List[TroovRate]:
+    timeout = aiohttp.ClientTimeout(total=15)
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url, headers=headers) as resp:
+            return await resp.text()
+
+
+# =========================================================
+# 核心业务逻辑
+# =========================================================
+
+async def get_rate_by_date(
+    redis_client: Redis,
+    date: str
+) -> Optional[List[TroovRate]]:
     """
     """
-    核心业务逻辑:根据日期返回 Troov 预约信息
+    根据日期获取 Troov 预约信息
     """
     """
-    proxy_pools = ['oxylabs']
+
+    # ---------- 1️⃣ 加载代理 ----------
     proxies = []
     proxies = []
-    for pp in proxy_pools:
-        proxies = proxies + load_proxies_from_json("data/proxy_pool_config.json", pp)
-    
-    result = None
-    while True:
-        result = pop_redis_value_token(redis_client)
-        if not result:
-            time.sleep(1)
-            continue
-        break
-    body_str = result[1]
-    body = json.loads(body_str)
-    captcha = body.get("token")
-    
-    session_dic = troov_create_session_old(random.choice(proxies), captcha)
+    for pool in ("oxylabs",):
+        proxies.extend(
+            load_proxies_from_json("data/proxy_pool_config.json", pool)
+        )
+
+    if not proxies:
+        logger.error("Proxy pool is empty")
+        return None
+
+    # ---------- 2️⃣ 获取验证码 token(最多等待 30 秒) ----------
+    token_data = None
+    for _ in range(30):
+        token_data = await pop_redis_value_token(redis_client)
+        if token_data:
+            break
+        await asyncio.sleep(1)
+
+    if not token_data:
+        logger.warning("No captcha token available")
+        return None
+
+    _, body_str, ttl = token_data
+
+    try:
+        body = json.loads(body_str)
+        captcha_token = body.get("token")
+    except Exception:
+        logger.exception("Invalid captcha token format")
+        return None
+
+    # ---------- 3️⃣ 创建 Troov session(同步函数放线程池) ----------
+    proxy = random.choice(proxies)
+
+    session_dic = await run_in_threadpool(troov_create_session_old, proxy, captcha_token)
     if not session_dic:
     if not session_dic:
+        logger.warning("Failed to create Troov session")
         return None
         return None
-    logger.info(f'创建 session 成功: {session_dic}')
-    res = fetch_rate(session_dic, date)
-    return json.loads(res)
 
 
-        
+    logger.info(f"Troov session created: {session_dic}")
+
+    # ---------- 4️⃣ 请求预约数据 ----------
+    try:
+        response_text = await fetch_rate(session_dic, date)
+        return json.loads(response_text)
+    except Exception as e:
+        logger.error(f"Fetch rate failed: {e}")
+        return None

+ 70 - 34
app/services/user_service.py

@@ -1,19 +1,29 @@
 # app/services/user_service.py
 # app/services/user_service.py
+
+import uuid
 from datetime import datetime
 from datetime import datetime
-from typing import List
-from sqlalchemy.orm import Session
-from app.utils.search import apply_keyword_search
+from typing import List, Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.user import VasUser
 from app.models.user import VasUser
-from app.schemas.user import VasUserCreate, VasUserUpdate, VasUserSetProfiles, VasUserOut
+from app.schemas.user import (
+    VasUserCreate,
+    VasUserUpdate,
+    VasUserSetProfiles,
+    VasUserOut,
+)
 
 
 
 
 class UserService:
 class UserService:
-    
+
     @staticmethod
     @staticmethod
-    def create(db: Session, data: VasUserCreate):
-        uid = f'usr-{uuid.uuid4().hex[:8]}'
+    async def create(db: AsyncSession, data: VasUserCreate) -> VasUser:
+        uid = f"usr-{uuid.uuid4().hex[:8]}"
 
 
         user = VasUser(
         user = VasUser(
             id=uid,
             id=uid,
@@ -22,60 +32,86 @@ class UserService:
             phone=data.phone,
             phone=data.phone,
             preferred_language="en",
             preferred_language="en",
             timezone="Asia/Shanghai",
             timezone="Asia/Shanghai",
+            created_at=datetime.utcnow(),
         )
         )
+
         db.add(user)
         db.add(user)
-        db.commit()
-        return rec
-    
+        await db.commit()
+        await db.refresh(user)
+
+        return user
+
     @staticmethod
     @staticmethod
-    def get(db: Session, id: str):
-        user = db.query(VasUser).filter_by(id=id).first()
+    async def get(db: AsyncSession, id: str) -> VasUser:
+        stmt = select(VasUser).where(VasUser.id == id)
+        result = await db.execute(stmt)
+        user = result.scalar_one_or_none()
+
         if not user:
         if not user:
             raise NotFoundError("User not exist")
             raise NotFoundError("User not exist")
+
         return user
         return user
-    
+
     @staticmethod
     @staticmethod
-    def update(db: Session, id: str, payload: VasUserUpdate):
-        obj = db.query(VasUser).filter(VasUser.id == id).first()
+    async def update(
+        db: AsyncSession,
+        id: str,
+        payload: VasUserUpdate
+    ) -> VasUser:
+        stmt = select(VasUser).where(VasUser.id == id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("User not exist")
             raise NotFoundError("User not exist")
-        data = payload.dict(exclude_unset=True)  # ⭐ 关键
+
+        data = payload.dict(exclude_unset=True)
 
 
         for key, value in data.items():
         for key, value in data.items():
             setattr(obj, key, value)
             setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
+
+        obj.updated_at = datetime.utcnow()
+
+        await db.commit()
+        await db.refresh(obj)
+
         return obj
         return obj
-    
+
     @staticmethod
     @staticmethod
-    def set_profiles(db: Session, user: VasUser, payload: VasUserSetProfiles):
+    async def set_profiles(
+        db: AsyncSession,
+        user: VasUser,
+        payload: VasUserSetProfiles
+    ) -> VasUser:
         """
         """
         更新用户资料(profile)
         更新用户资料(profile)
         """
         """
 
 
-        # 1️⃣ 字段赋值(显式,避免误更新)
         user.phone = payload.phone
         user.phone = payload.phone
         user.nickname = payload.nickname
         user.nickname = payload.nickname
         user.avatar_url = payload.avatar_url
         user.avatar_url = payload.avatar_url
-
-        # 2️⃣ 更新时间(可选,SQLAlchemy onupdate 也会生效)
         user.updated_at = datetime.utcnow()
         user.updated_at = datetime.utcnow()
 
 
-        # 3️⃣ 持久化
         db.add(user)
         db.add(user)
-        db.commit()
-        db.refresh(user)
+        await db.commit()
+        await db.refresh(user)
 
 
         return user
         return user
 
 
     @staticmethod
     @staticmethod
-    def list_all(db: Session, page: int = 1, size: int = 20, keyword: str=None):
-        query = db.query(VasUser)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasUser)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasUser,
             model=VasUser,
             keyword=keyword,
             keyword=keyword,
-            fields=["id", "email", "nickname", "phone"]
+            fields=["id", "email", "nickname", "phone"],
         )
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 103 - 53
app/services/vas_task_service.py

@@ -1,88 +1,138 @@
 # app/services/task_service.py
 # app/services/task_service.py
-from sqlalchemy.orm import Session
-from typing import List
-from app.utils.search import apply_keyword_search
+
+from datetime import datetime
+from typing import List, Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.vas_task import VasTask
 from app.models.vas_task import VasTask
+from app.models.order import VasOrder
 from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
 from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
-from datetime import datetime
+
 
 
 class VasTaskService:
 class VasTaskService:
 
 
-    def create(db: Session, data: VasTaskCreate):
-        rec = VasTask(**data.dict(), status='pending', created_at=datetime.utcnow())
+    @staticmethod
+    async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
+        rec = VasTask(
+            **data.dict(),
+            status="pending",
+            created_at=datetime.utcnow(),
+        )
         db.add(rec)
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def list_task(
-        db: Session,
-        status: str = None,
-        routing_key: str = None,
-        script_version: str = None,
-        keyword: str = None,
+    @staticmethod
+    async def list_task(
+        db: AsyncSession,
+        status: Optional[str] = None,
+        routing_key: Optional[str] = None,
+        script_version: Optional[str] = None,
+        keyword: Optional[str] = None,
         page: int = 0,
         page: int = 0,
         size: int = 10,
         size: int = 10,
     ):
     ):
-        query = db.query(VasTask)
-        
+        stmt = select(VasTask)
+
         if status:
         if status:
-            query = query.filter(VasTask.status == status)
-        
+            stmt = stmt.where(VasTask.status == status)
+
         if routing_key:
         if routing_key:
-            query = query.filter(VasTask.routing_key == routing_key)
+            stmt = stmt.where(VasTask.routing_key == routing_key)
 
 
         if script_version:
         if script_version:
-            query = query.filter(VasTask.script_version == script_version)
+            stmt = stmt.where(VasTask.script_version == script_version)
 
 
-        query = apply_keyword_search(
-            query=query,
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTask,
             model=VasTask,
             keyword=keyword,
             keyword=keyword,
-            fields=["order_id", "routing_key", "user_inputs"]
+            fields=["order_id", "routing_key", "user_inputs"],
         )
         )
-        
-        query = query.order_by(
+
+        stmt = stmt.order_by(
             VasTask.priority.desc(),
             VasTask.priority.desc(),
-            VasTask.id.asc()
+            VasTask.id.asc(),
         )
         )
-        return paginate(query, page, size)
-    
-    def update(db: Session, id: int, payload: VasTaskUpdate):
-        obj = db.query(VasTask).filter(VasTask.id == id).first()
+
+        return await paginate(db, stmt, page, size)
+
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        payload: VasTaskUpdate,
+    ) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
         if not obj:
             raise NotFoundError("Task not exist")
             raise NotFoundError("Task not exist")
-        data = payload.dict(exclude_unset=True)  # ⭐ 关键
+
+        data = payload.dict(exclude_unset=True)
 
 
         for key, value in data.items():
         for key, value in data.items():
             setattr(obj, key, value)
             setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
         return obj
-        
-    def get_active_task_by_order_id(db: Session, order_id:str):
-        recs = db.query(VasTask).filter(
+
+    @staticmethod
+    async def get_active_task_by_order_id(
+        db: AsyncSession,
+        order_id: str,
+    ) -> List[VasTask]:
+        stmt = select(VasTask).where(
             VasTask.status == "pending",
             VasTask.status == "pending",
-            VasTask.order_id == order_id
-        ).all()
-        return recs
-    
-    def return_to_queue(db: Session, id:int):
-        rec = db.query(VasTask).filter_by(id=id).first()
+            VasTask.order_id == order_id,
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    @staticmethod
+    async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        rec = result.scalar_one_or_none()
+
         if not rec:
         if not rec:
             raise NotFoundError("Task not exist")
             raise NotFoundError("Task not exist")
-        rec.status = 'pending'
-        db.commit()
-        db.refresh(rec)
+
+        rec.status = "pending"
+        rec.attempt_count = (rec.attempt_count or 0) + 1
+
+        await db.commit()
+        await db.refresh(rec)
         return rec
         return rec
 
 
-    def manual_confirm(db: Session, id:int):
-        rec = db.query(VasTask).filter_by(id=id).first()
-        if not rec:
+    @staticmethod
+    async def manual_confirm(db: AsyncSession, id: int) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        task = result.scalar_one_or_none()
+
+        if not task:
             raise NotFoundError("Task not exist")
             raise NotFoundError("Task not exist")
-        rec.status = 'completed'
-        db.commit()
-        db.refresh(rec)
-        return rec
+
+        task.status = "completed"
+
+        order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
+        if not order:
+            raise NotFoundError("Order not exist")
+
+        order.status = "completed"
+
+        await db.commit()
+        await db.refresh(task)
+        return task

+ 113 - 116
app/services/webhook_service.py

@@ -1,61 +1,62 @@
 import re
 import re
 import json
 import json
 from datetime import datetime, timedelta
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
 from typing import List, Optional
 from typing import List, Optional
-from decimal import Decimal, ROUND_HALF_UP
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from decimal import Decimal
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.order import VasOrder
 from app.models.order import VasOrder
 from app.models.vas_task import VasTask
 from app.models.vas_task import VasTask
-from app.models.product import VasProduct
 from app.models.product_routing import VasProductRouting
 from app.models.product_routing import VasProductRouting
 from app.models.payment_event import VasPaymentEvent
 from app.models.payment_event import VasPaymentEvent
 from app.models.payment import VasPayment
 from app.models.payment import VasPayment
 from app.models.payment_qr import VasPaymentQR
 from app.models.payment_qr import VasPaymentQR
 from app.schemas.webhook import SMSHelperWebhookPayload, PaymentWebhookOut
 from app.schemas.webhook import SMSHelperWebhookPayload, PaymentWebhookOut
 
 
+
 class WebhookService:
 class WebhookService:
-    
+
+    # =========================================================
+    # 内部方法:创建 Task(幂等)
+    # =========================================================
     @staticmethod
     @staticmethod
-    def _create_task_if_not_exists(
-        db: Session,
-        order: VasOrder
-    ):
-        routings = (
-            db.query(VasProductRouting)
-            .filter(
-                VasProductRouting.product_id == order.product_id,
-                VasProductRouting.is_active == 1
-            )
-            .all()
+    async def _create_task_if_not_exists(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> List[VasTask]:
+
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1,
         )
         )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
 
 
         if not routings:
         if not routings:
             return []
             return []
 
 
-        created_tasks = []
+        created_tasks: List[VasTask] = []
 
 
         for routing in routings:
         for routing in routings:
-
-            # ---------- 2. 幂等判断 ----------
-            exists = (
-                db.query(VasTask)
-                .filter(
-                    VasTask.order_id == order.id,
-                    VasTask.routing_key == routing.routing_key,
-                    VasTask.script_version == routing.script_version,
-                )
-                .first()
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
             )
             )
+            exists_result = await db.execute(exists_stmt)
+            exists = exists_result.scalar_one_or_none()
+
             if exists:
             if exists:
                 continue
                 continue
 
 
-            # ---------- 3. 创建 task ----------
             task = VasTask(
             task = VasTask(
                 order_id=order.id,
                 order_id=order.id,
                 routing_key=routing.routing_key,
                 routing_key=routing.routing_key,
                 script_version=routing.script_version,
                 script_version=routing.script_version,
-                priority=10,
+                priority=routing.priority,
                 status="pending",
                 status="pending",
                 user_inputs=order.user_inputs,
                 user_inputs=order.user_inputs,
                 config=routing.config,
                 config=routing.config,
@@ -66,24 +67,27 @@ class WebhookService:
             )
             )
             db.add(task)
             db.add(task)
             created_tasks.append(task)
             created_tasks.append(task)
-            
+
         return created_tasks
         return created_tasks
 
 
-    
+    # =========================================================
+    # SMSHelper 微信 / 支付宝 收款 webhook
+    # =========================================================
     @staticmethod
     @staticmethod
-    def smshelper_payment_webhook(db: Session, payload: SMSHelperWebhookPayload):
-        """
-        webhook payload 示例:
-        title=【微信】微信支付
-        content=【SM-E5260】个人收款码到账¥0.01
-        """
+    async def smshelper_payment_webhook(
+        db: AsyncSession,
+        payload: SMSHelperWebhookPayload,
+    ) -> Optional[PaymentWebhookOut]:
 
 
         title = payload.title
         title = payload.title
         content = payload.content
         content = payload.content
+
         if "微信" in title:
         if "微信" in title:
             provider = "wechat"
             provider = "wechat"
         elif "支付宝" in title:
         elif "支付宝" in title:
             provider = "alipay"
             provider = "alipay"
+        else:
+            raise BizLogicError("Unknown payment provider")
 
 
         device_match = re.search(r"【(.+?)】", content)
         device_match = re.search(r"【(.+?)】", content)
         device_id = device_match.group(1) if device_match else None
         device_id = device_match.group(1) if device_match else None
@@ -94,7 +98,7 @@ class WebhookService:
 
 
         amount_yuan = Decimal(amount_match.group(1))
         amount_yuan = Decimal(amount_match.group(1))
         amount_cent = int(amount_yuan * 100)
         amount_cent = int(amount_yuan * 100)
-        
+
         event = VasPaymentEvent(
         event = VasPaymentEvent(
             provider=provider,
             provider=provider,
             event_type="payment_received",
             event_type="payment_received",
@@ -104,59 +108,57 @@ class WebhookService:
             parsed_currency="CNY",
             parsed_currency="CNY",
             parsed_device=device_id,
             parsed_device=device_id,
             raw_payload=payload.dict(),
             raw_payload=payload.dict(),
-            status="received"
+            status="received",
         )
         )
         db.add(event)
         db.add(event)
-        db.commit()
-        db.refresh(event)
-        
-        payment_qr = (
-            db.query(VasPaymentQR)
-            .filter(
-                VasPaymentQR.provider == provider,
-                VasPaymentQR.device == device_id,
-                VasPaymentQR.is_active == 1
-            )
-            .first()
+        await db.commit()
+        await db.refresh(event)
+
+        # ---------- 查找 QR ----------
+        qr_stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider,
+            VasPaymentQR.device == device_id,
+            VasPaymentQR.is_active == 1,
         )
         )
-        
+        qr_result = await db.execute(qr_stmt)
+        payment_qr = qr_result.scalar_one_or_none()
+
         if not payment_qr:
         if not payment_qr:
             event.status = "failed"
             event.status = "failed"
             event.error_message = "QR not found"
             event.error_message = "QR not found"
-            db.commit()
+            await db.commit()
             raise BizLogicError("QR not found")
             raise BizLogicError("QR not found")
 
 
-        payment = (
-            db.query(VasPayment)
-            .filter(
+        # ---------- 查找 payment ----------
+        pay_stmt = (
+            select(VasPayment)
+            .where(
                 VasPayment.provider == provider,
                 VasPayment.provider == provider,
                 VasPayment.amount == amount_cent,
                 VasPayment.amount == amount_cent,
                 VasPayment.qr_id == payment_qr.id,
                 VasPayment.qr_id == payment_qr.id,
-                VasPayment.status == "pending"
+                VasPayment.status == "pending",
             )
             )
             .order_by(VasPayment.created_at.desc())
             .order_by(VasPayment.created_at.desc())
-            .first()
         )
         )
-        
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
         if not payment:
         if not payment:
             event.status = "failed"
             event.status = "failed"
             event.error_message = "No matching pending payment"
             event.error_message = "No matching pending payment"
-            db.commit()
+            await db.commit()
             raise BizLogicError("Payment not found")
             raise BizLogicError("Payment not found")
+
         if payment.status in ("succeeded", "late_paid"):
         if payment.status in ("succeeded", "late_paid"):
             event.status = "duplicate"
             event.status = "duplicate"
             event.matched_payment_id = payment.id
             event.matched_payment_id = payment.id
             event.matched_order_id = payment.order_id
             event.matched_order_id = payment.order_id
-            db.commit()
+            await db.commit()
             return None
             return None
 
 
         now = datetime.utcnow()
         now = datetime.utcnow()
-        if payment.expire_at and now > payment.expire_at:
-            payment.status = "late_paid"
-        else:
-            payment.status = "succeeded"
-                            
-        # ---------- 写入原始 payload ----------
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
+
         payment.provider_payload = {
         payment.provider_payload = {
             "title": title,
             "title": title,
             "content": content,
             "content": content,
@@ -164,47 +166,51 @@ class WebhookService:
             "received_at": now.isoformat(),
             "received_at": now.isoformat(),
         }
         }
 
 
-        order = db.query(VasOrder).filter(VasOrder.id == payment.order_id).first()
+        order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
         if order and order.status != "paid":
         if order and order.status != "paid":
             order.status = "paid"
             order.status = "paid"
-            
-        WebhookService._create_task_if_not_exists(db, order)
+
+        await WebhookService._create_task_if_not_exists(db, order)
 
 
         event.status = "applied"
         event.status = "applied"
         event.matched_payment_id = payment.id
         event.matched_payment_id = payment.id
         event.matched_order_id = payment.order_id
         event.matched_order_id = payment.order_id
 
 
-        db.commit()
-        db.refresh(payment)
-        
+        await db.commit()
+        await db.refresh(payment)
+
         return PaymentWebhookOut(
         return PaymentWebhookOut(
             status=True,
             status=True,
             order_id=order.id,
             order_id=order.id,
             user_id=order.user_id,
             user_id=order.user_id,
             payment_id=payment.id,
             payment_id=payment.id,
-            notify=True
+            notify=True,
         )
         )
-        
+
+    # =========================================================
+    # Stripe webhook
+    # =========================================================
     @staticmethod
     @staticmethod
-    def stripe_payment_webhook(db: Session, event):
-        """
-        Stripe webhook handler
-        """
+    async def stripe_payment_webhook(
+        db: AsyncSession,
+        event: dict,
+    ) -> Optional[PaymentWebhookOut]:
 
 
         event_id = event["id"]
         event_id = event["id"]
         event_type = event["type"]
         event_type = event["type"]
         data = event["data"]["object"]
         data = event["data"]["object"]
-        # ---------- 1. 幂等(事件级) ----------
-        existed_event = (
-            db.query(VasPaymentEvent)
-            .filter(VasPaymentEvent.provider == "stripe")
-            .filter(VasPaymentEvent.event_id == event_id)
-            .first()
+
+        existed_stmt = select(VasPaymentEvent).where(
+            VasPaymentEvent.provider == "stripe",
+            VasPaymentEvent.event_id == event_id,
         )
         )
-        if existed_event:
+        existed_result = await db.execute(existed_stmt)
+        if existed_result.scalar_one_or_none():
             return None
             return None
 
 
-        # ---------- 2. 只处理关心的事件 ----------
         if event_type != "checkout.session.completed":
         if event_type != "checkout.session.completed":
             db.add(
             db.add(
                 VasPaymentEvent(
                 VasPaymentEvent(
@@ -215,10 +221,9 @@ class WebhookService:
                     created_at=datetime.utcnow(),
                     created_at=datetime.utcnow(),
                 )
                 )
             )
             )
-            db.commit()
+            await db.commit()
             return None
             return None
 
 
-        # ---------- 3. 解析 metadata ----------
         metadata = data.get("metadata", {})
         metadata = data.get("metadata", {})
         payment_id = metadata.get("payment_id")
         payment_id = metadata.get("payment_id")
         order_id = metadata.get("order_id")
         order_id = metadata.get("order_id")
@@ -226,56 +231,50 @@ class WebhookService:
         if not payment_id or not order_id:
         if not payment_id or not order_id:
             raise BizLogicError("Missing payment_id or order_id in metadata")
             raise BizLogicError("Missing payment_id or order_id in metadata")
 
 
-        # ---------- 4. 查找 payment(业务级幂等) ----------
-        payment = (
-            db.query(VasPayment)
-            .filter(VasPayment.id == int(payment_id))
-            .first()
-        )
+        pay_stmt = select(VasPayment).where(VasPayment.id == int(payment_id))
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
         if not payment:
         if not payment:
             raise NotFoundError("Payment not found")
             raise NotFoundError("Payment not found")
 
 
         if payment.status == "succeeded":
         if payment.status == "succeeded":
-            # 已处理过
             db.add(
             db.add(
                 VasPaymentEvent(
                 VasPaymentEvent(
                     provider="stripe",
                     provider="stripe",
                     event_id=event_id,
                     event_id=event_id,
                     event_type=event_type,
                     event_type=event_type,
-                    payload=event,
                     payment_id=payment.id,
                     payment_id=payment.id,
                     created_at=datetime.utcnow(),
                     created_at=datetime.utcnow(),
                 )
                 )
             )
             )
-            db.commit()
+            await db.commit()
             return None
             return None
 
 
-        # ---------- 5. 金额校验 ----------
-        paid_amount = data["amount_total"]  # 单位:cent
+        paid_amount = data["amount_total"]
         paid_currency = data["currency"].upper()
         paid_currency = data["currency"].upper()
 
 
         if paid_amount != payment.amount or paid_currency != payment.currency:
         if paid_amount != payment.amount or paid_currency != payment.currency:
-            raise BizLogicError(f"Amount mismatch, expected {payment.amount} {payment.currency}, got {paid_amount} {paid_currency}")
+            raise BizLogicError(
+                f"Amount mismatch, expected {payment.amount} {payment.currency}, "
+                f"got {paid_amount} {paid_currency}"
+            )
 
 
-        # ---------- 6. 判断是否超时 ----------
         now = datetime.utcnow()
         now = datetime.utcnow()
-        if payment.expire_at and now > payment.expire_at:
-            payment.status = "late_paid"
-        else:
-            payment.status = "succeeded"
-            
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
         payment.provider_payload = event
         payment.provider_payload = event
         payment.updated_at = now
         payment.updated_at = now
 
 
-        # ---------- 7. 更新 order ----------
-        order = db.query(VasOrder).filter(VasOrder.id == order_id).first()
+        order_stmt = select(VasOrder).where(VasOrder.id == order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
         if order and order.status != "paid":
         if order and order.status != "paid":
             order.status = "paid"
             order.status = "paid"
             order.updated_at = now
             order.updated_at = now
-            
-        WebhookService._create_task_if_not_exists(db, order)
 
 
-        # ---------- 8. 写 payment_event ----------
+        await WebhookService._create_task_if_not_exists(db, order)
+
         db.add(
         db.add(
             VasPaymentEvent(
             VasPaymentEvent(
                 provider="stripe",
                 provider="stripe",
@@ -288,15 +287,13 @@ class WebhookService:
             )
             )
         )
         )
 
 
-        db.commit()
-        db.refresh(payment)
+        await db.commit()
+        await db.refresh(payment)
 
 
         return PaymentWebhookOut(
         return PaymentWebhookOut(
             status=True,
             status=True,
             order_id=order.id,
             order_id=order.id,
             user_id=order.user_id,
             user_id=order.user_id,
             payment_id=payment.id,
             payment_id=payment.id,
-            notify=True
+            notify=True,
         )
         )
-
-   

+ 31 - 12
app/services/wechat_service.py

@@ -1,23 +1,42 @@
-import json
-import time
-import requests
-from typing import List
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+import httpx
+from app.core.biz_exception import BizLogicError
 from app.schemas.wechat import WechatIn
 from app.schemas.wechat import WechatIn
 
 
 
 
 class WechatService:
 class WechatService:
-    def push_to_wechat(payload: WechatIn):
+    @staticmethod
+    async def push_to_wechat(payload: WechatIn):
         """
         """
-        企业微信 WebHook 格式:
+        企业微信 WebHook 推送(Async 版)
+
+        WebHook 格式:
         https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY
         https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY
         """
         """
         url = f"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key={payload.api_token}"
         url = f"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key={payload.api_token}"
-        payload = {"msgtype": "text", "text": {"content": payload.message}}
 
 
-        response = requests.post(url, json=payload, timeout=10)
+        body = {
+            "msgtype": "text",
+            "text": {
+                "content": payload.message
+            }
+        }
+
+        try:
+            async with httpx.AsyncClient(timeout=10) as client:
+                response = await client.post(url, json=body)
+
+        except httpx.RequestError as e:
+            raise BizLogicError(f"Wechat push request error: {e}")
+
+        if response.status_code != 200:
+            raise BizLogicError(
+                f"Wechat push failed, http_status={response.status_code}"
+            )
+
         data = response.json()
         data = response.json()
+        if data.get("errcode") != 0:
+            raise BizLogicError(
+                f"Wechat push failed, errcode={data.get('errcode')}, errmsg={data.get('errmsg')}"
+            )
 
 
-        if response.status_code != 200 or data.get("errcode") != 0:
-            # logger.error(f"企业微信推送失败: {response.text}")
-            raise BizLogicError("Wechat push failed")
+        return True

+ 256 - 0
app/tasks/notification_task.py

@@ -0,0 +1,256 @@
+
+import asyncio
+from typing import Dict, Any
+from redis.asyncio import Redis
+from app.services.wechat_service import WechatService
+from app.services.email_authorizations_service import EmailAuthorizationService
+from app.utils.redis_utils import redis_qpop
+
+async def notification_consumer(redis_client: Redis):
+    """
+    异步消费 Redis 队列 vas_notification_queue
+    """
+    queue_name = "vas_notification_queue"
+    return
+    while True:
+        try:
+            # 阻塞获取队列消息
+            message: Dict[str, Any] = await redis_qpop(redis_client, queue_name, timeout=5)
+            if not message:
+                await asyncio.sleep(1)  # 队列为空,休眠
+                continue
+
+            channels = message.get("channels", [])
+            template_id = message.get("template_id")
+            payload = message.get("payload", {})
+            user_id = message.get("user_id")
+
+            # 按渠道发送
+            if "email" in channels:
+                # EmailService.create(user_id, template_id, payload) 是你自己实现的发送逻辑
+                await EmailAuthorizationService.send(user_id=user_id, template_id=template_id, payload=payload)
+
+            if "wechat" in channels:
+                api_token = payload.get("api_token")
+                content = payload.get("message") or payload.get("content")
+                if api_token and content:
+                    await WechatService.push_to_wechat({"api_token": api_token, "message": content})
+
+            print(f"✅ Notification sent: {message.get('notification_id')}")
+
+        except Exception as e:
+            print(f"⚠️ Notification consumer error: {e}")
+            await asyncio.sleep(1)  # 避免异常循环过快
+
+def template_for_bind_email(payload):
+    
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Email Verification</title>
+        <style>
+            body { font-family: Arial, sans-serif; background-color: #f4f4f4; margin: 0; padding: 0; }
+            .container { max-width: 600px; margin: 20px auto; background-color: #ffffff; padding: 30px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
+            .code { font-size: 24px; font-weight: bold; letter-spacing: 5px; color: #333; background-color: #f0f0f0; padding: 15px; text-align: center; border-radius: 4px; margin: 20px 0; }
+            .footer { font-size: 12px; color: #888; margin-top: 30px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <h2>Verify your email address</h2>
+            <p>Hello,</p>
+            <p>You requested to bind this email address to your <strong>{{app_name}}</strong> account. Please use the verification code below to proceed:</p>
+            
+            <div class="code">{{code}}</div>
+            
+            <p>This code will expire in <strong>{{expiration_time}}</strong>.</p>
+            <p>If you did not request this change, please ignore this email.</p>
+            
+            <br>
+            <p>Best regards,<br>The {{app_name}} Team</p>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+
+def template_for_reset_pwd(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Reset Password</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f9f9f9; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.05); }
+            .header { border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h2 { margin: 0; color: #333; }
+            .code { font-size: 32px; font-weight: bold; letter-spacing: 5px; color: #2563eb; background-color: #eff6ff; padding: 20px; text-align: center; border-radius: 8px; margin: 30px 0; border: 1px solid #dbeafe; }
+            .warning { background-color: #fff7ed; border-left: 4px solid #f97316; padding: 15px; font-size: 14px; color: #c2410c; margin-top: 30px; }
+            .footer { font-size: 12px; color: #999; margin-top: 40px; text-align: center; border-top: 1px solid #eee; padding-top: 20px; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h2>Password Reset Request</h2>
+            </div>
+            
+            <p>Hello,</p>
+            <p>We received a request to reset the password for your <strong>{{app_name}}</strong> account. Please use the following code to verify your identity:</p>
+            
+            <div class="code">{{code}}</div>
+            
+            <p>This code is valid for <strong>{{expiration_time}}</strong>.</p>
+            
+            <div class="warning">
+                <strong>Security Tip:</strong> If you did not request a password reset, please ignore this email. No changes will be made to your account.
+            </div>
+            
+            <br>
+            <p>Best regards,<br>The {{app_name}} Team</p>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                This is an automated message, please do not reply.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_for_login_credentials(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Your Account Details</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f4f6f8; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 12px rgba(0,0,0,0.05); }
+            .header { text-align: center; border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h1 { font-size: 24px; color: #1a1a1a; margin: 0; }
+            .creds-box { background-color: #f0f7ff; border: 1px solid #dbeafe; border-radius: 8px; padding: 20px; margin: 25px 0; }
+            .creds-item { margin-bottom: 10px; font-size: 16px; }
+            .creds-label { font-weight: bold; color: #555; width: 100px; display: inline-block; }
+            .creds-value { font-family: 'Courier New', Courier, monospace; font-weight: bold; color: #2563eb; }
+            .btn { display: block; width: 200px; margin: 30px auto; background-color: #2563eb; color: #ffffff !important; text-align: center; padding: 12px 0; border-radius: 6px; text-decoration: none; font-weight: bold; }
+            .note { font-size: 13px; color: #666; background-color: #fff4e5; padding: 10px; border-radius: 4px; border-left: 4px solid #f97316; }
+            .footer { font-size: 12px; color: #999; margin-top: 40px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h1>Welcome to {{app_name}}</h1>
+            </div>
+            
+            <p>Dear User,</p>
+            <p>Your account has been successfully set up. Below are your temporary login credentials.</p>
+            
+            <div class="creds-box">
+                <div class="creds-item">
+                    <span class="creds-label">Username:</span>
+                    <span class="creds-value">{{username}}</span>
+                </div>
+                <div class="creds-item">
+                    <span class="creds-label">Password:</span>
+                    <span class="creds-value">{{password}}</span>
+                </div>
+            </div>
+
+            <div class="note">
+                <strong>Important:</strong> For your security, please change your password immediately after logging in.
+            </div>
+
+            <a href="{{login_url}}" class="btn">Log In Now</a>
+            
+            <p style="text-align: center; font-size: 14px;">
+                Or copy this link: <a href="{{login_url}}">{{login_url}}</a>
+            </p>
+
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                If you did not request this account, please contact support.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_ticket_open(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Ticket Created</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f4f6f8; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 12px rgba(0,0,0,0.05); }
+            .header { border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h1 { font-size: 22px; color: #1a1a1a; margin: 0; }
+            .ticket-info { background-color: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 20px; margin: 20px 0; }
+            .info-row { margin-bottom: 10px; display: flex; justify-content: space-between; }
+            .info-label { color: #64748b; font-size: 14px; }
+            .info-value { font-weight: bold; color: #0f172a; font-size: 14px; }
+            .btn { display: block; width: 200px; margin: 30px auto; background-color: #2563eb; color: #ffffff !important; text-align: center; padding: 12px 0; border-radius: 6px; text-decoration: none; font-weight: bold; font-size: 14px; }
+            .footer { font-size: 12px; color: #94a3b8; margin-top: 40px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h1>Support Request Received</h1>
+            </div>
+            
+            <p>Hello {{username}},</p>
+            <p>We wanted to let you know that we've received your request. Our team is currently reviewing the details.</p>
+            
+            <div class="ticket-info">
+                <div class="info-row">
+                    <span class="info-label">Ticket ID:</span>
+                    <span class="info-value">#{{ticket_id}}</span>
+                </div>
+                <div class="info-row">
+                    <span class="info-label">Type:</span>
+                    <span class="info-value">{{ticket_type}}</span>
+                </div>
+                <div class="info-row" style="margin-bottom: 0;">
+                    <span class="info-label">Time:</span>
+                    <span class="info-value">{{created_at}}</span>
+                </div>
+            </div>
+
+            <p>We usually reply within 24 hours. You will receive an email notification when our agent replies.</p>
+
+            <a href="{{ticket_url}}" class="btn">View Ticket Details</a>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                Please do not reply to this email directly.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_confirm_payment(payload):
+    template = {
+        "touser": "ADMIN_USER_ID",
+        "msgtype": "textcard",
+        "agentid": 1000001,
+        "textcard": {
+            "title": "💰 待确认:收到新的手动转账",
+            "description": "<div class=\"gray\">2025-12-31 10:30:00</div> <br>订单号:ORD-20251231-001<br>用户:user@example.com<br><div class=\"highlight\">金额:¥ 3,500.00</div><br>请核实资金到账后,点击卡片确认收款。",
+            "url": "https://admin.visafly.com/payment/confirm?payment_id=123&token=secure_token_abc",
+            "btntxt": "立即确认"
+        }
+    }

+ 18 - 9
app/utils/pagination.py

@@ -1,24 +1,33 @@
-from sqlalchemy.orm import Query
-from sqlalchemy.orm import Session
+from typing import Any, Dict
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql import Select
 
 
-def paginate(
-    query: Query,
+
+async def paginate(
+    db: AsyncSession,
+    stmt: Select,
     page: int = 1,
     page: int = 1,
     size: int = 20,
     size: int = 20,
-):
+) -> Dict[str, Any]:
     if page < 1:
     if page < 1:
         page = 1
         page = 1
     if size < 1:
     if size < 1:
         size = 20
         size = 20
 
 
-    total = query.count()
+    # ---------- 1️⃣ 查询总数 ----------
+    count_stmt = select(func.count()).select_from(
+        stmt.subquery()
+    )
+    total = await db.scalar(count_stmt) or 0
 
 
-    items = (
-        query
+    # ---------- 2️⃣ 查询分页数据 ----------
+    result = await db.execute(
+        stmt
         .offset((page - 1) * size)
         .offset((page - 1) * size)
         .limit(size)
         .limit(size)
-        .all()
     )
     )
+    items = result.scalars().all()
 
 
     return {
     return {
         "items": items,
         "items": items,

+ 28 - 14
app/utils/redis_utils.py

@@ -1,22 +1,36 @@
 import json
 import json
-import pytz
-import resource
+from typing import Optional
 from redis.asyncio import Redis
 from redis.asyncio import Redis
-from datetime import datetime, timedelta
 
 
 
 
-def redis_qpush(redis_client, qname: str, data: dict, max_len: int = 30):
-    """向队列右侧推入数据,并限制队列最大长度"""
+async def redis_qpush(
+    redis_client: Redis,
+    qname: str,
+    data: dict,
+    max_len: int = 30,
+):
+    """
+    向队列右侧推入数据,并限制队列最大长度(Async 版)
+    """
     data_string = json.dumps(data)
     data_string = json.dumps(data)
-    pipe = redis_client.pipeline()
+
+    pipe = redis_client.pipeline(transaction=True)
     pipe.rpush(qname, data_string)
     pipe.rpush(qname, data_string)
-    pipe.ltrim(qname, -max_len, -1)  # 只保留右侧 max_len 个元素
-    pipe.execute()
+    pipe.ltrim(qname, -max_len, -1)
+    await pipe.execute()
+
 
 
-def redis_qpop(redis_client, qname:str, timeout: int = 5):
-    message = redis_client.blpop(qname, timeout=timeout)
+async def redis_qpop(
+    redis_client: Redis,
+    qname: str,
+    timeout: int = 5,
+) -> Optional[dict]:
+    """
+    从队列左侧阻塞弹出数据(Async 版)
+    """
+    message = await redis_client.blpop(qname, timeout=timeout)
     if message is None:
     if message is None:
-        return None  # 队列为空,直接返回
-    message_string = message[1]
-    data = json.loads(message_string)
-    return data
+        return None
+
+    _, message_string = message
+    return json.loads(message_string)

+ 18 - 4
app/utils/search.py

@@ -1,15 +1,29 @@
+from typing import List, Optional
 from sqlalchemy import or_
 from sqlalchemy import or_
-from typing import List
+from sqlalchemy.sql import Select
 
 
-def apply_keyword_search(query, model, keyword: str, fields: List[str]):
+
+def apply_keyword_search_stmt(
+    stmt: Select,
+    model,
+    keyword: Optional[str],
+    fields: List[str],
+) -> Select:
+    """
+    Async / SQLAlchemy 2.0 Select 版本的关键字搜索
+    """
     if not keyword:
     if not keyword:
-        return query
+        return stmt
 
 
     like = f"%{keyword}%"
     like = f"%{keyword}%"
 
 
     conditions = [
     conditions = [
         getattr(model, field).ilike(like)
         getattr(model, field).ilike(like)
         for field in fields
         for field in fields
+        if hasattr(model, field)
     ]
     ]
 
 
-    return query.filter(or_(*conditions))
+    if not conditions:
+        return stmt
+
+    return stmt.where(or_(*conditions))

+ 1 - 1
app/utils/validation_utils.py

@@ -2,7 +2,7 @@ from jsonschema import validate, ValidationError
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 
 
 def validate_user_inputs(schema_json: dict, user_inputs: dict):
 def validate_user_inputs(schema_json: dict, user_inputs: dict):
-    print(f'schema_json={schema_json}, user_inputs={user_inputs}')
+    # print(f'schema_json={schema_json}, user_inputs={user_inputs}')
     try:
     try:
         validate(instance=user_inputs, schema=schema_json)
         validate(instance=user_inputs, schema=schema_json)
     except ValidationError as e:
     except ValidationError as e:

+ 11 - 0
app/utils/wrappers.py

@@ -0,0 +1,11 @@
+from functools import wraps
+
+def after_call(after_func):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            result = func(*args, **kwargs)
+            after_func(*args, **kwargs)
+            return result
+        return wrapper
+    return decorator

+ 52 - 13
starter.py

@@ -1,35 +1,74 @@
 #!/usr/bin/env python3
 #!/usr/bin/env python3
 import os
 import os
-import sys
 import subprocess
 import subprocess
 
 
+
 def main():
 def main():
     """
     """
     启动 FastAPI 应用,根据环境选择不同参数:
     启动 FastAPI 应用,根据环境选择不同参数:
-    - DEV: 热重载,单进程
+    - DEV: 热重载,单进程,仅监听 app 目录
     - PROD: 多进程,高性能 uvloop + httptools
     - PROD: 多进程,高性能 uvloop + httptools
     """
     """
-    env = os.getenv("ENV", "DEV").upper()  # 默认开发环境
+    
+    os.environ.setdefault(
+        "WATCHFILES_IGNORE",
+        "**/.git/**,**/venv/**,**/__pycache__/**"
+    )
+    
+    env = os.getenv("ENV", "DEV").upper()
+    
+    env = "DEV"
+
     host = "0.0.0.0"
     host = "0.0.0.0"
     port = "8888"
     port = "8888"
     app_module = "app.main:app"
     app_module = "app.main:app"
 
 
-    base_cmd = ["uvicorn", app_module, "--host", host, "--port", port]
+    base_cmd = [
+        "uvicorn",
+        app_module,
+        "--host", host,
+        "--port", port,
+    ]
 
 
     if env == "DEV":
     if env == "DEV":
-        print("启动开发环境(热重载)...")
-        base_cmd += ["--reload", "--workers", "1"]
+        print("🚀 启动开发环境(热重载)")
+
+        base_cmd += [
+            "--reload",
+            "--workers", "1",
+
+            # ⭐ 关键:限制监听范围
+            "--reload-dir", "app",
+            "--reload-exclude", ".git",
+            "--reload-exclude", "venv",
+            "--reload-exclude", "__pycache__",
+        ]
+
     elif env == "PROD":
     elif env == "PROD":
-        print("启动生产环境(多进程 + 高性能)...")
-        base_cmd += ["--workers", str(os.cpu_count()), "--loop", "uvloop", "--http", "httptools"]
+        print("🔥 启动生产环境(多进程 + 高性能)")
+
+        base_cmd += [
+            "--workers", str(os.cpu_count() or 1),
+            "--loop", "uvloop",
+            "--http", "httptools",
+        ]
+
     else:
     else:
-        print(f"未知环境 {env},使用默认开发配置")
-        base_cmd += ["--reload", "--workers", "1"]
+        print(f"⚠️ 未知环境 {env},使用默认 DEV 配置")
+
+        base_cmd += [
+            "--reload",
+            "--workers", "1",
+            "--reload-dir", "app",
+        ]
+
+    print("\n执行命令:")
+    print(" ", " ".join(base_cmd))
+    print(f"\nSwagger UI: http://{host}:{port}/docs")
+    print(f"ReDoc UI:   http://{host}:{port}/redoc\n")
 
 
-    print("执行命令:", " ".join(base_cmd))
-    print(f"Swagger UI: http://{host}:{port}/docs")
-    print(f"ReDoc UI:   http://{host}:{port}/redoc")
     subprocess.run(base_cmd)
     subprocess.run(base_cmd)
 
 
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     main()
     main()

Некоторые файлы не были показаны из-за большого количества измененных файлов