导入套件
import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
模型初始化资料
2.1 资料内容
2.2 程序码
app = Flask(__name__)
# 队长email
CAPTAIN_EMAIL = '[email protected]'
# uuid加密
SALT = '1688'
# CPU运算(关闭GPU)
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
# 4个Model初始化
# Xception
model_Xception = None
# InceptionResNetV2
model_V2 = None
# Densenet201
model_swa = None
# R_SVM_model
model_R = None
# 接收的图片Log档
file1 = open('./pic_base64.txt', 'a')
# 官方800字清单
words_path = r'./800_words.txt'
file2 = open(words_path, 'rt', encoding='Big5')
labels = list(file2.read())
# 模型组合之权重与阈值表
# 载入表格
weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
# DenseNet权重
weight_swa = weight_df['wei_ex6'].values
# InceptionResNetV2权重
weight_V2 = weight_df['wei_ex5'].values
# Xception权重
weight_Xception = weight_df['wei_3'].values
API初始化
3.1 before_first_request:在处理第一个request前,先执行API初始化,用以载入模型。
3.2 程序码
@app.before_first_request
def init():
# Xception
global model_Xception
model_Xception = load_model('./Xception_retrained_v2.h5')
# InceptionResNetV2
global model_V2
model_V2 = load_model('./InceptionResNetV2.h5')
# DenseNet201
global model_swa
model_densenet = models.densenet201(num_classes=800)
model_path = './swa_densenet201.pth'
model_fang = model_densenet
model_swa = AveragedModel(model_fang)
model_swa.eval()
model_swa.load_state_dict(torch.load(model_path,
map_location=torch.device('cpu')))
# SVM模型
global model_R
MODEL_PATH = "./model_svm_v3"
model_R = Model().load(MODEL_PATH)
print('====================API初始化完成init====================')
产出server_uuid
def generate_server_uuid(input_string):
s = hashlib.sha256()
data = (input_string + SALT).encode("utf-8")
s.update(data)
server_uuid = s.hexdigest()
return server_uuid
检查预测结果是否为字串:供後续输出预测结果之前,判定资料型态。
def _check_datatype_to_string(prediction):
if isinstance(prediction, str):
return True
raise TypeError('Prediction is not in string type.')
将接收到的图片转换格式
6.1 流程
6.2 程序码
def base64_to_binary_for_cv2(image_64_encoded):
# base64转numpy
img_base64_binary = image_64_encoded.encode("utf-8")
img_binary = base64.b64decode(img_base64_binary)
image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
cv2.IMREAD_COLOR)
# 图片预处理
image = process_img(image)
image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
image_for_tensorflow = np.asarray(image)
# 将接收的图片,储存到Log档
file1.write(image_64_encoded + '\n')
# Xception之input图片格式
image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
interpolation=cv2.INTER_CUBIC)
image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
image_for_Xception = image_for_Xception / 255
# InceptionResNetV2之input图片格式
image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
interpolation=cv2.INTER_CUBIC)
image_for_V2 = np.expand_dims(image_for_V2, axis=0)
image_for_V2 = image_for_V2 / 255
# DenseNet201之input图片格式
transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
contrast=(6, 6), saturation=(1, 1),
hue=(-0.1, 0.1)), ToTensor(),
Normalize((0.5,), (0.5,))])
image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
image_for_swa = transforms(image_for_swa)
return image_for_Xception,image_for_V2,image_for_swa
模型辨识手写中文字
7.1 流程
7.2 程序码
def predict(image_for_Xception,image_for_V2,image_for_swa):
# InceptionResNetV2 predict的机率加权
# 机率向量
pred_V2 = model_V2.predict(image_for_V2)[0]
# 乘上权重的新机率向量
new_V2_prob = pred_V2 * weight_V2
# Xception predict的机率加权
# 机率向量
pred_Xception = model_Xception.predict(image_for_Xception)[0]
# 乘上权重的新机率向量
new_Xception_prob = pred_Xception * weight_Xception
# DenseNet201 predict的机率加权
img = image_for_swa.view(1, 3, 80, 80)
output = model_swa(img)
output = output.view(-1, 800)
output_prob = nn.functional.softmax(output, dim=1)
# 机率向量
output_prob_np = output_prob.cpu().detach().numpy()[0]
# 乘上权重取得新机率向量
new_swa_prob = output_prob_np * weight_swa
# 三个模型向量相加取得新的向量,判定手写中文字
new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
max_prob = np.max(new_prob)
pred_word = np.argmax(new_prob)
# 读取该手写中文字的阈值
judge = labels[pred_word]
mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
# 判断阈值
if max_prob < mean_prob:
prediction = "isnull"
else:
# 考虑加上SVM模型
new_prob_2dim = new_prob[np.newaxis,:]
# 丢入Rmodel预测是否为isnull,1为800字内,2为isnull
pred = model_R.predict(new_prob_2dim)
if pred == 2:
prediction = "isnull"
else:
final_answer = np.argmax(new_prob)
prediction = labels[final_answer]
# 检查预测结果是否为字串
if _check_datatype_to_string(prediction):
return prediction
API服务(inference 资料传输格式:json)
8.1 流程
8.2 程序码
@app.route('/inference', methods=['POST'])
def inference():
# 接收用户request
data = request.get_json(force=True)
# 取image base64 encoded,并以cv2转换格式
image_64_encoded = data['image']
image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
# 产出server_uuid
t = datetime.datetime.now()
ts = str(int(t.utcnow().timestamp()))
server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
# 记录API错误log
try:
answer = predict(image_for_Xception,
image_for_V2,
image_for_swa)
except TypeError as type_error:
raise type_error
except Exception as e:
raise e
server_timestamp = time.time()
# 回传预测结果给用户
return jsonify({'esun_uuid': data['esun_uuid'],
'server_uuid': server_uuid,
'answer': answer,
'server_timestamp': server_timestamp})
if __name__ == "__main__":
arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
' [--port <port>] [--help]')
arg_parser.add_argument('-p', '--port', default=8080, help='port')
arg_parser.add_argument('-d', '--debug', default=True, help='debug')
options = arg_parser.parse_args()
app.run(host='0.0.0.0', port=options.port, debug=options.debug)
启用API服务
9.1 如何启用API服务
9.2 成功启用API服务(如下图)
让我们继续看下去...
有句话说,没用过 unmarshal 就等於没写过 go func Unmarshal(data [...
Aloha!我是少女人妻 Uerica!这个周末朋友要求婚了~朋友前阵子喝了一点然後问我婚姻的感觉...
为了要了解企业流程管理(BPM),很多人上网搜寻到的文章,常常都有些八股,或是看不到想要了解的部分,...
因缘际会需要串某个 JSON API ,然後跟加密这方面实在是不熟,而对方给的范例又不是 Pytho...
前言 昨天的文章讲完前端 Nginx 的写法後,今天就要来进入後端的写法啦!在昨天的小结提到後端的写...