jerry пре 1 месец
родитељ
комит
0f4487bcde

+ 55 - 0
app/api/router.py

@@ -56,6 +56,7 @@ from app.schemas.llm import ParseUserInputsPayload, ParseUserInputsOut
 from app.schemas.account import AccountResponse, AccountCreate, LockRequest
 from app.schemas.docker_remote import RemoteServerConfig, DockerStatusOut, DockerLogsRequest, DockerLogsOut, ConfigReadOut, ConfigReadRequest, ConfigUpdateRequest, LogReadRequest, LogReadOut, LogListOut, DockerContainerStatus, DockerActionRequest, ServerConfigItem, ServerListOut, RemoteActionRequest
 from app.schemas.order_event import VasOrderEventCreate, VasOrderEventOut
+from app.schemas.troov_session import TroovSessionCreate, TroovSessionUpdate, TroovSessionOut
 from app.services.docker_remote_service import DockerRemoteService
 from app.services.configuration_service import ConfigurationService
 from app.services.troov_service import TroovService
@@ -92,6 +93,7 @@ from app.services.llm_service import LlmService
 from app.services.slot_refresh_status_service import SlotRefreshStatusService
 from app.services.account_service import AccountService
 from app.services.order_event_service import OrderEventService
+from app.services.troov_session_service import TroovSessionService
 
 # 公共路由
 public_router = APIRouter()
@@ -1564,3 +1566,56 @@ async def get_ticket_messages(
         size=size
     )
     return success(data=msgs)
+
+
+# -----------------------
+# Troov Session APIs
+# -----------------------
+@admin_required_router.post("/troov-session/add", summary="新增troov session", tags=["Troov"], response_model=ApiResponse[TroovSessionOut])
+async def troov_session_add(
+    payload: TroovSessionCreate,
+    db: AsyncSession = Depends(get_db)
+):
+    obj = await TroovSessionService.add(db, payload)
+    return success(data=obj)
+
+
+@admin_required_router.get("/troov-session/pop", summary="获取并锁定一个pending的troov session", tags=["Troov"], response_model=ApiResponse[TroovSessionOut])
+async def troov_session_pop(
+    slot_date: str = Query("", description="slot日期筛选"),
+    slot_time: str = Query("", description="slot时间筛选"),
+    db: AsyncSession = Depends(get_db)
+):
+    obj = await TroovSessionService.pop(db, slot_date or None, slot_time or None)
+    return success(data=obj)
+
+
+@admin_required_router.put("/troov-session/update", summary="更新troov session", tags=["Troov"], response_model=ApiResponse[TroovSessionOut])
+async def troov_session_update(
+    session_id: str = Query(..., description="session_id"),
+    payload: TroovSessionUpdate = Body(...),
+    db: AsyncSession = Depends(get_db)
+):
+    obj = await TroovSessionService.update(db, session_id, payload)
+    return success(data=obj)
+
+
+@admin_required_router.get("/troov-session/get", summary="根据session_id获取troov session", tags=["Troov"], response_model=ApiResponse[TroovSessionOut])
+async def troov_session_get(
+    session_id: str = Query(..., description="session_id"),
+    db: AsyncSession = Depends(get_db)
+):
+    obj = await TroovSessionService.get_by_session_id(db, session_id)
+    return success(data=obj)
+
+
+@admin_required_router.get("/troov-session/list", summary="分页获取troov session列表", tags=["Troov"], response_model=PageResponse[List[TroovSessionOut]])
+async def troov_session_list(
+    status: str = Query("", description="状态筛选"),
+    page: int = Query(0, description="第几页"),
+    size: int = Query(10, description="分页大小"),
+    keyword: str = Query("", description="查询条件"),
+    db: AsyncSession = Depends(get_db)
+):
+    obj = await TroovSessionService.list(db, status, keyword, page, size)
+    return success(data=obj)

+ 17 - 0
app/models/troov_session.py

@@ -0,0 +1,17 @@
+from datetime import datetime
+from sqlalchemy import Column, String, DateTime, JSON, func
+from app.core.database import Base
+
+
+class TroovSession(Base):
+    __tablename__ = "troov_session"
+
+    session_id = Column(String(128), primary_key=True)
+    slot_date = Column(String(64), nullable=False)
+    slot_time = Column(String(64), nullable=False)
+    source = Column(String(128), nullable=False)
+    data = Column(JSON, nullable=True)
+    status = Column(String(32), default="pending", comment="pending, booking, expired")
+    
+    created_at = Column(DateTime, default=datetime.utcnow)
+    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

+ 1 - 0
app/schemas/task.py

@@ -1,3 +1,4 @@
+import json
 from pydantic import BaseModel, field_validator
 from typing import Optional, Any, Dict
 from datetime import datetime

+ 46 - 0
app/schemas/troov_session.py

