【第25天】部署API服务-Python Flask

摘要

  1. 导入套件
  2. 模型初始化资料
  3. API初始化
  4. server_uuid
  5. 转换图片格式
  6. 模型辨识手写中文字
  7. 检查预测结果是否为字串
  8. API服务
  9. 启用API服务

内容

  1. 导入套件

    import base64
    import datetime
    import hashlib
    import time
    from argparse import ArgumentParser
    import multiprocessing
    import cv2
    import numpy as np
    from flask import Flask
    from flask import jsonify
    from flask import request
    from img_gray import process_img
    from PIL import Image
    import torch
    from torch import nn
    from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
    import torchvision.models as models
    import pandas as pd
    from R_model_load import Model
    from tensorflow.keras.preprocessing import image
    from tensorflow.keras.models import load_model
    import numpy as np
    import os
    from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
    
  2. 模型初始化资料

    2.1 资料内容

    • 队长Email
    • uuid加密
    • CPU运算:因GCP免费试用,开启的VM无GPU,故以CPU运算。
    • 4个Model初始化:3个影像辨识模型+1个SVM模型(用以进一步判断isnull)。
    • 接收图片的Log与官方800字清单
    • 模型组合之权重与阈值表

    2.2 程序码

    app = Flask(__name__)
    
    # 队长email
    CAPTAIN_EMAIL = '[email protected]'
    
    # uuid加密
    SALT = '1688'
    
    # CPU运算(关闭GPU)
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    
    # 4个Model初始化
    # Xception
    model_Xception = None
    # InceptionResNetV2
    model_V2 = None
    # Densenet201
    model_swa = None
    # R_SVM_model
    model_R = None
    
    # 接收的图片Log档
    file1 = open('./pic_base64.txt', 'a')
    # 官方800字清单
    words_path = r'./800_words.txt'
    file2 = open(words_path, 'rt', encoding='Big5')
    labels = list(file2.read())
    
    # 模型组合之权重与阈值表
    # 载入表格
    weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
    # DenseNet权重
    weight_swa = weight_df['wei_ex6'].values
    # InceptionResNetV2权重
    weight_V2 = weight_df['wei_ex5'].values
    # Xception权重
    weight_Xception = weight_df['wei_3'].values
    
  3. API初始化

    3.1 before_first_request:在处理第一个request前,先执行API初始化,用以载入模型。

    3.2 程序码

    @app.before_first_request
    def init():
        # Xception
        global model_Xception
        model_Xception = load_model('./Xception_retrained_v2.h5')
    
        # InceptionResNetV2
        global model_V2
        model_V2 = load_model('./InceptionResNetV2.h5')
    
        # DenseNet201
        global model_swa
        model_densenet = models.densenet201(num_classes=800)
        model_path = './swa_densenet201.pth'
        model_fang = model_densenet
        model_swa = AveragedModel(model_fang)
        model_swa.eval()
        model_swa.load_state_dict(torch.load(model_path,
                                  map_location=torch.device('cpu')))
    
        # SVM模型
        global model_R
        MODEL_PATH = "./model_svm_v3"
        model_R = Model().load(MODEL_PATH)
        print('====================API初始化完成init====================')
    
  4. 产出server_uuid

    def generate_server_uuid(input_string):
        s = hashlib.sha256()
        data = (input_string + SALT).encode("utf-8")
        s.update(data)
        server_uuid = s.hexdigest()
        return server_uuid
    
  5. 检查预测结果是否为字串:供後续输出预测结果之前,判定资料型态。

    def _check_datatype_to_string(prediction):
        if isinstance(prediction, str):
            return True
        raise TypeError('Prediction is not in string type.')
    
  6. 将接收到的图片转换格式

    6.1 流程

    • 将base64编码转换成numpy格式。
    • 图片预处理:将图片转换成灰阶。
    • 将接收的图片,储存到Log档:纪录比赛图片样本,供後续改善模型之用。
    • 将图片转换成模型input格式

    6.2 程序码

    def base64_to_binary_for_cv2(image_64_encoded):
        # base64转numpy
        img_base64_binary = image_64_encoded.encode("utf-8")
        img_binary = base64.b64decode(img_base64_binary)
        image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                             cv2.IMREAD_COLOR)
    
        # 图片预处理
        image = process_img(image)
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
        image_for_tensorflow = np.asarray(image)
    
        # 将接收的图片,储存到Log档
        file1.write(image_64_encoded + '\n')
    
        # Xception之input图片格式
        image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
                                        interpolation=cv2.INTER_CUBIC)
        image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
        image_for_Xception = image_for_Xception / 255
    
        # InceptionResNetV2之input图片格式
        image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
                                  interpolation=cv2.INTER_CUBIC)
        image_for_V2 = np.expand_dims(image_for_V2, axis=0)
        image_for_V2 = image_for_V2 / 255
    
        # DenseNet201之input图片格式
        transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                              contrast=(6, 6), saturation=(1, 1),
                              hue=(-0.1, 0.1)), ToTensor(),
                              Normalize((0.5,), (0.5,))])
        image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
        image_for_swa = transforms(image_for_swa)
    
        return image_for_Xception,image_for_V2,image_for_swa
    
  7. 模型辨识手写中文字

    7.1 流程

    • 计算3个模型之800字机率,并乘以加权分数。
    • 将800字的加权机率进行加总,取得新的800字机率。
    • 从新的800字机率中,取机率值最大的那个字,做为预测结果。
    • 以阈值判断,该字是否属於800字内。若机率大於阈值,输出该字;反之,则输出isnull。
    • 检查预测结果是否为字串。

    7.2 程序码

    def predict(image_for_Xception,image_for_V2,image_for_swa):
        # InceptionResNetV2 predict的机率加权
        # 机率向量
        pred_V2 = model_V2.predict(image_for_V2)[0]
        # 乘上权重的新机率向量
        new_V2_prob = pred_V2 * weight_V2
    
        # Xception predict的机率加权
        # 机率向量
        pred_Xception = model_Xception.predict(image_for_Xception)[0] 
        # 乘上权重的新机率向量
        new_Xception_prob = pred_Xception * weight_Xception 
    
        # DenseNet201 predict的机率加权
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img)
        output = output.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        # 机率向量
        output_prob_np = output_prob.cpu().detach().numpy()[0]
        # 乘上权重取得新机率向量
        new_swa_prob = output_prob_np * weight_swa 
    
        # 三个模型向量相加取得新的向量,判定手写中文字
        new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
        max_prob = np.max(new_prob)
        pred_word = np.argmax(new_prob)
    
        # 读取该手写中文字的阈值
        judge = labels[pred_word]
        mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    
        # 判断阈值
        if max_prob < mean_prob:
            prediction = "isnull"
        else:
            # 考虑加上SVM模型
            new_prob_2dim = new_prob[np.newaxis,:]
            # 丢入Rmodel预测是否为isnull,1为800字内,2为isnull
            pred = model_R.predict(new_prob_2dim)
            if pred == 2:
                prediction = "isnull"
            else:
                final_answer = np.argmax(new_prob)
                prediction = labels[final_answer]
    
        # 检查预测结果是否为字串
        if _check_datatype_to_string(prediction):
            return prediction
    
  8. API服务(inference 资料传输格式:json)

    8.1 流程

    • 接收API用户之request。
    • 取出json中image,并转换成图片格式。
    • 产出server_uuid:做为回传时json内容之一。
    • 记录错误log:供後续检查API服务error之用。
    • 回传预测结果给用户

    8.2 程序码

    @app.route('/inference', methods=['POST'])
    def inference():
        # 接收用户request
        data = request.get_json(force=True)
    
        # 取image base64 encoded,并以cv2转换格式
        image_64_encoded = data['image']
        image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
    
        # 产出server_uuid
        t = datetime.datetime.now()
        ts = str(int(t.utcnow().timestamp()))
        server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    
        # 记录API错误log
        try:
            answer = predict(image_for_Xception,
                             image_for_V2,
                             image_for_swa)
        except TypeError as type_error:
            raise type_error
        except Exception as e:
            raise e
        server_timestamp = time.time()
    
        # 回传预测结果给用户
        return jsonify({'esun_uuid': data['esun_uuid'],
                        'server_uuid': server_uuid,
                        'answer': answer,
                        'server_timestamp': server_timestamp})
    
    if __name__ == "__main__":
    
        arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                   ' [--port <port>] [--help]')
        arg_parser.add_argument('-p', '--port', default=8080, help='port')
        arg_parser.add_argument('-d', '--debug', default=True, help='debug')
        options = arg_parser.parse_args()
    
        app.run(host='0.0.0.0', port=options.port, debug=options.debug)
    
  9. 启用API服务

    9.1 如何启用API服务

    • 到GCP启用VM
    • 开启cmd与ssh连线
    • 前往目标资料夹
    • 输入指令:python3 api_1.py

    9.2 成功启用API服务(如下图)


