llm_service.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import aiohttp
  2. import asyncio
  3. import json
  4. import os
  5. from sqlalchemy.ext.asyncio import AsyncSession
  6. from sqlalchemy import select
  7. from app.core.config import settings
  8. from app.schemas.llm import ParseUserInputsPayload, ParseUserInputsOut
  9. from app.models.schema import VasSchema
  10. # --- 配置区 ---
  11. # 请换成你新生成的 API Key
  12. API_KEY = settings.openai_api_key
  13. API_URL = "https://api.openai.com/v1/chat/completions"
  14. class LlmService:
  15. async def handle_parse(db: AsyncSession, payload: ParseUserInputsPayload):
  16. stmt = select(VasSchema).where(VasSchema.id == payload.schema_id)
  17. obj = (await db.execute(stmt)).scalar_one_or_none()
  18. if not obj:
  19. raise NotFoundError("Schema not exist")
  20. parsed_obj = await LlmService.parse_data_async(payload.input_raw_str, obj.schema_content)
  21. out = ParseUserInputsOut(parsed_obj=parsed_obj)
  22. return out
  23. @staticmethod
  24. async def parse_data_async(user_text: str, json_schema: dict):
  25. """
  26. [异步版本] 调用 LLM 解析数据
  27. """
  28. headers = {
  29. "Authorization": f"Bearer {API_KEY}",
  30. "Content-Type": "application/json"
  31. }
  32. # 构造 Prompt
  33. system_instruction = "You are a specialized data extraction API. Output valid JSON only."
  34. user_prompt = f"""
  35. Extract data from the text strictly based on the provided JSON Schema.
  36. [JSON Schema]
  37. {json.dumps(json_schema)}
  38. [User Text]
  39. {user_text}
  40. """
  41. payload = {
  42. "model": "gpt-4o", # 或 gpt-3.5-turbo
  43. "messages": [
  44. {"role": "system", "content": system_instruction},
  45. {"role": "user", "content": user_prompt}
  46. ],
  47. "temperature": 0,
  48. "response_format": {"type": "json_object"} # 强制 JSON 模式
  49. }
  50. async with aiohttp.ClientSession() as session:
  51. async with session.post(API_URL, headers=headers, json=payload, timeout=30) as response:
  52. # 1. 检查 HTTP 状态码
  53. if response.status != 200:
  54. error_text = await response.text()
  55. return {"error": f"HTTP {response.status}", "detail": error_text}
  56. # 2. 获取响应体
  57. result = await response.json()
  58. # 3. 提取并解析内容
  59. content_str = result['choices'][0]['message']['content']
  60. return json.loads(content_str)
  61. # --- 测试运行 ---
  62. if __name__ == "__main__":
  63. # 定义你的 Schema
  64. my_schema = {
  65. "type": "object",
  66. "properties": {
  67. "full_name": {"type": "string"},
  68. "budget": {"type": "integer"},
  69. "items": {"type": "array", "items": {"type": "string"}}
  70. },
  71. "required": ["full_name", "budget"]
  72. }
  73. # 模拟用户输入
  74. user_input = "我是张三,打算花2000块钱买个耳机和键盘。"
  75. print("正在解析...")
  76. result = parse_data_api(user_input, my_schema)
  77. print(json.dumps(result, ensure_ascii=False, indent=2))