@@ -0,0 +1,46 @@
+import json
+from pydantic import BaseModel, field_validator
+from typing import Optional, Any, Dict
+from datetime import datetime
+
+
+class TroovSessionBase(BaseModel):
+    slot_date: str
+    slot_time: str
+    session_id: str
+    source: str
+    data: Optional[Dict[str, Any]] = None
+    status: str = "pending"
+    
+            
+    @field_validator("data", mode="before")
+    def normalize_json_field(cls, v):
+        if v is None:
+            return None
+        if isinstance(v, str):
+            try:
+                return json.loads(v)
+            except Exception:
+                return {}
+        return v
+
+
+class TroovSessionCreate(TroovSessionBase):
+    pass
+
+
+class TroovSessionUpdate(BaseModel):
+    slot_date: Optional[str] = None
+    slot_time: Optional[str] = None
+    session_id: Optional[str] = None
+    source: Optional[str] = None
+    data: Optional[Dict[str, Any]] = None
+    status: Optional[str] = None
+
+
+class TroovSessionOut(TroovSessionBase):
+    created_at: datetime
+    updated_at: datetime
+    model_config = {
+        "from_attributes": True
+    }

+ 100 - 0
app/services/troov_session_service.py

@@ -0,0 +1,100 @@
+from typing import Optional, Any
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import select
+from app.core.biz_exception import NotFoundError
+from app.models.troov_session import TroovSession
+from app.schemas.troov_session import TroovSessionCreate, TroovSessionUpdate
+
+# 假设你的项目中分页和搜索工具方法路径如下,请根据实际情况调整
+from app.utils.pagination import paginate 
+from app.utils.search import apply_keyword_search_stmt
+
+
+class TroovSessionService:
+
+    @staticmethod
+    async def add(db: AsyncSession, obj_in: TroovSessionCreate) -> TroovSession:
+        # 优化点 1: 使用 model_dump 直接解包,避免手动挨个赋值
+        db_obj = TroovSession(**obj_in.model_dump())
+        
+        # 默认值已经在 schema 中定义了,这里不需要再 obj_in.status or "pending"
+        db.add(db_obj)
+        await db.commit()
+        await db.refresh(db_obj)
+        return db_obj
+
+    @staticmethod
+    async def pop(db: AsyncSession, slot_date: Optional[str] = None, slot_time: Optional[str] = None) -> TroovSession:
+        stmt = select(TroovSession).where(TroovSession.status == "pending")
+        if slot_date:
+            stmt = stmt.where(TroovSession.slot_date == slot_date)
+        if slot_time:
+            stmt = stmt.where(TroovSession.slot_time == slot_time)
+            
+        # 优点保留: 使用 skip_locked=True 完美解决并发争抢问题
+        stmt = stmt.order_by(TroovSession.created_at.asc()).limit(1).with_for_update(skip_locked=True)
+
+        result = await db.execute(stmt)
+        obj = result.scalar_one_or_none()
+
+        if not obj:
+            raise NotFoundError("No pending troov session found")
+
+        obj.status = "booking"
+        await db.commit()
+        await db.refresh(obj)
+        return obj
+
+    @staticmethod
+    async def update(db: AsyncSession, session_id: str, obj_in: TroovSessionUpdate) -> TroovSession:
+        stmt = select(TroovSession).where(TroovSession.session_id == session_id)
+        db_obj = (await db.execute(stmt)).scalar_one_or_none()
+
+        if not db_obj:
+            raise NotFoundError("TroovSession not found")
+
+        # 优化点 2: 使用 exclude_unset=True 仅提取客户端真正传入的字段
+        update_data = obj_in.model_dump(exclude_unset=True)
+        
+        # 安全防御: 防止客户端恶意或错误修改主键
+        update_data.pop("session_id", None) 
+
+        # 动态赋值,后续 Schema 加字段这里完全不用改代码
+        for field, value in update_data.items():
+            setattr(db_obj, field, value)
+
+        await db.commit()
+        await db.refresh(db_obj)
+        return db_obj
+
+    @staticmethod
+    async def get_by_session_id(db: AsyncSession, session_id: str) -> Optional[TroovSession]:
+        stmt = select(TroovSession).where(TroovSession.session_id == session_id)
+        return (await db.execute(stmt)).scalar_one_or_none()
+
+    @staticmethod
+    async def list(
+        db: AsyncSession, 
+        status: Optional[str] = None, 
+        keyword: Optional[str] = None, # 新增 keyword 支持
+        page: int = 1,                 # 分页库通常 page 从 1 开始
+        size: int = 10
+    ) -> Any:
+        stmt = select(TroovSession)
+        
+        if status:
+            stmt = stmt.where(TroovSession.status == status)
+
+        # 优化点 3: 引入你提供的 apply_keyword_search_stmt 范式
+        if keyword:
+            stmt = apply_keyword_search_stmt(
+                stmt=stmt,
+                model=TroovSession,
+                keyword=keyword,
+                fields=["session_id", "source"] # 根据业务需求定义支持模糊搜索的字段
+            )
+
+        stmt = stmt.order_by(TroovSession.created_at.desc())
+        
+        # 优化点 4: 使用统一的 paginate 方法取代手动 offset 和 limit
+        return await paginate(db, stmt, page, size)