server.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. import re
  2. import asyncio
  3. import uvicorn
  4. import tempfile
  5. from pathlib import Path
  6. from contextlib import asynccontextmanager
  7. from fastapi import FastAPI, Body, Request, Query, HTTPException
  8. from fastapi.responses import JSONResponse
  9. from fastapi.concurrency import run_in_threadpool
  10. from utils.browser_util import get_browser
  11. from toolkit.ocr_engine import PyTorchEngine, DddOcrEngine
  12. # ================= 全局资源 =================
  13. # 全局浏览器对象
  14. GLOBAL_PAGE = None
  15. # 异步锁,用于互斥控制
  16. BROWSER_LOCK = asyncio.Lock()
  17. # OCR 引擎字典
  18. engines = {}
  19. def _sync_get_visatype_ids(tmp_file: str):
  20. """
  21. 这是实际执行浏览器操作的同步函数。
  22. 它会在独立的线程中运行,不会阻塞服务器。
  23. """
  24. result = {"status": "failed", "message": ""}
  25. try:
  26. html_file_path = Path(tmp_file).resolve()
  27. file_url = f'file://{html_file_path}'
  28. GLOBAL_PAGE.get(file_url)
  29. jur_id = None
  30. loc_id = None
  31. type_id = None
  32. subtype_id = None
  33. cat_id = None
  34. # 匹配 ID
  35. app_category_labels = GLOBAL_PAGE.eles(f'Appointment Category', timeout=1)
  36. for app_category_label in app_category_labels:
  37. if app_category_label.states.has_rect and app_category_label.tag == 'label':
  38. eid = app_category_label.after('tag:input').attr('id')
  39. cat_id = int(''.join(filter(str.isdigit, eid)))
  40. break
  41. jurisdiction_labels = GLOBAL_PAGE.eles(f'Jurisdiction', timeout=1)
  42. if jurisdiction_labels:
  43. for jurisdiction_label in jurisdiction_labels:
  44. if jurisdiction_label.states.has_rect and jurisdiction_label.tag == 'label':
  45. eid = jurisdiction_label.after('tag:input').attr('id')
  46. jur_id = int(''.join(filter(str.isdigit, eid)))
  47. break
  48. location_labels = GLOBAL_PAGE.eles(f'Location', timeout=1)
  49. for location_label in location_labels:
  50. if location_label.states.has_rect and location_label.tag == 'label':
  51. eid = location_label.after('tag:input', index=2).attr('id')
  52. loc_id = int(''.join(filter(str.isdigit, eid)))
  53. break
  54. visa_type_labels = GLOBAL_PAGE.eles(f'Visa Type', timeout=1)
  55. for visa_type_label in visa_type_labels:
  56. if visa_type_label.states.has_rect and visa_type_label.tag == 'label':
  57. eid = visa_type_label.after('tag:input').attr('id')
  58. type_id = int(''.join(filter(str.isdigit, eid)))
  59. break
  60. visa_subtype_labels = GLOBAL_PAGE.eles(f'Visa Sub Type', timeout=1)
  61. for visa_subtype_label in visa_subtype_labels:
  62. if visa_subtype_label.states.has_rect and visa_subtype_label.tag == 'label':
  63. eid = visa_subtype_label.after('tag:input').attr('id')
  64. subtype_id = int(''.join(filter(str.isdigit, eid)))
  65. break
  66. data = {
  67. "jur_id": jur_id,
  68. "loc_id": loc_id,
  69. "type_id": type_id,
  70. "subtype_id": subtype_id,
  71. "cat_id": cat_id,
  72. }
  73. result["status"] = "success"
  74. result['data'] = data
  75. except Exception as e:
  76. result["message"] = str(e)
  77. print(f"[DrissionPage] Error: {e}")
  78. return result
  79. def _sync_get_visable_image_ids(tmp_file: str):
  80. """
  81. 这是实际执行浏览器操作的同步函数。
  82. 它会在独立的线程中运行,不会阻塞服务器。
  83. """
  84. result = {"status": "failed", "message": ""}
  85. try:
  86. images_ids = []
  87. html_file_path = Path(tmp_file).resolve()
  88. file_url = f'file://{html_file_path}'
  89. GLOBAL_PAGE.get(file_url)
  90. captions_ele = GLOBAL_PAGE.ele('xpath://*[@id="captcha-main-div"]/div/div[1]', timeout=5)
  91. if not captions_ele:
  92. raise Exception('Captions elements not found')
  93. caption_eles = captions_ele.children()
  94. caption_text = ''
  95. for caption in caption_eles:
  96. if not caption.states.is_covered:
  97. caption_text = caption.text
  98. number = re.findall(r'\d+', caption_text)[0]
  99. captcha_images_ele = GLOBAL_PAGE.ele('xpath://*[@id="captcha-main-div"]/div/div[2]')
  100. captcha_image_eles = captcha_images_ele.children()
  101. for captcha_image in captcha_image_eles:
  102. img = captcha_image.ele('.captcha-img')
  103. if img.states.has_rect and img.states.is_covered == False:
  104. img_src = img.attr('src')
  105. if img_src and img_src.startswith('data:image'):
  106. images_ids.append(captcha_image.attr('id'))
  107. data = {
  108. "number": number,
  109. "image_ids": images_ids,
  110. }
  111. result["status"] = "success"
  112. result['data'] = data
  113. except Exception as e:
  114. result["message"] = str(e)
  115. print(f"[DrissionPage] Error: {e}")
  116. return result
  117. # ================= 2. 生命周期管理 =================
  118. @asynccontextmanager
  119. async def lifespan(app: FastAPI):
  120. # --- 启动 OCR (伪代码,请保留你之前的逻辑) ---
  121. print("--- Loading OCR Models ---")
  122. engines['pytorch'] = PyTorchEngine('data/ctc.pth')
  123. engines['ddddocr'] = DddOcrEngine()
  124. # --- 启动 DrissionPage ---
  125. print("--- Starting DrissionPage ---")
  126. global GLOBAL_PAGE
  127. # 创建浏览器对象,连接浏览器
  128. GLOBAL_PAGE = get_browser()
  129. yield
  130. # --- 关闭资源 ---
  131. print("--- Shutting Down ---")
  132. if GLOBAL_PAGE:
  133. try:
  134. GLOBAL_PAGE.quit() # 关闭浏览器
  135. except:
  136. pass
  137. engines.clear()
  138. app = FastAPI(lifespan=lifespan)
  139. # ================= 3. 浏览器接口 (带忙碌检测) =================
  140. @app.post("/browser/visable_captchas")
  141. async def browser_get_data(html_content: str = Body(..., media_type="text/plain")
  142. ):
  143. # 1. 非阻塞检查:锁是否被占用
  144. if BROWSER_LOCK.locked():
  145. return JSONResponse(
  146. status_code=503,
  147. content={
  148. "code": 503,
  149. "status": "busy",
  150. "msg": "Browser is busy. One task at a time."
  151. }
  152. )
  153. # 2. 获取锁
  154. async with BROWSER_LOCK:
  155. print(f"[Browser] Processing")
  156. # 3. 写入临时 HTML 文件
  157. with tempfile.NamedTemporaryFile(
  158. mode="w+",
  159. suffix=".html",
  160. delete=True,
  161. encoding="utf-8"
  162. ) as f:
  163. f.write(html_content)
  164. f.flush()
  165. # 3. 核心:将同步的 DrissionPage 代码扔到线程池运行
  166. # 这样主线程(处理 OCR 请求的线程)不会被卡死
  167. result = await run_in_threadpool(_sync_get_visable_image_ids, f.name)
  168. return result
  169. # ================= 3. 浏览器接口 (带忙碌检测) =================
  170. @app.post("/browser/visatype_visable")
  171. async def browser_get_data(html_content: str = Body(..., media_type="text/plain")
  172. ):
  173. # 1. 非阻塞检查:锁是否被占用
  174. if BROWSER_LOCK.locked():
  175. return JSONResponse(
  176. status_code=503,
  177. content={
  178. "code": 503,
  179. "status": "busy",
  180. "msg": "Browser is busy. One task at a time."
  181. }
  182. )
  183. # 2. 获取锁
  184. async with BROWSER_LOCK:
  185. print(f"[Browser] Processing")
  186. # 3. 写入临时 HTML 文件
  187. with tempfile.NamedTemporaryFile(
  188. mode="w+",
  189. suffix=".html",
  190. delete=True,
  191. encoding="utf-8"
  192. ) as f:
  193. f.write(html_content)
  194. f.flush()
  195. # 3. 核心:将同步的 DrissionPage 代码扔到线程池运行
  196. # 这样主线程(处理 OCR 请求的线程)不会被卡死
  197. result = await run_in_threadpool(_sync_get_visatype_ids, f.name)
  198. return result
  199. # ================= 路由 2: OCR 识别 (BLS) =================
  200. @app.post("/predict/bls")
  201. async def predict_bls(request: Request, model: str = Query("ddddocr", enum=["ddddocr", "pytorch"])):
  202. """ 处理 BLS 验证码 """
  203. try:
  204. image_bytes = await request.body()
  205. if not image_bytes:
  206. raise HTTPException(status_code=400, detail="Empty body")
  207. if model == 'ddddocr':
  208. res = engines['ddddocr'].inference_bytes(image_bytes)
  209. else:
  210. res = engines['pytorch'].inference_bytes(image_bytes)
  211. return {"code": 200, "msg": "success", "data": res, "engine": model}
  212. except Exception as e:
  213. return JSONResponse(status_code=500, content={"code": 500, "msg": str(e), "data": ""})
  214. # ================= 路由 3: OCR 识别 (Visametric) =================
  215. @app.post("/predict/visametric")
  216. async def predict_visametric(request: Request):
  217. """ 处理 Visametric 验证码 (特殊预处理) """
  218. try:
  219. image_bytes = await request.body()
  220. res = engines['ddddocr'].inference_captcha(image_bytes)
  221. return {"code": 200, "msg": "success", "data": res}
  222. except Exception as e:
  223. return JSONResponse(status_code=500, content={"code": 500, "msg": str(e), "data": ""})
  224. if __name__ == '__main__':
  225. # 运行服务
  226. # host='0.0.0.0' 允许局域网访问
  227. print("API Documentation: http://127.0.0.1:8085/docs")
  228. uvicorn.run(app, host='0.0.0.0', port=8085)