jerry 4 ヶ月 前
コミット
882a5ce465
53 ファイル変更3770 行追加2071 行削除
  1. 1 2
      .env
  2. 263 193
      app/api/router.py
  3. 25 18
      app/core/auth.py
  4. 35 8
      app/core/config.py
  5. 26 25
      app/core/database.py
  6. 7 6
      app/core/redis.py
  7. 45 13
      app/main.py
  8. 0 46
      app/models/auto_booking.py
  9. 24 5
      app/models/payment.py
  10. 18 0
      app/models/payment_confirmation.py
  11. 1 0
      app/models/product_routing.py
  12. 1 1
      app/models/session.py
  13. 2 4
      app/models/verification_token.py
  14. 1 2
      app/schemas/auth.py
  15. 0 52
      app/schemas/auto_booking.py
  16. 4 4
      app/schemas/payment.py
  17. 30 0
      app/schemas/payment_confirmation.py
  18. 2 1
      app/schemas/product_routing.py
  19. 191 133
      app/services/auth_service.py
  20. 0 72
      app/services/auto_booking_service.py
  21. 37 16
      app/services/card_service.py
  22. 85 24
      app/services/configuration_service.py
  23. 469 441
      app/services/email_authorizations_service.py
  24. 57 21
      app/services/http_session_service.py
  25. 18 8
      app/services/notification_service.py
  26. 128 126
      app/services/order_service.py
  27. 103 55
      app/services/payment_provider_service.py
  28. 130 31
      app/services/payment_qr_service.py
  29. 384 78
      app/services/payment_service.py
  30. 37 12
      app/services/product_routing_service.py
  31. 63 30
      app/services/product_service.py
  32. 54 22
      app/services/schema_service.py
  33. 86 45
      app/services/seaweedfs_service.py
  34. 26 20
      app/services/session_service.py
  35. 45 18
      app/services/short_url_service.py
  36. 30 7
      app/services/slot_snapshot_service.py
  37. 52 22
      app/services/sms_service.py
  38. 157 90
      app/services/statistics_service.py
  39. 52 18
      app/services/task_service.py
  40. 17 12
      app/services/telegram_service.py
  41. 238 78
      app/services/ticket_service.py
  42. 125 56
      app/services/troov_service.py
  43. 70 34
      app/services/user_service.py
  44. 103 53
      app/services/vas_task_service.py
  45. 113 116
      app/services/webhook_service.py
  46. 31 12
      app/services/wechat_service.py
  47. 256 0
      app/tasks/notification_task.py
  48. 18 9
      app/utils/pagination.py
  49. 28 14
      app/utils/redis_utils.py
  50. 18 4
      app/utils/search.py
  51. 1 1
      app/utils/validation_utils.py
  52. 11 0
      app/utils/wrappers.py
  53. 52 13
      starter.py

+ 1 - 2
.env

@@ -1,5 +1,4 @@
-DATABASE_URL=mysql://root:GqLLL7Bofj0WaaOpp.0@visafly.top:3306/book_user_info?charset=utf8mb4
+DATABASE_URL=mysql+asyncmy://root:GqLLL7Bofj0WaaOpp.0@visafly.top:3306/book_user_info?charset=utf8mb4
 REDIS_URL=redis://:STEs2x6ML0U1HlpE9SojM6YU7QPhqzY8@45.137.220.138:6379/0
-API_TOKEN=7x9EjFpmv7GjZc6AfVeqxuUBANpqkpkHAtxJM7CAW5oZhs0nEyCJBy39N4XXs5hgfYWXw3jFrcgXqQ42HAx9Qvwtk9vC2GvKBbWz
 OPENAI_API_KEY=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
 STRIPE_API_KEY=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

ファイルの差分が大きいため隠しています
+ 263 - 193
app/api/router.py


+ 25 - 18
app/core/auth.py

@@ -1,40 +1,47 @@
 from enum import IntEnum
-from fastapi import Depends, HTTPException, status
+from fastapi import Depends
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.core.config import settings
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
+
 from app.core.database import get_db
+from app.core.biz_exception import PermissionDeniedError
 from app.services.session_service import SessionService
 
-
-security = HTTPBearer()
+security = HTTPBearer(auto_error=False)
 
 
 class RoleLevel(IntEnum):
     user = 10
     admin = 100
 
+
 ROLE_LEVEL_MAP = {
     "user": 10,
     "admin": 100,
 }
 
-def require_min_role(min_role: RoleLevel):
-    def checker(user=Depends(get_current_user)):
-        current_level = ROLE_LEVEL_MAP.get(user.role, 0)
-        if current_level < min_role:
-            raise PermissionDeniedError("Permission denied")
-        return user
-    return checker
 
