| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import os
- # Suppress ONNX Runtime warnings
- os.environ["ORT_LOGGING_LEVEL"] = "3"
- import string
- import io
- import torch
- import ddddocr
- from torch import nn
- from collections import OrderedDict
- from torchvision import transforms
- from PIL import Image, ImageFilter
- from io import BytesIO
- # ================= PyTorch model architecture (original logic preserved) =================
class Model(nn.Module):
    """CNN + bidirectional LSTM network producing per-timestep class scores.

    The convolutional stack is five blocks of two Conv-BatchNorm-ReLU layers,
    each followed by max pooling. The last pool uses a ``(2, 1)`` kernel, so
    it halves the height only. The resulting feature map is read along its
    width axis by a 2-layer bidirectional LSTM, and a linear layer maps every
    timestep to ``n_classes`` scores.

    Sub-module names (``cnn.conv11``, ``cnn.bn11``, ..., ``lstm``, ``fc``) are
    part of the checkpoint format and must not be renamed.

    Args:
        n_classes: Number of output classes per timestep.
        input_shape: ``(channels, height, width)`` of the expected input.
    """

    def __init__(self, n_classes, input_shape=(3, 64, 128)):
        super().__init__()
        self.input_shape = input_shape
        channels = [32, 64, 128, 256, 256]
        layers = [2, 2, 2, 2, 2]
        kernels = [3, 3, 3, 3, 3]
        pools = [2, 2, 2, 2, (2, 1)]
        hidden_size = 128
        modules = OrderedDict()

        def cba(name, in_channels, out_channels, kernel_size):
            """Register one Conv -> BatchNorm -> ReLU triple under *name*."""
            modules[f'conv{name}'] = nn.Conv2d(
                in_channels, out_channels, kernel_size,
                padding=(1, 1) if kernel_size == 3 else 0)
            modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
            modules[f'relu{name}'] = nn.ReLU(inplace=True)

        # Take the channel count from input_shape instead of assuming RGB, so
        # that e.g. grayscale (1, H, W) inputs build a matching first layer.
        last_channel = input_shape[0]
        for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(
                zip(channels, layers, kernels, pools)):
            for layer in range(1, n_layer + 1):
                cba(f'{block + 1}{layer}', last_channel, n_channel, n_kernel)
                last_channel = n_channel
            modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
        modules['dropout'] = nn.Dropout(0.25, inplace=True)

        self.cnn = nn.Sequential(modules)
        self.lstm = nn.LSTM(input_size=self.infer_features(),
                            hidden_size=hidden_size, num_layers=2,
                            bidirectional=True)
        # A bidirectional LSTM concatenates both directions: 2 * hidden_size.
        self.fc = nn.Linear(in_features=2 * hidden_size, out_features=n_classes)

    def infer_features(self):
        """Return the per-timestep feature size the CNN feeds to the LSTM.

        A dummy all-zero batch of ``self.input_shape`` is pushed through the
        CNN and the channel and height axes are flattened together.

        The probe runs in eval mode with autograd disabled so that it neither
        updates the BatchNorm running statistics nor applies dropout; the
        CNN's previous train/eval mode is restored afterwards.
        """
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros((1,) + tuple(self.input_shape))
                probe = self.cnn(probe)
        finally:
            self.cnn.train(was_training)
        probe = probe.reshape(probe.shape[0], -1, probe.shape[-1])
        return probe.shape[1]

    def forward(self, x):
        """Map an image batch to per-timestep class scores.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of shape ``(timesteps, batch, n_classes)`` where
            ``timesteps`` is the width of the CNN feature map.
        """
        x = self.cnn(x)
        # Merge channel and height axes: (batch, C * H, W).
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        # The LSTM is sequence-first: (W, batch, C * H).
        x = x.permute(2, 0, 1)
        x, _ = self.lstm(x)
        x = self.fc(x)
        return x
- # ================= Engine wrappers =================
class PyTorchEngine:
    """Runs the CNN+LSTM ``Model`` on raw image bytes and decodes the result.

    The engine is only usable (``ready`` is True) when the checkpoint at
    ``model_path`` exists and has been loaded; otherwise inference returns an
    error string.
    """

    def __init__(self, model_path):
        """Build the model and, if *model_path* exists, load its weights on CPU."""
        self.num_classes = 12
        # Index 0 ('-') is the symbol that decode() strips from its output.
        self.characters = '-' + string.digits + '$'
        self.width = 150
        self.hight = 80  # (sic) attribute name kept as-is for existing callers
        self.ready = False
        self.transforms_func = transforms.Compose([
            transforms.Resize((self.hight, self.width)),
            transforms.ToTensor()
        ])
        self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))

        if not os.path.exists(model_path):
            return

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(model_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state)
        self.model.eval()
        self.ready = True
        print(f"[PyTorch] Loaded: {model_path}")

    def decode(self, sequence):
        """Turn a sequence of class indices into text.

        Consecutive duplicate symbols are collapsed into one, then every
        occurrence of ``self.characters[0]`` is dropped.
        """
        blank = self.characters[0]
        collapsed = []
        for index in sequence:
            symbol = self.characters[index]
            if not collapsed or collapsed[-1] != symbol:
                collapsed.append(symbol)
        return ''.join(symbol for symbol in collapsed if symbol != blank)

    def inference_bytes(self, image_bytes):
        """Recognise the text in an encoded image.

        Returns the decoded string, or a string starting with ``"Error: "``
        when the model is not loaded or any step fails.
        """
        if not self.ready:
            return "Error: Model not loaded"
        try:
            picture = Image.open(BytesIO(image_bytes)).convert('RGB')
            tensor = self.transforms_func(picture)
            batch = tensor.unsqueeze(0).cpu()
            with torch.no_grad():
                scores = self.model(batch)
            # Model output is (timesteps, batch, classes); put batch first.
            batch_first = scores.detach().permute(1, 0, 2)
            best_classes = batch_first.argmax(dim=-1)
            return self.decode(best_classes[0])
        except Exception as e:
            return f"Error: {str(e)}"
class DddOcrEngine:
    """Wrapper around a ``ddddocr.DdddOcr`` classifier.

    Both inference methods return the recognised text, or a string starting
    with ``"Error: "`` if anything raises.
    """

    def __init__(self):
        """Create the underlying ddddocr classifier."""
        self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)

    def inference_bytes(self, image_bytes):
        """Classify the encoded image bytes as-is."""
        try:
            text = self.ocr.classification(image_bytes)
        except Exception as e:
            return f"Error: {e}"
        return text

    def inference_captcha(self, image_bytes):
        """Classify the image after grayscale, median-filter and threshold steps.

        The image is converted to 8-bit grayscale, smoothed with a 3x3 median
        filter, binarised (pixels above 128 become 255, the rest 0), re-encoded
        as PNG and then passed to the classifier.
        """
        try:
            source = Image.open(io.BytesIO(image_bytes))
            grayscale = source.convert("L")
            denoised = grayscale.filter(ImageFilter.MedianFilter(size=3))
            binarized = denoised.point(lambda level: 255 if level > 128 else 0)
            png_buffer = io.BytesIO()
            binarized.save(png_buffer, format="PNG")
            png_bytes = png_buffer.getvalue()
            return self.ocr.classification(png_bytes)
        except Exception as e:
            return f"Error: {e}"
|