Bladeren bron

feat: update

jerry 4 maanden geleden
bovenliggende
commit
5fc3a15992

+ 1 - 0
.env

@@ -1,3 +1,4 @@
+ENV=DEV
 DATABASE_URL=mysql+asyncmy://root:GqLLL7Bofj0WaaOpp.0@visafly.top:3306/book_user_info?charset=utf8mb4
 REDIS_URL=redis://:STEs2x6ML0U1HlpE9SojM6YU7QPhqzY8@45.137.220.138:6379/0
 OPENAI_API_KEY=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

+ 30 - 0
Dockerfile

@@ -0,0 +1,30 @@
+# /root/backend/Dockerfile
+
+# 使用 Python 官方镜像
+FROM python:3.9-slim
+
+# 设置工作目录
+WORKDIR /app
+
+# 优化 Python 环境
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+
+# 安装系统依赖 (如有需要,例如数据库驱动)
+RUN apt-get update && apt-get install -y default-libmysqlclient-dev build-essential libffi-dev libssl-dev pkg-config
+
+# 复制依赖文件
+COPY requirements.txt .
+
+# 安装 Python 依赖
+# 使用清华源加速
+RUN pip install --no-cache-dir -r requirements.txt
+
+# 复制项目代码
+COPY . .
+
+# 暴露端口 (仅用于文档,实际由 docker-compose 映射)
+EXPOSE 8888
+
+# 启动命令
+CMD ["python3", "-u", "starter.py"]

+ 5 - 4
app/api/router.py

