import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
初始化
程序码
# Flask application object that serves the handwriting-recognition API.
app = Flask(__name__)

# Team captain's e-mail; inference() feeds it into the server_uuid hash.
CAPTAIN_EMAIL = '[email protected]'
# Salt appended to the input string before hashing in generate_server_uuid().
SALT = '1688'

# Run on CPU only: hide every CUDA device (GPU disabled).
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Placeholders for the four models; init() replaces them with loaded models.
model_Xception = None  # Xception
model_V2 = None        # InceptionResNetV2
model_swa = None       # DenseNet201
model_R = None         # SVM model

# Log file for received images (one base64 string per line). It is opened in
# append mode and intentionally kept open: base64_to_binary_for_cv2() writes
# to it on every request.
file1 = open('./pic_base64.txt', 'a')

# Official 800-character list. Every character read from the file becomes one
# label, in file order. The handle is only needed for this single read, so it
# is closed as soon as the labels are loaded.
words_path = r'./800_words.txt'
with open(words_path, 'rt', encoding='Big5') as file2:
    labels = list(file2.read())

# Table holding the per-model ensemble weights and the per-character
# thresholds used by predict().
weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
# DenseNet weights
weight_swa = weight_df['wei_ex6'].values
# InceptionResNetV2 weights
weight_V2 = weight_df['wei_ex5'].values
# Xception weights
weight_Xception = weight_df['wei_3'].values
API初始化
before_first_request:在处理第一个request前,先执行API初始化,用以载入模型。使用此装饰器的原因:
程序码
@app.before_first_request
def init():
    """Load the four models into their module-level globals.

    Registered through ``app.before_first_request``, so Flask calls it once
    before the first request is handled.
    """
    # NOTE(review): ``before_first_request`` is no longer available in recent
    # Flask releases - confirm the pinned Flask version before upgrading.
    global model_Xception, model_V2, model_swa, model_R

    # Keras models: Xception and InceptionResNetV2.
    model_Xception = load_model('./Xception_retrained_v2.h5')
    model_V2 = load_model('./InceptionResNetV2.h5')

    # DenseNet201 with 800 output classes, wrapped in AveragedModel so the
    # state-dict keys match the SWA checkpoint; weights are mapped to the CPU.
    densenet = models.densenet201(num_classes=800)
    model_swa = AveragedModel(densenet)
    model_swa.eval()
    swa_state = torch.load('./swa_densenet201.pth',
                           map_location=torch.device('cpu'))
    model_swa.load_state_dict(swa_state)

    # SVM model.
    model_R = Model().load("./model_svm_v3")

    print('====================API初始化完成init====================')
产出server_uuid
def generate_server_uuid(input_string):
    """Return the SHA-256 hex digest of ``input_string`` followed by ``SALT``.

    The salted text is UTF-8 encoded before hashing.
    """
    salted_bytes = (input_string + SALT).encode("utf-8")
    return hashlib.sha256(salted_bytes).hexdigest()