小结

  1. 完成最後一关「部署API服务」後,玉山竞赛流程就到此暂时告一段落。
  2. 监於此次竞赛仍有进步空间,後续几天将透过实作,和大家分享一些可能提升模型辨识效果的方法。

让我们继续看下去...


<<:  [Day25] 求值策略

>>:  [经典回顾]知名通讯软件过度存取用户资讯事件

[13th][Day18] Unmarshal

有句话说,没用过 unmarshal 就等於没写过 go func Unmarshal(data [...

【少女人妻的30天Elastic】Day 29 : App Search_API 介绍与应用_Curations

Aloha!我是少女人妻 Uerica!这个周末朋友要求婚了~朋友前阵子喝了一点然後问我婚姻的感觉...

BPM懒人包 让你一次搞懂BPM的大小事

为了要了解企业流程管理(BPM),很多人上网搜寻到的文章,常常都有些八股,或是看不到想要了解的部分,...

[JSON] python-jose 安装与范例

因缘际会需要串某个 JSON API ,然後跟加密这方面实在是不熟,而对方给的范例又不是 Pytho...

Day27-好用的网页服务器-nginx(三)

前言 昨天的文章讲完前端 Nginx 的写法後,今天就要来进入後端的写法啦!在昨天的小结提到後端的写...