predict_server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. import os
  2. # 屏蔽 ONNX Runtime 的警告日志
  3. os.environ["ORT_LOGGING_LEVEL"] = "3"
  4. import json
  5. import string
  6. import socket
  7. import traceback
  8. import io # 新增
  9. from http.server import BaseHTTPRequestHandler, HTTPServer
  10. from io import BytesIO
  11. from collections import OrderedDict
  12. from urllib.parse import urlparse, parse_qs
  13. # 图像处理依赖
  14. import cv2
  15. import numpy as np
  16. from PIL import Image, ImageFilter # 新增 ImageFilter
  17. # 深度学习依赖
  18. import torch
  19. from torch import nn
  20. from torchvision import transforms
  21. # ddddocr 依赖
  22. try:
  23. import ddddocr
  24. HAS_DDDDOCR = True
  25. except ImportError:
  26. print("[WARNING] ddddocr not installed. Run 'pip install ddddocr'")
  27. HAS_DDDDOCR = False
  28. # ================= 核心优化:图像去噪 (BLS专用) =================
  29. def advanced_denoise(image_bytes):
  30. """
  31. 针对 BLS 验证码的去噪流程
  32. """
  33. try:
  34. nparr = np.frombuffer(image_bytes, np.uint8)
  35. img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
  36. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  37. gray_blur = cv2.medianBlur(gray, 3)
  38. binary = cv2.adaptiveThreshold(
  39. gray_blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
  40. cv2.THRESH_BINARY, 11, 2
  41. )
  42. contours, _ = cv2.findContours(255 - binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  43. clean_img = np.ones(binary.shape, dtype="uint8") * 255
  44. for cnt in contours:
  45. area = cv2.contourArea(cnt)
  46. if 30 < area < 1000:
  47. cv2.drawContours(clean_img, [cnt], -1, 0, -1)
  48. return Image.fromarray(clean_img)
  49. except Exception as e:
  50. print(f"[Denoise] Error: {e}")
  51. return Image.open(BytesIO(image_bytes))
  52. # ================= PyTorch 模型结构 =================
  53. class Model(nn.Module):
  54. def __init__(self, n_classes, input_shape=(3, 64, 128)):
  55. super(Model, self).__init__()
  56. self.input_shape = input_shape
  57. channels = [32, 64, 128, 256, 256]
  58. layers = [2, 2, 2, 2, 2]
  59. kernels = [3, 3, 3, 3, 3]
  60. pools = [2, 2, 2, 2, (2, 1)]
  61. modules = OrderedDict()
  62. def cba(name, in_channels, out_channels, kernel_size):
  63. modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
  64. padding=(1, 1) if kernel_size == 3 else 0)
  65. modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
  66. modules[f'relu{name}'] = nn.ReLU(inplace=True)
  67. last_channel = 3
  68. for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
  69. for layer in range(1, n_layer + 1):
  70. cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
  71. last_channel = n_channel
  72. modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
  73. modules[f'dropout'] = nn.Dropout(0.25, inplace=True)
  74. self.cnn = nn.Sequential(modules)
  75. self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128, num_layers=2, bidirectional=True)
  76. self.fc = nn.Linear(in_features=256, out_features=n_classes)
  77. def infer_features(self):
  78. x = torch.zeros((1,)+self.input_shape)
  79. x = self.cnn(x)
  80. x = x.reshape(x.shape[0], -1, x.shape[-1])
  81. return x.shape[1]
  82. def forward(self, x):
  83. x = self.cnn(x)
  84. x = x.reshape(x.shape[0], -1, x.shape[-1])
  85. x = x.permute(2, 0, 1)
  86. x, _ = self.lstm(x)
  87. x = self.fc(x)
  88. return x
  89. # ================= 引擎1: PyTorch =================
  90. class PyTorchEngine:
  91. def __init__(self, model_path):
  92. self.num_classes = 12
  93. self.characters = '-' + string.digits + '$'
  94. self.width = 150
  95. self.hight = 80
  96. self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))
  97. if os.path.exists(model_path):
  98. self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  99. self.model.eval()
  100. print(f"[PyTorch] Model loaded successfully from {model_path}")
  101. self.ready = True
  102. else:
  103. print(f"[PyTorch] Warning: Model file not found at {model_path}")
  104. self.ready = False
  105. self.transforms_func = transforms.Compose([
  106. transforms.Resize((self.hight, self.width)),
  107. transforms.ToTensor()
  108. ])
  109. def decode(self, sequence):
  110. a = ''.join([self.characters[x] for x in sequence])
  111. s = []
  112. last = None
  113. for x in a:
  114. if x != last:
  115. s.append(x)
  116. last = x
  117. s2 = ''.join([x for x in s if x != self.characters[0]])
  118. return s2
  119. def inference_bytes(self, image_bytes):
  120. if not self.ready:
  121. return "Error: Model not loaded"
  122. try:
  123. # === 恢复:直接使用 PIL 打开图片,移除 advanced_denoise ===
  124. image = Image.open(BytesIO(image_bytes))
  125. image = image.convert('RGB')
  126. if self.transforms_func is not None:
  127. image = self.transforms_func(image)
  128. with torch.no_grad():
  129. output = self.model(image.unsqueeze(0).cpu())
  130. output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
  131. predict_label = self.decode(output_argmax[0])
  132. return predict_label
  133. except Exception as e:
  134. print(f"[PyTorch] Inference error: {e}")
  135. return ""
  136. # ================= 引擎2: DDDDOCR =================
  137. class DddOcrEngine:
  138. def __init__(self):
  139. if HAS_DDDDOCR:
  140. self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)
  141. print("[DDDDOCR] Initialized successfully")
  142. self.ready = True
  143. else:
  144. print("[DDDDOCR] Library missing")
  145. self.ready = False
  146. def inference_bytes(self, image_bytes):
  147. """ 原有的 VFCode 识别逻辑 """
  148. if not self.ready:
  149. return "Error: ddddocr not installed"
  150. try:
  151. # 1. VF 专用预处理
  152. img_pil = advanced_denoise(image_bytes)
  153. # 2. 转 bytes 传给 ddddocr
  154. img_byte_arr = BytesIO()
  155. img_pil.save(img_byte_arr, format='PNG')
  156. processed_bytes = img_byte_arr.getvalue()
  157. # 3. 识别
  158. res = self.ocr.classification(processed_bytes)
  159. return res
  160. except Exception as e:
  161. print(f"[DDDDOCR] Inference error: {e}")
  162. return ""
  163. def inference_captcha(self, image_bytes):
  164. """
  165. [新增] 适配你提供的预处理逻辑
  166. 路径: /predict/visametric
  167. """
  168. if not self.ready:
  169. return "Error: ddddocr not installed"
  170. try:
  171. # 1. 打开图片
  172. image = Image.open(io.BytesIO(image_bytes))
  173. # 2. 自定义预处理: 灰度 -> 中值滤波 -> 二值化
  174. gray_img = image.convert("L").filter(ImageFilter.MedianFilter(size=3))
  175. binary_img = gray_img.point(lambda p: 255 if p > 128 else 0)
  176. # 3. 转 bytes 并识别
  177. with io.BytesIO() as img_buffer:
  178. binary_img.save(img_buffer, format="PNG")
  179. processed_bytes = img_buffer.getvalue()
  180. return self.ocr.classification(processed_bytes)
  181. except Exception as e:
  182. print(f"[DDDDOCR-Captcha] Inference error: {e}")
  183. return ""
  184. # ================= HTTP 处理 =================
  185. engines = {}
  186. class RequestHandler(BaseHTTPRequestHandler):
  187. def _send_response(self, status, content_type, content):
  188. self.send_response(status)
  189. self.send_header('Content-type', content_type)
  190. self.end_headers()
  191. self.wfile.write(content)
  192. def log_message(self, format, *args):
  193. return
  194. def do_POST(self):
  195. parsed_path = urlparse(self.path)
  196. path = parsed_path.path
  197. query_params = parse_qs(parsed_path.query)
  198. # 获取 Content-Length
  199. try:
  200. content_length = int(self.headers.get('Content-Length', 0))
  201. if content_length == 0:
  202. self._send_response(400, 'application/json', json.dumps({'code': 400, 'msg': 'Empty body'}).encode())
  203. return
  204. file_content = self.rfile.read(content_length)
  205. except Exception:
  206. self._send_response(400, 'application/json', json.dumps({'code': 400, 'msg': 'Read body failed'}).encode())
  207. return
  208. result_string = ""
  209. try:
  210. # === 路由 1: 原有的 VFCode 识别 ===
  211. if path == '/predict/bls':
  212. model_type = query_params.get('model', ['ddddocr'])[0]
  213. if model_type == 'ddddocr':
  214. if 'ddddocr' in engines:
  215. result_string = engines['ddddocr'].inference_bytes(file_content)
  216. else:
  217. result_string = "Error: ddddocr not available"
  218. else:
  219. if 'pytorch' in engines:
  220. result_string = engines['pytorch'].inference_bytes(file_content)
  221. else:
  222. result_string = "Error: pytorch model not available"
  223. print(f"[VFCode] [{model_type}] Result: {result_string}")
  224. # === 路由 2: 新增的通用 Captcha 识别 ===
  225. elif path == '/predict/visametric':
  226. if 'ddddocr' in engines:
  227. # 使用新增的预处理逻辑
  228. result_string = engines['ddddocr'].inference_captcha(file_content)
  229. else:
  230. result_string = "Error: ddddocr not available"
  231. print(f"[Captcha] Result: {result_string}")
  232. else:
  233. self._send_response(404, 'text/plain', b'Not Found')
  234. return
  235. # 返回成功响应
  236. response = {
  237. 'data': result_string,
  238. 'msg': "success",
  239. 'code': 200
  240. }
  241. self._send_response(200, 'application/json', json.dumps(response).encode())
  242. except Exception as e:
  243. traceback.print_exc()
  244. response = {'data': '', 'msg': 'failed', 'code': 500}
  245. self._send_response(500, 'application/json', json.dumps(response).encode())
  246. if __name__ == '__main__':
  247. MODEL_PATH = 'data/ctc.pth'
  248. PORT = 8085
  249. # 初始化 PyTorch 引擎
  250. pytorch_engine = PyTorchEngine(MODEL_PATH)
  251. if pytorch_engine.ready:
  252. engines['pytorch'] = pytorch_engine
  253. # 初始化 DDDDOCR 引擎
  254. ddd_engine = DddOcrEngine()
  255. if ddd_engine.ready:
  256. engines['ddddocr'] = ddd_engine
  257. server_address = ('0.0.0.0', PORT)
  258. httpd = HTTPServer(server_address, RequestHandler)
  259. print(f'OCR Server running on port {PORT}...')
  260. print(f'Routes available:')
  261. print(f' POST /predict/bls?model=ddddocr|pytorch')
  262. print(f' POST /predict/visametric (Uses specific preprocessing)')
  263. try:
  264. httpd.serve_forever()
  265. except KeyboardInterrupt:
  266. pass
  267. httpd.server_close()