predict_server.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # predict_server.py
  2. import os
  3. import json
  4. import string
  5. import socket
  6. import traceback
  7. from http.server import BaseHTTPRequestHandler, HTTPServer
  8. from io import BytesIO
  9. from collections import OrderedDict
  10. # 深度学习依赖
  11. import torch
  12. from torch import nn
  13. from torchvision import transforms
  14. from PIL import Image
# ================= Model architecture (keep unchanged) =================
  16. class Model(nn.Module):
  17. def __init__(self, n_classes, input_shape=(3, 64, 128)):
  18. super(Model, self).__init__()
  19. self.input_shape = input_shape
  20. channels = [32, 64, 128, 256, 256]
  21. layers = [2, 2, 2, 2, 2]
  22. kernels = [3, 3, 3, 3, 3]
  23. pools = [2, 2, 2, 2, (2, 1)]
  24. modules = OrderedDict()
  25. def cba(name, in_channels, out_channels, kernel_size):
  26. modules[f'conv{name}'] = nn.Conv2d(in_channels, out_channels, kernel_size,
  27. padding=(1, 1) if kernel_size == 3 else 0)
  28. modules[f'bn{name}'] = nn.BatchNorm2d(out_channels)
  29. modules[f'relu{name}'] = nn.ReLU(inplace=True)
  30. last_channel = 3
  31. for block, (n_channel, n_layer, n_kernel, k_pool) in enumerate(zip(channels, layers, kernels, pools)):
  32. for layer in range(1, n_layer + 1):
  33. cba(f'{block+1}{layer}', last_channel, n_channel, n_kernel)
  34. last_channel = n_channel
  35. modules[f'pool{block + 1}'] = nn.MaxPool2d(k_pool)
  36. modules[f'dropout'] = nn.Dropout(0.25, inplace=True)
  37. self.cnn = nn.Sequential(modules)
  38. self.lstm = nn.LSTM(input_size=self.infer_features(), hidden_size=128, num_layers=2, bidirectional=True)
  39. self.fc = nn.Linear(in_features=256, out_features=n_classes)
  40. def infer_features(self):
  41. x = torch.zeros((1,)+self.input_shape)
  42. x = self.cnn(x)
  43. x = x.reshape(x.shape[0], -1, x.shape[-1])
  44. return x.shape[1]
  45. def forward(self, x):
  46. x = self.cnn(x)
  47. x = x.reshape(x.shape[0], -1, x.shape[-1])
  48. x = x.permute(2, 0, 1)
  49. x, _ = self.lstm(x)
  50. x = self.fc(x)
  51. return x
# ================= Inference wrapper =================
  53. class DeployModel:
  54. def __init__(self, model_path):
  55. self.num_classes = 12
  56. self.characters = '-' + string.digits + '$'
  57. self.width = 150
  58. self.hight = 80
  59. self.model = Model(self.num_classes, input_shape=(3, self.hight, self.width))
  60. if os.path.exists(model_path):
  61. self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
  62. self.model.eval()
  63. print(f"Model loaded successfully from {model_path}")
  64. else:
  65. raise FileNotFoundError(f"Model file not found: {model_path}")
  66. self.transforms_func = transforms.Compose([
  67. transforms.Resize((self.hight, self.width)),
  68. transforms.ToTensor()
  69. ])
  70. def decode(self, sequence):
  71. a = ''.join([self.characters[x] for x in sequence])
  72. s = []
  73. last = None
  74. for x in a:
  75. if x != last:
  76. s.append(x)
  77. last = x
  78. s2 = ''.join([x for x in s if x != self.characters[0]])
  79. return s2
  80. def inference_bytes(self, image_bytes):
  81. try:
  82. image = Image.open(BytesIO(image_bytes))
  83. if image.mode == 'RGBA':
  84. image = image.convert('RGB')
  85. if self.transforms_func is not None:
  86. image = self.transforms_func(image)
  87. with torch.no_grad():
  88. output = self.model(image.unsqueeze(0).cpu())
  89. output_argmax = output.detach().permute(1, 0, 2).argmax(dim=-1)
  90. predict_label = self.decode(output_argmax[0])
  91. return predict_label
  92. except Exception as e:
  93. print(f"Inference error: {e}")
  94. return ""
# ================= HTTP handling =================
# Global model instance; assigned in the __main__ block before the server starts
# and read by RequestHandler.do_POST.
deploy_model = None
  98. class RequestHandler(BaseHTTPRequestHandler):
  99. def _send_response(self, status, content_type, content):
  100. self.send_response(status)
  101. self.send_header('Content-type', content_type)
  102. self.end_headers()
  103. self.wfile.write(content)
  104. def do_POST(self):
  105. if self.path == '/predict/vfcode':
  106. try:
  107. # 获取内容长度
  108. content_length = int(self.headers.get('Content-Length', 0))
  109. if content_length == 0:
  110. self._send_response(400, 'application/json', json.dumps({'code': 400, 'msg': 'Empty body'}).encode())
  111. return
  112. # 直接读取 Raw Binary 数据 (简化通信,避免 multipart 解析问题)
  113. file_content = self.rfile.read(content_length)
  114. # 推理
  115. result_string = deploy_model.inference_bytes(file_content)
  116. response = {
  117. 'data': result_string,
  118. 'msg': "success",
  119. 'code': 200
  120. }
  121. self._send_response(200, 'application/json', json.dumps(response).encode())
  122. print(f"Processed request. Result: {result_string}")
  123. except Exception as e:
  124. traceback.print_exc()
  125. response = {'data': '', 'msg': 'failed', 'code': 500}
  126. self._send_response(500, 'application/json', json.dumps(response).encode())
  127. else:
  128. self._send_response(404, 'text/plain', b'Not Found')
  129. if __name__ == '__main__':
  130. # 配置区
  131. MODEL_PATH = 'data/ocr.pth'
  132. PORT = 8085
  133. # 启动
  134. if not os.path.exists(MODEL_PATH):
  135. print(f"[ERROR] 请确保模型文件存在: {MODEL_PATH}")
  136. exit(1)
  137. deploy_model = DeployModel(MODEL_PATH)
  138. server_address = ('0.0.0.0', PORT) # 监听所有接口
  139. httpd = HTTPServer(server_address, RequestHandler)
  140. print(f'OCR Server running on port {PORT}...')
  141. try:
  142. httpd.serve_forever()
  143. except KeyboardInterrupt:
  144. pass
  145. httpd.server_close()