def _check_datatype_to_string(prediction):
if isinstance(prediction, str):
return True
raise TypeError('Prediction is not in string type.')
将接收到的图片转换格式
将base64编码转换成numpy格式。
将图片去杂讯并转换成灰阶。
记录比赛图片样本,存入log档,供后续改善模型之用。
将图片转换成模型input格式
程序码
def base64_to_binary_for_cv2(image_64_encoded):
    """Decode a base64 image and build the input of each of the three networks.

    The image is decoded with OpenCV, cleaned up by ``process_img`` and
    converted from grayscale to RGB. The received base64 string is then
    appended to the ``file1`` log, and one input per model is produced.

    Args:
        image_64_encoded: base64-encoded image, as a ``str``.

    Returns:
        A tuple ``(image_for_Xception, image_for_V2, image_for_swa)``:
        an 80x80 array and a 150x150 array, each with a leading batch axis
        and divided by 255, plus the torchvision-transformed 80x80 image for
        DenseNet201.
    """
    # base64 text -> raw bytes -> image array.
    img_binary = base64.b64decode(image_64_encoded.encode("utf-8"))
    # NOTE(review): cv2.imdecode returns None for bytes it cannot decode and
    # that value is handed to process_img unchecked - confirm how process_img
    # reacts before adding explicit validation here.
    image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                         cv2.IMREAD_COLOR)

    # Pre-processing, then back to a 3-channel image (PIL and ndarray forms).
    image = process_img(image)
    image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))
    image_for_tensorflow = np.asarray(image)

    # Keep the received image in the log file.
    file1.write(image_64_encoded + '\n')

    # Xception input: 80x80, batch axis added, divided by 255.
    image_for_Xception = cv2.resize(image_for_tensorflow, (80, 80),
                                    interpolation=cv2.INTER_CUBIC)
    image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
    image_for_Xception = image_for_Xception / 255

    # InceptionResNetV2 input: 150x150, batch axis added, divided by 255.
    image_for_V2 = cv2.resize(image_for_tensorflow, (150, 150),
                              interpolation=cv2.INTER_CUBIC)
    image_for_V2 = np.expand_dims(image_for_V2, axis=0)
    image_for_V2 = image_for_V2 / 255

    # DenseNet201 input: colour jitter, tensor conversion, normalisation.
    transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                                      contrast=(6, 6), saturation=(1, 1),
                                      hue=(-0.1, 0.1)), ToTensor(),
                          Normalize((0.5,), (0.5,))])
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the filter it
    # aliased, so the resampling itself is unchanged.
    image_for_swa = image.resize((80, 80), Image.LANCZOS)
    image_for_swa = transforms(image_for_swa)

    return image_for_Xception, image_for_V2, image_for_swa
辨识手写中文字
计算3个模型之800字机率,并乘以加权分数。
将800字的加权机率进行加总,取得新的800字机率。
从新的800字机率中,取机率值最大的那个字,做为预测结果。
以阈值判断,该字是否属于800字内。若机率大于阈值,输出该字;反之,则输出isnull。
检查预测结果是否为字串。
程序码
def predict(image_for_Xception, image_for_V2, image_for_swa):
    """Recognise one handwritten character with the weighted three-model ensemble.

    Each model's probability vector is multiplied by its weight vector and the
    three results are summed. The character with the highest combined score is
    the candidate. It is rejected as ``"isnull"`` when its score is below the
    character's threshold in ``weight_df`` or when the SVM model predicts
    class 2.

    Args:
        image_for_Xception: Xception input from base64_to_binary_for_cv2().
        image_for_V2: InceptionResNetV2 input from base64_to_binary_for_cv2().
        image_for_swa: DenseNet201 input tensor; viewed as (1, 3, 80, 80).

    Returns:
        The predicted character, or ``"isnull"``.

    Raises:
        TypeError: if the prediction is not a ``str``.
    """
    # InceptionResNetV2: probability vector times its weights.
    pred_V2 = model_V2.predict(image_for_V2)[0]
    new_V2_prob = pred_V2 * weight_V2

    # Xception: probability vector times its weights.
    pred_Xception = model_Xception.predict(image_for_Xception)[0]
    new_Xception_prob = pred_Xception * weight_Xception

    # DenseNet201: inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img).view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
    output_prob_np = output_prob.cpu().detach().numpy()[0]
    new_swa_prob = output_prob_np * weight_swa

    # Sum the three weighted vectors and pick the best-scoring character.
    new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
    max_prob = np.max(new_prob)
    pred_word = np.argmax(new_prob)
    judge = labels[pred_word]

    # Threshold of the candidate character.
    # NOTE(review): ``mean_prob`` is an array, so the comparison below relies
    # on exactly one ``weight_df`` row matching ``judge`` - confirm the table
    # has one row per character.
    mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values

    if max_prob < mean_prob:
        prediction = "isnull"
    else:
        # Second opinion from the SVM model: 1 means one of the 800
        # characters, 2 means isnull.
        pred = model_R.predict(new_prob[np.newaxis, :])
        if pred == 2:
            prediction = "isnull"
        else:
            # ``judge`` already is the label at argmax(new_prob).
            prediction = judge

    # Make sure the result is a string before returning it.
    if _check_datatype_to_string(prediction):
        return prediction
