import io
import itertools
import os
import string
from collections import OrderedDict
from io import BytesIO

# Silence ONNX Runtime warnings. This is set before the third-party imports
# below so it is already in the environment when they load, which is why
# those imports follow this statement.
# NOTE(review): confirm the installed onnxruntime actually reads
# ORT_LOGGING_LEVEL; onnxruntime.set_default_logger_severity(3) is the
# documented programmatic alternative.
os.environ["ORT_LOGGING_LEVEL"] = "3"

import torch  # noqa: E402
import ddddocr  # noqa: E402
from torch import nn  # noqa: E402
from torchvision import transforms  # noqa: E402
from PIL import Image, ImageFilter  # noqa: E402


# ================= PyTorch model architecture (original logic preserved) =================
class Model(nn.Module):
    """CNN + bidirectional LSTM sequence model (CRNN-style).

    The CNN reduces an image of shape ``input_shape`` (channels, height,
    width) to a feature map whose width axis becomes the sequence axis fed
    to the LSTM; a linear head maps every step to ``n_classes`` scores.

    Sub-module names (``conv11``, ``bn11``, ``relu11``, ``pool1``, ...,
    ``dropout``) determine the keys of saved state dicts, so they must not
    change.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128)):
        super().__init__()
        self.input_shape = input_shape
        channels = [32, 64, 128, 256, 256]
        layers = [2, 2, 2, 2, 2]
        kernels = [3, 3, 3, 3, 3]
        # The last pool halves only the height, so the width (which becomes
        # the sequence axis in forward()) is preserved.
        pools = [2, 2, 2, 2, (2, 1)]
        modules = OrderedDict()

        def cba(name, in_channels, out_channels, kernel_size):
            """Append a Conv2d -> BatchNorm2d -> ReLU triple suffixed with ``name``."""
            modules[f'conv{name}'] = nn.Conv2d(
                in_channels, out_channels, kernel_size,
                padding=(1, 1) if kernel_size == 3 else 0)
            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
            modules[f'relu{name}'] = nn.ReLU(inplace=True)

        last_channel = 3
        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(
                zip(channels, layers, kernels, pools)):
            for layer in range(1, n_layer + 1):
                cba(f'{block + 1}{layer}', last_channel, n_channel, n_kernel)
                last_channel = n_channel
            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
        modules['dropout'] = nn.Dropout(0.25, inplace=True)

        self.cnn = nn.Sequential(modules)
        self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128,
                            num_layers=2, bidirectional=True)
        # 2 directions x hidden_size 128 = 256 LSTM output features per step.
        self.fc = nn.Linear(in_features=256, out_features=n_classes)

    def infer_features(self):
        """Return the per-step LSTM input size produced by ``self.cnn``.

        Runs one all-zeros dummy image of ``input_shape`` through the CNN.
        The CNN is put in eval mode (with gradients disabled) for this probe
        so the BatchNorm layers' running statistics and batch counters are
        not updated by the fake input; the previous mode is restored after.
        """
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with torch.no_grad():
                x = self.cnn(torch.zeros((1,) + tuple(self.input_shape)))
        finally:
            self.cnn.train(was_training)
        # (1, C, H', W') -> (1, C * H', W'): features per width step.
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        return x.shape[1]

    def forward(self, x):
        """Map an image batch to per-step class scores.

        Args:
            x: image batch, presumably (batch, *input_shape); the first
               conv layer expects 3 input channels.

        Returns:
            Tensor of shape (steps, batch, n_classes), where ``steps`` is
            the CNN output width (sequence-first, nn.LSTM's default layout).
        """
        x = self.cnn(x)
        x = x.reshape(x.shape[0], -1, x.shape[-1])  # (B, C * H', W')
        x = x.permute(2, 0, 1)                      # (W', B, C * H')
        x, _ = self.lstm(x)
        return self.fc(x)


# ================= Engine wrappers =================
class PyTorchEngine:
    """Recognize character strings with a ``Model`` checkpoint and greedy CTC decoding."""

    def __init__(self, model_path):
        """Build the model and load weights from ``model_path`` if it exists.

        If the file does not exist, ``self.ready`` stays False and
        ``inference_bytes`` returns an error string instead of predicting.
        Errors raised while loading an existing file propagate to the caller.
        """
        # Index 0 ('-') is the blank symbol that decode() removes.
        self.characters = '-' + string.digits + '$'
        self.num_classes = len(self.characters)  # 12
        self.width = 150
        self.hight = 80  # (sic) attribute name kept for backward compatibility
        self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))
        self.ready = False
        self.transforms_func = transforms.Compose([
            transforms.Resize((self.hight, self.width)),
            transforms.ToTensor(),
        ])
        if os.path.exists(model_path):
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints (newer PyTorch versions support weights_only=True).
            state = torch.load(model_path, map_location=torch.device('cpu'))
            self.model.load_state_dict(state)
            self.model.eval()
            self.ready = True
            print(f"[PyTorch] Loaded: {model_path}")

    def decode(self, sequence):
        """Greedy CTC decode: collapse consecutive repeats, then drop blanks.

        Args:
            sequence: 1-D tensor or iterable of class indices into
                ``self.characters``.

        Returns:
            The decoded string.
        """
        if hasattr(sequence, "tolist"):
            # Convert a tensor to plain ints once instead of indexing with
            # 0-d tensors element by element.
            sequence = sequence.tolist()
        chars = (self.characters[i] for i in sequence)
        collapsed = (ch for ch, _ in itertools.groupby(chars))
        blank = self.characters[0]
        return ''.join(ch for ch in collapsed if ch != blank)

    def inference_bytes(self, image_bytes):
        """Recognize the text in encoded image bytes.

        Returns the decoded string, or a string starting with ``"Error: "``
        when the model is not loaded or any step fails (never raises).
        """
        if not self.ready:
            return "Error: Model not loaded"
        try:
            with Image.open(BytesIO(image_bytes)) as img:
                image = img.convert('RGB')
            batch = self.transforms_func(image).unsqueeze(0)
            with torch.no_grad():
                output = self.model(batch)  # (steps, 1, n_classes)
            best_path = output.permute(1, 0, 2).argmax(dim=-1)  # (1, steps)
            return self.decode(best_path[0])
        except Exception as e:
            return f"Error: {str(e)}"


class DddOcrEngine:
    """Wrapper around ``ddddocr.DdddOcr`` that reports failures as strings."""

    def __init__(self):
        self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)

    def inference_bytes(self, image_bytes):
        """Classify raw image bytes; return ``"Error: ..."`` on failure."""
        try:
            return self.ocr.classification(image_bytes)
        except Exception as e:
            return f"Error: {e}"

    def inference_captcha(self, image_bytes, *, threshold=128):
        """Denoise and binarize the image, then classify it.

        Steps: grayscale, 3x3 median filter, then pixels greater than
        ``threshold`` become 255 and all others 0. The result is re-encoded
        as PNG and passed to the OCR. Returns ``"Error: ..."`` on failure.
        """
        try:
            with Image.open(io.BytesIO(image_bytes)) as image:
                gray = image.convert("L")
            denoised = gray.filter(ImageFilter.MedianFilter(size=3))
            binary = denoised.point(lambda p: 255 if p > threshold else 0)
            buf = io.BytesIO()
            binary.save(buf, format="PNG")
            return self.ocr.classification(buf.getvalue())
        except Exception as e:
            return f"Error: {e}"