| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- import os
- # 屏蔽 ONNX Runtime 的警告日志
- os.environ["ORT_LOGGING_LEVEL"] = "3"
- import json
- import string
- import socket
- import traceback
- from http.server import BaseHTTPRequestHandler, HTTPServer
- from io import BytesIO
- from collections import OrderedDict
- from urllib.parse import urlparse, parse_qs
- # 图像处理依赖
- import cv2
- import numpy as np
- from PIL import Image
- # 深度学习依赖
- import torch
- from torch import nn
- from torchvision import transforms
- # ddddocr 依赖
- try:
- import ddddocr
- HAS_DDDDOCR = True
- except ImportError:
- print("[WARNING] ddddocr not installed. Run 'pip install ddddocr'")
- HAS_DDDDOCR = False
- # ================= 核心优化:图像去噪 =================
def advanced_denoise(image_bytes, *, median_ksize=3, thresh_block_size=11,
                     thresh_c=2, min_area=30, max_area=1000):
    """Denoise a captcha image and return a black-on-white PIL image.

    Pipeline:
      1. Decode the bytes with OpenCV and convert to grayscale.
      2. Median blur (``median_ksize``) to remove salt-and-pepper specks
         while keeping thicker strokes.
      3. Gaussian adaptive threshold (``thresh_block_size``, ``thresh_c``)
         to strip the coloured background.
      4. Region filter: keep only outer contours whose area is strictly
         between ``min_area`` and ``max_area``. Smaller regions are treated
         as leftover noise; larger ones are presumably borders.

    Step 4 uses the kept contours only as a *mask* over the binarized
    foreground. Filling them directly would also paint the inner holes of
    glyphs such as 0, 6, 8 and 9 solid black.

    Args:
        image_bytes: Encoded image file contents (any format cv2.imdecode
            accepts).
        median_ksize: Aperture of the median filter (odd, > 1).
        thresh_block_size: Neighbourhood size for adaptive thresholding
            (odd, > 1).
        thresh_c: Constant subtracted from the weighted neighbourhood mean.
        min_area: Exclusive lower bound on contour area to keep.
        max_area: Exclusive upper bound on contour area to keep.

    Returns:
        PIL.Image.Image: Single-channel image with black text on white.
        If any processing step raises, the original bytes are opened with
        PIL and returned unprocessed instead.

    Raises:
        Whatever ``PIL.Image.open`` raises if the fallback cannot read the
        bytes either.
    """
    try:
        nparr = np.frombuffer(image_bytes, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals "not an image" by returning None, not raising.
            raise ValueError("cv2.imdecode could not decode the image bytes")

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        gray_blur = cv2.medianBlur(gray, median_ksize)
        binary = cv2.adaptiveThreshold(
            gray_blur, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, thresh_block_size, thresh_c,
        )

        # findContours traces white regions, so invert: dark text -> white.
        foreground = 255 - binary
        # [-2] selects the contour list under both OpenCV 3 (three return
        # values) and OpenCV 4 (two return values).
        contours = cv2.findContours(
            foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )[-2]
        kept = [cnt for cnt in contours
                if min_area < cv2.contourArea(cnt) < max_area]

        # Filled outlines of the accepted regions...
        mask = np.zeros_like(foreground)
        if kept:
            cv2.drawContours(mask, kept, -1, 255, thickness=cv2.FILLED)

        # ...intersected with the real foreground pixels, so glyph holes
        # stay white.
        clean_img = np.full_like(binary, 255)
        clean_img[(mask > 0) & (foreground > 0)] = 0
        return Image.fromarray(clean_img)

    except Exception as e:
        print(f"[Denoise] Error: {e}")
        # Fall back to the untouched original image.
        return Image.open(BytesIO(image_bytes))
- # ================= PyTorch 模型结构 (保持不变) =================
class Model(nn.Module):
    """CNN + bidirectional LSTM sequence recogniser producing CTC logits.

    The CNN has five blocks, each with two Conv-BN-ReLU layers (channels
    32/64/128/256/256) followed by max pooling. The last pool is (2, 1), so
    it halves the height but keeps the width. Each remaining width column
    becomes one time step for a 2-layer bidirectional LSTM (hidden size 128),
    followed by a linear layer to ``n_classes``.

    Sub-module names (``conv11``, ``bn11``, ``relu11``, ``pool1``, ...,
    ``dropout``) form the state_dict keys. Do not rename them, or existing
    checkpoints will no longer load.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128)):
        super().__init__()
        self.input_shape = tuple(input_shape)
        channels = [32, 64, 128, 256, 256]
        layers = [2, 2, 2, 2, 2]
        kernels = [3, 3, 3, 3, 3]
        pools = [2, 2, 2, 2, (2, 1)]
        modules = OrderedDict()

        def cba(name, in_channels, out_channels, kernel_size):
            # Conv -> BatchNorm -> ReLU; 3x3 kernels are padded to keep size.
            modules[f'conv{name}'] = nn.Conv2d(
                in_channels, out_channels, kernel_size,
                padding=(1, 1) if kernel_size == 3 else 0)
            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
            modules[f'relu{name}'] = nn.ReLU(inplace=True)

        last_channel = 3
        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(
                zip(channels, layers, kernels, pools)):
            for layer in range(1, n_layer + 1):
                cba(f'{block + 1}{layer}', last_channel, n_channel, n_kernel)
                last_channel = n_channel
            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
        # Dropout has no parameters and is an identity in eval mode, so its
        # position affects neither checkpoint loading nor inference.
        modules['dropout'] = nn.Dropout(0.25, inplace=True)

        self.cnn = nn.Sequential(modules)
        self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128,
                            num_layers=2, bidirectional=True)
        # Bidirectional LSTM concatenates both directions: 2 * 128 = 256.
        self.fc = nn.Linear(in_features=256, out_features=n_classes)

    def infer_features(self):
        """Return the per-time-step feature size the CNN emits for input_shape.

        Runs one zero-filled dummy batch through ``self.cnn``. It does this
        without autograd and with the CNN temporarily in eval mode, so the
        probe does not update BatchNorm running statistics or
        ``num_batches_tracked``. The previous train/eval mode is restored
        afterwards.
        """
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with torch.no_grad():
                x = self.cnn(torch.zeros((1,) + self.input_shape))
        finally:
            self.cnn.train(was_training)
        # (N, C, H', W') -> (N, C*H', W'); the feature size is C*H'.
        return x.reshape(x.shape[0], -1, x.shape[-1]).shape[1]

    def forward(self, x):
        """Map an image batch to logits shaped (W', N, n_classes).

        W' is the CNN output width, i.e. the number of CTC time steps.
        """
        x = self.cnn(x)
        x = x.reshape(x.shape[0], -1, x.shape[-1])  # (N, C*H', W')
        x = x.permute(2, 0, 1)  # (W', N, C*H'): nn.LSTM here is seq-first
        x, _ = self.lstm(x)
        return self.fc(x)
- # ================= 引擎1: PyTorch =================
class PyTorchEngine:
    """Captcha recogniser backed by the local CTC ``Model``.

    There are 12 output classes: index 0 is the CTC blank ``'-'``, indices
    1-10 are the digits ``'0'``-``'9'``, and index 11 is ``'$'``. Input
    images are denoised, converted to RGB and resized to 80x150
    (height x width) before inference.
    """

    def __init__(self, model_path):
        self.num_classes = 12
        self.characters = '-' + string.digits + '$'
        # NB: the attribute is spelled 'hight'; kept as-is for compatibility.
        self.width, self.hight = 150, 80
        self.model = Model(self.num_classes,
                           input_shape=(3, self.hight, self.width))

        if not os.path.exists(model_path):
            print(f"[PyTorch] Warning: Model file not found at {model_path}")
            self.ready = False
        else:
            # torch.load unpickles the file: only point this at trusted
            # checkpoints.
            weights = torch.load(model_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(weights)
            self.model.eval()
            print(f"[PyTorch] Model loaded successfully from {model_path}")
            self.ready = True

        self.transforms_func = transforms.Compose([
            transforms.Resize((self.hight, self.width)),
            transforms.ToTensor()
        ])

    def decode(self, sequence):
        """Greedy CTC decode of class indices.

        Merges consecutive repeats first, then drops the blank character.
        """
        blank = self.characters[0]
        merged = []
        for idx in sequence:
            symbol = self.characters[idx]
            if not merged or merged[-1] != symbol:
                merged.append(symbol)
        return ''.join(ch for ch in merged if ch != blank)

    def inference_bytes(self, image_bytes):
        """Recognise an encoded image.

        Returns the decoded text, "" if inference fails, or an error string
        if no model was loaded.
        """
        if not self.ready:
            return "Error: Model not loaded"
        try:
            rgb_image = advanced_denoise(image_bytes).convert('RGB')
            tensor = rgb_image
            if self.transforms_func is not None:
                tensor = self.transforms_func(tensor)
            with torch.no_grad():
                logits = self.model(tensor.unsqueeze(0).cpu())
            # (T, N, C) -> (N, T, C), then best class per time step.
            best_path = logits.detach().permute(1, 0, 2).argmax(dim=-1)
            return self.decode(best_path[0])
        except Exception as e:
            print(f"[PyTorch] Inference error: {e}")
            return ""
- # ================= 引擎2: DDDDOCR (已优化) =================
class DddOcrEngine:
    """Recogniser that denoises the image, then hands it to ``ddddocr``."""

    def __init__(self):
        if not HAS_DDDDOCR:
            print("[DDDDOCR] Library missing")
            self.ready = False
            return
        # show_ad=False turns off ddddocr's startup advertisement.
        # beta=True selects ddddocr's alternative ("beta") recognition model.
        # NOTE(review): an earlier comment called this the *older* model and
        # said it is steadier on digit-only captchas. Verify against the
        # ddddocr README.
        self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)
        print("[DDDDOCR] Initialized successfully")
        self.ready = True

    def inference_bytes(self, image_bytes):
        """Recognise an encoded image.

        Returns ddddocr's text, "" if inference fails, or an error string if
        ddddocr is unavailable.
        """
        if not self.ready:
            return "Error: ddddocr not installed"
        try:
            cleaned = advanced_denoise(image_bytes)
            # ddddocr is fed encoded bytes, so re-encode the cleaned image.
            png_buffer = BytesIO()
            cleaned.save(png_buffer, format='PNG')
            return self.ocr.classification(png_buffer.getvalue())
        except Exception as e:
            print(f"[DDDDOCR] Inference error: {e}")
            return ""
- # ================= HTTP 处理 =================
# Registry of successfully initialised OCR engines, keyed by "pytorch" /
# "ddddocr". The __main__ section fills it; RequestHandler.do_POST reads it.
engines = dict()
class RequestHandler(BaseHTTPRequestHandler):
    """HTTP handler serving ``POST /predict/vfcode`` captcha recognition.

    The request body is the raw image bytes. The ``model`` query parameter
    picks the engine: ``ddddocr`` (the default) uses the ddddocr engine, and
    any other value uses the PyTorch engine. Replies are JSON objects with
    ``data``/``msg``/``code`` keys. Any other path gets a plain-text 404.
    """

    # Largest accepted request body. Bigger uploads are rejected with 413
    # before being read, so a client cannot make the server buffer
    # arbitrary amounts of data.
    max_body_bytes = 10 * 1024 * 1024

    def _send_response(self, status, content_type, content):
        """Write a complete response: status line, headers and body bytes."""
        self.send_response(status)
        self.send_header('Content-type', content_type)
        self.send_header('Content-Length', str(len(content)))
        self.end_headers()
        self.wfile.write(content)

    def _send_json(self, status, payload):
        """Serialize ``payload`` as JSON and send it with ``status``."""
        self._send_response(status, 'application/json',
                            json.dumps(payload).encode())

    def log_message(self, format, *args):
        # Silence per-request access logs; only recognition results are
        # printed.
        return

    def _read_body(self):
        """Read the request body.

        Returns:
            The body bytes, or None after an error response (400/413) has
            already been sent.
        """
        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except (TypeError, ValueError):
            self._send_json(400, {'code': 400, 'msg': 'Invalid Content-Length'})
            return None
        if content_length < 0:
            # rfile.read(-1) would block until the client closes the socket,
            # stalling this single-threaded server.
            self._send_json(400, {'code': 400, 'msg': 'Invalid Content-Length'})
            return None
        if content_length == 0:
            self._send_json(400, {'code': 400, 'msg': 'Empty body'})
            return None
        if content_length > self.max_body_bytes:
            self._send_json(413, {'code': 413, 'msg': 'Body too large'})
            return None
        return self.rfile.read(content_length)

    def do_POST(self):
        parsed_path = urlparse(self.path)
        query_params = parse_qs(parsed_path.query)
        # Default to ddddocr: with denoising it usually does better than the
        # pytorch model, which was not trained specifically for this captcha.
        model_type = query_params.get('model', ['ddddocr'])[0]

        if parsed_path.path != '/predict/vfcode':
            self._send_response(404, 'text/plain', b'Not Found')
            return

        try:
            file_content = self._read_body()
            if file_content is None:
                return

            if model_type == 'ddddocr':
                engine_key = 'ddddocr'
                missing_msg = "Error: ddddocr not available"
            else:
                engine_key = 'pytorch'
                missing_msg = "Error: pytorch model not available"
            engine = engines.get(engine_key)
            if engine is not None:
                result_string = engine.inference_bytes(file_content)
            else:
                result_string = missing_msg

            response = {
                'data': result_string,
                'msg': "success",
                'code': 200
            }
            self._send_json(200, response)

            # Short recognition log line.
            print(f"[{model_type}] Result: {result_string}")
        except Exception:
            traceback.print_exc()
            self._send_json(500, {'data': '', 'msg': 'failed', 'code': 500})
if __name__ == '__main__':
    MODEL_PATH = 'data/ctc.pth'
    PORT = 8085

    # 1. PyTorch engine: registered only if its weights file was loaded.
    pytorch_engine = PyTorchEngine(MODEL_PATH)
    if pytorch_engine.ready:
        engines['pytorch'] = pytorch_engine

    # 2. ddddocr engine: registered only if the library imported.
    ddd_engine = DddOcrEngine()
    if ddd_engine.ready:
        engines['ddddocr'] = ddd_engine

    server_address = ('0.0.0.0', PORT)
    # The context manager closes the listening socket on every exit path,
    # not only after Ctrl+C.
    with HTTPServer(server_address, RequestHandler) as httpd:
        print(f'OCR Server running on port {PORT}...')
        print(f'Active engines: {list(engines.keys())}')
        try:
            httpd.serve_forever()
        except KeyboardInterrupt:
            pass
|