ocr_engine.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import os
  2. # 屏蔽 ONNX Runtime 警告
  3. os.environ["ORT_LOGGING_LEVEL"] = "3"
  4. import string
  5. import io
  6. import torch
  7. import ddddocr
  8. from torch import nn
  9. from collections import OrderedDict
  10. from torchvision import transforms
  11. from PIL import Image, ImageFilter
  12. from io import BytesIO
# ================= PyTorch model architecture (original logic preserved) =================
  14. class Model(nn.Module):
  15. def __init__(self, n_classes, input_shape=(3, 64, 128)):
  16. super(Model, self).__init__()
  17. self.input_shape = input_shape
  18. channels = [32, 64, 128, 256, 256]
  19. layers = [2, 2, 2, 2, 2]
  20. kernels = [3, 3, 3, 3, 3]
  21. pools = [2, 2, 2, 2, (2, 1)]
  22. modules = OrderedDict()
  23. def cba(name, in_channels, out_channels, kernel_size):
  24. modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
  25. padding=(1, 1) if kernel_size == 3 else 0)
  26. modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
  27. modules[f'relu{name}'] = nn.ReLU(inplace=True)
  28. last_channel = 3
  29. for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
  30. for layer in range(1, n_layer + 1):
  31. cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
  32. last_channel = n_channel
  33. modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
  34. modules[f'dropout'] = nn.Dropout(0.25, inplace=True)
  35. self.cnn = nn.Sequential(modules)
  36. self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128, num_layers=2, bidirectional=True)
  37. self.fc = nn.Linear(in_features=256, out_features=n_classes)
  38. def infer_features(self):
  39. x = torch.zeros((1,)+self.input_shape)
  40. x = self.cnn(x)
  41. x = x.reshape(x.shape[0], -1, x.shape[-1])
  42. return x.shape[1]
  43. def forward(self, x):
  44. x = self.cnn(x)
  45. x = x.reshape(x.shape[0], -1, x.shape[-1])
  46. x = x.permute(2, 0, 1)
  47. x, _ = self.lstm(x)
  48. x = self.fc(x)
  49. return x
# ================= Engine wrappers =================
  51. class PyTorchEngine:
  52. def __init__(self, model_path):
  53. self.num_classes = 12
  54. self.characters = '-' + string.digits + '$'
  55. self.width = 150
  56. self.hight = 80
  57. self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))
  58. self.ready = False
  59. self.transforms_func = transforms.Compose([
  60. transforms.Resize((self.hight, self.width)),
  61. transforms.ToTensor()
  62. ])
  63. if os.path.exists(model_path):
  64. self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  65. self.model.eval()
  66. self.ready = True
  67. print(f"[PyTorch] Loaded: {model_path}")
  68. def decode(self, sequence):
  69. a = ''.join([self.characters[x] for x in sequence])
  70. s = []
  71. last = None
  72. for x in a:
  73. if x != last:
  74. s.append(x)
  75. last = x
  76. return ''.join([x for x in s if x != self.characters[0]])
  77. def inference_bytes(self, image_bytes):
  78. if not self.ready: return "Error: Model not loaded"
  79. try:
  80. image = Image.open(BytesIO(image_bytes)).convert('RGB')
  81. image = self.transforms_func(image)
  82. with torch.no_grad():
  83. output = self.model(image.unsqueeze(0).cpu())
  84. output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
  85. return self.decode(output_argmax[0])
  86. except Exception as e:
  87. return f"Error: {str(e)}"
  88. class DddOcrEngine:
  89. def __init__(self):
  90. self.ocr = ddddocr.DdddOcr(show_ad=False, beta=True)
  91. def inference_bytes(self, image_bytes):
  92. try:
  93. return self.ocr.classification(image_bytes)
  94. except Exception as e:
  95. return f"Error: {e}"
  96. def inference_captcha(self, image_bytes):
  97. try:
  98. image = Image.open(io.BytesIO(image_bytes))
  99. gray_img = image.convert("L").filter(ImageFilter.MedianFilter(size=3))
  100. binary_img = gray_img.point(lambda p: 255 if p > 128 else 0)
  101. buf = io.BytesIO()
  102. binary_img.save(buf, format="PNG")
  103. return self.ocr.classification(buf.getvalue())
  104. except Exception as e:
  105. return f"Error: {e}"