DAY29:开启API服务(完赛)

部署及开启API服务-flask

  1. 导入套件
import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
  1. 初始化

    • 队长Email
    • uuid加密
    • CPU运算:因GCP免费试用,开启的VM无GPU,故以CPU运算。
    • 4个Model初始化:3个影像辨识模型+1个SVM模型(用以判断isnull)。
    • 接收图片的Log与官方800字清单
    • 模型组合之权重与阈值表

    程序码

    app = Flask(__name__)
    
    # 队长email
    CAPTAIN_EMAIL = '[email protected]'
    
    # uuid加密
    SALT = '1688'
    
    # CPU运算(关闭GPU)
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    
    # 4个Model初始化
    # Xception
    model_Xception = None
    # InceptionResNetV2
    model_V2 = None
    # Densenet201
    model_swa = None
    # R_SVM_model
    model_R = None
    
    # 接收的图片Log档
    file1 = open('./pic_base64.txt', 'a')
    # 官方800字清单
    words_path = r'./800_words.txt'
    file2 = open(words_path, 'rt', encoding='Big5')
    labels = list(file2.read())
    
    # 模型组合之权重与阈值表
    # 载入表格
    weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
    # DenseNet权重
    weight_swa = weight_df['wei_ex6'].values
    # InceptionResNetV2权重
    weight_V2 = weight_df['wei_ex5'].values
    # Xception权重
    weight_Xception = weight_df['wei_3'].values   
    
  2. API初始化

    • before_first_request:在处理第一个request前,先执行API初始化,用以载入模型。使用此装饰器的原因:

      • 一开始没有用多线程,我们的模型又较大,在接收图片时会处理较慢,导致无法一次接收处理多张图片,让我们比赛有一天只回传不到一半的答案。
      • 後来使用多线程发现,tensorflow的模型会无法读取到,後来找到before_first_request这个解决方式。
    • 程序码

      @app.before_first_request
      def init():
         # Xception
         global model_Xception
         model_Xception = load_model('./Xception_retrained_v2.h5')
      
         # InceptionResNetV2
         global model_V2
         model_V2 = load_model('./InceptionResNetV2.h5')
      
         # DenseNet201
         global model_swa
         model_densenet = models.densenet201(num_classes=800)
         model_path = './swa_densenet201.pth'
         model_fang = model_densenet
         model_swa = AveragedModel(model_fang)
         model_swa.eval()
         model_swa.load_state_dict(torch.load(model_path,
                                   map_location=torch.device('cpu')))
      
         # SVM模型
         global model_R
         MODEL_PATH = "./model_svm_v3"
         model_R = Model().load(MODEL_PATH)
         print('====================API初始化完成init====================')
      
  3. 产出server_uuid

def generate_server_uuid(input_string):
    s = hashlib.sha256()
    data = (input_string + SALT).encode("utf-8")
    s.update(data)
    server_uuid = s.hexdigest()
    return server_uuid
  1. 检查预测结果是否为字串:供後续输出预测结果之前,判定资料型态。