-
-def get_current_user(
+async def get_current_user(
     credentials: HTTPAuthorizationCredentials = Depends(security),
-    db: Session = Depends(get_db)
+    db: AsyncSession = Depends(get_db),
 ):
+    if not credentials:
+        raise PermissionDeniedError("Missing token")
+
     token = credentials.credentials
-    user = SessionService().get_user_by_token(db, token)
+    user = await SessionService().get_user_by_token(db, token)
+
     if not user:
         raise PermissionDeniedError("Invalid or expired token")
-    return user
+
+    return user
+
+
+def require_min_role(min_role: RoleLevel):
+    async def checker(user=Depends(get_current_user)):
+        current_level = ROLE_LEVEL_MAP.get(user.role, 0)
+        if current_level < min_role:
+            raise PermissionDeniedError("Permission denied")
+        return user
+
+    return checker

+ 35 - 8
app/core/config.py

@@ -1,18 +1,45 @@
-from pydantic_settings import BaseSettings
+from functools import lru_cache
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from pydantic import Field
+
 
 class Settings(BaseSettings):
+    # -----------------------
+    # App
+    # -----------------------
     app_name: str = "MyApp"
     debug: bool = False
-    database_url: str
+
+    # -----------------------
+    # Database / Cache
+    # -----------------------
+    database_url: str = Field(..., description="Async database DSN")
     redis_url: str
-    api_token: str
+
+    # -----------------------
+    # Security / API Keys
+    # -----------------------
     openai_api_key: str
     stripe_api_key: str
-    
-    class Config:
-        env_file = ".env"
 
-settings = Settings()
+    model_config = SettingsConfigDict(
+        env_file=".env",
+        env_file_encoding="utf-8",
+        case_sensitive=False,
+    )
+
+
+@lru_cache
+def get_settings() -> Settings:
+    """
+    避免多次实例化 Settings(FastAPI 官方推荐)
+    """
+    return Settings()
+
 
-base_currency = "EUR"
+settings = get_settings()
 
+# -----------------------
+# Global constants
+# -----------------------
+BASE_CURRENCY = "EUR"

+ 26 - 25
app/core/database.py

@@ -1,25 +1,29 @@
-from sqlalchemy import create_engine
-from sqlalchemy.orm import sessionmaker, declarative_base
+from sqlalchemy.ext.asyncio import (
+    create_async_engine,
+    async_sessionmaker,
+    AsyncSession,
+)
+from sqlalchemy.orm import declarative_base
 from app.core.config import settings
 
 # =========================
-# 数据库初始化
+# Async Engine
 # =========================
-# 建立 Engine
-engine = create_engine(
-    settings.database_url,
-    echo=settings.debug,        # 是否打印 SQL 日志
-    pool_pre_ping=True,         # 检测断开的连接
-    pool_recycle=1800,          # 连接回收时间,防止 MySQL 8 小时断开
-    future=True                 # 启用 SQLAlchemy 2.0 风格
+engine = create_async_engine(
+    settings.database_url,      # ⚠️ 必须是 async URL
+    echo=settings.debug,
+    pool_pre_ping=True,
+    pool_recycle=1800,
 )
 
-# 建立 Session 工厂
-SessionLocal = sessionmaker(
-    autocommit=False,
-    autoflush=False,
+# =========================
+# Async Session 工厂
+# =========================
+AsyncSessionLocal = async_sessionmaker(
     bind=engine,
-    expire_on_commit=False
+    class_=AsyncSession,
+    autoflush=False,
+    expire_on_commit=False,
 )
 
 # ORM 基类
@@ -27,14 +31,11 @@ Base = declarative_base()
 
 
 # =========================
-# 数据库依赖
+# FastAPI 依赖
 # =========================
-def get_db():
-    """
-    FastAPI 的依赖注入函数,用于在请求周期内创建并关闭数据库会话。
-    """
-    db = SessionLocal()
-    try:
-        yield db
-    finally:
-        db.close()
+async def get_db() -> AsyncSession:
+    async with AsyncSessionLocal() as session:
+        try:
+            yield session
+        finally:
+            await session.close()

+ 7 - 6
app/core/redis.py

@@ -1,14 +1,15 @@
+
 from typing import Optional
-from redis import Redis
+from redis.asyncio import Redis
 from app.core.config import settings
 
 _redis_client: Optional[Redis] = None
 
-def get_redis_client() -> Redis:
-    """
-    同步依赖(FastAPI 可以直接注入)
-    """
+async def get_redis_client() -> Redis:
     global _redis_client
     if _redis_client is None:
-        _redis_client = Redis.from_url(settings.redis_url, decode_responses=True)
+        _redis_client = Redis.from_url(
+            settings.redis_url,
+            decode_responses=True
+        )
     return _redis_client

+ 45 - 13
app/main.py

@@ -1,23 +1,48 @@
+import asyncio
+
 from fastapi import FastAPI, Depends, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.openapi.utils import get_openapi
-from fastapi.security import HTTPBearer
 
 from app.api import router
+from app.core.redis import get_redis_client
 from app.core.auth import RoleLevel, require_min_role
 from app.core.config import settings
 from app.core.payment import init_stripe
 from app.core.biz_exception import BizException
 from app.core.logger import logger
+from app.tasks.notification_task import notification_consumer
 
 
 app = FastAPI(title=settings.app_name)
 
+# -----------------------
+# Startup
+# -----------------------
 @app.on_event("startup")
-def startup():
+async def startup():
+    # 如果 init_stripe 是 async
     init_stripe()
     
+ 
+# 全局 Redis 客户端
+
+
+@app.on_event("startup")
+async def startup_event():
+    """
+    FastAPI 启动时执行
+    """
+    # 启动后台消费任务
+    redis_client = await get_redis_client()
+    asyncio.create_task(notification_consumer(redis_client))
+    print("🟢 Notification consumer started")   
+
+
+# -----------------------
+# Exception Handlers
+# -----------------------
 @app.exception_handler(BizException)
 async def biz_exception_handler(request: Request, exc: BizException):
     return JSONResponse(
@@ -29,10 +54,10 @@ async def biz_exception_handler(request: Request, exc: BizException):
         },
     )
 
+
 @app.exception_handler(Exception)
 async def unhandled_exception_handler(request: Request, exc: Exception):
-    # ⚠️ 一定要打日志
-    logger.error("Unhandled exception")
+    logger.error("Unhandled exception", exc_info=exc)
 
     return JSONResponse(
         status_code=500,
@@ -43,8 +68,9 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
         },
     )
 
+
 # -----------------------
-# CORS(可选)
+# CORS
 # -----------------------
 app.add_middleware(
     CORSMiddleware,
@@ -53,52 +79,58 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
 # -----------------------
-# 路由注册
+# Routers
 # -----------------------
-# 公共路由,不鉴权
 app.include_router(
     router.public_router,
     prefix="/api"
 )
-# 需要鉴权的路由
+
 app.include_router(
     router.protected_router,
     prefix="/api",
     dependencies=[Depends(require_min_role(RoleLevel.user))]
 )
 
-# 需要管理员权限的路由
 app.include_router(
     router.admin_required_router,
     prefix="/api",
     dependencies=[Depends(require_min_role(RoleLevel.admin))]
 )
 
+
 # -----------------------
-# Swagger 支持 Bearer Token
+# Swagger Bearer Token
 # -----------------------
 def custom_openapi():
     if app.openapi_schema:
         return app.openapi_schema
+
     openapi_schema = get_openapi(
         title=app.title,
         version="1.0.0",
         description="API documentation",
         routes=app.routes,
     )
-    # 添加全局 Bearer
+
+    openapi_schema.setdefault("components", {})
     openapi_schema["components"]["securitySchemes"] = {
         "BearerAuth": {
             "type": "http",
             "scheme": "bearer",
-            "bearerFormat": "JWT"
+            "bearerFormat": "JWT",
         }
     }
+
     for path in openapi_schema["paths"].values():
         for method in path.values():
-            method["security"] = [{"BearerAuth": []}]
+            method.setdefault("security", [])
+            method["security"].append({"BearerAuth": []})
+
     app.openapi_schema = openapi_schema
     return app.openapi_schema
 
+
 app.openapi = custom_openapi

+ 0 - 46
app/models/auto_booking.py

@@ -1,46 +0,0 @@
-from sqlalchemy import Column, BigInteger, String, Integer, Text, Date, DateTime
-from app.core.database import Base
-
-class AutoBooking(Base):
-    __tablename__ = "auto_booking"
-
-    id = Column(BigInteger, primary_key=True, autoincrement=True)
-    provider = Column(String(100))
-    visa_center = Column(String(100))
-    order_no = Column(String(100))
-    social_account = Column(String(100))
-    account = Column(String(100))
-    password = Column(String(100))
-    last_name = Column(String(100))
-    first_name = Column(String(100))
-    gender = Column(String(10))
-    birthday = Column(Date)
-    email = Column(String(150))
-    alias_email = Column(String(150))
-    phone_country_code = Column(String(20))
-    phone_no = Column(String(50))
-    passport_no = Column(String(50))
-    nationality = Column(String(50))
-    passport_expiry_date = Column(Date)
-    address_line1 = Column(Text)
-    address_line2 = Column(Text)
-    state = Column(String(100))
-    city = Column(String(100))
-    postcode = Column(String(100))
-    travel_date = Column(Date)
-    cover_letter = Column(String(100))
-    passport_image_url = Column(Text)
-    selfie_image_url = Column(Text)
-    application_form_url = Column(Text)
-    priority = Column(Integer, default=0)
-    expected_submit_start = Column(Date)
-    expected_submit_end = Column(Date)
-    rules = Column(Text)
-    status = Column(Integer, default=0)
-    placeholder = Column(Integer, default=0)
-    appointment_datetime = Column(DateTime)
-    appointment_letter_url = Column(Text)
-    pnr_number = Column(String(100))
-    payment_link = Column(Text)
-    payment_help = Column(Integer, default=0)
-    note = Column(Text)

+ 24 - 5
app/models/payment.py

@@ -2,21 +2,31 @@ from sqlalchemy import Column, Integer, String, Text, DateTime, Enum, JSON, DECI
 from datetime import datetime
 from app.core.database import Base
 
-
 class VasPayment(Base):
     __tablename__ = "vas_payment"
 
     id = Column(Integer, primary_key=True, autoincrement=True)
     order_id = Column(String(128), nullable=False)
 
-    provider = Column(Enum('stripe','wechat','alipay'), nullable=False)
+    provider = Column(Enum('stripe', 'wechat', 'alipay'), nullable=False)
     channel = Column(String(50), nullable=False)
+    
+    # 支付意向ID (Stripe PaymentIntent ID)
     payment_intent_id = Column(String(255))
+    # 外部交易号 (微信/支付宝的 transaction_id)
     external_trade_no = Column(String(255))
 
+    # --- 修改点 1: 扩充状态枚举 ---
     status = Column(
-        Enum('pending','succeeded','failed','expired', 'late_paid'),
-        default='pending'
+        Enum(
+            'pending',              # 待支付
+            'succeeded',            # 支付成功
+            'failed',               # 支付失败
+            'expired',              # 支付超时
+            'late_paid',            # 逾期支付(极少见)
+            'refunded',             # 已全额退款
+        ),
+        default='pending',
     )
     
     base_amount = Column(Integer, nullable=False)
@@ -33,6 +43,15 @@ class VasPayment(Base):
     expire_at = Column(DateTime)
 
     provider_payload = Column(JSON)
+    
+    # 外部退款单号 (Stripe Refund ID / 微信退款单号 / 支付宝退款单号)
+    external_refund_no = Column(String(255))
+    
+    # 退款时间
+    refunded_at = Column(DateTime)
+    
+    # 退款原因/备注
+    refund_reason = Column(String(255))
 
     created_at = Column(DateTime, default=datetime.utcnow)
-    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

+ 18 - 0
app/models/payment_confirmation.py

@@ -0,0 +1,18 @@
+from sqlalchemy import Column, BigInteger, Integer, String, Enum, DateTime, func
+from app.core.database import Base
+
+
+class VasPaymentConfirmation(Base):
+    __tablename__ = "vas_payment_confirm"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True)
+    payment_id = Column(BigInteger, nullable=False)
+    amount = Column(Integer, nullable=False)
+    currency = Column(String(10), nullable=False)
+    random_offset = Column(Integer, nullable=False)
+    user_id = Column(String(128), nullable=False)
+    status = Column(Enum("pending","confirmed","ignored"), nullable=False, default="pending")
+    created_at = Column(DateTime, nullable=False, server_default=func.now())
+    confirmed_at = Column(DateTime, nullable=True)
+    admin_id = Column(String(128), nullable=True)
+    admin_confirmed_at = Column(DateTime, nullable=True)

+ 1 - 0
app/models/product_routing.py

@@ -12,6 +12,7 @@ class VasProductRouting(Base):
     routing_key = Column(String(255), nullable=False)
     script_version = Column(String(50), nullable=False)
     is_active = Column(Integer, default=1)
+    priority = Column(Integer, default=10)
 
     config = Column(JSON)
 

+ 1 - 1
app/models/session.py

@@ -8,7 +8,7 @@ class VasSession(Base):
 
     id = Column(String(128), primary_key=True)   # token
     user_id = Column(String(128), nullable=False)
-    user_agent = Column(String(128), nullable=False)
+    user_agent = Column(String(255), nullable=False)
     ip = Column(String(128), nullable=False)
     expire_at = Column(DateTime, nullable=False)
 

+ 2 - 4
app/models/email_verification.py → app/models/verification_token.py

@@ -3,12 +3,10 @@ from datetime import datetime
 from app.core.database import Base
 
 
-class VasEmailVerification(Base):
-    __tablename__ = "vas_email_verification"
+class VasVerificationToken(Base):
+    __tablename__ = "vas_verification_token"
 
     id = Column(Integer, primary_key=True, autoincrement=True)
-    user_id = Column(String(64), nullable=False)
-    email = Column(String(255), nullable=False)
     token = Column(String(128), nullable=False)
     used = Column(Integer, default=0)
     expire_at = Column(DateTime, nullable=False)

+ 1 - 2
app/schemas/auth.py

@@ -4,8 +4,7 @@ from app.schemas.common import ApiResponse
 from app.schemas.user import VasUserOut
 
 class AutoRegisterRequest(BaseModel):
-    user_agent: Optional[str] = None
-    register_ip: str
+    pass
 
 class AutoRegisterData(BaseModel):
     user: VasUserOut

+ 0 - 52
app/schemas/auto_booking.py

@@ -1,52 +0,0 @@
-from pydantic import BaseModel
-from typing import Optional
-from datetime import date, datetime
-
-class AutoBookingBase(BaseModel):
-    provider: Optional[str] = None
-    visa_center: Optional[str] = None
-    order_no: Optional[str] = None
-    social_account: Optional[str] = None
-    account: Optional[str] = None
-    password: Optional[str] = None
-    last_name: Optional[str] = None
-    first_name: Optional[str] = None
-    gender: Optional[str] = None
-    birthday: Optional[date] = None
-    email: Optional[str] = None
-    alias_email: Optional[str] = None
-    phone_country_code: Optional[str] = None
-    phone_no: Optional[str] = None
-    passport_no: Optional[str] = None
-    nationality: Optional[str] = None
-    passport_expiry_date: Optional[date] = None
-    address_line1: Optional[str] = None
-    address_line2: Optional[str] = None
-    state: Optional[str] = None
-    city: Optional[str] = None
-    postcode: Optional[str] = None
-    travel_date: Optional[date] = None
-    cover_letter: Optional[str] = None
-    passport_image_url: Optional[str] = None
-    selfie_image_url: Optional[str] = None
-    application_form_url: Optional[str] = None
-    priority: Optional[int] = None
-    expected_submit_start: Optional[date] = None
-    expected_submit_end: Optional[date] = None
-    rules: Optional[str] = None
-    status: Optional[int] = None
-    placeholder: Optional[int] = None
-    appointment_datetime: Optional[datetime] = None
-    appointment_letter_url: Optional[str] = None
-    pnr_number: Optional[str] = None
-    payment_link: Optional[str] = None
-    payment_help: Optional[int] = None
-    note: Optional[str] = None
-
-class AutoBookingCreate(AutoBookingBase):
-    pass
-
-class AutoBookingOut(AutoBookingBase):
-    id: int
-    class Config:
-        orm_mode = True

+ 4 - 4
app/schemas/payment.py

@@ -5,7 +5,7 @@ from datetime import datetime
 
 
 class VasPaymentBase(BaseModel):
-    status: Optional[Literal['pending', 'succeeded', 'failed', 'expired', 'late_paid']] = None
+    status: Optional[Literal['pending', 'succeeded', 'failed', 'expired', 'late_paid', 'refunded']] = None
 
     qr_id: Optional[int] = None
     payment_url: Optional[str] = None
@@ -28,9 +28,6 @@ class VasPaymentCreate(BaseModel):
     order_id: str
     provider: Literal['stripe', 'wechat', 'alipay']
 
-class VasPaymentUpdate(VasPaymentBase):
-    pass
-
 class VasPaymentOut(VasPaymentBase):
     id: int
     order_id: str
@@ -50,6 +47,9 @@ class VasPaymentOut(VasPaymentBase):
     
     exchange_rate: float  # 注意:仅用于展示,DB 里是 DECIMAL
 
+    external_refund_no: Optional[str]
+    refund_reason: Optional[str]
+    refunded_at: Optional[datetime]
     created_at: datetime
     updated_at: datetime
 

+ 30 - 0
app/schemas/payment_confirmation.py

@@ -0,0 +1,30 @@
+from pydantic import BaseModel
+from typing import Optional, Dict, Any, Literal, List
+from datetime import datetime
+
+class VasPaymentConfirmationBase(BaseModel):
+    payment_id: int
+    amount: int
+    currency: str
+    random_offset: int
+    confirmed_at: datetime
+
+class VasPaymentConfirmationCreate(VasPaymentConfirmationBase):
+    pass
+
+class VasPaymentConfirmationUpdate(BaseModel):
+    status: Optional[Literal['confirmed', 'ignored']] = None
+    admin_id: Optional[str] = None
+    admin_confirmed_at: Optional[datetime] = None
+
+class VasPaymentConfirmationOut(VasPaymentConfirmationBase):
+    id: int
+    user_id: str
+    status: str
+    created_at: datetime
+    admin_id: Optional[str] = None
+    admin_confirmed_at: Optional[datetime] = None
+
+    model_config = {
+        "from_attributes": True
+    }

+ 2 - 1
app/schemas/product_routing.py

@@ -22,6 +22,7 @@ class VasProductRoutingCreate(BaseModel):
     product_id: int
     routing_key: str
     script_version: str
+    priority: int
     config: Dict[str, Any]
 
 class VasProductRoutingUpdate(VasProductRoutingBase):
@@ -30,7 +31,7 @@ class VasProductRoutingUpdate(VasProductRoutingBase):
 class VasProductRoutingOut(VasProductRoutingBase):
     id: int
     product_id: int
-
+    priority: int
     routing_key: str
     script_version: str
 

+ 191 - 133
app/services/auth_service.py

@@ -1,27 +1,56 @@
-import uuid, bcrypt, random, string
-from sqlalchemy.orm import Session
+# app/services/auth_service.py
+
+import uuid
+import bcrypt
+import random
+import string
 from datetime import datetime, timedelta
+from typing import Dict
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
 from redis.asyncio import Redis
-from app.utils.redis_utils import redis_qpush
-from app.core.auth import get_current_user
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from app.core.biz_exception import (
+    NotFoundError,
+    PermissionDeniedError,
+    BizLogicError,
+)
 from app.models.user import VasUser
 from app.models.session import VasSession
-from app.models.email_verification import VasEmailVerification
-from app.schemas.auth import AutoRegisterRequest, SendBindCodeRequest, SendResetCodeRequest, BindEmailRequest, ResetPasswordRequest, LoginRequest
+from app.models.verification_token import VasVerificationToken
+from app.schemas.auth import (
+    AutoRegisterRequest,
+    SendBindCodeRequest,
+    SendResetCodeRequest,
+    BindEmailRequest,
+    ResetPasswordRequest,
+    LoginRequest,
+)
 from app.services.notification_service import NotificationService
 
 
-def _random_password(length=16):
-    return ''.join(random.choices(string.ascii_letters + string.digits + "!@#$%", k=length))
+def _random_password(length: int = 16) -> str:
+    return "".join(
+        random.choices(
+            string.ascii_letters + string.digits + "!@#$%",
+            k=length,
+        )
+    )
+
 
 class AuthService:
-    # -----------------------
-    # 自动注册
-    # -----------------------
+    # =========================
+    # 自动注册(游客)
+    # =========================
     @staticmethod
-    def auto_register(db: Session, req:AutoRegisterRequest):
-        uid = f'usr-{uuid.uuid4().hex[:8]}'
+    async def auto_register(
+        db: AsyncSession,
+        req: AutoRegisterRequest,
+        ip: str = None,
+        user_agent = None
+    ) -> Dict:
+        uid = f"usr-{uuid.uuid4().hex[:8]}"
 
         user = VasUser(
             id=uid,
@@ -29,197 +58,226 @@ class AuthService:
             nickname="anonymous visitor",
             preferred_language="en",
             timezone="Asia/Shanghai",
-            register_ip=req.register_ip,
+            register_ip=ip or '',
         )
         db.add(user)
-        db.commit()
-
-        # 创建 session
-        token = f"tok_{uuid.uuid4().hex}"
 
+        token = "tok_" + uuid.uuid4().hex
         session = VasSession(
             id=token,
             user_id=uid,
-            user_agent=req.user_agent or "",
-            ip=req.register_ip,
-            expire_at=datetime.utcnow() + timedelta(days=7)
+            user_agent=user_agent or '',
+            ip=ip or '',
+            expire_at=datetime.utcnow() + timedelta(days=7),
         )
         db.add(session)
-        db.commit()
-        return {
-            "user": user,
-            "token": token
-        }
-        
-    def send_bind_code(db: Session, payload: SendBindCodeRequest, auth_user: VasUser, redis_client:Redis):
-        token = uuid.uuid4().hex[0:6]
-        record = VasEmailVerification(
-            user_id=auth_user.id,
-            email=payload.email,
+
+        await db.commit()
+        await db.refresh(user)
+
+        return {"user": user, "token": token}
+
+    # =========================
+    # 发送绑定邮箱验证码
+    # =========================
+    @staticmethod
+    async def send_bind_code(
+        db: AsyncSession,
+        payload: SendBindCodeRequest,
+        auth_user: VasUser,
+        redis_client: Redis,
+    ):
+        token = uuid.uuid4().hex[:6]
+
+        record = VasVerificationToken(
             token=token,
-            expire_at=datetime.utcnow() + timedelta(minutes=30)
+            expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         db.add(record)
-        db.commit()
-        
-        print(f"📧 send verification email token={token}")
-        NotificationService.create(
+        await db.commit()
+
+        await NotificationService.create(
             redis_client=redis_client,
-            ntype="email verification email",
+            ntype="email_verification",
             user_id=auth_user.id,
             channels=["email"],
             template_id="email_verification_for_bind",
-            payload={
-                "token": token
-            }
+            payload={"token": token},
         )
-        
-    def send_reset_code(db: Session, payload: SendResetCodeRequest, redis_client:Redis):
-        user = db.query(VasUser).filter(
+
+    # =========================
+    # 发送重置密码验证码
+    # =========================
+    @staticmethod
+    async def send_reset_code(
+        db: AsyncSession,
+        payload: SendResetCodeRequest,
+        redis_client: Redis,
+    ):
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
+            VasUser.email_verified == 1,
+        )
+        user = (await db.execute(stmt)).scalar_one_or_none()
+
         if not user:
             raise BizLogicError("User not exist")
-        
-        token = uuid.uuid4().hex[0:6]
-        record = VasEmailVerification(
-            user_id=user.id,
-            email=payload.email,
+
+        token = uuid.uuid4().hex[:6]
+        record = VasVerificationToken(
             token=token,
-            expire_at=datetime.utcnow() + timedelta(minutes=30)
+            expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         db.add(record)
-        db.commit()
-        
-        print(f"📧 send verification email token={token}")
-        NotificationService.create(
+        await db.commit()
+
+        await NotificationService.create(
             redis_client=redis_client,
-            ntype="email verification email",
+            ntype="email_verification",
             user_id=user.id,
             channels=["email"],
             template_id="email_verification_for_reset",
-            payload={
-                "token": token
-            }
+            payload={"token": token},
         )
-    # -----------------------
+
+    # =========================
     # 绑定邮箱
-    # -----------------------
+    # =========================
     @staticmethod
-    def bind_email(db: Session, payload: BindEmailRequest, auth_user: VasUser, redis_client:Redis):
-        user = db.query(VasUser).filter(
+    async def bind_email(
+        db: AsyncSession,
+        payload: BindEmailRequest,
+        auth_user: VasUser,
+        redis_client: Redis,
+        ip: str = None,
+        user_agent = None
+    ) -> Dict:
+        # 邮箱是否已被绑定
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
-        if user:
-            raise BizLogicError("Email has been bound")
-        
-        record = (
-            db.query(VasEmailVerification)
-            .filter_by(token=payload.code, used=0)
-            .first()
+            VasUser.email_verified == 1,
+        )
+        if (await db.execute(stmt)).scalar_one_or_none():
+            raise BizLogicError("Email already bound")
+
+        # 校验验证码
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == payload.code,
+            VasVerificationToken.used == 0,
         )
+        record = (await db.execute(stmt)).scalar_one_or_none()
         if not record:
             raise BizLogicError("Token invalid")
 
         if record.expire_at < datetime.utcnow():
             raise BizLogicError("Token expired")
- 
-        # 更新 user.email
-        user = db.query(VasUser).filter_by(id=record.user_id).first()
-        user.email = payload.email
 
-        # 随机密码
-        plain = _random_password()
-        hashed = bcrypt.hashpw(plain.encode(), bcrypt.gensalt()).decode()
-        user.password_hash = hashed
+        user = await db.get(VasUser, auth_user.id)
+
+        plain_pwd = _random_password()
+        hashed_pwd = bcrypt.hashpw(
+            plain_pwd.encode(),
+            bcrypt.gensalt(),
+        ).decode()
+
+        user.email = payload.email
+        user.password_hash = hashed_pwd
         user.email_verified = 1
         record.used = 1
-        
-        # 创建 session
-        session_id = "tok_" + uuid.uuid4().hex
 
+        token = "tok_" + uuid.uuid4().hex
         session = VasSession(
-            id=session_id,
+            id=token,
             user_id=user.id,
-            user_agent="",
-            ip="",
-            expire_at=datetime.utcnow() + timedelta(days=30)
+            ip=ip or '',
+            user_agent=user_agent or '',
+            expire_at=datetime.utcnow() + timedelta(days=30),
         )
         db.add(session)
-        db.commit()
-        db.refresh(user)
-        
-        print(f"📧 send login email and password")
-        NotificationService.create(
+
+        await db.commit()
+        await db.refresh(user)
+
+        await NotificationService.create(
             redis_client=redis_client,
-            ntype="login credentials",
+            ntype="login_credentials",
             user_id=user.id,
             channels=["email"],
             template_id="login_credentials",
             payload={
                 "username": payload.email,
-                "password": plain
-            }
+                "password": plain_pwd,
+            },
         )
-        
-        return {
-            "user": user,
-            "token": session_id
-        }
-    
-    def reset_password(db: Session, payload: ResetPasswordRequest):
-        user = db.query(VasUser).filter(
+
+        return {"user": user, "token": token}
+
+    # =========================
+    # 重置密码
+    # =========================
+    @staticmethod
+    async def reset_password(
+        db: AsyncSession,
+        payload: ResetPasswordRequest,
+    ) -> bool:
+        stmt = select(VasUser).where(
             VasUser.email == payload.email,
-            VasUser.email_verified == 1
-        ).first()
+            VasUser.email_verified == 1,
+        )
+        user = (await db.execute(stmt)).scalar_one_or_none()
         if not user:
             raise BizLogicError("User not exist")
-        
-        record = (
-            db.query(VasEmailVerification)
-            .filter_by(token=payload.code, used=0)
-            .first()
+
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == payload.code,
+            VasVerificationToken.used == 0,
         )
+        record = (await db.execute(stmt)).scalar_one_or_none()
         if not record:
             raise BizLogicError("Token invalid")
 
         if record.expire_at < datetime.utcnow():
             raise BizLogicError("Token expired")
-        
-        hashed = bcrypt.hashpw(payload.new_password.encode(), bcrypt.gensalt()).decode()
-        user.password_hash = hashed
+
+        user.password_hash = bcrypt.hashpw(
+            payload.new_password.encode(),
+            bcrypt.gensalt(),
+        ).decode()
         record.used = 1
-        db.commit()
+
+        await db.commit()
         return True
-    # -----------------------
-    # 用户登录
-    # -----------------------
+
+    # =========================
+    # 登录
+    # =========================
     @staticmethod
-    def login(db: Session, req:LoginRequest):
-        user = db.query(VasUser).filter_by(email=req.email).first()
+    async def login(
+        db: AsyncSession,
+        req: LoginRequest,
+        ip: str = None,
+        user_agent: str = None
+    ) -> Dict:
+        stmt = select(VasUser).where(VasUser.email == req.email)
+        user = (await db.execute(stmt)).scalar_one_or_none()
         if not user:
             raise NotFoundError("User not found")
 
-        # 对比密码
-        if not bcrypt.checkpw(req.password.encode(), user.password_hash.encode()):
+        if not bcrypt.checkpw(
+            req.password.encode(),
+            user.password_hash.encode(),
+        ):
             raise PermissionDeniedError("Password incorrect")
 
-        # 创建 session
         token = "tok_" + uuid.uuid4().hex
-
         session = VasSession(
             id=token,
             user_id=user.id,
-            user_agent="",
-            ip="",
-            expire_at=datetime.utcnow() + timedelta(days=7)
+            user_agent=user_agent or "",
+            ip=ip or "",
+            expire_at=datetime.utcnow() + timedelta(days=7),
         )
         db.add(session)
-        db.commit()
+        await db.commit()
 
-        return {
-            "user": user,
-            "token": token
-        }
+        return {"user": user, "token": token}

+ 0 - 72
app/services/auto_booking_service.py

@@ -1,72 +0,0 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import func
-from app.models.auto_booking import AutoBooking
-from app.schemas.auto_booking import AutoBookingCreate
-from typing import List
-
-class AutoBookingService:
-
-    @staticmethod
-    def create(db: Session, obj_in: AutoBookingCreate) -> AutoBooking:
-        db_obj = AutoBooking(**obj_in.dict())
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
-        return db_obj
-
-    @staticmethod
-    def get_by_id(db: Session, id: int):
-        return db.query(AutoBooking).filter(AutoBooking.id == id).first()
-
-    @staticmethod
-    def delete_by_id(db: Session, id: int):
-        obj = db.query(AutoBooking).filter(AutoBooking.id == id).first()
-        if obj:
-            db.delete(obj)
-            db.commit()
-            return True
-        return False
-
-    @staticmethod
-    def update_by_id(db: Session, id: int, updated_data: dict):
-        obj = db.query(AutoBooking).filter(AutoBooking.id == id).first()
-        if not obj:
-            return None
-        for key, value in updated_data.items():
-            setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
-        return obj
-
-    @staticmethod
-    def get_paginated(db: Session, tech_provider: str, keyword: str, page: int, size: int):
-        query = db.query(AutoBooking)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        if keyword:
-            like_str = f"%{keyword}%"
-            query = query.filter(
-                (AutoBooking.first_name.like(like_str)) |
-                (AutoBooking.last_name.like(like_str)) |
-                (AutoBooking.email.like(like_str)) |
-                (AutoBooking.visa_center.like(like_str))
-            )
-        return query.offset(page * size).limit(size).all()
-
-    @staticmethod
-    def batch_get_by_ids(db: Session, ids: List[int]):
-        return db.query(AutoBooking).filter(AutoBooking.id.in_(ids)).all()
-
-    @staticmethod
-    def statistics(db: Session, tech_provider: str):
-        query = db.query(AutoBooking.provider, func.count(AutoBooking.id)).group_by(AutoBooking.provider)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        return query.all()
-
-    @staticmethod
-    def get_pending(db: Session, tech_provider: str):
-        query = db.query(AutoBooking).filter(AutoBooking.status == 0)
-        if tech_provider:
-            query = query.filter(AutoBooking.provider == tech_provider)
-        return query.all()

+ 37 - 16
app/services/card_service.py

@@ -1,30 +1,51 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import text
-from typing import List, Optional
-from app.utils.search import apply_keyword_search
-from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+# app/services/card_service.py
+
+from typing import Optional
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
 from app.models.card import Card
 from app.schemas.card import CardCreate
+from app.utils.search import apply_keyword_search_stmt
+from app.utils.pagination import paginate
+from app.core.biz_exception import BizLogicError
 
 
 class CardService:
+
+    # =========================
+    # 创建 Card
+    # =========================
     @staticmethod
-    def create(db: Session, obj_in: CardCreate) -> Card:
+    async def create(
+        db: AsyncSession,
+        obj_in: CardCreate,
+    ) -> Card:
         db_obj = Card(**obj_in.dict())
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
 
+    # =========================
+    # 关键字分页查询
+    # =========================
     @staticmethod
-    def list_by_keyword(db: Session, keyword: str = None, page: int = 0, size: int = 10, culture: str = "english"):
-        query = db.query(Card).filter(Card.culture == culture)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_keyword(
+        db: AsyncSession,
+        keyword: Optional[str] = None,
+        page: int = 0,
+        size: int = 10,
+        culture: str = "english",
+    ):
+        stmt = select(Card).where(Card.culture == culture)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=Card,
             keyword=keyword,
-            fields=["title", "content", "label"]
+            fields=["title", "content", "label"],
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 85 - 24
app/services/configuration_service.py

@@ -1,50 +1,111 @@
-from sqlalchemy.orm import Session
+# app/services/configuration_service.py
+
 from typing import List, Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.configuration import Configuration
 from app.schemas.configuration import ConfigurationCreate, ConfigurationUpdate
 
 
 class ConfigurationService:
+
+    # =========================
+    # 创建配置
+    # =========================
     @staticmethod
-    def create(db: Session, config_in: ConfigurationCreate) -> Configuration:
-        config = db.query(Configuration).filter(Configuration.config_key == config_in.config_key).first()
-        if config:
-            raise BizLogicError(f"Config Key '{config_in.config_key}' already exist")
+    async def create(
+        db: AsyncSession,
+        config_in: ConfigurationCreate,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_in.config_key
+        )
+        existing = (await db.execute(stmt)).scalar_one_or_none()
+        if existing:
+            raise BizLogicError(
+                f"Config Key '{config_in.config_key}' already exist"
+            )
+
         db_obj = Configuration(**config_in.dict())
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
 
+    # =========================
+    # 获取全部配置
+    # =========================
     @staticmethod
-    def get_all(db: Session) -> List[Configuration]:
-        return db.query(Configuration).order_by(Configuration.id.desc()).all()
+    async def get_all(
+        db: AsyncSession,
+    ) -> List[Configuration]:
+        stmt = select(Configuration).order_by(Configuration.id.desc())
+        result = await db.execute(stmt)
+        return result.scalars().all()
 
+    # =========================
+    # 根据 key 获取配置
+    # =========================
     @staticmethod
-    def get_by_key(db: Session, config_key: str) -> Optional[Configuration]:
-        config = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def get_by_key(
+        db: AsyncSession,
+        config_key: str,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        config = (await db.execute(stmt)).scalar_one_or_none()
         if not config:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
         return config
 
+    # =========================
+    # 根据 key 更新配置
+    # =========================
     @staticmethod
-    def update_by_key(db: Session, config_key: str, config_in: ConfigurationUpdate) -> Optional[Configuration]:
-        db_obj = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def update_by_key(
+        db: AsyncSession,
+        config_key: str,
+        config_in: ConfigurationUpdate,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
         if not db_obj:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
+
         for field, value in config_in.dict(exclude_unset=True).items():
             setattr(db_obj, field, value)
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
 
+    # =========================
+    # 根据 key 删除配置
+    # =========================
     @staticmethod
-    def delete_by_key(db: Session, config_key: str) -> Optional[Configuration]:
-        db_obj = db.query(Configuration).filter(Configuration.config_key == config_key).first()
+    async def delete_by_key(
+        db: AsyncSession,
+        config_key: str,
+    ) -> Configuration:
+        stmt = select(Configuration).where(
+            Configuration.config_key == config_key
+        )
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
         if not db_obj:
-            raise NotFoundError(f"Config Key '{config_key}' not exist")
-        db.delete(db_obj)
-        db.commit()
+            raise NotFoundError(
+                f"Config Key '{config_key}' not exist"
+            )
+
+        await db.delete(db_obj)
+        await db.commit()
         return db_obj

ファイルの差分が大きいため隠しています
+ 469 - 441
app/services/email_authorizations_service.py


+ 57 - 21
app/services/http_session_service.py

@@ -1,43 +1,79 @@
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from typing import Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
+
+from app.core.biz_exception import NotFoundError
 from app.models.http_session import HttpSession
 from app.schemas.http_session import HttpSessionCreate, HttpSessionUpdate
-from typing import Optional
+
 
 class HttpSessionService:
 
+    # ============================
+    # 创建 Session
+    # ============================
     @staticmethod
-    def create(db: Session, data: HttpSessionCreate) -> HttpSession:
+    async def create(db: AsyncSession, data: HttpSessionCreate) -> HttpSession:
         obj = HttpSession(**data.dict())
         db.add(obj)
-        db.commit()
-        db.refresh(obj)
+        await db.commit()
+        await db.refresh(obj)
         return obj
 
+    # ============================
+    # 根据 session_id 获取
+    # ============================
     @staticmethod
-    def get_by_sid(db: Session, session_id: str) -> Optional[HttpSession]:
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
+    async def get_by_sid(
+        db: AsyncSession,
+        session_id: str
+    ) -> HttpSession:
+        stmt = select(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("Session not found")
+
         return obj
 
+    # ============================
+    # 根据 session_id 删除
+    # ============================
     @staticmethod
-    def delete_by_sid(db: Session, session_id: str) -> bool:
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
-        if not obj:
+    async def delete_by_sid(
+        db: AsyncSession,
+        session_id: str
+    ) -> bool:
+        stmt = delete(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+
+        if result.rowcount == 0:
             raise NotFoundError("Session not found")
-        db.delete(obj)
-        db.commit()
-        return obj
 
+        await db.commit()
+        return True
+
+    # ============================
+    # 根据 session_id 更新
+    # ============================
     @staticmethod
-    def update_by_sid(db: Session, session_id: str, data: HttpSessionUpdate):
-        obj = db.query(HttpSession).filter(HttpSession.session_id == session_id).first()
+    async def update_by_sid(
+        db: AsyncSession,
+        session_id: str,
+        data: HttpSessionUpdate
+    ) -> HttpSession:
+        stmt = select(HttpSession).where(HttpSession.session_id == session_id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("Session not found")
-        for k, v in data.dict().items():
-            if v is not None:
-                setattr(obj, k, v)
-        db.commit()
-        db.refresh(obj)
+
+        for field, value in data.dict(exclude_unset=True).items():
+            setattr(obj, field, value)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj

+ 18 - 8
app/services/notification_service.py

@@ -1,15 +1,25 @@
-# app/services/product_service.py
+# app/services/notification_service.py
+
 import uuid
-from sqlalchemy.orm import Session
-from typing import Optional, List, Dict
+from typing import List, Dict, Any
+
 from redis.asyncio import Redis
-from app.utils.redis_utils import redis_qpush
+from app.utils.redis_utils import redis_qpush, redis_qpop
+
 
 class NotificationService:
 
-    def create(redis_client: Redis, ntype: str, user_id:str, channels:List[str], template_id=str, payload=Dict):
+    @staticmethod
+    async def create(
+        redis_client: Redis,
+        ntype: str,
+        user_id: str,
+        channels: List[str],
+        template_id: str,
+        payload: Dict[str, Any]
+    ) -> None:
         notification_payload = {
-            "notification_id": f'nid_{uuid.uuid4().hex}',
+            "notification_id": f"nid_{uuid.uuid4().hex}",
             "type": ntype,
             "user_id": user_id,
             "channels": channels,
@@ -17,10 +27,10 @@ class NotificationService:
             "payload": payload
         }
 
-        redis_qpush(
+        await redis_qpush(
             redis_client,
             "vas_notification_queue",
             notification_payload
         )
 
-
+    

+ 128 - 126
app/services/order_service.py

@@ -1,32 +1,42 @@
 # app/services/order_service.py
 import uuid
 import json
-from redis.asyncio import Redis
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
-from typing import List
-from app.utils.search import apply_keyword_search
+from typing import List, Optional
+
+from redis.asyncio import Redis
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.core.auth import get_current_user
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.user import VasUser
 from app.models.order import VasOrder
 from app.models.vas_task import VasTask
 from app.models.product import VasProduct
 from app.models.product_routing import VasProductRouting
 from app.schemas.order import VasOrderCreate, VasOrderPatchUserInputs
-from app.services.notification_service import NotificationService
+
 
 class OrderService:
-    
+
+    # --------------------------------------------------
+    # 管理员强制标记为已支付
+    # --------------------------------------------------
     @staticmethod
-    def mark_as_admin_paid(db: Session, order: VasOrder, admin_user):
+    async def mark_as_admin_paid(
+        db: AsyncSession,
+        order: VasOrder,
+        admin_user: VasUser,
+    ) -> VasOrder:
+
         if order.status == "paid":
             return order
 
         order.status = "paid"
 
-        # ===== 核心修复点 =====
+        # ===== user_inputs 安全修复 =====
         raw_inputs = order.user_inputs
 
         if isinstance(raw_inputs, str):
@@ -34,12 +44,9 @@ class OrderService:
                 order.user_inputs = json.loads(raw_inputs)
             except Exception:
                 order.user_inputs = {}
-        elif raw_inputs is None:
-            order.user_inputs = {}
-        elif not isinstance(raw_inputs, dict):
+        elif raw_inputs is None or not isinstance(raw_inputs, dict):
             order.user_inputs = {}
 
-        # 记录绕过支付的原因(非常重要)
         order.user_inputs["_admin_bypass"] = {
             "enabled": True,
             "by": admin_user.id,
@@ -48,55 +55,50 @@ class OrderService:
         }
 
         db.add(order)
-        db.commit()
-        db.refresh(order)
+        await db.commit()
+        await db.refresh(order)
 
         return order
-    
+
+    # --------------------------------------------------
+    # 为订单创建任务(幂等)
+    # --------------------------------------------------
     @staticmethod
-    def create_tasks_for_order(db: Session, order: VasOrder):
-        """
-        为已支付订单创建任务(幂等)
-        """
+    async def create_tasks_for_order(
+        db: AsyncSession,
+        order: VasOrder
+    ) -> List[VasTask]:
+
         if order.status != "paid":
             return []
 
-        # ---------- 1. 查 routing ----------
-        routings = (
-            db.query(VasProductRouting)
-            .filter(
-                VasProductRouting.product_id == order.product_id,
-                VasProductRouting.is_active == 1
-            )
-            .all()
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1
         )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
 
         if not routings:
             return []
 
-        created_tasks = []
+        created_tasks: List[VasTask] = []
 
         for routing in routings:
-
-            # ---------- 2. 幂等判断 ----------
-            exists = (
-                db.query(VasTask)
-                .filter(
-                    VasTask.order_id == order.id,
-                    VasTask.routing_key == routing.routing_key,
-                    VasTask.script_version == routing.script_version,
-                )
-                .first()
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
             )
+            exists = (await db.execute(exists_stmt)).scalar_one_or_none()
             if exists:
                 continue
 
-            # ---------- 3. 创建 task ----------
             task = VasTask(
                 order_id=order.id,
                 routing_key=routing.routing_key,
                 script_version=routing.script_version,
-                priority=10,
+                priority=routing.priority,
                 status="pending",
                 user_inputs=order.user_inputs,
                 config=routing.config,
@@ -105,113 +107,113 @@ class OrderService:
                 expire_at=datetime.utcnow() + timedelta(days=7),
                 created_at=datetime.utcnow(),
             )
-
             db.add(task)
             created_tasks.append(task)
 
-        db.commit()
-
+        await db.commit()
         return created_tasks
-    
+
+    # --------------------------------------------------
+    # 创建订单
+    # --------------------------------------------------
     @staticmethod
-    def cancel_order(db, order_id, reason, admin_id):
-        if order.status in (OrderStatus.cancelled, OrderStatus.completed):
-            return order
+    async def create(
+        db: AsyncSession,
+        data: VasOrderCreate,
+        product: VasProduct,
+        auth_user: VasUser,
+        redis_client: Redis,
+    ) -> VasOrder:
 
-        if order.status == OrderStatus.paid:
-            raise HTTPException(
-                400,
-                "Paid order must be refunded",
+        if not auth_user.email:
+            raise BizLogicError(
+                "Your account must be linked to an email address before you can place an order."
             )
 
-        # 2️⃣ user_inputs 写入取消信息
-        user_inputs = order.user_inputs or {}
-        user_inputs["cancel"] = {
-            "reason": reason,
-            "by": "admin",
-            "admin_id": admin.user_id,
-            "at": datetime.utcnow().isoformat(),
-        }
-        order.user_inputs = user_inputs
+        order_id = f"ORD-{datetime.utcnow():%Y%m%d%H%M%S}-{uuid.uuid4().hex[:8]}"
 
-        # payment
-        for payment in order.payments:
-            if payment.status in (PaymentStatus.pending,):
-                payment.status = PaymentStatus.expired
+        order = VasOrder(
+            id=order_id,
+            **data.dict(),
+            product_name=product.title,
+            base_amount=product.price_amount,
+            base_currency=product.price_currency,
+            user_id=auth_user.id,
+        )
 
-        # task
-        for task in order.tasks:
-            task.status = TaskStatus.cancelled
+        db.add(order)
+        await db.commit()
+        await db.refresh(order)
 
         return order
-    
-    @staticmethod
-    def create(db: Session, data: VasOrderCreate, product: VasProduct, auth_user: VasUser, redis_client:Redis):
-        if not auth_user.email:
-            raise BizLogicError('Your account must be linked to an email address before you can place an order.')
-        order_id = f"ORD-{datetime.utcnow().strftime('%Y%m%d%H%M%S')}-{uuid.uuid4().hex[:8]}"
-        rec = VasOrder(id=order_id, **data.dict())
-        rec.product_name = product.title
-        rec.base_amount = product.price_amount
-        rec.base_currency = product.price_currency
-        rec.user_id = auth_user.id
-        db.add(rec)
-        db.commit()
-        db.refresh(rec)
-        
-        print(f"📧 send order created notification email")
-        NotificationService.create(
-            redis_client=redis_client,
-            ntype="order create notify",
-            user_id=auth_user.id,
-            channels=["email"],
-            template_id="order_create_notify",
-            payload={
-                "order_id": rec.id
-            }
-        )
-        return rec
 
+    # --------------------------------------------------
+    # 获取订单
+    # --------------------------------------------------
     @staticmethod
-    def get(db: Session, id: str):
-        return db.query(VasOrder).filter_by(id=id).first()
+    async def get(db: AsyncSession, order_id: str) -> Optional[VasOrder]:
+        stmt = select(VasOrder).where(VasOrder.id == order_id)
+        return (await db.execute(stmt)).scalar_one_or_none()
 
+    # --------------------------------------------------
+    # 用户订单列表
+    # --------------------------------------------------
     @staticmethod
-    def list_by_user(db: Session, user_id: str, page: int=0, size: int=10, keyword: str=None):
-        query = db.query(VasOrder).filter_by(user_id=user_id)
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_user(
+        db: AsyncSession,
+        user_id: str,
+        page: int = 0,
+        size: int = 10,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasOrder).where(VasOrder.user_id == user_id)
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasOrder,
             keyword=keyword,
-            fields=["id", "user_id", "product_name"]
-        )
-        query = query.order_by(
-            VasOrder.created_at.desc()
-        )
-        return paginate(query, page, size)
-    
+            fields=["id", "user_id", "product_name"],
+        ).order_by(VasOrder.created_at.desc())
+
+        return await paginate(db, stmt, page, size)
+
+    # --------------------------------------------------
+    # 管理员订单列表
+    # --------------------------------------------------
     @staticmethod
-    def list_all(db: Session, page: int=0, size: int=10, keyword: str=None):
-        query = db.query(VasOrder)
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 0,
+        size: int = 10,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasOrder)
+        query = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasOrder,
             keyword=keyword,
-            fields=["id", "user_id", "user_name", "product_name", "user_inputs"]
-        )
-        query = query.order_by(
-            VasOrder.created_at.desc()
-        )
-        return paginate(query, page, size)
-    
+            fields=["id", "user_id", "user_name", "product_name"],
+        ).order_by(VasOrder.created_at.desc())
+
+        return await paginate(db, query, page, size)
+
+    # --------------------------------------------------
+    # 更新 user_inputs
+    # --------------------------------------------------
     @staticmethod
-    def patch_user_inputs(db: Session, order_id: str, payload: VasOrderPatchUserInputs):
-        order = db.query(VasOrder).filter_by(id=order_id).first()
+    async def patch_user_inputs(
+        db: AsyncSession,
+        order_id: str,
+        payload: VasOrderPatchUserInputs,
+    ) -> VasOrder:
+
+        stmt = select(VasOrder).where(VasOrder.id == order_id)
+        order = (await db.execute(stmt)).scalar_one_or_none()
+
         if not order:
             raise NotFoundError("Order not exist")
+
         order.user_inputs = payload.user_inputs
-        db.commit()
-        db.refresh(order)
+        await db.commit()
+        await db.refresh(order)
+
         return order
-    
-    

+ 103 - 55
app/services/payment_provider_service.py

@@ -1,26 +1,34 @@
 # app/services/payment_provider.py
 
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
-from app.models.payment_provider import VasPaymentProvider
-from app.schemas.payment_provider import VasPaymentProviderCreate, VasPaymentProviderUpdate
+from typing import List, Optional
 
-class PaymentProviderSerivce:
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
 
-    def create(
-        db: Session,
-        data: VasPaymentProviderCreate
+from app.core.biz_exception import NotFoundError, BizLogicError
+from app.models.payment_provider import VasPaymentProvider
+from app.schemas.payment_provider import (
+    VasPaymentProviderCreate,
+    VasPaymentProviderUpdate,
+)
+
+
+class PaymentProviderService:
+    # --------------------------------------------------
+    # 创建支付提供商(防重复)
+    # --------------------------------------------------
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasPaymentProviderCreate,
     ) -> VasPaymentProvider:
-        # 防止重复注册
-        exists = (
-            db.query(VasPaymentProvider)
-            .filter(
-                VasPaymentProvider.name == data.name,
-                VasPaymentProvider.channel == data.channel,
-                VasPaymentProvider.currency == data.currency,
-            )
-            .first()
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.name == data.name,
+            VasPaymentProvider.channel == data.channel,
+            VasPaymentProvider.currency == data.currency,
         )
+        exists = (await db.execute(stmt)).scalar_one_or_none()
 
         if exists:
             raise BizLogicError("Payment provider already exists")
@@ -35,64 +43,104 @@ class PaymentProviderSerivce:
         )
 
         db.add(provider)
-        db.commit()
-        db.refresh(provider)
+        await db.commit()
+        await db.refresh(provider)
         return provider
 
-    def update(
-        db: Session,
+    # --------------------------------------------------
+    # 更新支付提供商(禁止修改 name/channel/currency)
+    # --------------------------------------------------
+    @staticmethod
+    async def update(
+        db: AsyncSession,
         provider_id: int,
-        data: VasPaymentProviderUpdate
+        data: VasPaymentProviderUpdate,
     ) -> VasPaymentProvider:
-        provider = db.query(VasPaymentProvider).get(provider_id)
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
         if not provider:
-            raise BizLogicError("Payment provider not found")
+            raise NotFoundError("Payment provider not found")
 
         update_data = data.dict(exclude_unset=True)
 
-        # 安全起见,禁止修改三元组
+        # 🚫 禁止修改三元组
         for forbidden in ("name", "channel", "currency"):
             update_data.pop(forbidden, None)
 
         for key, value in update_data.items():
             setattr(provider, key, value)
 
-        db.commit()
-        db.refresh(provider)
+        await db.commit()
+        await db.refresh(provider)
         return provider
-    
-    def delete(db: Session, id: int):
-        provider = db.query(VasPaymentProvider).filter_by(id=id).first()
+
+    # --------------------------------------------------
+    # 删除支付提供商
+    # --------------------------------------------------
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        provider_id: int,
+    ) -> bool:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
         if not provider:
             raise NotFoundError("Provider not exist")
-        db.delete(provider)
-        db.commit()
-    
-    def list_all(db: Session):
-        return db.query(VasPaymentProvider).all()
-
-    def list_enabled(
-        db: Session,
-        currency: str = None
-    ):
-        q = db.query(VasPaymentProvider).filter(
+
+        await db.delete(provider)
+        await db.commit()
+        return True
+
+    # --------------------------------------------------
+    # 所有支付提供商
+    # --------------------------------------------------
+    @staticmethod
+    async def list_all(
+        db: AsyncSession,
+    ) -> List[VasPaymentProvider]:
+
+        result = await db.execute(select(VasPaymentProvider))
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 可用的支付提供商(可按币种)
+    # --------------------------------------------------
+    @staticmethod
+    async def list_enabled(
+        db: AsyncSession,
+        currency: Optional[str] = None,
+    ) -> List[VasPaymentProvider]:
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1
         )
 
         if currency:
-            q = q.filter(VasPaymentProvider.currency == currency)
-
-        return q.all()
-    
-    def get_by_name(
-        db: Session,
-        name: str
-    ):
-        q = db.query(VasPaymentProvider).filter(
-            VasPaymentProvider.enabled == 1
+            stmt = stmt.where(VasPaymentProvider.currency == currency)
+
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 name 获取(只返回 enabled)
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_name(
+        db: AsyncSession,
+        name: str,
+    ) -> Optional[VasPaymentProvider]:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.enabled == 1,
+            VasPaymentProvider.name == name,
         )
 
-        if name:
-            q = q.filter(VasPaymentProvider.name == name)
-
-        return q.first()
+        return (await db.execute(stmt)).scalar_one_or_none()

+ 130 - 31
app/services/payment_qr_service.py

@@ -1,50 +1,149 @@
 # app/services/payment_qr_service.py
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from typing import List
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_qr import VasPaymentQR
-from app.schemas.payment_qr import VasPaymentQrCreate, VasPaymentQrSetEnableIn
+from app.schemas.payment_qr import (
+    VasPaymentQrCreate,
+    VasPaymentQrSetEnableIn,
+)
+
 
 class PaymentQrService:
 
-    def create(db: Session, data: VasPaymentQrCreate):
+    # --------------------------------------------------
+    # 创建支付二维码
+    # --------------------------------------------------
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasPaymentQrCreate,
+    ) -> VasPaymentQR:
+
         rec = VasPaymentQR(**data.dict())
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
-    
-    def get_by_id(db: Session, id: int):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 根据 ID 获取
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_id(
+        db: AsyncSession,
+        qr_id: int,
+    ) -> VasPaymentQR:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("QR not exist")
+
         return obj
-    
-    def set_enable(db: Session, id: int, payload: VasPaymentQrSetEnableIn):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 启用 / 禁用 QR
+    # --------------------------------------------------
+    @staticmethod
+    async def set_enable(
+        db: AsyncSession,
+        qr_id: int,
+        payload: VasPaymentQrSetEnableIn,
+    ) -> VasPaymentQR:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("QR not exist")
+
         obj.is_active = payload.is_active
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
-    
-    def delete(db: Session, id: int):
-        obj = db.query(VasPaymentQR).filter(VasPaymentQR.id == id).first()
+
+    # --------------------------------------------------
+    # 删除 QR
+    # --------------------------------------------------
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        qr_id: int,
+    ) -> bool:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.id == qr_id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("QR not exist")
-        db.delete(obj)
-        db.commit()
-    
-    def get_by_devid(db: Session, devid: str):
-        return db.query(VasPaymentQR).filter(VasPaymentQR.devid == devid).all()
-
-    def get_by_provider(db: Session, provider: str):
-        return db.query(VasPaymentQR).filter(VasPaymentQR.provider == provider).all()
-    
-    def list_by_provider(db: Session, provider_id: int):
-        obj = db.query(VasPaymentProvider).filter(VasPaymentProvider.id==provider_id).first()
-        if not obj:
+
+        await db.delete(obj)
+        await db.commit()
+        return True
+
+    # --------------------------------------------------
+    # 根据设备 ID 查询
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_devid(
+        db: AsyncSession,
+        devid: str,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.devid == devid
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 provider 名称查询
+    # --------------------------------------------------
+    @staticmethod
+    async def get_by_provider(
+        db: AsyncSession,
+        provider: str,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    # --------------------------------------------------
+    # 根据 provider_id 查询 QR(安全校验)
+    # --------------------------------------------------
+    @staticmethod
+    async def list_by_provider(
+        db: AsyncSession,
+        provider_id: int,
+    ) -> List[VasPaymentQR]:
+
+        stmt = select(VasPaymentProvider).where(
+            VasPaymentProvider.id == provider_id
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+
+        if not provider:
             raise NotFoundError("Provider not exist")
-        
-        return db.query(VasPaymentQR).filter(VasPaymentQR.provider==obj.name).all()
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider.name
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()

+ 384 - 78
app/services/payment_service.py

@@ -1,72 +1,358 @@
 # app/services/payment_service.py
+
 import time
 import stripe
 import random
-from typing import List,Dict
+import uuid
+from typing import Dict, List, Optional
+from redis.asyncio import Redis
 from decimal import Decimal, ROUND_HALF_UP
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
+from app.utils.pagination import paginate
+from app.core.biz_exception import NotFoundError, BizLogicError
+from app.models.user import VasUser
 from app.models.order import VasOrder
-from app.models.product import VasProduct
+from app.models.vas_task import VasTask
 from app.models.payment import VasPayment
+from app.models.ticket import VasTicket
+from app.models.payment_event import VasPaymentEvent
+from app.models.product_routing import VasProductRouting
+from app.models.verification_token import VasVerificationToken
 from app.models.payment_provider import VasPaymentProvider
 from app.models.payment_qr import VasPaymentQR
+from app.models.payment_confirmation import VasPaymentConfirmation
 from app.schemas.payment import VasPaymentCreate
+from app.schemas.payment_confirmation import VasPaymentConfirmationCreate, VasPaymentConfirmationUpdate
+from app.services.notification_service import NotificationService
+
 
 
 class PaymentService:
-    
+
+    # --------------------------------------------------
+    # 创建支付(统一入口)
+    # --------------------------------------------------
     @staticmethod
-    def create_payment(db: Session, payload: VasPaymentCreate, rate_table: Dict):
-        # ① 锁住订单,防止并发创建 payment
-        order = (
-            db.query(VasOrder)
-            .filter(VasOrder.id == payload.order_id)
+    async def create_payment(
+        db: AsyncSession,
+        payload: VasPaymentCreate,
+        rate_table: Dict,
+        redis_client: Redis
+    ) -> VasPayment:
+
+        # ① 锁住订单(防并发)
+        stmt = (
+            select(VasOrder)
+            .where(VasOrder.id == payload.order_id)
             .with_for_update()
-            .one()
         )
+        order = (await db.execute(stmt)).scalar_one_or_none()
+        if not order:
+            raise NotFoundError("Order not found")
 
-        # ② 是否已有进行中的 payment
-        active_payment = (
-            db.query(VasPayment)
-            .filter(
-                VasPayment.order_id == order.id,
-                VasPayment.status == "pending"
-            )
-            .first()
+        # ② 是否已有 pending payment(幂等)
+        stmt = select(VasPayment).where(
+            VasPayment.order_id == order.id,
+            VasPayment.status == "pending",
         )
+        active_payment = (await db.execute(stmt)).scalar_one_or_none()
+
         if active_payment:
             if active_payment.provider == payload.provider:
-                return active_payment  # 直接返回旧的,不报错(幂等性)
+                return active_payment
             else:
-                active_payment.status = 'failed' 
-                db.add(active_payment)
+                active_payment.status = "failed"
 
+        # ③ 根据 provider 创建
         if payload.provider in ("wechat", "alipay"):
-            payment = PaymentService.create_offline_payment(db=db, order=order, provider_name=payload.provider, rate_table=rate_table)
-            db.commit()
+            payment = await PaymentService.create_offline_payment(
+                db=db,
+                order=order,
+                provider_name=payload.provider,
+                rate_table=rate_table,
+            )
+            await db.commit()
             return payment
 
         if payload.provider == "stripe":
-            payment = PaymentService.create_stripe_payment(db=db, order=order, rate_table=rate_table)
-            db.commit()
+            payment = await PaymentService.create_stripe_payment(
+                db=db,
+                order=order,
+                rate_table=rate_table,
+            )
+            await db.commit()
             return payment
 
         raise BizLogicError("Unsupported provider")
     
     @staticmethod
-    def create_offline_payment(db, order, provider_name: str, rate_table: dict):
+    async def confirm_by_user(
+        db: AsyncSession,
+        payload: VasPaymentConfirmationCreate,
+        current_user: VasUser,
+        redis_client: Redis
+    ):
+        """
+        用户点击“我已支付”
+        """
+        # 1️⃣ 查询是否存在对应 payment 确认记录
+        result = await db.execute(
+            select(VasPaymentConfirmation)
+            .where(VasPaymentConfirmation.payment_id == payload.payment_id)
+            .where(VasPaymentConfirmation.user_id == current_user.id)
+        )
+        record = result.scalar_one_or_none()
+
+        if not record:
+            # 没有则创建一条 pending -> confirmed 记录
+            record = VasPaymentConfirmation(
+                payment_id=payload.payment_id,
+                amount=payload.amount,
+                currency=payload.currency,
+                random_offset=payload.random_offset,
+                user_id=current_user.id,
+                status="pending",
+                confirmed_at=payload.confirmed_at
+            )
+            db.add(record)
+            await db.commit()
+            await db.refresh(record)
+
+        # 2️⃣ 推送异步通知给管理员
+        # await NotificationService.create(
+        #     redis_client=redis_client,
+        #     ntype="payment_user_confirmed",
+        #     user_id=current_user.id,
+        #     channels=["wechat"],
+        #     template_id="payment_user_confirmed",
+        #     payload={
+        #         "payment_id": payload.payment_id,
+        #         "user_id": current_user.id,
+        #         "confirmed_at": record.confirmed_at.isoformat()
+        #     }
+        # )
+
+        return record
+    
+    @staticmethod
+    async def confirm_by_admin(
+        db: AsyncSession,
+        id: int,
+        payload: VasPaymentConfirmationUpdate,
+        current_user: VasUser
+    ):
+        """
+        管理员确认用户的支付
+        """
+        # 1️⃣ 查询对应确认记录
+        result = await db.execute(
+            select(VasPaymentConfirmation)
+            .where(VasPaymentConfirmation.id == id)
+        )
+        record = result.scalar_one_or_none()
+
+        if not record:
+            raise NotFoundError("Payment confirmation record not found")
+
+        # 3️⃣ 更新管理员确认状态
+        record.admin_id = current_user.id
+        record.admin_confirmed_at = datetime.utcnow()
+        record.status = 'confirmed'
+        await PaymentService._confirm_payment_action(db, record.payment_id)
+        
+        await db.commit()
+        await db.refresh(record)
+
+        return record
+    
+    @staticmethod
+    async def list_payment_confirmation(
+        db: AsyncSession,
+        keyword: Optional[str] = None,
+        page: int = 1,
+        size: int = 20,
+    ):
+        stmt = select(VasPaymentConfirmation)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
+            model=VasPaymentConfirmation,
+            keyword=keyword,
+            fields=["user_id"],
+        )
+
+        stmt = stmt.order_by(VasPaymentConfirmation.id.desc())
+
+        return await paginate(db, stmt, page, size)
+    
+    @staticmethod
+    async def _create_task_if_not_exists(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> List[VasTask]:
+
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1,
+        )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
+
+        if not routings:
+            return []
+
+        created_tasks: List[VasTask] = []
+
+        for routing in routings:
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
+            )
+            exists_result = await db.execute(exists_stmt)
+            exists = exists_result.scalar_one_or_none()
+
+            if exists:
+                continue
+
+            task = VasTask(
+                order_id=order.id,
+                routing_key=routing.routing_key,
+                script_version=routing.script_version,
+                priority=routing.priority,
+                status="pending",
+                user_inputs=order.user_inputs,
+                config=routing.config,
+                attempt_count=0,
+                notify_count=0,
+                expire_at=datetime.utcnow() + timedelta(days=60),
+                created_at=datetime.utcnow(),
+            )
+            db.add(task)
+            created_tasks.append(task)
+
+        return created_tasks
+    
+    @staticmethod
+    async def confirm_payment(
+        db: AsyncSession,
+        payment_id: int,
+        token: str
+    ):
+        # 校验验证码
+        stmt = select(VasVerificationToken).where(
+            VasVerificationToken.token == token,
+            VasVerificationToken.used == 0,
+        )
+        token_obj = (await db.execute(stmt)).scalar_one_or_none()
+        if not token_obj:
+            raise BizLogicError("Token invalid")
+
+        if token_obj.expire_at < datetime.utcnow():
+            raise BizLogicError("Token expired")
+        
+        payment = await PaymentService._confirm_payment_action(db, payment_id)
+        token_obj.used = 1
+        await db.commit()
+        return payment
+    
+    @staticmethod
+    async def _confirm_payment_action(db: AsyncSession, payment_id: int):
+        # ---------- 查找 payment ----------
+        pay_stmt = (
+            select(VasPayment)
+            .where(
+                VasPayment.id == payment_id,
+                VasPayment.status == "pending",
+            )
+            .order_by(VasPayment.created_at.desc())
+        )
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
+        if not payment:
+            raise BizLogicError("Payment not found")
+        
+        event = VasPaymentEvent(
+            provider=payment.provider,
+            event_type="payment_received",
+            title='confirm payment',
+            content='confirm payment by admin',
+            parsed_amount=payment.amount,
+            parsed_currency=payment.currency,
+            parsed_device='',
+            status="received",
+        )
+        db.add(event)
+        await db.commit()
+        await db.refresh(event)
+
+        if payment.status in ("succeeded", "late_paid"):
+            event.status = "duplicate"
+            event.matched_payment_id = payment.id
+            event.matched_order_id = payment.order_id
+            await db.commit()
+            return None
+
+        now = datetime.utcnow()
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
+
+        payment.provider_payload = {
+            "title": "confirm by admin",
+            "received_at": now.isoformat(),
+        }
+
+        order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
+        if order and order.status != "paid":
+            order.status = "paid"
+
+        await PaymentService._create_task_if_not_exists(db, order)
+
+        event.status = "applied"
+        event.matched_payment_id = payment.id
+        event.matched_order_id = payment.order_id
+
+        
+        await db.commit()
+        await db.refresh(payment)
+        
+        return payment
+
+    @staticmethod
+    async def create_offline_payment(
+        db: AsyncSession,
+        order: VasOrder,
+        provider_name: str,
+        rate_table: Dict,
+    ) -> VasPayment:
+
         payment = (
-            PaymentService._create_wechat_payment(db, order)
+            await PaymentService._create_wechat_payment(db, order)
             if provider_name == "wechat"
-            else PaymentService._create_alipay_payment(db, order)
+            else await PaymentService._create_alipay_payment(db, order)
         )
-        provider = db.query(VasPaymentProvider).filter(
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1,
-            VasPaymentProvider.name == provider_name
-        ).first()
-        qrs = db.query(VasPaymentQR).filter(VasPaymentQR.provider == provider_name).all()
+            VasPaymentProvider.name == provider_name,
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+        if not provider:
+            raise BizLogicError("Payment provider not available")
+
+        stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider_name,
+            VasPaymentQR.is_active == 1,
+        )
+        qrs = (await db.execute(stmt)).scalars().all()
         if not qrs:
             raise BizLogicError("No payment QR available")
 
@@ -79,25 +365,34 @@ class PaymentService:
         converted = (
             Decimal(payment.base_amount) * exchange_rate
         ).quantize(Decimal("1"), rounding=ROUND_HALF_UP)
-        
+
         max_discount = min(99, int(converted * Decimal("0.01")))
         discount = random.randint(1, max_discount) if max_discount >= 1 else 0
 
-        final_amount = int(converted) - discount
-
         payment.exchange_rate = exchange_rate
-        payment.amount = final_amount
+        payment.amount = int(converted) - discount
         payment.currency = provider.currency
         payment.random_offset = discount
+
         return payment
-    
+
     @staticmethod
-    def create_stripe_payment(db, order, rate_table: dict):
-        payment = PaymentService._create_stripe_payment(db, order)
-        provider = db.query(VasPaymentProvider).filter(
+    async def create_stripe_payment(
+        db: AsyncSession,
+        order: VasOrder,
+        rate_table: Dict,
+    ) -> VasPayment:
+
+        payment = await PaymentService._create_stripe_payment(db, order)
+
+        stmt = select(VasPaymentProvider).where(
             VasPaymentProvider.enabled == 1,
-            VasPaymentProvider.name == 'stripe'
-        ).first()
+            VasPaymentProvider.name == "stripe",
+        )
+        provider = (await db.execute(stmt)).scalar_one_or_none()
+        if not provider:
+            raise BizLogicError("Stripe provider not enabled")
+
         rate_key = f"{order.base_currency}->{provider.currency}".upper()
         exchange_rate = Decimal(rate_table[rate_key])
 
@@ -113,8 +408,8 @@ class PaymentService:
         stripe_session = PaymentService.create_checkout_session(
             order=order,
             payment=payment,
-            success_url="https://yourdomain.com/pay/success",
-            cancel_url="https://yourdomain.com/pay/cancel"
+            success_url="https://visafly.top/dashboard",
+            cancel_url="https://visafly.top/dashboard",
         )
 
         payment.payment_intent_id = stripe_session.id
@@ -124,22 +419,16 @@ class PaymentService:
 
     @staticmethod
     def create_checkout_session(
-        order,
-        payment,
+        order: VasOrder,
+        payment: VasPayment,
         success_url: str,
         cancel_url: str,
     ):
-        """
-        order.base_amount  单位:cent
-        payment.amount     单位:cent
-        """
-        
-        expires_at = int(time.time()) + 30 * 60  # Stripe 专用
+        expires_at = int(time.time()) + 30 * 60
 
-        session = stripe.checkout.Session.create(
+        return stripe.checkout.Session.create(
             mode="payment",
             payment_method_types=["card"],
-
             line_items=[
                 {
                     "price_data": {
@@ -152,23 +441,22 @@ class PaymentService:
                     "quantity": 1,
                 }
             ],
-
             metadata={
                 "order_id": order.id,
                 "payment_id": payment.id,
                 "user_id": order.user_id,
             },
-
-            success_url=success_url + "?session_id={CHECKOUT_SESSION_ID}",
+            success_url=success_url,
             cancel_url=cancel_url,
-
             expires_at=expires_at,
         )
 
-        return session
-
     @staticmethod
-    def _create_wechat_payment(db: Session, order: VasOrder):
+    async def _create_wechat_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
             order_id=order.id,
             provider="wechat",
@@ -183,11 +471,15 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
-    
+
     @staticmethod
-    def _create_alipay_payment(db: Session, order: VasOrder):
+    async def _create_alipay_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
             order_id=order.id,
             provider="alipay",
@@ -202,11 +494,15 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
-    
+
     @staticmethod
-    def _create_stripe_payment(db: Session, order: VasOrder):
+    async def _create_stripe_payment(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> VasPayment:
+
         payment = VasPayment(
             order_id=order.id,
             provider="stripe",
@@ -221,16 +517,26 @@ class PaymentService:
             expire_at=datetime.utcnow() + timedelta(minutes=30),
         )
         db.add(payment)
-        db.flush()
+        await db.flush()
         return payment
-    
+
     @staticmethod
-    def list_by_order(db: Session, order_id: str):
-        payments = (
-            db.query(VasPayment)
-            .filter(
-                VasPayment.order_id == order_id
-            )
-            .all()
+    async def list_by_order(
+        db: AsyncSession,
+        order_id: int,
+    ) -> List[VasPayment]:
+
+        stmt = select(VasPayment).where(
+            VasPayment.order_id == order_id
         )
-        return payments
+        result = await db.execute(stmt)
+        return result.scalars().all()
+    
+    @staticmethod
+    async def get_by_id(
+        db: AsyncSession,
+        id: int,
+    ) -> VasPayment:
+        stmt = select(VasPayment).where(VasPayment.id == id)
+        return (await db.execute(stmt)).scalar_one_or_none()
+

+ 37 - 12
app/services/product_routing_service.py

@@ -1,23 +1,48 @@
 # app/services/product_routing_service.py
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
+
+from app.core.biz_exception import NotFoundError
 from app.models.product_routing import VasProductRouting
 from app.schemas.product_routing import VasProductRoutingCreate
 
+
 class ProductRoutingService:
-    def create(db: Session, data: VasProductRoutingCreate):
+
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasProductRoutingCreate,
+    ) -> VasProductRouting:
         rec = VasProductRouting(**data.dict())
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def list_by_product(db: Session, product_id:int):
-        return db.query(VasProductRouting).filter_by(product_id=product_id).all()
-    
-    def delete(db: Session, id: int):
-        obj = db.query(VasProductRouting).filter_by(id=id).first()
+    @staticmethod
+    async def list_by_product(
+        db: AsyncSession,
+        product_id: int,
+    ):
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == product_id
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        id: int,
+    ):
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.id == id
+        )
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
             raise NotFoundError("Product routing not exist")
-        db.delete(obj)
-        db.commit()
+
+        await db.delete(obj)
+        await db.commit()

+ 63 - 30
app/services/product_service.py

@@ -1,50 +1,83 @@
 # app/services/product_service.py
-from sqlalchemy.orm import Session
-from typing import Optional, List
-from app.utils.search import apply_keyword_search
+
+from typing import Optional
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.product import VasProduct
-from app.schemas.product import VasProductCreate, VasProductUpdate, VasProductOut
+from app.schemas.product import VasProductCreate, VasProductUpdate
+
 
 class ProductService:
 
-    def create(db: Session, data: VasProductCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasProductCreate,
+    ) -> VasProduct:
         rec = VasProduct(**data.dict())
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def get(db: Session, id:int):
-        obj = db.query(VasProduct).filter_by(id=id).first()
+    @staticmethod
+    async def get(
+        db: AsyncSession,
+        id: int,
+    ) -> VasProduct:
+        stmt = select(VasProduct).where(VasProduct.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
-            raise NotFoundError('Product not exist')
+            raise NotFoundError("Product not exist")
         return obj
 
-    def update(db: Session, id:int, data: VasProductUpdate):
-        rec = db.query(VasProduct).filter_by(id=id).first()
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        data: VasProductUpdate,
+    ) -> VasProduct:
+        stmt = select(VasProduct).where(VasProduct.id == id)
+        rec = (await db.execute(stmt)).scalar_one_or_none()
         if not rec:
-            raise NotFoundError('Product not exist')
-        for k,v in data.dict(exclude_unset=True).items():
-            setattr(rec,k,v)
-        db.commit()
-        db.refresh(rec)
+            raise NotFoundError("Product not exist")
+
+        for k, v in data.dict(exclude_unset=True).items():
+            setattr(rec, k, v)
+
+        await db.commit()
+        await db.refresh(rec)
         return rec
-    
-    def list_product(db: Session, country: str = None, visa_type: str = None, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasProduct)
-        
+
+    @staticmethod
+    async def list_product(
+        db: AsyncSession,
+        country: str = None,
+        visa_type: str = None,
+        page: int = 0,
+        size: int = 10,
+        keyword: str = None,
+    ):
+        # ⚠️ paginate / apply_keyword_search 仍然基于 Query
+        # 如果你当前 paginate 是同步实现,这里保持与你原项目一致
+
+        stmt = select(VasProduct)
+
         if country:
-            query = query.filter(VasProduct.country == country)
-            
+            stmt = stmt.where(VasProduct.country == country)
+
         if visa_type:
-            query = query.filter(VasProduct.visa_type == visa_type)
-        
-        query = apply_keyword_search(
-            query=query,
+            stmt = stmt.where(VasProduct.visa_type == visa_type)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasProduct,
             keyword=keyword,
-            fields=["title", "provider", "description"]
+            fields=["title", "provider", "description"],
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 54 - 22
app/services/schema_service.py

@@ -1,41 +1,73 @@
 # app/services/schema_service.py
-from sqlalchemy.orm import Session
-from typing import Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from typing import List
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError
 from app.models.schema import VasSchema
 from app.schemas.schema import VasSchemaCreate, VasSchemaUpdate
 
+
 class SchemaService:
 
-    def create(db: Session, data: VasSchemaCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: VasSchemaCreate,
+    ) -> VasSchema:
         rec = VasSchema(**data.dict())
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def get(db: Session, id: int):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def get(
+        db: AsyncSession,
+        id: int,
+    ) -> VasSchema:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
-            raise NotFoundError('Schema not exist')
+            raise NotFoundError("Schema not exist")
         return obj
 
-    def update(db: Session, id: int, data: VasSchemaUpdate):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        data: VasSchemaUpdate,
+    ) -> VasSchema:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
-            raise NotFoundError('Schema not exist')
-        for k,v in data.dict(exclude_unset=True).items():
+            raise NotFoundError("Schema not exist")
+
+        for k, v in data.dict(exclude_unset=True).items():
             setattr(obj, k, v)
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
 
-    def delete(db: Session, id:int):
-        obj = db.query(VasSchema).filter_by(id=id).first()
+    @staticmethod
+    async def delete(
+        db: AsyncSession,
+        id: int,
+    ) -> None:
+        stmt = select(VasSchema).where(VasSchema.id == id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
-            raise NotFoundError('Schema not exist')
-        db.delete(obj)
-        db.commit()
+            raise NotFoundError("Schema not exist")
+
+        await db.delete(obj)
+        await db.commit()
 
-    def list_all(db: Session):
-        return db.query(VasSchema).all()
+    @staticmethod
+    async def list_all(
+        db: AsyncSession,
+    ) -> List[VasSchema]:
+        stmt = select(VasSchema)
+        result = await db.execute(stmt)
+        return result.scalars().all()

+ 86 - 45
app/services/seaweedfs_service.py

@@ -1,71 +1,112 @@
-import requests
+# app/services/seaweedfs_service.py
+
+import httpx
 from fastapi import UploadFile
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import BizLogicError
 from app.core.logger import logger
 
 
 class SeaweedFSService:
-    MASTER_URL = "http://127.0.0.1:9333"  # 你的 SeaweedFS master 地址
+    MASTER_URL = "http://127.0.0.1:9333"  # SeaweedFS master 地址
+    DOWNLOAD_GATEWAY = "http://45.137.220.138:8888/api/resource/download_file"
 
     @classmethod
-    def upload(cls, file: UploadFile):
-        """上传文件到 SeaweedFS"""
+    async def upload(cls, file: UploadFile):
+        """上传文件到 SeaweedFS(异步)"""
         try:
-            # 1️⃣ 获取可上传的 volume 地址
-            assign_resp = requests.get(f"{cls.MASTER_URL}/dir/assign", timeout=5)
-            assign_data = assign_resp.json()
-            fid = assign_data["fid"]
-            public_url = assign_data["publicUrl"]
-
-            # 2️⃣ 上传文件数据
-            upload_url = f"http://{public_url}/{fid}"
-            
-            download_url = f"http://45.137.220.138:8888/api/resource/download_file?fid={fid}"
-            files = {"file": (file.filename, file.file, file.content_type)}
-            upload_resp = requests.post(upload_url, files=files, timeout=10)
-
-            if upload_resp.status_code == 201:
-                return {"fid": fid, "url": download_url}
-            else:
+            async with httpx.AsyncClient(timeout=10) as client:
+                # 1️⃣ 获取 volume
+                assign_resp = await client.get(f"{cls.MASTER_URL}/dir/assign")
+                assign_resp.raise_for_status()
+                assign_data = assign_resp.json()
+
+                fid = assign_data["fid"]
+                public_url = assign_data["publicUrl"]
+
+                upload_url = f"http://{public_url}/{fid}"
+                download_url = f"{cls.DOWNLOAD_GATEWAY}?fid={fid}"
+
+                # 2️⃣ 上传文件
+                files = {
+                    "file": (
+                        file.filename,
+                        await file.read(),
+                        file.content_type,
+                    )
+                }
+
+                upload_resp = await client.post(upload_url, files=files)
+
+                if upload_resp.status_code == 201:
+                    return {
+                        "fid": fid,
+                        "url": download_url,
+                    }
+
                 raise BizLogicError(f"file upload error: {upload_resp.text}")
+
         except Exception as e:
+            logger.exception("SeaweedFS upload failed")
             raise BizLogicError(f"file upload exception: {e}")
 
     @classmethod
-    def get(cls, fid: str):
-        """根据 fid 读取文件"""
+    async def get(cls, fid: str):
+        """根据 fid 读取文件(异步)"""
         try:
-            resp = requests.get(f"{cls.MASTER_URL}/dir/lookup?volumeId={fid.split(',')[0]}", timeout=5)
-            data = resp.json()
-            if not data.get("locations"):
-                return None
+            volume_id = fid.split(",")[0]
+
+            async with httpx.AsyncClient(timeout=10) as client:
+                lookup_resp = await client.get(
+                    f"{cls.MASTER_URL}/dir/lookup",
+                    params={"volumeId": volume_id},
+                )
+                lookup_resp.raise_for_status()
+                data = lookup_resp.json()
+
+                if not data.get("locations"):
+                    return None
+
+                public_url = data["locations"][0]["publicUrl"]
+                file_url = f"http://{public_url}/{fid}"
 
-            public_url = data["locations"][0]["publicUrl"]
-            file_url = f"http://{public_url}/{fid}"
+                file_resp = await client.get(file_url)
+                if file_resp.status_code == 200:
+                    return (
+                        file_resp.content,
+                        file_resp.headers.get(
+                            "Content-Type", "application/octet-stream"
+                        ),
+                    )
 
-            file_resp = requests.get(file_url, timeout=10)
-            if file_resp.status_code == 200:
-                return file_resp.content, file_resp.headers.get("Content-Type", "application/octet-stream")
-            else:
                 return None
+
         except Exception as e:
-            logger.exception(f"SeaweedFS 读取异常, 原因={e}")
+            logger.exception(f"SeaweedFS get failed, reason={e}")
             return None
 
     @classmethod
-    def delete(cls, fid: str):
-        """删除文件"""
+    async def delete(cls, fid: str) -> bool:
+        """删除文件(异步)"""
         try:
-            resp = requests.get(f"{cls.MASTER_URL}/dir/lookup?volumeId={fid.split(',')[0]}", timeout=5)
-            data = resp.json()
-            if not data.get("locations"):
-                return False
+            volume_id = fid.split(",")[0]
+
+            async with httpx.AsyncClient(timeout=10) as client:
+                lookup_resp = await client.get(
+                    f"{cls.MASTER_URL}/dir/lookup",
+                    params={"volumeId": volume_id},
+                )
+                lookup_resp.raise_for_status()
+                data = lookup_resp.json()
+
+                if not data.get("locations"):
+                    return False
+
+                public_url = data["locations"][0]["publicUrl"]
+                delete_url = f"http://{public_url}/{fid}"
 
-            public_url = data["locations"][0]["publicUrl"]
-            delete_url = f"http://{public_url}/{fid}"
+                del_resp = await client.delete(delete_url)
+                return del_resp.status_code == 202
 
-            del_resp = requests.delete(delete_url, timeout=10)
-            return del_resp.status_code == 202
         except Exception as e:
-            logger.exception(f"SeaweedFS 删除异常, 原因={e}")
+            logger.exception(f"SeaweedFS delete failed, reason={e}")
             return False

+ 26 - 20
app/services/session_service.py

@@ -1,10 +1,10 @@
 # app/services/session_service.py
 
-from datetime import datetime, timedelta
-import uuid
+from datetime import datetime
+from typing import Optional
 
-from sqlalchemy.orm import Session as DBSession
-from sqlalchemy import delete
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, delete
 
 from app.models.session import VasSession
 from app.models.user import VasUser
@@ -12,28 +12,34 @@ from app.models.user import VasUser
 
 class SessionService:
 
-    # ============================
-    # token → user(鉴权用)
-    # ============================
     @staticmethod
-    def get_user_by_token(db: DBSession, session_id: str) -> VasUser:
-        session_obj = db.query(VasSession).filter(VasSession.id == session_id).first()
+    async def get_user_by_token(
+        db: AsyncSession,
+        session_id: str
+    ) -> Optional[VasUser]:
+        result = await db.execute(
+            select(VasSession).where(VasSession.id == session_id)
+        )
+        session_obj = result.scalar_one_or_none()
+
         if not session_obj:
             return None
 
-        # session 是否过期
         if session_obj.expire_at < datetime.utcnow():
-            # 自动删除过期 session
-            SessionService.delete_session(db, session_id)
+            await SessionService.delete_session(db, session_id)
             return None
 
-        user = db.query(VasUser).filter(VasUser.id == session_obj.user_id).first()
-        return user
+        result = await db.execute(
+            select(VasUser).where(VasUser.id == session_obj.user_id)
+        )
+        return result.scalar_one_or_none()
 
-    # ============================
-    # 删除 session(登出)
-    # ============================
     @staticmethod
-    def delete_session(db: DBSession, session_id: str):
-        db.query(VasSession).filter(VasSession.id == session_id).delete()
-        db.commit()
+    async def delete_session(
+        db: AsyncSession,
+        session_id: str
+    ) -> None:
+        await db.execute(
+            delete(VasSession).where(VasSession.id == session_id)
+        )
+        await db.commit()

+ 45 - 18
app/services/short_url_service.py

@@ -1,40 +1,67 @@
+# app/services/short_url_service.py
+
 import string
 import random
-from sqlalchemy.orm import Session
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.short_url import ShortUrl
 
 
 class ShortUrlService:
+
     @staticmethod
     def generate_short_key(length: int = 8) -> str:
-        """生成随机短 Key(字母+数字组成)"""
+        """生成随机短 Key(字母 + 数字)"""
         chars = string.ascii_letters + string.digits
         return ''.join(random.choices(chars, k=length))
 
     @staticmethod
-    def create_short_url(db: Session, long_url: str) -> ShortUrl:
-        """创建短链接"""
-        # 检查是否已经存在相同的长链接
-        existing = db.query(ShortUrl).filter(ShortUrl.long_url == long_url).first()
+    async def create_short_url(
+        db: AsyncSession,
+        long_url: str
+    ) -> ShortUrl:
+        """创建短链接(异步 + 并发安全)"""
+
+        # 1️⃣ 是否已有相同 long_url
+        stmt = select(ShortUrl).where(ShortUrl.long_url == long_url)
+        existing = await db.scalar(stmt)
         if existing:
             raise BizLogicError("Short url already exist")
 
-        # 生成唯一 short_key
-        short_key = ShortUrlService.generate_short_key()
-        while db.query(ShortUrl).filter(ShortUrl.short_key == short_key).first():
+        # 2️⃣ 生成 short_key(循环直到成功)
+        while True:
             short_key = ShortUrlService.generate_short_key()
 
-        db_obj = ShortUrl(short_key=short_key, long_url=long_url)
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
-        return db_obj
+            db_obj = ShortUrl(
+                short_key=short_key,
+                long_url=long_url
+            )
+            db.add(db_obj)
+
+            try:
+                await db.commit()
+                await db.refresh(db_obj)
+                return db_obj
+
+            except IntegrityError:
+                # short_key 冲突,重试
+                await db.rollback()
+                continue
 
     @staticmethod
-    def get_long_url(db: Session, short_key: str) -> str:
-        """通过短 key 获取原始长链接"""
-        record = db.query(ShortUrl).filter(ShortUrl.short_key == short_key).first()
+    async def get_long_url(
+        db: AsyncSession,
+        short_key: str
+    ) -> str:
+        """通过短 key 获取原始长链接(异步)"""
+
+        stmt = select(ShortUrl).where(ShortUrl.short_key == short_key)
+        record = await db.scalar(stmt)
+
         if not record:
             raise NotFoundError("Short url not found")
+
         return record.long_url

+ 30 - 7
app/services/slot_snapshot_service.py

@@ -1,17 +1,40 @@
 # app/services/slot_snapshot_service.py
-from sqlalchemy.orm import Session
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
 from app.models.slot_snapshot import VasSlotSnapshot
 from app.schemas.slot_snapshot import SlotSnapshotCreate
-from datetime import datetime
+
 
 class SlotSnapshotService:
 
-    def create(db: Session, data: SlotSnapshotCreate):
+    @staticmethod
+    async def create(
+        db: AsyncSession,
+        data: SlotSnapshotCreate
+    ) -> VasSlotSnapshot:
         rec = VasSlotSnapshot(**data.dict())
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def latest_for(db: Session, country:str, city:str, visa_type:str):
-        return db.query(VasSlotSnapshot).filter_by(country=country, city=city, visa_type=visa_type).order_by(VasSlotSnapshot.snapshot_at.desc()).first()
+    @staticmethod
+    async def latest_for(
+        db: AsyncSession,
+        country: str,
+        city: str,
+        visa_type: str
+    ) -> VasSlotSnapshot:
+        stmt = (
+            select(VasSlotSnapshot)
+            .where(
+                VasSlotSnapshot.country == country,
+                VasSlotSnapshot.city == city,
+                VasSlotSnapshot.visa_type == visa_type,
+            )
+            .order_by(VasSlotSnapshot.snapshot_at.desc())
+            .limit(1)
+        )
+
+        return await db.scalar(stmt)

+ 52 - 22
app/services/sms_service.py

@@ -1,56 +1,86 @@
+# app/services/sms_service.py
+
 import json
-import time
-import requests
 from typing import List
+from redis.asyncio import Redis
 from app.schemas.sms import ShortMessageDetail
 
 
-def save_short_message(redis_client, phone: str, message: str, received_at: str, max_ttl: int) -> ShortMessageDetail:
+async def save_short_message(
+    redis_client: Redis,
+    phone: str,
+    message: str,
+    received_at: str,
+    max_ttl: int
+) -> ShortMessageDetail:
     """
-    将短信保存到 Redis。
-    键格式: sms:{phone}
-    值为 JSON 数组(保存最近多条短信)
+    将短信保存到 Redis(异步版)
+    key: sms:{phone}
+    value: JSON 数组(最多保留最近 20 条
     """
     key = f"sms:{phone}"
 
-    # 取出已有数据
-    existing_data = redis_client.get(key)
+    # 1️⃣ 读取已有短信
+    existing_data = await redis_client.get(key)
     if existing_data:
         messages = json.loads(existing_data)
     else:
         messages = []
 
-    # 添加新短信
-    new_msg = ShortMessageDetail(phone=phone, message=message, received_at=received_at)
+    # 2️⃣ 添加新短信
+    new_msg = ShortMessageDetail(
+        phone=phone,
+        message=message,
+        received_at=received_at
+    )
     messages.append(new_msg.dict())
 
-    # 最多保留最近 20 条(可调整)
+    # 3️⃣ 保留最近 20 条
     messages = messages[-20:]
 
-    # 保存回 Redis
-    redis_client.setex(key, max_ttl, json.dumps(messages))
+    # 4️⃣ 写回 Redis(重置 TTL)
+    await redis_client.setex(
+        key,
+        max_ttl,
+        json.dumps(messages)
+    )
 
     return new_msg
 
 
-def query_short_message(redis_client, phone: str, keyword: str, sent_at: str) -> List[ShortMessageDetail]:
+async def query_short_message(
+    redis_client: Redis,
+    phone: str,
+    keyword: str = None,
+    sent_at: str = None
+) -> List[ShortMessageDetail]:
     """
-    从 Redis 查询短信。
-    支持按关键字和时间过滤。
+    从 Redis 查询短信(异步版)
+    支持关键字 / 时间过滤
     """
     key = f"sms:{phone}"
-    existing_data = redis_client.get(key)
+
+    existing_data = await redis_client.get(key)
     if not existing_data:
         return []
 
-    messages = [ShortMessageDetail(**m) for m in json.loads(existing_data)]
+    messages = [
+        ShortMessageDetail(**m)
+        for m in json.loads(existing_data)
+    ]
 
     # 关键字过滤
     if keyword:
-        messages = [m for m in messages if keyword in m.message]
+        messages = [
+            m for m in messages
+            if keyword in m.message
+        ]
 
-    # 时间过滤(如果需要更复杂可转为时间戳比较)
+    # 时间过滤(字符串比较,ISO8601 安全
     if sent_at:
-        messages = [m for m in messages if m.received_at >= sent_at]
+        messages = [
+            m for m in messages
+            if m.received_at >= sent_at
+        ]
 
-    return messages
+    return messages

+ 157 - 90
app/services/statistics_service.py

@@ -1,5 +1,7 @@
-from sqlalchemy.orm import Session
-from sqlalchemy import func, desc, and_, case
+# app/services/statistics_service.py
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import func, desc, case, select
 from datetime import datetime, timedelta, date
 from typing import Dict, Any
 
@@ -9,8 +11,11 @@ from app.models.vas_task import VasTask
 from app.models.user import VasUser
 from app.models.product import VasProduct
 
-# 静态汇率配置 (基准: CNY)
-# 实际生产环境建议从数据库或缓存获取实时汇率
+
+# ======================
+# 汇率 & 货币符号
+# ======================
+
 EXCHANGE_RATES = {
     "CNY": 1.0,
     "USD": 7.25,
@@ -24,161 +29,223 @@ CURRENCY_SYMBOLS = {
     "CNY": "¥", "USD": "$", "EUR": "€", "GBP": "£", "HKD": "HK$", "JPY": "¥"
 }
 
+
 class StatisticsService:
+
+    # ======================
+    # 工具方法
+    # ======================
+
     @staticmethod
     def _convert_to_cny(amount: any, currency: str) -> int:
-        """
-        辅助函数:将金额(分)转换为人民币(分)
-        """
+        """金额(分) → CNY(分)"""
         if not amount:
             return 0
-        rate = EXCHANGE_RATES.get(currency, 1.0) # 未知货币默认按 1:1 处理
-        
-        # 修复点:将 amount 转换为 float,解决 Decimal * float 报错的问题
+        rate = EXCHANGE_RATES.get(currency, 1.0)
         return int(float(amount) * rate)
 
+    # ======================
+    # 核心接口
+    # ======================
+
     @staticmethod
-    def overview(db: Session) -> Dict[str, Any]:
+    async def overview(db: AsyncSession) -> Dict[str, Any]:
         """
-        获取后台概览数据 (统一换算为 CNY 统计)
+        后台统计概览(Async 版)
         """
-        # --- 1. 核心指标卡片 ---
-        
-        # 1.1 总营收 (按币种分组求和,再换算)
-        revenue_groups = db.query(
-            VasOrder.base_currency,
-            func.sum(VasOrder.base_amount)
-        ).filter(
-            VasOrder.status.in_(['paid', 'completed', 'succeeded'])
-        ).group_by(VasOrder.base_currency).all()
-
-        total_revenue_cny = 0
-        for currency, amount in revenue_groups:
-            total_revenue_cny += StatisticsService._convert_to_cny(amount, currency)
+
+        # --------------------------------------------------
+        # 1. 核心指标
+        # --------------------------------------------------
+
+        # 1.1 总营收(按币种分组)
+        revenue_stmt = (
+            select(
+                VasOrder.base_currency,
+                func.sum(VasOrder.base_amount)
+            )
+            .where(VasOrder.status.in_(["paid", "completed", "succeeded"]))
+            .group_by(VasOrder.base_currency)
+        )
+
+        revenue_rows = (await db.execute(revenue_stmt)).all()
+
+        total_revenue_cny = sum(
+            StatisticsService._convert_to_cny(amount, currency)
+            for currency, amount in revenue_rows
+        )
 
         # 1.2 活跃订单数
-        total_orders = db.query(func.count(VasOrder.id))\
-            .filter(VasOrder.status != 'closed')\
-            .scalar() or 0
+        total_orders = (
+            await db.scalar(
+                select(func.count(VasOrder.id))
+                .where(VasOrder.status != "closed")
+            )
+        ) or 0
 
         # 1.3 活跃用户数
-        active_users = db.query(func.count(VasUser.id)).scalar() or 0
+        active_users = (
+            await db.scalar(select(func.count(VasUser.id)))
+        ) or 0
 
         # 1.4 待处理工单
-        pending_tickets = db.query(func.count(VasTicket.id))\
-            .filter(VasTicket.status.in_(['pending', 'info_required']))\
-            .scalar() or 0
+        pending_tickets = (
+            await db.scalar(
+                select(func.count(VasTicket.id))
+                .where(VasTicket.status.in_(["pending", "info_required"]))
+            )
+        ) or 0
 
         # 1.5 任务成功率
-        task_counts = db.query(
-            func.count(VasTask.id).label('total'),
-            func.sum(case((VasTask.status == 'completed', 1), else_=0)).label('success')
-        ).first()
-        
+        task_stmt = select(
+            func.count(VasTask.id).label("total"),
+            func.sum(
+                case((VasTask.status == "completed", 1), else_=0)
+            ).label("success")
+        )
+
+        task_counts = (await db.execute(task_stmt)).first()
+
         success_rate_str = "0%"
         if task_counts and task_counts.total > 0:
             rate = (task_counts.success / task_counts.total) * 100
             success_rate_str = f"{rate:.1f}%"
 
-        # --- 2. 营收趋势图 (Last 7 Days) ---
-        
+        # --------------------------------------------------
+        # 2. 最近 7 天营收趋势
+        # --------------------------------------------------
+
         revenue_trend = []
         today = date.today()
-        
+
         for i in range(6, -1, -1):
             target_date = today - timedelta(days=i)
             start_dt = datetime.combine(target_date, datetime.min.time())
             end_dt = datetime.combine(target_date, datetime.max.time())
-            
-            # 按币种分组查询当天的营收
-            daily_groups = db.query(
-                VasOrder.base_currency,
-                func.sum(VasOrder.base_amount).label('amount'),
-                func.count(VasOrder.id).label('orders')
-            ).filter(
-                VasOrder.created_at >= start_dt,
-                VasOrder.created_at <= end_dt,
-                VasOrder.status.in_(['paid', 'completed', 'succeeded'])
-            ).group_by(VasOrder.base_currency).all()
+
+            daily_stmt = (
+                select(
+                    VasOrder.base_currency,
+                    func.sum(VasOrder.base_amount).label("amount"),
+                    func.count(VasOrder.id).label("orders")
+                )
+                .where(
+                    VasOrder.created_at >= start_dt,
+                    VasOrder.created_at <= end_dt,
+                    VasOrder.status.in_(["paid", "completed", "succeeded"])
+                )
+                .group_by(VasOrder.base_currency)
+            )
+
+            daily_rows = (await db.execute(daily_stmt)).all()
 
             daily_amount_cny = 0
             daily_order_count = 0
-            
-            for curr, amt, cnt in daily_groups:
+
+            for curr, amt, cnt in daily_rows:
                 daily_amount_cny += StatisticsService._convert_to_cny(amt, curr)
                 daily_order_count += cnt
 
             revenue_trend.append({
                 "date": target_date.strftime("%m-%d"),
-                "amount": float(daily_amount_cny) / 100.0, # 转为元 (浮点数)
+                "amount": daily_amount_cny / 100.0,
                 "orders": daily_order_count
             })
 
-        # --- 3. 商品销量分布 ---
-        
-        product_stats = db.query(
-            VasProduct.title,
-            func.count(VasOrder.id).label('count')
-        ).join(VasOrder, VasOrder.product_id == VasProduct.id)\
-         .filter(VasOrder.status.in_(['paid', 'completed', 'succeeded']))\
-         .group_by(VasProduct.title)\
-         .order_by(desc('count'))\
-         .limit(5).all()
-
-        product_dist = [{"name": p.title, "value": p.count} for p in product_stats]
-
-        # --- 4. 最新动态 ---
-        
+        # --------------------------------------------------
+        # 3. 商品销量分布(Top 5)
+        # --------------------------------------------------
+
+        product_stmt = (
+            select(
+                VasProduct.title,
+                func.count(VasOrder.id).label("count")
+            )
+            .join(VasOrder, VasOrder.product_id == VasProduct.id)
+            .where(VasOrder.status.in_(["paid", "completed", "succeeded"]))
+            .group_by(VasProduct.title)
+            .order_by(desc("count"))
+            .limit(5)
+        )
+
+        product_rows = (await db.execute(product_stmt)).all()
+
+        product_dist = [
+            {"name": title, "value": count}
+            for title, count in product_rows
+        ]
+
+        # --------------------------------------------------
+        # 4. 最新动态
+        # --------------------------------------------------
+
         activities = []
-        
-        # 订单动态
-        recent_orders = db.query(VasOrder).order_by(desc(VasOrder.created_at)).limit(5).all()
+
+        # 最近订单
+        order_stmt = (
+            select(VasOrder)
+            .order_by(desc(VasOrder.created_at))
+            .limit(5)
+        )
+        recent_orders = (await db.execute(order_stmt)).scalars().all()
+
         for o in recent_orders:
             symbol = CURRENCY_SYMBOLS.get(o.base_currency, o.base_currency)
             amt_display = f"{symbol}{o.base_amount / 100}"
-            
+
             activities.append({
                 "id": f"order_{o.id}",
                 "text": f"用户下单: {o.product_name or '未知商品'} ({amt_display})",
                 "time": o.created_at,
-                "type": "order" if o.status == 'pending' else "money"
+                "type": "order" if o.status == "pending" else "money"
             })
 
-        # 工单动态
-        recent_tickets = db.query(VasTicket).order_by(desc(VasTicket.created_at)).limit(5).all()
+        # 最近工单
+        ticket_stmt = (
+            select(VasTicket)
+            .order_by(desc(VasTicket.created_at))
+            .limit(5)
+        )
+        recent_tickets = (await db.execute(ticket_stmt)).scalars().all()
+
         for t in recent_tickets:
-            reason_preview = t.reason[:20] + "..." if len(t.reason) > 20 else t.reason
+            reason_preview = (
+                t.reason[:20] + "..."
+                if t.reason and len(t.reason) > 20
+                else t.reason
+            )
             activities.append({
                 "id": f"ticket_{t.id}",
                 "text": f"新工单 #{t.id}: {reason_preview}",
                 "time": t.created_at,
                 "type": "ticket"
             })
-        
-        # 排序与时间格式
-        activities.sort(key=lambda x: x['time'], reverse=True)
+
+        # 排序 + 时间人性
+        activities.sort(key=lambda x: x["time"], reverse=True)
         activities = activities[:10]
 
         now = datetime.now()
         for act in activities:
-            dt = act['time']
-            if not isinstance(dt, datetime):
-                 continue
+            dt = act["time"]
             diff = now - dt
             if diff.days > 0:
-                time_str = f"{diff.days}天前"
+                act["time"] = f"{diff.days}天前"
             elif diff.seconds > 3600:
-                time_str = f"{diff.seconds // 3600}小时前"
+                act["time"] = f"{diff.seconds // 3600}小时前"
             elif diff.seconds > 60:
-                time_str = f"{diff.seconds // 60}分钟前"
+                act["time"] = f"{diff.seconds // 60}分钟前"
             else:
-                time_str = "刚刚"
-            act['time'] = time_str
+                act["time"] = "刚刚"
+
+        # --------------------------------------------------
+        # 返回结果
+        # --------------------------------------------------
 
         return {
             "stats": {
                 "totalOrders": total_orders,
-                "totalRevenue": total_revenue_cny, # 单位:分 (CNY)
+                "totalRevenue": total_revenue_cny,  # CNY 分
                 "activeUsers": active_users,
                 "pendingTickets": pending_tickets,
                 "successRate": success_rate_str
@@ -186,4 +253,4 @@ class StatisticsService:
             "revenue_trend": revenue_trend,
             "product_dist": product_dist,
             "recent_activities": activities
-        }
+        }

+ 52 - 18
app/services/task_service.py

@@ -1,34 +1,56 @@
-import json
-from sqlalchemy.orm import Session
+# app/services/task_service.py
+
 from typing import List, Optional
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select, update
+
+from app.core.biz_exception import NotFoundError
 from app.models.task import Task
 from app.schemas.task import TaskCreate, TaskUpdate
 
 
 class TaskService:
+
+    # ======================
+    # 创建任务
+    # ======================
     @staticmethod
-    def create(db: Session, obj_in: TaskCreate) -> Task:
+    async def create(db: AsyncSession, obj_in: TaskCreate) -> Task:
         db_obj = Task(
             command=obj_in.command,
             args=obj_in.args,
             status=obj_in.status or 0,
         )
         db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
 
+    # ======================
+    # 根据 ID 获取
+    # ======================
     @staticmethod
-    def get_by_id(db: Session, task_id: int) -> Optional[Task]:
-        obj = db.query(Task).filter(Task.id == task_id).first()
+    async def get_by_id(db: AsyncSession, task_id: int) -> Task:
+        stmt = select(Task).where(Task.id == task_id)
+        obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("Task not exist")
         return obj
 
+    # ======================
+    # 更新任务
+    # ======================
     @staticmethod
-    def update(db: Session, task_id: int, obj_in: TaskUpdate) -> Optional[Task]:
-        db_obj = db.query(Task).filter(Task.id == task_id).first()
+    async def update(
+        db: AsyncSession,
+        task_id: int,
+        obj_in: TaskUpdate
+    ) -> Task:
+        stmt = select(Task).where(Task.id == task_id)
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
+
         if not db_obj:
             raise NotFoundError("Task not exist")
 
@@ -37,19 +59,31 @@ class TaskService:
         if obj_in.status is not None:
             db_obj.status = obj_in.status
 
-        db.add(db_obj)
-        db.commit()
-        db.refresh(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
         return db_obj
 
+    # ======================
+    # 获取待处理任务(分页)
+    # ======================
     @staticmethod
-    def get_pending(db: Session, command: str, page: int, size: int) -> List[Task]:
+    async def get_pending(
+        db: AsyncSession,
+        command: str,
+        page: int,
+        size: int
+    ) -> List[Task]:
         offset = page * size
-        return (
-            db.query(Task)
-            .filter(Task.command == command, Task.status == 0)
+
+        stmt = (
+            select(Task)
+            .where(
+                Task.command == command,
+                Task.status == 0
+            )
             .order_by(Task.create_at.asc())
             .offset(offset)
             .limit(size)
-            .all()
         )
+
+        return (await db.execute(stmt)).scalars().all()

+ 17 - 12
app/services/telegram_service.py

@@ -1,21 +1,26 @@
-import json
-import time
-import requests
-from typing import List
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+# app/services/telegram_service.py
+
+import httpx
+from app.core.biz_exception import BizLogicError
 from app.schemas.telegram import TelegramIn
 
 
 class TelegramService:
-    def push_to_telegram(payload: TelegramIn):
+
+    @staticmethod
+    async def push_to_telegram(payload: TelegramIn):
         url = f"https://api.telegram.org/bot{payload.api_token}/sendMessage"
-        payload = {
+
+        body = {
             "chat_id": payload.chat_id,
             "text": payload.message,
-            "parse_mode": "HTML"
+            "parse_mode": "HTML",
         }
 
-        response = requests.post(url, json=payload, timeout=10)
-        if response.status_code != 200:
-            raise BizLogicError("Telegram push failed")
-    
+        async with httpx.AsyncClient(timeout=10) as client:
+            resp = await client.post(url, json=body)
+
+        if resp.status_code != 200:
+            raise BizLogicError(
+                f"Telegram push failed: {resp.status_code}, {resp.text}"
+            )

+ 238 - 78
app/services/ticket_service.py

@@ -1,12 +1,17 @@
-# app/services/ticket_service.py
 from datetime import datetime
+from typing import List, Optional
+
 from redis.asyncio import Redis
-from typing import List
-from sqlalchemy.orm import Session
-from app.utils.search import apply_keyword_search
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 from app.models.user import VasUser
+from app.models.order import VasOrder
+from app.models.vas_task import VasTask
+from app.models.payment import VasPayment
 from app.models.ticket import VasTicket
 from app.models.ticket_message import VasTicketMessage
 from app.schemas.ticket import VasTicketCreate
@@ -14,135 +19,290 @@ from app.services.notification_service import NotificationService
 
 
 class TicketService:
-    
+
     @staticmethod
-    def create(db: Session, data: VasTicketCreate, current_user: VasUser, redis_client: Redis):
-        rec = VasTicket(**data.dict(), status='pending', created_at=datetime.utcnow())
-        rec.user_id = current_user.id
-        db.add(rec)
-        db.commit()
-        db.refresh(rec)
-        print(f"📧 send ticket created notification email")
-        NotificationService.create(
+    async def create(
+        db: AsyncSession,
+        data: VasTicketCreate,
+        current_user: VasUser,
+        redis_client: Redis
+    ):
+        ticket = VasTicket(
+            **data.dict(),
+            user_id=current_user.id,
+            status="pending",
+            created_at=datetime.utcnow(),
+            updated_at=datetime.utcnow(),
+        )
+
+        db.add(ticket)
+        await db.commit()
+        await db.refresh(ticket)
+
+        await NotificationService.create(
             redis_client=redis_client,
             ntype="ticket created",
             user_id=current_user.id,
             channels=["email"],
             template_id="ticket_created",
             payload={
-                "ticket_id": rec.id,
-                "order_id": rec.order_id
-            }
+                "ticket_id": ticket.id,
+            },
         )
-        return rec
+
+        return ticket
 
     @staticmethod
-    def update_status(db: Session, ticket_id, status, comment, admin_id):
-        ticket = db.query(VasTicket).filter_by(id=ticket_id).first()
+    async def update_status(
+        db: AsyncSession,
+        ticket_id: int,
+        status: str,
+        comment: str,
+        admin_id: str,
+    ) -> VasTicket:
+
+        # 🔒 锁住 ticket,防并发管理员操作
+        result = await db.execute(
+            select(VasTicket)
+            .where(VasTicket.id == ticket_id)
+            .with_for_update()
+        )
+        ticket = result.scalar_one_or_none()
+
         if not ticket:
             raise NotFoundError("Ticket not exist")
+
         ticket.status = status
         ticket.admin_comment = comment
+        ticket.updated_at = datetime.utcnow()
+
         db.add(
             VasTicketMessage(
-                ticket_id=ticket_id,
+                ticket_id=ticket.id,
                 sender_type="admin",
-                sender_id=admin_id,
-                content=comment
+                sender_id=str(admin_id),
+                content=comment,
+                created_at=datetime.utcnow(),
             )
         )
-        db.commit()
-        db.refresh(ticket)
+
+        if status == "resolved":
+            await TicketService._handle_resolution(db, ticket)
+
+        elif status == "rejected":
+            await TicketService._handle_rejection(db, ticket)
+
+        await db.commit()
+        await db.refresh(ticket)
         return ticket
-        
+
+    # =========================
+    # 工单解决逻辑
+    # =========================
+    @staticmethod
+    async def _handle_resolution(
+        db: AsyncSession,
+        ticket: VasTicket,
+    ) -> None:
+
+        if not ticket.order_id:
+            return
+
+        result = await db.execute(
+            select(VasOrder)
+            .where(VasOrder.id == ticket.order_id)
+            .with_for_update()
+        )
+        order = result.scalar_one_or_none()
+        if not order:
+            return
+
+        # ---------- 退款 ----------
+        if ticket.type == "refund":
+            order.status = "closed"
+
+            pay_res = await db.execute(
+                select(VasPayment).where(
+                    VasPayment.order_id == order.id,
+                    VasPayment.status.in_(["succeeded", "late_paid"]),
+                )
+            )
+            payment = pay_res.scalar_one_or_none()
+            if payment:
+                payment.status = "refunded"
+                payment.refunded_at = datetime.utcnow()
+                payment.refund_reason = ticket.reason
+
+            task_res = await db.execute(
+                select(VasTask).where(
+                    VasTask.order_id == order.id,
+                    VasTask.status.in_(["pending", "grabbed", "running"]),
+                )
+            )
+            for task in task_res.scalars().all():
+                task.status = "cancelled"
+
+        # ---------- 变更请求 ----------
+        elif ticket.type == "change_request":
+
+            # 1️⃣ 取消旧任务
+            task_res = await db.execute(
+                select(VasTask).where(
+                    VasTask.order_id == order.id,
+                    VasTask.status.in_(["pending", "grabbed", "running"]),
+                )
+            )
+            for task in task_res.scalars().all():
+                task.status = "cancelled"
+
+            # 2️⃣ 重新生成任务(幂等)
+            routing_res = await db.execute(
+                select(VasProductRouting).where(
+                    VasProductRouting.product_id == order.product_id,
+                    VasProductRouting.is_active == 1,
+                )
+            )
+            routings = routing_res.scalars().all()
+
+            for routing in routings:
+                exists_res = await db.execute(
+                    select(VasTask).where(
+                        VasTask.order_id == order.id,
+                        VasTask.routing_key == routing.routing_key,
+                        VasTask.script_version == routing.script_version,
+                    )
+                )
+                if exists_res.scalar_one_or_none():
+                    continue
+
+                db.add(
+                    VasTask(
+                        order_id=order.id,
+                        routing_key=routing.routing_key,
+                        script_version=routing.script_version,
+                        priority=10,
+                        status="pending",
+                        user_inputs=order.user_inputs,
+                        config=routing.config,
+                        attempt_count=0,
+                        notify_count=0,
+                        expire_at=datetime.utcnow() + timedelta(days=60),
+                        created_at=datetime.utcnow(),
+                    )
+                )
+
+    # =========================
+    # 工单拒绝逻辑
+    # =========================
+    @staticmethod
+    async def _handle_rejection(
+        db: AsyncSession,
+        ticket: VasTicket,
+    ) -> None:
+
+        if not ticket.order_id:
+            return
+
+        result = await db.execute(
+            select(VasOrder)
+            .where(VasOrder.id == ticket.order_id)
+            .with_for_update()
+        )
+        order = result.scalar_one_or_none()
+        if not order:
+            return
+
+        if ticket.type == "refund" and order.status == "refund_pending":
+            order.status = "paid"
+
     @staticmethod
-    def add_message(
-        db: Session,
+    async def add_message(
+        db: AsyncSession,
         ticket_id: int,
         sender_type: str,   # "user" | "admin" | "system"
-        sender_id: str = None,
+        sender_id: Optional[str] = None,
         content: str = "",
-        attachments: dict = None
+        attachments: Optional[dict] = None,
     ):
-        # 1️⃣ 校验工单是否存在
-        ticket = db.query(VasTicket).filter(
-            VasTicket.id == ticket_id
-        ).first()
-
+        ticket = await db.get(VasTicket, ticket_id)
         if not ticket:
             raise NotFoundError("Ticket not exist")
 
-        # 2️⃣ 创建消息
         message = VasTicketMessage(
             ticket_id=ticket_id,
             sender_type=sender_type,
             sender_id=sender_id,
             content=content,
             attachments=attachments,
-            created_at=datetime.utcnow()
+            created_at=datetime.utcnow(),
         )
 
-        # 3️⃣ 写入数据库
         db.add(message)
-
-        # 4️⃣ 更新工单更新时间(非常重要)
         ticket.updated_at = datetime.utcnow()
 
-        db.commit()
-        db.refresh(message)
-
+        await db.commit()
+        await db.refresh(message)
         return message
-    
+
     @staticmethod
-    def list_messages(
-        db: Session,
+    async def list_messages(
+        db: AsyncSession,
         ticket_id: int,
         page: int = 1,
-        size: int = 20
+        size: int = 20,
     ):
-        # 1️⃣ 校验 ticket 是否存在
-        exists = db.query(VasTicket.id).filter(
-            VasTicket.id == ticket_id
-        ).first()
-
+        # 校验 ticket 是否存在
+        exists = await db.scalar(
+            select(VasTicket.id).where(VasTicket.id == ticket_id)
+        )
         if not exists:
             raise NotFoundError("Ticket not exist")
 
-        # 2️⃣ 查询消息(按时间正序)
-        query = (
-            db.query(VasTicketMessage)
-            .filter(VasTicketMessage.ticket_id == ticket_id)
+        stmt = (
+            select(VasTicketMessage)
+            .where(VasTicketMessage.ticket_id == ticket_id)
             .order_by(VasTicketMessage.created_at.desc())
         )
 
-        return paginate(query, page, size)
-    
+        return await paginate(db, stmt, page, size)
+
     @staticmethod
-    def list_by_user(db: Session, user_id: str, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasTicket).filter_by(user_id=user_id)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_by_user(
+        db: AsyncSession,
+        user_id: str,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasTicket).where(VasTicket.user_id == user_id)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTicket,
             keyword=keyword,
-            fields=["order_id", "user_id", "reason", "admin_comment"]
+            fields=["order_id", "user_id", "reason", "admin_comment"],
         )
-        query = query.order_by(
-            VasTicket.id.desc()
-        )
-        return paginate(query, page, size)  
-      
+
+        stmt = stmt.order_by(VasTicket.id.desc())
+
+        return await paginate(db, stmt, page, size)
+
     @staticmethod
-    def list_all(db: Session, page: int=0, size: int=10, keyword: str=None):    
-        query = db.query(VasTicket)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasTicket)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTicket,
             keyword=keyword,
-            fields=["order_id", "user_id", "reason", "admin_comment"]
-        )
-        query = query.order_by(
-            VasTicket.id.desc()
+            fields=["order_id", "user_id", "reason", "admin_comment"],
         )
-        return paginate(query, page, size)
+
+        stmt = stmt.order_by(VasTicket.id.desc())
+
+        return await paginate(db, stmt, page, size)

+ 125 - 56
app/services/troov_service.py

@@ -1,80 +1,149 @@
 import json
 import time
-import requests
-from typing import List
-from fastapi import Depends
+import random
+import asyncio
+import aiohttp
+from typing import List, Optional
+
+from redis.asyncio import Redis
+from starlette.concurrency import run_in_threadpool
 from app.schemas.troov import TroovRate
-from app.utils.france_slot_api import *
+from app.utils.france_slot_api import troov_create_session_old
 from app.utils.proxy_utils import load_proxies_from_json
+from app.core.logger import logger
+
 
+# =========================================================
+# Redis 原子弹出 token(不使用 KEYS,避免阻塞)
+# =========================================================
 
-def pop_redis_value_token(redis_client):
-    lua_script = '''
-local keys = redis.call('keys', 'token:*')
+POP_TOKEN_LUA = """
+local cursor = "0"
 local max_ttl = -1
 local max_key = nil
 
-for _, key in ipairs(keys) do
-    local ttl = redis.call('ttl', key)
-    if ttl > 0 and ttl > max_ttl then
-        max_ttl = ttl
-        max_key = key
+repeat
+    local result = redis.call('SCAN', cursor, 'MATCH', 'token:*', 'COUNT', 50)
+    cursor = result[1]
+    local keys = result[2]
+
+    for _, key in ipairs(keys) do
+        local ttl = redis.call('TTL', key)
+        if ttl > max_ttl then
+            max_ttl = ttl
+            max_key = key
+        end
     end
-end
+until cursor == "0"
 
 if max_key then
-    local value = redis.call('get', max_key)
-    redis.call('del', max_key)
+    local value = redis.call('GET', max_key)
+    redis.call('DEL', max_key)
     return {max_key, value, max_ttl}
-else
-    return nil
 end
-'''
-    result = redis_client.eval(lua_script, 0)
-    return result
 
-def fetch_rate(session_dic, date):
-    url = f"https://api.consulat.gouv.fr/api/team/621540d353069dec25bd0045/reservations/availability?name=Visas&date={date}&places=-5&matching=&maxCapacity=-5&sessionId={session_dic['session_id']}"
+return nil
+"""
+
+
+async def pop_redis_value_token(redis_client: Redis):
+    """
+    原子性获取 TTL 最大的 token 并删除
+    """
+    return await redis_client.eval(POP_TOKEN_LUA, 0)
+
+
+# =========================================================
+# 请求法国 Troov 接口(async,不阻塞)
+# =========================================================
+
+async def fetch_rate(session_dic: dict, date: str) -> str:
+    url = (
+        "https://api.consulat.gouv.fr/api/team/"
+        "621540d353069dec25bd0045/reservations/availability"
+        f"?name=Visas&date={date}&places=-5&matching=&maxCapacity=-5"
+        f"&sessionId={session_dic['session_id']}"
+    )
+
     headers = {
-        'accept': 'application/json, text/plain, */*',
-        'accept-language': 'zh-CN,zh;q=0.9,en;q=0.8',
-        'origin': 'https://consulat.gouv.fr',
-        'referer': 'https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas',
-        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36',
-        'x-gouv-app-id': session_dic['x_gouv_app_id'],
-        'x-gouv-web': 'fr.gouv.consulat'
+        "accept": "application/json, text/plain, */*",
+        "accept-language": "zh-CN,zh;q=0.9,en;q=0.8",
+        "origin": "https://consulat.gouv.fr",
+        "referer": "https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas",
+        "user-agent": (
+            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+            "AppleWebKit/537.36 (KHTML, like Gecko) "
+            "Chrome/141.0.0.0 Safari/537.36"
+        ),
+        "x-gouv-app-id": session_dic["x_gouv_app_id"],
+        "x-gouv-web": "fr.gouv.consulat",
     }
-    try:
-        response = requests.get(url, headers=headers)
-        return response.text
-    except Exception as e:
-        return f'Exception info={e}'
 
-def get_rate_by_date(redis_client, date: str) -> List[TroovRate]:
+    timeout = aiohttp.ClientTimeout(total=15)
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(url, headers=headers) as resp:
+            return await resp.text()
+
+
+# =========================================================
+# 核心业务逻辑
+# =========================================================
+
+async def get_rate_by_date(
+    redis_client: Redis,
+    date: str
+) -> Optional[List[TroovRate]]:
     """
-    核心业务逻辑:根据日期返回 Troov 预约信息
+    根据日期获取 Troov 预约信息
     """
-    proxy_pools = ['oxylabs']
+
+    # ---------- 1️⃣ 加载代理 ----------
     proxies = []
-    for pp in proxy_pools:
-        proxies = proxies + load_proxies_from_json("data/proxy_pool_config.json", pp)
-    
-    result = None
-    while True:
-        result = pop_redis_value_token(redis_client)
-        if not result:
-            time.sleep(1)
-            continue
-        break
-    body_str = result[1]
-    body = json.loads(body_str)
-    captcha = body.get("token")
-    
-    session_dic = troov_create_session_old(random.choice(proxies), captcha)
+    for pool in ("oxylabs",):
+        proxies.extend(
+            load_proxies_from_json("data/proxy_pool_config.json", pool)
+        )
+
+    if not proxies:
+        logger.error("Proxy pool is empty")
+        return None
+
+    # ---------- 2️⃣ 获取验证码 token(最多等待 30 秒) ----------
+    token_data = None
+    for _ in range(30):
+        token_data = await pop_redis_value_token(redis_client)
+        if token_data:
+            break
+        await asyncio.sleep(1)
+
+    if not token_data:
+        logger.warning("No captcha token available")
+        return None
+
+    _, body_str, ttl = token_data
+
+    try:
+        body = json.loads(body_str)
+        captcha_token = body.get("token")
+    except Exception:
+        logger.exception("Invalid captcha token format")
+        return None
+
+    # ---------- 3️⃣ 创建 Troov session(同步函数放线程池) ----------
+    proxy = random.choice(proxies)
+
+    session_dic = await run_in_threadpool(troov_create_session_old, proxy, captcha_token)
     if not session_dic:
+        logger.warning("Failed to create Troov session")
         return None
-    logger.info(f'创建 session 成功: {session_dic}')
-    res = fetch_rate(session_dic, date)
-    return json.loads(res)
 
-        
+    logger.info(f"Troov session created: {session_dic}")
+
+    # ---------- 4️⃣ 请求预约数据 ----------
+    try:
+        response_text = await fetch_rate(session_dic, date)
+        return json.loads(response_text)
+    except Exception as e:
+        logger.error(f"Fetch rate failed: {e}")
+        return None

+ 70 - 34
app/services/user_service.py

@@ -1,19 +1,29 @@
 # app/services/user_service.py
+
+import uuid
 from datetime import datetime
-from typing import List
-from sqlalchemy.orm import Session
-from app.utils.search import apply_keyword_search
+from typing import List, Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.user import VasUser
-from app.schemas.user import VasUserCreate, VasUserUpdate, VasUserSetProfiles, VasUserOut
+from app.schemas.user import (
+    VasUserCreate,
+    VasUserUpdate,
+    VasUserSetProfiles,
+    VasUserOut,
+)
 
 
 class UserService:
-    
+
     @staticmethod
-    def create(db: Session, data: VasUserCreate):
-        uid = f'usr-{uuid.uuid4().hex[:8]}'
+    async def create(db: AsyncSession, data: VasUserCreate) -> VasUser:
+        uid = f"usr-{uuid.uuid4().hex[:8]}"
 
         user = VasUser(
             id=uid,
@@ -22,60 +32,86 @@ class UserService:
             phone=data.phone,
             preferred_language="en",
             timezone="Asia/Shanghai",
+            created_at=datetime.utcnow(),
         )
+
         db.add(user)
-        db.commit()
-        return rec
-    
+        await db.commit()
+        await db.refresh(user)
+
+        return user
+
     @staticmethod
-    def get(db: Session, id: str):
-        user = db.query(VasUser).filter_by(id=id).first()
+    async def get(db: AsyncSession, id: str) -> VasUser:
+        stmt = select(VasUser).where(VasUser.id == id)
+        result = await db.execute(stmt)
+        user = result.scalar_one_or_none()
+
         if not user:
             raise NotFoundError("User not exist")
+
         return user
-    
+
     @staticmethod
-    def update(db: Session, id: str, payload: VasUserUpdate):
-        obj = db.query(VasUser).filter(VasUser.id == id).first()
+    async def update(
+        db: AsyncSession,
+        id: str,
+        payload: VasUserUpdate
+    ) -> VasUser:
+        stmt = select(VasUser).where(VasUser.id == id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("User not exist")
-        data = payload.dict(exclude_unset=True)  # ⭐ 关键
+
+        data = payload.dict(exclude_unset=True)
 
         for key, value in data.items():
             setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
+
+        obj.updated_at = datetime.utcnow()
+
+        await db.commit()
+        await db.refresh(obj)
+
         return obj
-    
+
     @staticmethod
-    def set_profiles(db: Session, user: VasUser, payload: VasUserSetProfiles):
+    async def set_profiles(
+        db: AsyncSession,
+        user: VasUser,
+        payload: VasUserSetProfiles
+    ) -> VasUser:
         """
         更新用户资料(profile)
         """
 
-        # 1️⃣ 字段赋值(显式,避免误更新)
         user.phone = payload.phone
         user.nickname = payload.nickname
         user.avatar_url = payload.avatar_url
-
-        # 2️⃣ 更新时间(可选,SQLAlchemy onupdate 也会生效)
         user.updated_at = datetime.utcnow()
 
-        # 3️⃣ 持久化
         db.add(user)
-        db.commit()
-        db.refresh(user)
+        await db.commit()
+        await db.refresh(user)
 
         return user
 
     @staticmethod
-    def list_all(db: Session, page: int = 1, size: int = 20, keyword: str=None):
-        query = db.query(VasUser)
-        
-        query = apply_keyword_search(
-            query=query,
+    async def list_all(
+        db: AsyncSession,
+        page: int = 1,
+        size: int = 20,
+        keyword: Optional[str] = None,
+    ):
+        stmt = select(VasUser)
+
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasUser,
             keyword=keyword,
-            fields=["id", "email", "nickname", "phone"]
+            fields=["id", "email", "nickname", "phone"],
         )
-        return paginate(query, page, size)
+
+        return await paginate(db, stmt, page, size)

+ 103 - 53
app/services/vas_task_service.py

@@ -1,88 +1,138 @@
 # app/services/task_service.py
-from sqlalchemy.orm import Session
-from typing import List
-from app.utils.search import apply_keyword_search
+
+from datetime import datetime
+from typing import List, Optional
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.utils.search import apply_keyword_search_stmt
 from app.utils.pagination import paginate
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from app.core.biz_exception import NotFoundError
 from app.models.vas_task import VasTask
+from app.models.order import VasOrder
 from app.schemas.vas_task import VasTaskCreate, VasTaskUpdate
-from datetime import datetime
+
 
 class VasTaskService:
 
-    def create(db: Session, data: VasTaskCreate):
-        rec = VasTask(**data.dict(), status='pending', created_at=datetime.utcnow())
+    @staticmethod
+    async def create(db: AsyncSession, data: VasTaskCreate) -> VasTask:
+        rec = VasTask(
+            **data.dict(),
+            status="pending",
+            created_at=datetime.utcnow(),
+        )
         db.add(rec)
-        db.commit()
-        db.refresh(rec)
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def list_task(
-        db: Session,
-        status: str = None,
-        routing_key: str = None,
-        script_version: str = None,
-        keyword: str = None,
+    @staticmethod
+    async def list_task(
+        db: AsyncSession,
+        status: Optional[str] = None,
+        routing_key: Optional[str] = None,
+        script_version: Optional[str] = None,
+        keyword: Optional[str] = None,
         page: int = 0,
         size: int = 10,
     ):
-        query = db.query(VasTask)
-        
+        stmt = select(VasTask)
+
         if status:
-            query = query.filter(VasTask.status == status)
-        
+            stmt = stmt.where(VasTask.status == status)
+
         if routing_key:
-            query = query.filter(VasTask.routing_key == routing_key)
+            stmt = stmt.where(VasTask.routing_key == routing_key)
 
         if script_version:
-            query = query.filter(VasTask.script_version == script_version)
+            stmt = stmt.where(VasTask.script_version == script_version)
 
-        query = apply_keyword_search(
-            query=query,
+        stmt = apply_keyword_search_stmt(
+            stmt=stmt,
             model=VasTask,
             keyword=keyword,
-            fields=["order_id", "routing_key", "user_inputs"]
+            fields=["order_id", "routing_key", "user_inputs"],
         )
-        
-        query = query.order_by(
+
+        stmt = stmt.order_by(
             VasTask.priority.desc(),
-            VasTask.id.asc()
+            VasTask.id.asc(),
         )
-        return paginate(query, page, size)
-    
-    def update(db: Session, id: int, payload: VasTaskUpdate):
-        obj = db.query(VasTask).filter(VasTask.id == id).first()
+
+        return await paginate(db, stmt, page, size)
+
+    @staticmethod
+    async def update(
+        db: AsyncSession,
+        id: int,
+        payload: VasTaskUpdate,
+    ) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
         if not obj:
             raise NotFoundError("Task not exist")
-        data = payload.dict(exclude_unset=True)  # ⭐ 关键
+
+        data = payload.dict(exclude_unset=True)
 
         for key, value in data.items():
             setattr(obj, key, value)
-        db.commit()
-        db.refresh(obj)
+
+        await db.commit()
+        await db.refresh(obj)
         return obj
-        
-    def get_active_task_by_order_id(db: Session, order_id:str):
-        recs = db.query(VasTask).filter(
+
+    @staticmethod
+    async def get_active_task_by_order_id(
+        db: AsyncSession,
+        order_id: str,
+    ) -> List[VasTask]:
+        stmt = select(VasTask).where(
             VasTask.status == "pending",
-            VasTask.order_id == order_id
-        ).all()
-        return recs
-    
-    def return_to_queue(db: Session, id:int):
-        rec = db.query(VasTask).filter_by(id=id).first()
+            VasTask.order_id == order_id,
+        )
+        result = await db.execute(stmt)
+        return result.scalars().all()
+
+    @staticmethod
+    async def return_to_queue(db: AsyncSession, id: int) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        rec = result.scalar_one_or_none()
+
         if not rec:
             raise NotFoundError("Task not exist")
-        rec.status = 'pending'
-        db.commit()
-        db.refresh(rec)
+
+        rec.status = "pending"
+        rec.attempt_count = (rec.attempt_count or 0) + 1
+
+        await db.commit()
+        await db.refresh(rec)
         return rec
 
-    def manual_confirm(db: Session, id:int):
-        rec = db.query(VasTask).filter_by(id=id).first()
-        if not rec:
+    @staticmethod
+    async def manual_confirm(db: AsyncSession, id: int) -> VasTask:
+        stmt = select(VasTask).where(VasTask.id == id)
+        result = await db.execute(stmt)
+        task = result.scalar_one_or_none()
+
+        if not task:
             raise NotFoundError("Task not exist")
-        rec.status = 'completed'
-        db.commit()
-        db.refresh(rec)
-        return rec
+
+        task.status = "completed"
+
+        order_stmt = select(VasOrder).where(VasOrder.id == task.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
+        if not order:
+            raise NotFoundError("Order not exist")
+
+        order.status = "completed"
+
+        await db.commit()
+        await db.refresh(task)
+        return task

+ 113 - 116
app/services/webhook_service.py

@@ -1,61 +1,62 @@
 import re
 import json
 from datetime import datetime, timedelta
-from sqlalchemy.orm import Session
 from typing import List, Optional
-from decimal import Decimal, ROUND_HALF_UP
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+from decimal import Decimal
+
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+
+from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.order import VasOrder
 from app.models.vas_task import VasTask
-from app.models.product import VasProduct
 from app.models.product_routing import VasProductRouting
 from app.models.payment_event import VasPaymentEvent
 from app.models.payment import VasPayment
 from app.models.payment_qr import VasPaymentQR
 from app.schemas.webhook import SMSHelperWebhookPayload, PaymentWebhookOut
 
+
 class WebhookService:
-    
+
+    # =========================================================
+    # 内部方法:创建 Task(幂等)
+    # =========================================================
     @staticmethod
-    def _create_task_if_not_exists(
-        db: Session,
-        order: VasOrder
-    ):
-        routings = (
-            db.query(VasProductRouting)
-            .filter(
-                VasProductRouting.product_id == order.product_id,
-                VasProductRouting.is_active == 1
-            )
-            .all()
+    async def _create_task_if_not_exists(
+        db: AsyncSession,
+        order: VasOrder,
+    ) -> List[VasTask]:
+
+        stmt = select(VasProductRouting).where(
+            VasProductRouting.product_id == order.product_id,
+            VasProductRouting.is_active == 1,
         )
+        result = await db.execute(stmt)
+        routings = result.scalars().all()
 
         if not routings:
             return []
 
-        created_tasks = []
+        created_tasks: List[VasTask] = []
 
         for routing in routings:
-
-            # ---------- 2. 幂等判断 ----------
-            exists = (
-                db.query(VasTask)
-                .filter(
-                    VasTask.order_id == order.id,
-                    VasTask.routing_key == routing.routing_key,
-                    VasTask.script_version == routing.script_version,
-                )
-                .first()
+            exists_stmt = select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.routing_key == routing.routing_key,
+                VasTask.script_version == routing.script_version,
             )
+            exists_result = await db.execute(exists_stmt)
+            exists = exists_result.scalar_one_or_none()
+
             if exists:
                 continue
 
-            # ---------- 3. 创建 task ----------
             task = VasTask(
                 order_id=order.id,
                 routing_key=routing.routing_key,
                 script_version=routing.script_version,
-                priority=10,
+                priority=routing.priority,
                 status="pending",
                 user_inputs=order.user_inputs,
                 config=routing.config,
@@ -66,24 +67,27 @@ class WebhookService:
             )
             db.add(task)
             created_tasks.append(task)
-            
+
         return created_tasks
 
-    
+    # =========================================================
+    # SMSHelper 微信 / 支付宝 收款 webhook
+    # =========================================================
     @staticmethod
-    def smshelper_payment_webhook(db: Session, payload: SMSHelperWebhookPayload):
-        """
-        webhook payload 示例:
-        title=【微信】微信支付
-        content=【SM-E5260】个人收款码到账¥0.01
-        """
+    async def smshelper_payment_webhook(
+        db: AsyncSession,
+        payload: SMSHelperWebhookPayload,
+    ) -> Optional[PaymentWebhookOut]:
 
         title = payload.title
         content = payload.content
+
         if "微信" in title:
             provider = "wechat"
         elif "支付宝" in title:
             provider = "alipay"
+        else:
+            raise BizLogicError("Unknown payment provider")
 
         device_match = re.search(r"【(.+?)】", content)
         device_id = device_match.group(1) if device_match else None
@@ -94,7 +98,7 @@ class WebhookService:
 
         amount_yuan = Decimal(amount_match.group(1))
         amount_cent = int(amount_yuan * 100)
-        
+
         event = VasPaymentEvent(
             provider=provider,
             event_type="payment_received",
@@ -104,59 +108,57 @@ class WebhookService:
             parsed_currency="CNY",
             parsed_device=device_id,
             raw_payload=payload.dict(),
-            status="received"
+            status="received",
         )
         db.add(event)
-        db.commit()
-        db.refresh(event)
-        
-        payment_qr = (
-            db.query(VasPaymentQR)
-            .filter(
-                VasPaymentQR.provider == provider,
-                VasPaymentQR.device == device_id,
-                VasPaymentQR.is_active == 1
-            )
-            .first()
+        await db.commit()
+        await db.refresh(event)
+
+        # ---------- 查找 QR ----------
+        qr_stmt = select(VasPaymentQR).where(
+            VasPaymentQR.provider == provider,
+            VasPaymentQR.device == device_id,
+            VasPaymentQR.is_active == 1,
         )
-        
+        qr_result = await db.execute(qr_stmt)
+        payment_qr = qr_result.scalar_one_or_none()
+
         if not payment_qr:
             event.status = "failed"
             event.error_message = "QR not found"
-            db.commit()
+            await db.commit()
             raise BizLogicError("QR not found")
 
-        payment = (
-            db.query(VasPayment)
-            .filter(
+        # ---------- 查找 payment ----------
+        pay_stmt = (
+            select(VasPayment)
+            .where(
                 VasPayment.provider == provider,
                 VasPayment.amount == amount_cent,
                 VasPayment.qr_id == payment_qr.id,
-                VasPayment.status == "pending"
+                VasPayment.status == "pending",
             )
             .order_by(VasPayment.created_at.desc())
-            .first()
         )
-        
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
         if not payment:
             event.status = "failed"
             event.error_message = "No matching pending payment"
-            db.commit()
+            await db.commit()
             raise BizLogicError("Payment not found")
+
         if payment.status in ("succeeded", "late_paid"):
             event.status = "duplicate"
             event.matched_payment_id = payment.id
             event.matched_order_id = payment.order_id
-            db.commit()
+            await db.commit()
             return None
 
         now = datetime.utcnow()
-        if payment.expire_at and now > payment.expire_at:
-            payment.status = "late_paid"
-        else:
-            payment.status = "succeeded"
-                            
-        # ---------- 写入原始 payload ----------
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
+
         payment.provider_payload = {
             "title": title,
             "content": content,
@@ -164,47 +166,51 @@ class WebhookService:
             "received_at": now.isoformat(),
         }
 
-        order = db.query(VasOrder).filter(VasOrder.id == payment.order_id).first()
+        order_stmt = select(VasOrder).where(VasOrder.id == payment.order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
         if order and order.status != "paid":
             order.status = "paid"
-            
-        WebhookService._create_task_if_not_exists(db, order)
+
+        await WebhookService._create_task_if_not_exists(db, order)
 
         event.status = "applied"
         event.matched_payment_id = payment.id
         event.matched_order_id = payment.order_id
 
-        db.commit()
-        db.refresh(payment)
-        
+        await db.commit()
+        await db.refresh(payment)
+
         return PaymentWebhookOut(
             status=True,
             order_id=order.id,
             user_id=order.user_id,
             payment_id=payment.id,
-            notify=True
+            notify=True,
         )
-        
+
+    # =========================================================
+    # Stripe webhook
+    # =========================================================
     @staticmethod
-    def stripe_payment_webhook(db: Session, event):
-        """
-        Stripe webhook handler
-        """
+    async def stripe_payment_webhook(
+        db: AsyncSession,
+        event: dict,
+    ) -> Optional[PaymentWebhookOut]:
 
         event_id = event["id"]
         event_type = event["type"]
         data = event["data"]["object"]
-        # ---------- 1. 幂等(事件级) ----------
-        existed_event = (
-            db.query(VasPaymentEvent)
-            .filter(VasPaymentEvent.provider == "stripe")
-            .filter(VasPaymentEvent.event_id == event_id)
-            .first()
+
+        existed_stmt = select(VasPaymentEvent).where(
+            VasPaymentEvent.provider == "stripe",
+            VasPaymentEvent.event_id == event_id,
         )
-        if existed_event:
+        existed_result = await db.execute(existed_stmt)
+        if existed_result.scalar_one_or_none():
             return None
 
-        # ---------- 2. 只处理关心的事件 ----------
         if event_type != "checkout.session.completed":
             db.add(
                 VasPaymentEvent(
@@ -215,10 +221,9 @@ class WebhookService:
                     created_at=datetime.utcnow(),
                 )
             )
-            db.commit()
+            await db.commit()
             return None
 
-        # ---------- 3. 解析 metadata ----------
         metadata = data.get("metadata", {})
         payment_id = metadata.get("payment_id")
         order_id = metadata.get("order_id")
@@ -226,56 +231,50 @@ class WebhookService:
         if not payment_id or not order_id:
             raise BizLogicError("Missing payment_id or order_id in metadata")
 
-        # ---------- 4. 查找 payment(业务级幂等) ----------
-        payment = (
-            db.query(VasPayment)
-            .filter(VasPayment.id == int(payment_id))
-            .first()
-        )
+        pay_stmt = select(VasPayment).where(VasPayment.id == int(payment_id))
+        pay_result = await db.execute(pay_stmt)
+        payment = pay_result.scalar_one_or_none()
+
         if not payment:
             raise NotFoundError("Payment not found")
 
         if payment.status == "succeeded":
-            # 已处理过
             db.add(
                 VasPaymentEvent(
                     provider="stripe",
                     event_id=event_id,
                     event_type=event_type,
-                    payload=event,
                     payment_id=payment.id,
                     created_at=datetime.utcnow(),
                 )
             )
-            db.commit()
+            await db.commit()
             return None
 
-        # ---------- 5. 金额校验 ----------
-        paid_amount = data["amount_total"]  # 单位:cent
+        paid_amount = data["amount_total"]
         paid_currency = data["currency"].upper()
 
         if paid_amount != payment.amount or paid_currency != payment.currency:
-            raise BizLogicError(f"Amount mismatch, expected {payment.amount} {payment.currency}, got {paid_amount} {paid_currency}")
+            raise BizLogicError(
+                f"Amount mismatch, expected {payment.amount} {payment.currency}, "
+                f"got {paid_amount} {paid_currency}"
+            )
 
-        # ---------- 6. 判断是否超时 ----------
         now = datetime.utcnow()
-        if payment.expire_at and now > payment.expire_at:
-            payment.status = "late_paid"
-        else:
-            payment.status = "succeeded"
-            
+        payment.status = "late_paid" if payment.expire_at and now > payment.expire_at else "succeeded"
         payment.provider_payload = event
         payment.updated_at = now
 
-        # ---------- 7. 更新 order ----------
-        order = db.query(VasOrder).filter(VasOrder.id == order_id).first()
+        order_stmt = select(VasOrder).where(VasOrder.id == order_id)
+        order_result = await db.execute(order_stmt)
+        order = order_result.scalar_one_or_none()
+
         if order and order.status != "paid":
             order.status = "paid"
             order.updated_at = now
-            
-        WebhookService._create_task_if_not_exists(db, order)
 
-        # ---------- 8. 写 payment_event ----------
+        await WebhookService._create_task_if_not_exists(db, order)
+
         db.add(
             VasPaymentEvent(
                 provider="stripe",
@@ -288,15 +287,13 @@ class WebhookService:
             )
         )
 
-        db.commit()
-        db.refresh(payment)
+        await db.commit()
+        await db.refresh(payment)
 
         return PaymentWebhookOut(
             status=True,
             order_id=order.id,
             user_id=order.user_id,
             payment_id=payment.id,
-            notify=True
+            notify=True,
         )
-
-   

+ 31 - 12
app/services/wechat_service.py

@@ -1,23 +1,42 @@
-import json
-import time
-import requests
-from typing import List
-from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+import httpx
+from app.core.biz_exception import BizLogicError
 from app.schemas.wechat import WechatIn
 
 
 class WechatService:
-    def push_to_wechat(payload: WechatIn):
+    @staticmethod
+    async def push_to_wechat(payload: WechatIn):
         """
-        企业微信 WebHook 格式:
+        企业微信 WebHook 推送(Async 版)
+
+        WebHook 格式:
         https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=YOUR_KEY
         """
         url = f"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key={payload.api_token}"
-        payload = {"msgtype": "text", "text": {"content": payload.message}}
 
-        response = requests.post(url, json=payload, timeout=10)
+        body = {
+            "msgtype": "text",
+            "text": {
+                "content": payload.message
+            }
+        }
+
+        try:
+            async with httpx.AsyncClient(timeout=10) as client:
+                response = await client.post(url, json=body)
+
+        except httpx.RequestError as e:
+            raise BizLogicError(f"Wechat push request error: {e}")
+
+        if response.status_code != 200:
+            raise BizLogicError(
+                f"Wechat push failed, http_status={response.status_code}"
+            )
+
         data = response.json()
+        if data.get("errcode") != 0:
+            raise BizLogicError(
+                f"Wechat push failed, errcode={data.get('errcode')}, errmsg={data.get('errmsg')}"
+            )
 
-        if response.status_code != 200 or data.get("errcode") != 0:
-            # logger.error(f"企业微信推送失败: {response.text}")
-            raise BizLogicError("Wechat push failed")
+        return True

+ 256 - 0
app/tasks/notification_task.py

@@ -0,0 +1,256 @@
+
+import asyncio
+from typing import Dict, Any
+from redis.asyncio import Redis
+from app.services.wechat_service import WechatService
+from app.services.email_authorizations_service import EmailAuthorizationService
+from app.utils.redis_utils import redis_qpop
+
+async def notification_consumer(redis_client: Redis):
+    """
+    异步消费 Redis 队列 vas_notification_queue
+    """
+    queue_name = "vas_notification_queue"
+    return
+    while True:
+        try:
+            # 阻塞获取队列消息
+            message: Dict[str, Any] = await redis_qpop(redis_client, queue_name, timeout=5)
+            if not message:
+                await asyncio.sleep(1)  # 队列为空,休眠
+                continue
+
+            channels = message.get("channels", [])
+            template_id = message.get("template_id")
+            payload = message.get("payload", {})
+            user_id = message.get("user_id")
+
+            # 按渠道发送
+            if "email" in channels:
+                # EmailService.create(user_id, template_id, payload) 是你自己实现的发送逻辑
+                await EmailAuthorizationService.send(user_id=user_id, template_id=template_id, payload=payload)
+
+            if "wechat" in channels:
+                api_token = payload.get("api_token")
+                content = payload.get("message") or payload.get("content")
+                if api_token and content:
+                    await WechatService.push_to_wechat({"api_token": api_token, "message": content})
+
+            print(f"✅ Notification sent: {message.get('notification_id')}")
+
+        except Exception as e:
+            print(f"⚠️ Notification consumer error: {e}")
+            await asyncio.sleep(1)  # 避免异常循环过快
+
+def template_for_bind_email(payload):
+    
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Email Verification</title>
+        <style>
+            body { font-family: Arial, sans-serif; background-color: #f4f4f4; margin: 0; padding: 0; }
+            .container { max-width: 600px; margin: 20px auto; background-color: #ffffff; padding: 30px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }
+            .code { font-size: 24px; font-weight: bold; letter-spacing: 5px; color: #333; background-color: #f0f0f0; padding: 15px; text-align: center; border-radius: 4px; margin: 20px 0; }
+            .footer { font-size: 12px; color: #888; margin-top: 30px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <h2>Verify your email address</h2>
+            <p>Hello,</p>
+            <p>You requested to bind this email address to your <strong>{{app_name}}</strong> account. Please use the verification code below to proceed:</p>
+            
+            <div class="code">{{code}}</div>
+            
+            <p>This code will expire in <strong>{{expiration_time}}</strong>.</p>
+            <p>If you did not request this change, please ignore this email.</p>
+            
+            <br>
+            <p>Best regards,<br>The {{app_name}} Team</p>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+
+def template_for_reset_pwd(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Reset Password</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f9f9f9; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 10px rgba(0,0,0,0.05); }
+            .header { border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h2 { margin: 0; color: #333; }
+            .code { font-size: 32px; font-weight: bold; letter-spacing: 5px; color: #2563eb; background-color: #eff6ff; padding: 20px; text-align: center; border-radius: 8px; margin: 30px 0; border: 1px solid #dbeafe; }
+            .warning { background-color: #fff7ed; border-left: 4px solid #f97316; padding: 15px; font-size: 14px; color: #c2410c; margin-top: 30px; }
+            .footer { font-size: 12px; color: #999; margin-top: 40px; text-align: center; border-top: 1px solid #eee; padding-top: 20px; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h2>Password Reset Request</h2>
+            </div>
+            
+            <p>Hello,</p>
+            <p>We received a request to reset the password for your <strong>{{app_name}}</strong> account. Please use the following code to verify your identity:</p>
+            
+            <div class="code">{{code}}</div>
+            
+            <p>This code is valid for <strong>{{expiration_time}}</strong>.</p>
+            
+            <div class="warning">
+                <strong>Security Tip:</strong> If you did not request a password reset, please ignore this email. No changes will be made to your account.
+            </div>
+            
+            <br>
+            <p>Best regards,<br>The {{app_name}} Team</p>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                This is an automated message, please do not reply.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_for_login_credentials(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Your Account Details</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f4f6f8; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 12px rgba(0,0,0,0.05); }
+            .header { text-align: center; border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h1 { font-size: 24px; color: #1a1a1a; margin: 0; }
+            .creds-box { background-color: #f0f7ff; border: 1px solid #dbeafe; border-radius: 8px; padding: 20px; margin: 25px 0; }
+            .creds-item { margin-bottom: 10px; font-size: 16px; }
+            .creds-label { font-weight: bold; color: #555; width: 100px; display: inline-block; }
+            .creds-value { font-family: 'Courier New', Courier, monospace; font-weight: bold; color: #2563eb; }
+            .btn { display: block; width: 200px; margin: 30px auto; background-color: #2563eb; color: #ffffff !important; text-align: center; padding: 12px 0; border-radius: 6px; text-decoration: none; font-weight: bold; }
+            .note { font-size: 13px; color: #666; background-color: #fff4e5; padding: 10px; border-radius: 4px; border-left: 4px solid #f97316; }
+            .footer { font-size: 12px; color: #999; margin-top: 40px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h1>Welcome to {{app_name}}</h1>
+            </div>
+            
+            <p>Dear User,</p>
+            <p>Your account has been successfully set up. Below are your temporary login credentials.</p>
+            
+            <div class="creds-box">
+                <div class="creds-item">
+                    <span class="creds-label">Username:</span>
+                    <span class="creds-value">{{username}}</span>
+                </div>
+                <div class="creds-item">
+                    <span class="creds-label">Password:</span>
+                    <span class="creds-value">{{password}}</span>
+                </div>
+            </div>
+
+            <div class="note">
+                <strong>Important:</strong> For your security, please change your password immediately after logging in.
+            </div>
+
+            <a href="{{login_url}}" class="btn">Log In Now</a>
+            
+            <p style="text-align: center; font-size: 14px;">
+                Or copy this link: <a href="{{login_url}}">{{login_url}}</a>
+            </p>
+
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                If you did not request this account, please contact support.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_ticket_open(payload):
+    template = '''
+    <!DOCTYPE html>
+    <html>
+    <head>
+        <meta charset="UTF-8">
+        <title>Ticket Created</title>
+        <style>
+            body { font-family: 'Helvetica Neue', Arial, sans-serif; background-color: #f4f6f8; margin: 0; padding: 0; color: #333; }
+            .container { max-width: 600px; margin: 40px auto; background-color: #ffffff; padding: 40px; border-radius: 8px; box-shadow: 0 4px 12px rgba(0,0,0,0.05); }
+            .header { border-bottom: 1px solid #eee; padding-bottom: 20px; margin-bottom: 30px; }
+            .header h1 { font-size: 22px; color: #1a1a1a; margin: 0; }
+            .ticket-info { background-color: #f8fafc; border: 1px solid #e2e8f0; border-radius: 8px; padding: 20px; margin: 20px 0; }
+            .info-row { margin-bottom: 10px; display: flex; justify-content: space-between; }
+            .info-label { color: #64748b; font-size: 14px; }
+            .info-value { font-weight: bold; color: #0f172a; font-size: 14px; }
+            .btn { display: block; width: 200px; margin: 30px auto; background-color: #2563eb; color: #ffffff !important; text-align: center; padding: 12px 0; border-radius: 6px; text-decoration: none; font-weight: bold; font-size: 14px; }
+            .footer { font-size: 12px; color: #94a3b8; margin-top: 40px; text-align: center; }
+        </style>
+    </head>
+    <body>
+        <div class="container">
+            <div class="header">
+                <h1>Support Request Received</h1>
+            </div>
+            
+            <p>Hello {{username}},</p>
+            <p>We wanted to let you know that we've received your request. Our team is currently reviewing the details.</p>
+            
+            <div class="ticket-info">
+                <div class="info-row">
+                    <span class="info-label">Ticket ID:</span>
+                    <span class="info-value">#{{ticket_id}}</span>
+                </div>
+                <div class="info-row">
+                    <span class="info-label">Type:</span>
+                    <span class="info-value">{{ticket_type}}</span>
+                </div>
+                <div class="info-row" style="margin-bottom: 0;">
+                    <span class="info-label">Time:</span>
+                    <span class="info-value">{{created_at}}</span>
+                </div>
+            </div>
+
+            <p>We usually reply within 24 hours. You will receive an email notification when our agent replies.</p>
+
+            <a href="{{ticket_url}}" class="btn">View Ticket Details</a>
+            
+            <div class="footer">
+                &copy; 2025 {{app_name}}. All rights reserved.<br>
+                Please do not reply to this email directly.
+            </div>
+        </div>
+    </body>
+    </html>
+    '''
+    
+def template_confirm_payment(payload):
+    template = {
+        "touser": "ADMIN_USER_ID",
+        "msgtype": "textcard",
+        "agentid": 1000001,
+        "textcard": {
+            "title": "💰 待确认:收到新的手动转账",
+            "description": "<div class=\"gray\">2025-12-31 10:30:00</div> <br>订单号:ORD-20251231-001<br>用户:user@example.com<br><div class=\"highlight\">金额:¥ 3,500.00</div><br>请核实资金到账后,点击卡片确认收款。",
+            "url": "https://admin.visafly.com/payment/confirm?payment_id=123&token=secure_token_abc",
+            "btntxt": "立即确认"
+        }
+    }

+ 18 - 9
app/utils/pagination.py

@@ -1,24 +1,33 @@
-from sqlalchemy.orm import Query
-from sqlalchemy.orm import Session
+from typing import Any, Dict
+from sqlalchemy import select, func
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.sql import Select
 
-def paginate(
-    query: Query,
+
+async def paginate(
+    db: AsyncSession,
+    stmt: Select,
     page: int = 1,
     size: int = 20,
-):
+) -> Dict[str, Any]:
     if page < 1:
         page = 1
     if size < 1:
         size = 20
 
-    total = query.count()
+    # ---------- 1️⃣ 查询总数 ----------
+    count_stmt = select(func.count()).select_from(
+        stmt.subquery()
+    )
+    total = await db.scalar(count_stmt) or 0
 
-    items = (
-        query
+    # ---------- 2️⃣ 查询分页数据 ----------
+    result = await db.execute(
+        stmt
         .offset((page - 1) * size)
         .limit(size)
-        .all()
     )
+    items = result.scalars().all()
 
     return {
         "items": items,

+ 28 - 14
app/utils/redis_utils.py

@@ -1,22 +1,36 @@
 import json
-import pytz
-import resource
+from typing import Optional
 from redis.asyncio import Redis
-from datetime import datetime, timedelta
 
 
-def redis_qpush(redis_client, qname: str, data: dict, max_len: int = 30):
-    """向队列右侧推入数据,并限制队列最大长度"""
+async def redis_qpush(
+    redis_client: Redis,
+    qname: str,
+    data: dict,
+    max_len: int = 30,
+):
+    """
+    向队列右侧推入数据,并限制队列最大长度(Async 版)
+    """
     data_string = json.dumps(data)
-    pipe = redis_client.pipeline()
+
+    pipe = redis_client.pipeline(transaction=True)
     pipe.rpush(qname, data_string)
-    pipe.ltrim(qname, -max_len, -1)  # 只保留右侧 max_len 个元素
-    pipe.execute()
+    pipe.ltrim(qname, -max_len, -1)
+    await pipe.execute()
+
 
-def redis_qpop(redis_client, qname:str, timeout: int = 5):
-    message = redis_client.blpop(qname, timeout=timeout)
+async def redis_qpop(
+    redis_client: Redis,
+    qname: str,
+    timeout: int = 5,
+) -> Optional[dict]:
+    """
+    从队列左侧阻塞弹出数据(Async 版)
+    """
+    message = await redis_client.blpop(qname, timeout=timeout)
     if message is None:
-        return None  # 队列为空,直接返回
-    message_string = message[1]
-    data = json.loads(message_string)
-    return data
+        return None
+
+    _, message_string = message
+    return json.loads(message_string)

+ 18 - 4
app/utils/search.py

@@ -1,15 +1,29 @@
+from typing import List, Optional
 from sqlalchemy import or_
-from typing import List
+from sqlalchemy.sql import Select
 
-def apply_keyword_search(query, model, keyword: str, fields: List[str]):
+
+def apply_keyword_search_stmt(
+    stmt: Select,
+    model,
+    keyword: Optional[str],
+    fields: List[str],
+) -> Select:
+    """
+    Async / SQLAlchemy 2.0 Select 版本的关键字搜索
+    """
     if not keyword:
-        return query
+        return stmt
 
     like = f"%{keyword}%"
 
     conditions = [
         getattr(model, field).ilike(like)
         for field in fields
+        if hasattr(model, field)
     ]
 
-    return query.filter(or_(*conditions))
+    if not conditions:
+        return stmt
+
+    return stmt.where(or_(*conditions))

+ 1 - 1
app/utils/validation_utils.py

@@ -2,7 +2,7 @@ from jsonschema import validate, ValidationError
 from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
 
 def validate_user_inputs(schema_json: dict, user_inputs: dict):
-    print(f'schema_json={schema_json}, user_inputs={user_inputs}')
+    # print(f'schema_json={schema_json}, user_inputs={user_inputs}')
     try:
         validate(instance=user_inputs, schema=schema_json)
     except ValidationError as e:

+ 11 - 0
app/utils/wrappers.py

@@ -0,0 +1,11 @@
+from functools import wraps
+
+def after_call(after_func):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            result = func(*args, **kwargs)
+            after_func(*args, **kwargs)
+            return result
+        return wrapper
+    return decorator

+ 52 - 13
starter.py

@@ -1,35 +1,74 @@
 #!/usr/bin/env python3
 import os
-import sys
 import subprocess
 
+
 def main():
     """
     启动 FastAPI 应用,根据环境选择不同参数:
-    - DEV: 热重载,单进程
+    - DEV: 热重载,单进程,仅监听 app 目录
     - PROD: 多进程,高性能 uvloop + httptools
     """
-    env = os.getenv("ENV", "DEV").upper()  # 默认开发环境
+    
+    os.environ.setdefault(
+        "WATCHFILES_IGNORE",
+        "**/.git/**,**/venv/**,**/__pycache__/**"
+    )
+    
+    env = os.getenv("ENV", "DEV").upper()
+    
+    env = "DEV"
+
     host = "0.0.0.0"
     port = "8888"
     app_module = "app.main:app"
 
-    base_cmd = ["uvicorn", app_module, "--host", host, "--port", port]
+    base_cmd = [
+        "uvicorn",
+        app_module,
+        "--host", host,
+        "--port", port,
+    ]
 
     if env == "DEV":
-        print("启动开发环境(热重载)...")
-        base_cmd += ["--reload", "--workers", "1"]
+        print("🚀 启动开发环境(热重载)")
+
+        base_cmd += [
+            "--reload",
+            "--workers", "1",
+
+            # ⭐ 关键:限制监听范围
+            "--reload-dir", "app",
+            "--reload-exclude", ".git",
+            "--reload-exclude", "venv",
+            "--reload-exclude", "__pycache__",
+        ]
+
     elif env == "PROD":
-        print("启动生产环境(多进程 + 高性能)...")
-        base_cmd += ["--workers", str(os.cpu_count()), "--loop", "uvloop", "--http", "httptools"]
+        print("🔥 启动生产环境(多进程 + 高性能)")
+
+        base_cmd += [
+            "--workers", str(os.cpu_count() or 1),
+            "--loop", "uvloop",
+            "--http", "httptools",
+        ]
+
     else:
-        print(f"未知环境 {env},使用默认开发配置")
-        base_cmd += ["--reload", "--workers", "1"]
+        print(f"⚠️ 未知环境 {env},使用默认 DEV 配置")
+
+        base_cmd += [
+            "--reload",
+            "--workers", "1",
+            "--reload-dir", "app",
+        ]
+
+    print("\n执行命令:")
+    print(" ", " ".join(base_cmd))
+    print(f"\nSwagger UI: http://{host}:{port}/docs")
+    print(f"ReDoc UI:   http://{host}:{port}/redoc\n")
 
-    print("执行命令:", " ".join(base_cmd))
-    print(f"Swagger UI: http://{host}:{port}/docs")
-    print(f"ReDoc UI:   http://{host}:{port}/redoc")
     subprocess.run(base_cmd)
 
+
 if __name__ == "__main__":
     main()

この差分においてかなりの量のファイルが変更されているため、一部のファイルを表示していません