@@ -734,7 +734,7 @@ async def vas_order_create(
     schema = await SchemaService.get(db, product.schema_id)
     # ② 校验 user_inputs
     validate_user_inputs(
-        schema_json=schema.schema_json,
+        schema_json=schema.schema_content,
         user_inputs=payload.user_inputs,
     )
     created_order = await OrderService.create(db, payload, product, current_user, redis_client)
@@ -752,7 +752,7 @@ async def vas_order_create_by_admin(
     schema = await SchemaService.get(db, product.schema_id)
     # ② 校验 user_inputs
     validate_user_inputs(
-        schema_json=schema.schema_json,
+        schema_json=schema.schema_content,
         user_inputs=payload.user_inputs,
     )
     created_order = await OrderService.create_by_admin(db, payload, product, current_user, redis_client)
@@ -772,7 +772,7 @@ async def vas_order_patch_user_inputs(
     payload: VasOrderPatchUserInputs,
     db: Session = Depends(get_db),
 ):
-    order = OrderService.patch_user_inputs(db, order_id, payload)
+    order = await OrderService.patch_user_inputs(db, order_id, payload)
     return success(data=order)
 
 @protected_router.get("/vas/order/list_by_user", summary="查看所有订单", tags=["Visafly签证系统"], response_model=ApiResponse[PageResponse[VasOrderOut]])
@@ -803,7 +803,8 @@ async def vas_order_cancel(
     db: Session = Depends(get_db),
     redis_client: Redis = Depends(get_redis_client)
 ):
-    pass
+    cancelled_order = await OrderService.cancel(db, order_id)
+    return success(data=cancelled_order)
 
 @protected_router.get("/vas/payment_provider/list_enabled", summary="获取支付方式", tags=["Visafly签证系统"], response_model=ApiResponse[List[VasPaymentProviderOut]])
 async def vas_payment_provider_simple_get(

+ 1 - 0
app/core/config.py

@@ -8,6 +8,7 @@ class Settings(BaseSettings):
     # App
     # -----------------------
     app_name: str = "MyApp"
+    env: str = "PROD"
     debug: bool = False
 
     # -----------------------

+ 5 - 4
app/main.py

@@ -1,5 +1,5 @@
 import asyncio
-
+import os
 from fastapi import FastAPI, Depends, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
@@ -30,9 +30,10 @@ async def startup():
     logger.info("🟢 Stripe config done")
     
     # 通知服务启动
-    redis_client = await get_redis_client()
-    asyncio.create_task(notification_consumer(AsyncSessionLocal, redis_client))
-    logger.info("🟢 Notification consumer started")    
+    if os.environ.get("RUN_ON_MASTER", "1") == "1" and os.getppid() == 1:
+        redis_client = await get_redis_client()
+        asyncio.create_task(notification_consumer(AsyncSessionLocal, redis_client))
+        logger.info("🟢 Notification consumer started")    
 
 # -----------------------
 # Exception Handlers

+ 2 - 2
app/schemas/common.py

@@ -1,7 +1,6 @@
 # app/schemas/common.py
 from typing import Generic, TypeVar, Optional, List
 from pydantic import BaseModel
-from pydantic.generics import GenericModel
 
 T = TypeVar("T")
 
@@ -10,7 +9,8 @@ class ApiResponse(BaseModel, Generic[T]):
     message: str = "success"
     data: Optional[T] = None
 
-class PageResponse(GenericModel, Generic[T]):
+# 将 GenericModel 替换为 BaseModel
+class PageResponse(BaseModel, Generic[T]):
     items: List[T]
     total: int
     page: int

+ 20 - 12
app/schemas/schema.py

@@ -1,14 +1,21 @@
-# app/schemas/schema.py
 import json
-from pydantic import BaseModel, field_validator
-from typing import Optional, Dict, Any, Literal, List
+from pydantic import BaseModel, field_validator, Field, ConfigDict
+from typing import Optional, Dict, Any
 from datetime import datetime
 
 class VasSchemaBase(BaseModel):
     name: Optional[str] = None
     description: Optional[str] = None
-    schema_json: Optional[Dict[str, Any]] = None
-    @field_validator("schema_json", mode="before")
+    
+    # 1. 将字段重命名为 schema_content (避免与 BaseModel.schema_json 冲突)
+    # 2. 使用 alias="schema_json" 确保前端传参或数据库读取时 key 仍然是 "schema_json"
+    schema_content: Optional[Dict[str, Any]] = Field(default=None, alias="schema_json")
+
+    # 配置允许通过字段名或别名进行赋值
+    model_config = ConfigDict(populate_by_name=True)
+
+    # 3. 验证器指向新的字段名 schema_content
+    @field_validator("schema_content", mode="before")
     def normalize_json_field(cls, v):
         if v is None:
             return None
@@ -18,22 +25,23 @@ class VasSchemaBase(BaseModel):
             except Exception:
                 return {}
         return v
-    
+
 class VasSchemaCreate(VasSchemaBase):
     name: str
     description: str
-    schema_json: Dict[str, Any]
+    # 这里也要重命名,并加上 alias
+    schema_content: Dict[str, Any] = Field(..., alias="schema_json")
 
 class VasSchemaUpdate(VasSchemaBase):
     pass
 
 class VasSchemaOut(VasSchemaBase):
     id: int
-    
     created_at: datetime
     updated_at: datetime
 
-    model_config = {
-        "from_attributes": True
-    }
-
+    # 继承 Base 的配置,并在其基础上增加 from_attributes
+    model_config = ConfigDict(
+        populate_by_name=True,
+        from_attributes=True
+    )

+ 1 - 1
app/services/llm_service.py

@@ -21,7 +21,7 @@ class LlmService:
         obj = (await db.execute(stmt)).scalar_one_or_none()
         if not obj:
             raise NotFoundError("Schema not exist")
-        parsed_obj = await LlmService.parse_data_async(payload.input_raw_str, obj.schema_json)
+        parsed_obj = await LlmService.parse_data_async(payload.input_raw_str, obj.schema_content)
         out = ParseUserInputsOut(parsed_obj=parsed_obj)
         return out
     

+ 48 - 37
app/services/order_service.py

@@ -17,6 +17,7 @@ from app.models.vas_task import VasTask
 from app.models.product import VasProduct
 from app.models.product_routing import VasProductRouting
 from app.schemas.order import VasOrderCreate, VasOrderPatchUserInputs
+from app.services.webhook_service import WebhookService
 
 
 class OrderService:
@@ -55,6 +56,34 @@ class OrderService:
 
         return order
     
+    # --------------------------------------------------
+    # 取消订单
+    # --------------------------------------------------
+    @staticmethod
+    async def cancel(
+        db: AsyncSession,
+        order_id,
+    ) -> VasOrder:
+        stmt = select(VasOrder).where(VasOrder.id == order_id)
+        order = (await db.execute(stmt)).scalar_one_or_none()
+
+        if not order:
+            raise NotFoundError("Order not exist")
+        
+        order.status = "closed"
+        task_res = await db.execute(
+            select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.status.in_(["pending", "grabbed", "running", "completed"]),
+            )
+        )
+        for task in task_res.scalars().all():
+            task.status = "cancelled"
+            
+        await db.commit()
+        await db.refresh(order)
+        return order
+    
     @staticmethod
     async def create_by_admin(
         db: AsyncSession,
@@ -93,41 +122,7 @@ class OrderService:
         
         db.add(order)
         
-        stmt = select(VasProductRouting).where(
-            VasProductRouting.product_id == order.product_id,
-            VasProductRouting.is_active == 1
-        )
-        result = await db.execute(stmt)
-        routings = result.scalars().all()
-
-        if routings:
-            created_tasks: List[VasTask] = []
-
-            for routing in routings:
-                exists_stmt = select(VasTask).where(
-                    VasTask.order_id == order.id,
-                    VasTask.routing_key == routing.routing_key,
-                    VasTask.script_version == routing.script_version,
-                )
-                exists = (await db.execute(exists_stmt)).scalar_one_or_none()
-                if exists:
-                    continue
-
-                task = VasTask(
-                    order_id=order.id,
-                    routing_key=routing.routing_key,
-                    script_version=routing.script_version,
-                    priority=routing.priority,
-                    status="pending",
-                    user_inputs=order.user_inputs,
-                    config=routing.config,
-                    attempt_count=0,
-                    notify_count=0,
-                    expire_at=datetime.utcnow() + timedelta(days=60),
-                    created_at=datetime.utcnow(),
-                )
-                db.add(task)
-                created_tasks.append(task)
+        await WebhookService._create_task_if_not_exists(db, order)
         
         await db.commit()
         await db.refresh(order)
@@ -157,7 +152,7 @@ class OrderService:
             stmt=stmt,
             model=VasOrder,
             keyword=keyword,
-            fields=["id", "user_id", "product_name"],
+            fields=["id", "user_id", "product_name", "user_inputs"],
         ).order_by(VasOrder.created_at.desc())
 
         return await paginate(db, stmt, page, size)
@@ -177,7 +172,7 @@ class OrderService:
             stmt=stmt,
             model=VasOrder,
             keyword=keyword,
-            fields=["id", "user_id", "user_name", "product_name"],
+            fields=["id", "user_id", "user_name", "product_name", "user_inputs"],
         ).order_by(VasOrder.created_at.desc())
 
         return await paginate(db, query, page, size)
@@ -199,6 +194,22 @@ class OrderService:
             raise NotFoundError("Order not exist")
 
         order.user_inputs = payload.user_inputs
+        
+
+        # 1️⃣ 取消旧任务
+        task_res = await db.execute(
+            select(VasTask).where(
+                VasTask.order_id == order.id,
+                VasTask.status.in_(["pending", "grabbed", "running", "completed"]),
+            )
+        )
+        for task in task_res.scalars().all():
+            task.status = "cancelled"
+        
+        await db.flush()
+        await WebhookService._create_task_if_not_exists(db, order)
+        
+        
         await db.commit()
         await db.refresh(order)
 

+ 1 - 2
app/services/payment_service.py

@@ -17,7 +17,6 @@ from app.utils.pagination import paginate
 from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.user import VasUser
 from app.models.order import VasOrder
-from app.models.vas_task import VasTask
 from app.models.payment import VasPayment
 from app.models.ticket import VasTicket
 from app.models.payment_event import VasPaymentEvent
@@ -299,7 +298,7 @@ class PaymentService:
 
         if order and order.status != "paid":
             order.status = "paid"
-
+        
         await WebhookService._create_task_if_not_exists(db, order)
 
         event.status = "applied"

+ 3 - 19
app/services/ticket_service.py

@@ -65,7 +65,6 @@ class TicketService:
         admin_id: str,
     ) -> VasTicket:
 