def _check_datatype_to_string(prediction):
    if isinstance(prediction, str):
        return True
    raise TypeError('Prediction is not in string type.')
  1. 将接收到的图片转换格式

    • 将base64编码转换成numpy格式。

    • 将图片去杂讯并转换成灰阶。

    • 纪录比赛图片样本,存入log档,供後续改善模型之用。

    • 将图片转换成模型input格式

    • 程序码

    def base64_to_binary_for_cv2(image_64_encoded):
        # base64转numpy
        img_base64_binary = image_64_encoded.encode("utf-8")
        img_binary = base64.b64decode(img_base64_binary)
        image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                             cv2.IMREAD_COLOR)
    
        # 图片预处理
        image = process_img(image)
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
        image_for_tensorflow = np.asarray(image)
    
        # 将接收的图片,储存到Log档
        file1.write(image_64_encoded + '\n')
    
        # Xception之input图片格式
        image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
                                        interpolation=cv2.INTER_CUBIC)
        image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
        image_for_Xception = image_for_Xception / 255
    
        # InceptionResNetV2之input图片格式
        image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
                                  interpolation=cv2.INTER_CUBIC)
        image_for_V2 = np.expand_dims(image_for_V2, axis=0)
        image_for_V2 = image_for_V2 / 255
    
        # DenseNet201之input图片格式
        transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                              contrast=(6, 6), saturation=(1, 1),
                              hue=(-0.1, 0.1)), ToTensor(),
                              Normalize((0.5,), (0.5,))])
        image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
        image_for_swa = transforms(image_for_swa)
    
        return image_for_Xception,image_for_V2,image_for_swa     
    
  2. 辨识手写中文字

    • 计算3个模型之800字机率,并乘以加权分数。

    • 将800字的加权机率进行加总,取得新的800字机率。

    • 从新的800字机率中,取机率值最大的那个字,做为预测结果。

    • 以阈值判断,该字是否属於800字内。若机率大於阈值,输出该字;反之,则输出isnull。

    • 检查预测结果是否为字串。

    • 程序码

    def predict(image_for_Xception,image_for_V2,image_for_swa):
        # InceptionResNetV2 predict的机率加权
        # 机率向量
        pred_V2 = model_V2.predict(image_for_V2)[0]
        # 乘上权重的新机率向量
        new_V2_prob = pred_V2 * weight_V2
    
        # Xception predict的机率加权
        # 机率向量
        pred_Xception = model_Xception.predict(image_for_Xception)[0] 
        # 乘上权重的新机率向量
        new_Xception_prob = pred_Xception * weight_Xception 
    
        # DenseNet201 predict的机率加权
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img)
        output = output.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        # 机率向量
        output_prob_np = output_prob.cpu().detach().numpy()[0]
        # 乘上权重取得新机率向量
        new_swa_prob = output_prob_np * weight_swa 
    
        # 三个模型向量相加取得新的向量,判定手写中文字
        new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
        max_prob = np.max(new_prob)
        pred_word = np.argmax(new_prob)
    
        # 读取该手写中文字的阈值
        judge = labels[pred_word]
        mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    
        # 判断阈值
        if max_prob < mean_prob:
            prediction = "isnull"
        else:
            # 考虑加上SVM模型
            new_prob_2dim = new_prob[np.newaxis,:]
            # 丢入Rmodel预测是否为isnull,1为800字内,2为isnull
            pred = model_R.predict(new_prob_2dim)
            if pred == 2:
                prediction = "isnull"
            else:
                final_answer = np.argmax(new_prob)
                prediction = labels[final_answer]
    
        # 检查预测结果是否为字串
        if _check_datatype_to_string(prediction):
            return prediction
    
  3. API服务(inference 资料传输格式:json)

    • 接收API用户之request。

    • 取出json中image,并转换成图片格式。

    • 产出server_uuid:做为回传时json内容之一。

    • 记录错误log:供後续检查API服务error之用。

    • 回传预测结果给主办方。

    • 程序码

    @app.route('/inference', methods=['POST'])
    def inference():
        # 接收用户request
        data = request.get_json(force=True)
    
        # 取image base64 encoded,并以cv2转换格式
        image_64_encoded = data['image']
        image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
    
        # 产出server_uuid
        t = datetime.datetime.now()
        ts = str(int(t.utcnow().timestamp()))
        server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    
        # 记录API错误log
        try:
            answer = predict(image_for_Xception,
                             image_for_V2,
                             image_for_swa)
        except TypeError as type_error:
            raise type_error
        except Exception as e:
            raise e
        server_timestamp = time.time()
    
        # 回传预测结果给用户
        return jsonify({'esun_uuid': data['esun_uuid'],
                        'server_uuid': server_uuid,
                        'answer': answer,
                        'server_timestamp': server_timestamp})
    
    if __name__ == "__main__":
    
        arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                   ' [--port <port>] [--help]')
        arg_parser.add_argument('-p', '--port', default=8080, help='port')
        arg_parser.add_argument('-d', '--debug', default=True, help='debug')
        options = arg_parser.parse_args()
    
        app.run(host='0.0.0.0', port=options.port, debug=options.debug)
    

小结

  • 整个比赛下来收获很多,虽然成绩不是说特别好,但对於第一次参赛的我们,觉得这个经验非常值得。

  • 结束了整个流程,明天来检讨哪些地方可以做改善,以及参考得奖队伍的做法。


<<:  Day 29: 细节:资料库、Web、框架 (待改进中... )

>>:  (特别篇)统计学的陷阱区,用资料绘制盒须—爬虫D3做成D3(下)

Day3-TypeScript(TS)安装开发环境

经过两天的简介,希望大家都对TypeScript(TS)有基本的了解。 今天呢要来讲解安装TS的开发...

[ Day 3 ] - 运算式与运算子

运算式与运算子 运算式 透过运算子进行运算而得到指定的结果值 运算子的介绍 这边会列出几个简单算是常...

这些日子我学到的JavaScript:Day29- 尾声

打开视野 藉由这次铁人赛我看到许多不同类型的文章,也看到很多人在前端技术上努力(铁人赛还有很多主题,...

Day2. Hello Matter.js World!

今天我们要来看到第一个画面了! 30天内所有写的扣笔者都会放在这个Git专案(https://git...

Day 16 self-attention的实作准备(二) 设定tensorflow和keras的环境

前言 昨天在建立环境的时候发现有很多相容性的问题,因此今天我想说这几天先来学习一下tensorflo...