predict_server.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. import os
  2. # 屏蔽 ONNX Runtime 的警告日志
  3. os.environ["ORT_LOGGING_LEVEL"] = "3"
  4. import json
  5. import string
  6. import socket
  7. import traceback
  8. from http.server import BaseHTTPRequestHandler, HTTPServer
  9. from io import BytesIO
  10. from collections import OrderedDict
  11. from urllib.parse import urlparse, parse_qs
  12. # 图像处理依赖
  13. import cv2
  14. import numpy as np
  15. from PIL import Image
  16. # 深度学习依赖
  17. import torch
  18. from torch import nn
  19. from torchvision import transforms
  20. # ddddocr 依赖
  21. try:
  22. import ddddocr
  23. HAS_DDDDOCR = True
  24. except ImportError:
  25. print("[WARNING] ddddocr not installed. Run 'pip install ddddocr'")
  26. HAS_DDDDOCR = False
  27. # ================= 核心优化:图像去噪 =================
  28. def advanced_denoise(image_bytes):
  29. """
  30. 针对 BLS 验证码的去噪流程:
  31. 1. 转灰度
  32. 2. 中值滤波 (关键:去除椒盐噪点)
  33. 3. 自适应二值化 (剥离彩色背景)
  34. 4. 连通域过滤 (去除残留的微小噪点)
  35. """
  36. try:
  37. # 1. 字节流转 OpenCV 格式
  38. nparr = np.frombuffer(image_bytes, np.uint8)
  39. img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
  40. # 2. 灰度化
  41. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  42. # 3. 中值滤波 (Median Blur) - 去除椒盐噪声的神器
  43. # ksize=3 表示 3x3 区域,能过滤掉独立的黑点,保留较粗的笔画
  44. gray_blur = cv2.medianBlur(gray, 3)
  45. # 4. 自适应二值化
  46. # 使用 Gaussian 方法,BlockSize=11, C=2 经验参数
  47. binary = cv2.adaptiveThreshold(
  48. gray_blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
  49. cv2.THRESH_BINARY, 11, 2
  50. )
  51. # 5. 连通域降噪 (Contour Filter)
  52. # 找到所有的黑色块(文字和残留噪点)
  53. # 注意:OpenCV findContours 找的是白色块,所以先反转
  54. contours, _ = cv2.findContours(255 - binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  55. # 创建一个纯白背景
  56. clean_img = np.ones(binary.shape, dtype="uint8") * 255
  57. for cnt in contours:
  58. area = cv2.contourArea(cnt)
  59. # 过滤逻辑:保留面积在 30 到 1000 像素之间的色块 (文字)
  60. # 小于 30 的通常是残留噪点,大于 1000 的可能是边框
  61. if 30 < area < 1000:
  62. cv2.drawContours(clean_img, [cnt], -1, 0, -1) # 在白底上画黑色文字
  63. # 6. 转回 PIL Image
  64. return Image.fromarray(clean_img)
  65. except Exception as e:
  66. print(f"[Denoise] Error: {e}")
  67. # 出错时回退到原始图片
  68. return Image.open(BytesIO(image_bytes))
  69. # ================= PyTorch 模型结构 (保持不变) =================
  70. class Model(nn.Module):
  71. def __init__(self, n_classes, input_shape=(3, 64, 128)):
  72. super(Model, self).__init__()
  73. self.input_shape = input_shape
  74. channels = [32, 64, 128, 256, 256]
  75. layers = [2, 2, 2, 2, 2]
  76. kernels = [3, 3, 3, 3, 3]
  77. pools = [2, 2, 2, 2, (2, 1)]
  78. modules = OrderedDict()
  79. def cba(name, in_channels, out_channels, kernel_size):
  80. modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
  81. padding=(1, 1) if kernel_size == 3 else 0)
  82. modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
  83. modules[f'relu{name}'] = nn.ReLU(inplace=True)
  84. last_channel = 3
  85. for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
  86. for layer in range(1, n_layer + 1):
  87. cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
  88. last_channel = n_channel
  89. modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
  90. modules[f'dropout'] = nn.Dropout(0.25, inplace=True)
  91. self.cnn = nn.Sequential(modules)
  92. self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128, num_layers=2, bidirectional=True)
  93. self.fc = nn.Linear(in_features=256, out_features=n_classes)
  94. def infer_features(self):
  95. x = torch.zeros((1,)+self.input_shape)
  96. x = self.cnn(x)
  97. x = x.reshape(x.shape[0], -1, x.shape[-1])
  98. return x.shape[1]
  99. def forward(self, x):
  100. x = self.cnn(x)
  101. x = x.reshape(x.shape[0], -1, x.shape[-1])
  102. x = x.permute(2, 0, 1)
  103. x, _ = self.lstm(x)
  104. x = self.fc(x)
  105. return x
  106. # ================= 引擎1: PyTorch =================
  107. class PyTorchEngine:
  108. def __init__(self, model_path):
  109. self.num_classes = 12
  110. self.characters = '-' + string.digits + '$'
  111. self.width = 150
  112. self.hight = 80
  113. self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))
  114. if os.path.exists(model_path):
  115. self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  116. self.model.eval()
  117. print(f"[PyTorch] Model loaded successfully from {model_path}")
  118. self.ready = True
  119. else:
  120. print(f"[PyTorch] Warning: Model file not found at {model_path}")
  121. self.ready = False
  122. self.transforms_func = transforms.Compose([
  123. transforms.Resize((self.hight, self.width)),
  124. transforms.ToTensor()
  125. ])
  126. def decode(self, sequence):
  127. a = ''.join([self.characters[x] for x in sequence])
  128. s = []
  129. last = None
  130. for x in a:
  131. if x != last:
  132. s.append(x)
  133. last = x
  134. s2 = ''.join([x for x in s if x != self.characters[0]])
  135. return s2
  136. def inference_bytes(self, image_bytes):
  137. if not self.ready:
  138. return "Error: Model not loaded"
  139. try:
  140. # 使用高级去噪
  141. image = advanced_denoise(image_bytes)
  142. image = image.convert('RGB')
  143. if self.transforms_func is not None:
  144. image = self.transforms_func(image)
  145. with torch.no_grad():
  146. output = self.model(image.unsqueeze(0).cpu())
  147. output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
  148. predict_label = self.decode(output_argmax[0])
  149. return predict_label
  150. except Exception as e:
  151. print(f"[PyTorch] Inference error: {e}")
  152. return ""
  153. # ================= 引擎2: DDDDOCR (已优化) =================
  154. class DddOcrEngine:
  155. def __init__(self):
  156. if HAS_DDDDOCR:
  157. # show_ad=False 关闭广告, beta=True 启用旧版模型(通常对纯数字更稳)
  158. self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)
  159. print("[DDDDOCR] Initialized successfully")
  160. self.ready = True
  161. else:
  162. print("[DDDDOCR] Library missing")
  163. self.ready = False
  164. def inference_bytes(self, image_bytes):
  165. if not self.ready:
  166. return "Error: ddddocr not installed"
  167. try:
  168. # 1. 预处理:去噪、二值化、过滤
  169. img_pil = advanced_denoise(image_bytes)
  170. # 2. 转 bytes 传给 ddddocr
  171. img_byte_arr = BytesIO()
  172. img_pil.save(img_byte_arr, format='PNG')
  173. processed_bytes = img_byte_arr.getvalue()
  174. # 3. 识别
  175. res = self.ocr.classification(processed_bytes)
  176. return res
  177. except Exception as e:
  178. print(f"[DDDDOCR] Inference error: {e}")
  179. return ""
  180. # ================= HTTP 处理 =================
  181. engines = {}
  182. class RequestHandler(BaseHTTPRequestHandler):
  183. def _send_response(self, status, content_type, content):
  184. self.send_response(status)
  185. self.send_header('Content-type', content_type)
  186. self.end_headers()
  187. self.wfile.write(content)
  188. def log_message(self, format, *args):
  189. # 屏蔽 HTTP 请求日志,只打印识别结果
  190. return
  191. def do_POST(self):
  192. parsed_path = urlparse(self.path)
  193. path = parsed_path.path
  194. query_params = parse_qs(parsed_path.query)
  195. # 默认使用 ddddocr,因为加上去噪后效果通常好于未针对性训练的 pytorch 模型
  196. model_type = query_params.get('model', ['ddddocr'])[0]
  197. if path == '/predict/vfcode':
  198. try:
  199. content_length = int(self.headers.get('Content-Length', 0))
  200. if content_length == 0:
  201. self._send_response(400, 'application/json', json.dumps({'code': 400, 'msg': 'Empty body'}).encode())
  202. return
  203. file_content = self.rfile.read(content_length)
  204. result_string = ""
  205. if model_type == 'ddddocr':
  206. if 'ddddocr' in engines:
  207. result_string = engines['ddddocr'].inference_bytes(file_content)
  208. else:
  209. result_string = "Error: ddddocr not available"
  210. else:
  211. if 'pytorch' in engines:
  212. result_string = engines['pytorch'].inference_bytes(file_content)
  213. else:
  214. result_string = "Error: pytorch model not available"
  215. response = {
  216. 'data': result_string,
  217. 'msg': "success",
  218. 'code': 200
  219. }
  220. self._send_response(200, 'application/json', json.dumps(response).encode())
  221. # 打印简洁的识别日志
  222. print(f"[{model_type}] Result: {result_string}")
  223. except Exception as e:
  224. traceback.print_exc()
  225. response = {'data': '', 'msg': 'failed', 'code': 500}
  226. self._send_response(500, 'application/json', json.dumps(response).encode())
  227. else:
  228. self._send_response(404, 'text/plain', b'Not Found')
  229. if __name__ == '__main__':
  230. MODEL_PATH = 'data/ocr.pth'
  231. PORT = 8085
  232. # 1. PyTorch
  233. pytorch_engine = PyTorchEngine(MODEL_PATH)
  234. if pytorch_engine.ready:
  235. engines['pytorch'] = pytorch_engine
  236. # 2. ddddocr
  237. ddd_engine = DddOcrEngine()
  238. if ddd_engine.ready:
  239. engines['ddddocr'] = ddd_engine
  240. server_address = ('0.0.0.0', PORT)
  241. httpd = HTTPServer(server_address, RequestHandler)
  242. print(f'OCR Server running on port {PORT}...')
  243. print(f'Active engines: {list(engines.keys())}')
  244. try:
  245. httpd.serve_forever()
  246. except KeyboardInterrupt:
  247. pass
  248. httpd.server_close()