API服务(inference 资料传输格式:json)
接收API用户之request。
取出json中image,并转换成图片格式。
产出server_uuid:做为回传时json内容之一。
记录错误log:供后续检查API服务error之用。
回传预测结果给主办方。
程序码
@app.route('/inference', methods=['POST'])
def inference():
    """Handle ``POST /inference``: recognise the image carried in the JSON body.

    The body must contain ``image`` (base64-encoded image) and ``esun_uuid``.
    The response is JSON with the echoed ``esun_uuid``, a ``server_uuid``, the
    predicted ``answer`` and a ``server_timestamp``.

    Any exception raised by predict() is logged with its traceback and then
    re-raised unchanged.
    """
    # Parse the request body as JSON regardless of its Content-Type header.
    data = request.get_json(force=True)

    # Take the base64-encoded image and build the three model inputs.
    image_64_encoded = data['image']
    image_for_Xception, image_for_V2, image_for_swa = \
        base64_to_binary_for_cv2(image_64_encoded)

    # Build the server_uuid from the captain e-mail and the current epoch
    # second. A timezone-aware "now" is used because calling .timestamp() on
    # the naive value returned by utcnow() interprets it as local time.
    ts = str(int(datetime.datetime.now(datetime.timezone.utc).timestamp()))
    server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)

    # Record API errors in the application log, then let them propagate.
    try:
        answer = predict(image_for_Xception,
                         image_for_V2,
                         image_for_swa)
    except Exception:
        app.logger.exception('predict() failed')
        raise

    server_timestamp = time.time()

    # Send the prediction back to the caller.
    return jsonify({'esun_uuid': data['esun_uuid'],
                    'server_uuid': server_uuid,
                    'answer': answer,
                    'server_timestamp': server_timestamp})
if __name__ == "__main__":
    def _str_to_bool(value):
        """Interpret a command-line value as a boolean switch."""
        return str(value).strip().lower() not in ('', '0', 'false', 'no', 'off')

    arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                ' [--port <port>] [--help]')
    # Parse the port as an integer so bad values are rejected by argparse.
    arg_parser.add_argument('-p', '--port', type=int, default=8080,
                            help='port')
    # Without a type, "-d False" produced the truthy string "False" and debug
    # mode could not be switched off from the command line.
    # NOTE(review): debug defaults to True while the server listens on
    # 0.0.0.0, which exposes the interactive debugger - consider defaulting
    # to False for deployment.
    arg_parser.add_argument('-d', '--debug', type=_str_to_bool, default=True,
                            help='debug')
    options = arg_parser.parse_args()
    app.run(host='0.0.0.0', port=options.port, debug=options.debug)
整个比赛下来收获很多,虽然成绩不是说特别好,但对于第一次参赛的我们,觉得这个经验非常值得。
结束了整个流程,明天来检讨哪些地方可以做改善,以及参考得奖队伍的做法。
<<: Day 29: 细节:资料库、Web、框架 (待改进中... )
>>: (特别篇)统计学的陷阱区,用资料绘制盒须—爬虫D3做成D3(下)
经过两天的简介,希望大家都对TypeScript(TS)有基本的了解。 今天呢要来讲解安装TS的开发...
运算式与运算子 运算式 透过运算子进行运算而得到指定的结果值 运算子的介绍 这边会列出几个简单算是常...
打开视野 藉由这次铁人赛我看到许多不同类型的文章,也看到很多人在前端技术上努力(铁人赛还有很多主题,...
今天我们要来看到第一个画面了! 30天内所有写的扣笔者都会放在这个Git专案(https://git...
前言 昨天在建立环境的时候发现有很多相容性的问题,因此今天我想说这几天先来学习一下tensorflo...