-        # 🔒 锁住 ticket,防并发管理员操作
         result = await db.execute(
             select(VasTicket)
             .where(VasTicket.id == ticket_id)
@@ -140,27 +139,12 @@ class TicketService:
             task_res = await db.execute(
                 select(VasTask).where(
                     VasTask.order_id == order.id,
-                    VasTask.status.in_(["pending", "grabbed", "running"]),
+                    VasTask.status.in_(["pending", "grabbed", "running", "completed"]),
                 )
             )
             for task in task_res.scalars().all():
                 task.status = "cancelled"
 
-        # ---------- 变更请求 ----------
-        elif ticket.type == "change_request":
-
-            # 1️⃣ 取消旧任务
-            task_res = await db.execute(
-                select(VasTask).where(
-                    VasTask.order_id == order.id,
-                    VasTask.status.in_(["pending", "grabbed", "running"]),
-                )
-            )
-            for task in task_res.scalars().all():
-                task.status = "cancelled"
-                
-            await WebhookService._create_task_if_not_exists(db, order)
-
     # =========================
     # 工单拒绝逻辑
     # =========================
@@ -182,8 +166,8 @@ class TicketService:
         if not order:
             return
 
-        if ticket.type == "refund" and order.status == "refund_pending":
-            order.status = "paid"
+        # if ticket.type == "refund" and order.status == "refund_pending":
+        #     order.status = "paid"
 
     @staticmethod
     async def add_message(

+ 106 - 69
app/services/troov_service.py

@@ -3,20 +3,18 @@ import time
 import random
 import asyncio
 import aiohttp
-from typing import List, Optional
+from typing import List, Optional, Tuple, Dict, Any
 
 from redis.asyncio import Redis
 from starlette.concurrency import run_in_threadpool
+from app.core.biz_exception import NotFoundError, PermissionDeniedError, BizLogicError
+
 from app.schemas.troov import TroovRate
 from app.utils.france_slot_api import troov_create_session_old
 from app.utils.proxy_utils import load_proxies_from_json
 from app.core.logger import logger
 
 
-# =========================================================
-# Redis 原子弹出 token(不使用 KEYS,避免阻塞)
-# =========================================================
-
 POP_TOKEN_LUA = """
 local cursor = "0"
 local max_ttl = -1
@@ -45,105 +43,144 @@ end
 return nil
 """
 
-
-async def pop_redis_value_token(redis_client: Redis):
+async def get_valid_token_from_redis(redis_client: Redis, timeout: int = 30) -> Optional[str]:
     """
-    原子性获取 TTL 最大的 token 并删除
+    尝试从 Redis 获取有效的验证码 Token。
+    包含重试机制。
     """
-    return await redis_client.eval(POP_TOKEN_LUA, 0)
+    start_time = time.time()
+    
+    while time.time() - start_time < timeout:
+        # 执行 Lua 脚本原子获取
+        result = await redis_client.eval(POP_TOKEN_LUA, 0)
+        
+        if result:
+            try:
+                # result 结构: [key, value_str, ttl]
+                body_str = result[1]
+                body = json.loads(body_str)
+                token = body.get("token")
+                if token:
+                    return token
+            except (json.JSONDecodeError, IndexError, AttributeError):
+                logger.warning("Redis retrieved invalid token format")
+        
+        # 没拿到或格式不对,稍作等待
+        await asyncio.sleep(1)
+    
+    return None
 
 
 # =========================================================
-# 请求法国 Troov 接口(async,不阻塞)
+# 2. 网络请求模块
 # =========================================================
 
-async def fetch_rate(session_dic: dict, date: str) -> str:
+async def fetch_troov_availability(
+    session_data: Dict[str, Any], 
+    date: str, 
+    proxy_url: str
+) -> str:
+    """
+    请求 Troov 预约可用性接口。
+    强制使用指定的代理。
+    """
     url = (
-        "https://api.consulat.gouv.fr/api/team/"
+        "https://51.254.177.49/api/team/"
         "621540d353069dec25bd0045/reservations/availability"
-        f"?name=Visas&date={date}&places=-5&matching=&maxCapacity=-5"
-        f"&sessionId={session_dic['session_id']}"
     )
+    
+    # URL 参数
+    params = {
+        "name": "Visas",
+        "date": date,
+        "places": "-5",
+        "matching": "",
+        "maxCapacity": "-5",
+        "sessionId": session_data.get("session_id")
+    }
 
     headers = {
         "accept": "application/json, text/plain, */*",
         "accept-language": "zh-CN,zh;q=0.9,en;q=0.8",
-        "origin": "https://consulat.gouv.fr",
-        "referer": "https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas",
+        # "origin": "https://consulat.gouv.fr",
+        # "referer": "https://consulat.gouv.fr/en/ambassade-de-france-en-irlande/appointment?name=Visas",
         "user-agent": (
             "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
             "AppleWebKit/537.36 (KHTML, like Gecko) "
             "Chrome/141.0.0.0 Safari/537.36"
         ),
-        "x-gouv-app-id": session_dic["x_gouv_app_id"],
+        "x-gouv-app-id": session_data.get("x_gouv_app_id"),
         "x-gouv-web": "fr.gouv.consulat",
     }
 
     timeout = aiohttp.ClientTimeout(total=15)
-
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        async with session.get(url, headers=headers) as resp:
+    
+    connector = aiohttp.TCPConnector(ssl=False)
+
+    # 显式使用传入的 proxy_url
+    async with aiohttp.ClientSession(timeout=timeout, connector=connector) as session:
+        async with session.get(
+            url, 
+            params=params, 
+            headers=headers, 
+            proxy=proxy_url
+        ) as resp:
+            resp.raise_for_status() # 如果状态码不是 200,抛出异常
             return await resp.text()
 
 
 # =========================================================
-# 核心业务逻辑
+# 3. 核心业务流程
 # =========================================================
 
+def _get_proxy_pool() -> List[str]:
+    """加载代理池配置"""
+    proxies = []
+    # 可以在此处扩展更多 pool 类型
+    for pool in ("oxylabs",):
+        proxies.extend(load_proxies_from_json("data/proxy_pool_config.json", pool))
+    return proxies
+
+
 async def get_rate_by_date(
     redis_client: Redis,
     date: str
 ) -> Optional[List[TroovRate]]:
     """
-    根据日期获取 Troov 预约信息
+    主入口:根据日期获取 Troov 预约信息
+    流程:获取代理 -> 获取 Token -> 创建会话(Sync) -> 获取数据(Async)
     """
 
-    # ---------- 1️⃣ 加载代理 ----------
-    proxies = []
-    for pool in ("oxylabs",):
-        proxies.extend(
-            load_proxies_from_json("data/proxy_pool_config.json", pool)
-        )
-
+    # 1. 准备代理
+    proxies = _get_proxy_pool()
     if not proxies:
-        logger.error("Proxy pool is empty")
-        return None
-
-    # ---------- 2️⃣ 获取验证码 token(最多等待 30 秒) ----------
-    token_data = None
-    for _ in range(30):
-        token_data = await pop_redis_value_token(redis_client)
-        if token_data:
-            break
-        await asyncio.sleep(1)
-
-    if not token_data:
-        logger.warning("No captcha token available")
-        return None
-
-    _, body_str, ttl = token_data
-
-    try:
-        body = json.loads(body_str)
-        captcha_token = body.get("token")
-    except Exception:
-        logger.exception("Invalid captcha token format")
-        return None
-
-    # ---------- 3️⃣ 创建 Troov session(同步函数放线程池) ----------
-    proxy = random.choice(proxies)
-
-    session_dic = await run_in_threadpool(troov_create_session_old, proxy, captcha_token)
+        raise NotFoundError(message="Proxy pool is empty")
+    
+    # 随机选择一个代理,并在整个流程中保持一致
+    current_proxy = random.choice(proxies)
+
+    # 2. 获取验证码 Token
+    captcha_token = await get_valid_token_from_redis(redis_client)
+    if not captcha_token:
+        raise NotFoundError(message="Failed to retrieve captcha token within timeout")
+
+
+    logger.info(f"Creating session with proxy: {current_proxy}...")
+    session_dic = await run_in_threadpool(
+        troov_create_session_old, 
+        current_proxy, 
+        captcha_token
+    )
+    
     if not session_dic:
-        logger.warning("Failed to create Troov session")
-        return None
-
-    logger.info(f"Troov session created: {session_dic}")
-
-    # ---------- 4️⃣ 请求预约数据 ----------
-    try:
-        response_text = await fetch_rate(session_dic, date)
-        return json.loads(response_text)
-    except Exception as e:
-        logger.error(f"Fetch rate failed: {e}")
-        return None
+        raise BizLogicError(message="Failed to create Troov session (session_dic is empty)")
+        
+    logger.info(f"Troov session created successfully: {session_dic.get('session_id')}")
+
+    # 确保这里传入了 current_proxy
+    response_text = await fetch_troov_availability(session_dic, date, current_proxy)
+    
+    # 解析数据
+    data = json.loads(response_text)
+    # 这里可以加一步数据校验,确保 data 是 List[TroovRate] 格式
+    return data

+ 9 - 2
app/services/webhook_service.py

@@ -7,6 +7,7 @@ from decimal import Decimal
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy import select
 
+from app.core.queue_manager import queue_manager
 from app.core.biz_exception import NotFoundError, BizLogicError
 from app.models.order import VasOrder
 from app.models.vas_task import VasTask
@@ -44,7 +45,7 @@ class WebhookService:
             exists_stmt = select(VasTask).where(
                 VasTask.order_id == order.id,
                 VasTask.routing_key == routing.routing_key,
-                VasTask.script_version == routing.script_version,
+                VasTask.status.in_(["pending", "grabbed", "running", "completed"]),
             )
             exists_result = await db.execute(exists_stmt)
             exists = exists_result.scalar_one_or_none()
@@ -66,8 +67,14 @@ class WebhookService:
                 created_at=datetime.utcnow(),
             )
             db.add(task)
+            await db.flush()
+            await db.refresh(task)
+            queue_manager.put(
+                queue_name=routing.routing_key,
+                task_id=task.id,
+                priority=task.priority
+            )
             created_tasks.append(task)
-
         return created_tasks
 
     # =========================================================

+ 23 - 0
docker-compose.yml

@@ -0,0 +1,23 @@
+version: '3.8'
+
+services:
+  backend:
+    container_name: visafly-backend
+    build: 
+      context: .
+      dockerfile: Dockerfile
+    restart: always
+    
+    # 端口映射:将容器的 8000 映射到宿主机的 8000
+    # 绑定到 127.0.0.1 保证安全性,只允许宿主机的 Nginx 访问
+    ports:
+      - "127.0.0.1:8888:8888"
+    
+    # 加载 .env 文件中的变量
+    env_file:
+      - .env
+
+    # 这是一个技巧:允许容器通过 'host.docker.internal' 访问宿主机
+    # 如果你的 DB_HOST 填公网 IP 连不上,可以试着填 host.docker.internal
+    extra_hosts:
+      - "host.docker.internal:host-gateway"

+ 9 - 0
requirements.txt

@@ -1,8 +1,17 @@
 fastapi
+requests
+aiohttp
+pysocks
+bcrypt
+stripe
+python-multipart
+jsonschema
+asyncmy
 uvicorn[standard]
 sqlalchemy
 pydantic
 pydantic-settings
+pydantic[email]
 psycopg2-binary  # PostgreSQL
 mysqlclient       # MySQL
 redis

+ 1 - 1
starter.py

@@ -17,7 +17,7 @@ def main():
     
     env = os.getenv("ENV", "DEV").upper()
     
-    env = "DEV"
+    # env = "DEV"
 
     host = "0.0.0.0"
     port = "8888"