当前位置: 首页 > 开发杂谈 >

DAY29:开启API服务(完赛)

部署及开启API服务-flask

  1. 导入套件
import base64
import datetime
import hashlib
import time
from argparse import ArgumentParser
import multiprocessing
import cv2
import numpy as np
from flask import Flask
from flask import jsonify
from flask import request
from img_gray import process_img
from PIL import Image
import torch
from torch import nn
from torchvision.transforms import Compose, ToTensor,Resize,ColorJitter,Normalize
import torchvision.models as models
import pandas as pd
from R_model_load import Model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.models import load_model
import numpy as np
import os
from torch.optim.swa_utils import AveragedModel, update_bn, SWALR
  1. 初始化

    • 队长Email
    • uuid加密
    • CPU运算:因GCP免费试用,开启的VM无GPU,故以CPU运算。
    • 4个Model初始化:3个影像辨识模型+1个SVM模型(用以判断isnull)。
    • 接收图片的Log与官方800字清单
    • 模型组合之权重与阈值表

    程序码

    app = Flask(__name__)
    
    # 队长email
    CAPTAIN_EMAIL = 'XXXXXX@gmail.com'
    
    # uuid加密
    SALT = '1688'
    
    # CPU运算(关闭GPU)
    os.environ["CUDA_VISIBLE_DEVICES"]="-1"
    
    # 4个Model初始化
    # Xception
    model_Xception = None
    # InceptionResNetV2
    model_V2 = None
    # Densenet201
    model_swa = None
    # R_SVM_model
    model_R = None
    
    # 接收的图片Log档
    file1 = open('./pic_base64.txt', 'a')
    # 官方800字清单
    words_path = r'./800_words.txt'
    file2 = open(words_path, 'rt', encoding='Big5')
    labels = list(file2.read())
    
    # 模型组合之权重与阈值表
    # 载入表格
    weight_df = pd.read_csv("./model_weight_final.csv", encoding="Big5")
    # DenseNet权重
    weight_swa = weight_df['wei_ex6'].values
    # InceptionResNetV2权重
    weight_V2 = weight_df['wei_ex5'].values
    # Xception权重
    weight_Xception = weight_df['wei_3'].values   
    
  2. API初始化

    • before_first_request:在处理第一个request前,先执行API初始化,用以载入模型。使用此装饰器的原因:

      • 一开始没有用多线程,我们的模型又较大,在接收图片时会处理较慢,导致无法一次接收处理多张图片,让我们比赛有一天只回传不到一半的答案。
      • 後来使用多线程发现,tensorflow的模型会无法读取到,後来找到before_first_request这个解决方式。
    • 程序码

      @app.before_first_request
      def init():
         # Xception
         global model_Xception
         model_Xception = load_model('./Xception_retrained_v2.h5')
      
         # InceptionResNetV2
         global model_V2
         model_V2 = load_model('./InceptionResNetV2.h5')
      
         # DenseNet201
         global model_swa
         model_densenet = models.densenet201(num_classes=800)
         model_path = './swa_densenet201.pth'
         model_fang = model_densenet
         model_swa = AveragedModel(model_fang)
         model_swa.eval()
         model_swa.load_state_dict(torch.load(model_path,
                                   map_location=torch.device('cpu')))
      
         # SVM模型
         global model_R
         MODEL_PATH = "./model_svm_v3"
         model_R = Model().load(MODEL_PATH)
         print('====================API初始化完成init====================')
      
  3. 产出server_uuid

def generate_server_uuid(input_string):
    s = hashlib.sha256()
    data = (input_string + SALT).encode("utf-8")
    s.update(data)
    server_uuid = s.hexdigest()
    return server_uuid
  1. 检查预测结果是否为字串:供後续输出预测结果之前,判定资料型态。
def _check_datatype_to_string(prediction):
    if isinstance(prediction, str):
        return True
    raise TypeError('Prediction is not in string type.')
  1. 将接收到的图片转换格式

    • 将base64编码转换成numpy格式。

    • 将图片去杂讯并转换成灰阶。

    • 纪录比赛图片样本,存入log档,供後续改善模型之用。

    • 将图片转换成模型input格式

    • 程序码

    def base64_to_binary_for_cv2(image_64_encoded):
        # base64转numpy
        img_base64_binary = image_64_encoded.encode("utf-8")
        img_binary = base64.b64decode(img_base64_binary)
        image = cv2.imdecode(np.frombuffer(img_binary, np.uint8),
                             cv2.IMREAD_COLOR)
    
        # 图片预处理
        image = process_img(image)
        image = Image.fromarray(cv2.cvtColor(image,cv2.COLOR_GRAY2RGB))
        image_for_tensorflow = np.asarray(image)
    
        # 将接收的图片,储存到Log档
        file1.write(image_64_encoded + '\n')
    
        # Xception之input图片格式
        image_for_Xception = cv2.resize(image_for_tensorflow, (80,80),
                                        interpolation=cv2.INTER_CUBIC)
        image_for_Xception = np.expand_dims(image_for_Xception, axis=0)
        image_for_Xception = image_for_Xception / 255
    
        # InceptionResNetV2之input图片格式
        image_for_V2 = cv2.resize(image_for_tensorflow, (150 , 150),
                                  interpolation=cv2.INTER_CUBIC)
        image_for_V2 = np.expand_dims(image_for_V2, axis=0)
        image_for_V2 = image_for_V2 / 255
    
        # DenseNet201之input图片格式
        transforms = Compose([ColorJitter(brightness=(1.5, 1.5),
                              contrast=(6, 6), saturation=(1, 1),
                              hue=(-0.1, 0.1)), ToTensor(),
                              Normalize((0.5,), (0.5,))])
        image_for_swa = image.resize((80, 80), Image.ANTIALIAS)
        image_for_swa = transforms(image_for_swa)
    
        return image_for_Xception,image_for_V2,image_for_swa     
    
  2. 辨识手写中文字

    • 计算3个模型之800字机率,并乘以加权分数。

    • 将800字的加权机率进行加总,取得新的800字机率。

    • 从新的800字机率中,取机率值最大的那个字,做为预测结果。

    • 以阈值判断,该字是否属於800字内。若机率大於阈值,输出该字;反之,则输出isnull。

    • 检查预测结果是否为字串。

    • 程序码

    def predict(image_for_Xception,image_for_V2,image_for_swa):
        # InceptionResNetV2 predict的机率加权
        # 机率向量
        pred_V2 = model_V2.predict(image_for_V2)[0]
        # 乘上权重的新机率向量
        new_V2_prob = pred_V2 * weight_V2
    
        # Xception predict的机率加权
        # 机率向量
        pred_Xception = model_Xception.predict(image_for_Xception)[0] 
        # 乘上权重的新机率向量
        new_Xception_prob = pred_Xception * weight_Xception 
    
        # DenseNet201 predict的机率加权
        img = image_for_swa.view(1, 3, 80, 80)
        output = model_swa(img)
        output = output.view(-1, 800)
        output_prob = nn.functional.softmax(output, dim=1)
        # 机率向量
        output_prob_np = output_prob.cpu().detach().numpy()[0]
        # 乘上权重取得新机率向量
        new_swa_prob = output_prob_np * weight_swa 
    
        # 三个模型向量相加取得新的向量,判定手写中文字
        new_prob = new_swa_prob + new_Xception_prob + new_V2_prob
        max_prob = np.max(new_prob)
        pred_word = np.argmax(new_prob)
    
        # 读取该手写中文字的阈值
        judge = labels[pred_word]
        mean_prob = weight_df[weight_df["word"] == judge]["mean_prob_new"].values
    
        # 判断阈值
        if max_prob < mean_prob:
            prediction = "isnull"
        else:
            # 考虑加上SVM模型
            new_prob_2dim = new_prob[np.newaxis,:]
            # 丢入Rmodel预测是否为isnull,1为800字内,2为isnull
            pred = model_R.predict(new_prob_2dim)
            if pred == 2:
                prediction = "isnull"
            else:
                final_answer = np.argmax(new_prob)
                prediction = labels[final_answer]
    
        # 检查预测结果是否为字串
        if _check_datatype_to_string(prediction):
            return prediction
    
  3. API服务(inference 资料传输格式:json)

    • 接收API用户之request。

    • 取出json中image,并转换成图片格式。

    • 产出server_uuid:做为回传时json内容之一。

    • 记录错误log:供後续检查API服务error之用。

    • 回传预测结果给主办方。

    • 程序码

    @app.route('/inference', methods=['POST'])
    def inference():
        # 接收用户request
        data = request.get_json(force=True)
    
        # 取image base64 encoded,并以cv2转换格式
        image_64_encoded = data['image']
        image_for_Xception,image_for_V2,image_for_swa = base64_to_binary_for_cv2(image_64_encoded)
    
        # 产出server_uuid
        t = datetime.datetime.now()
        ts = str(int(t.utcnow().timestamp()))
        server_uuid = generate_server_uuid(CAPTAIN_EMAIL + ts)
    
        # 记录API错误log
        try:
            answer = predict(image_for_Xception,
                             image_for_V2,
                             image_for_swa)
        except TypeError as type_error:
            raise type_error
        except Exception as e:
            raise e
        server_timestamp = time.time()
    
        # 回传预测结果给用户
        return jsonify({'esun_uuid': data['esun_uuid'],
                        'server_uuid': server_uuid,
                        'answer': answer,
                        'server_timestamp': server_timestamp})
    
    if __name__ == "__main__":
    
        arg_parser = ArgumentParser(usage='Usage: python ' + __file__ +
                                   ' [--port <port>] [--help]')
        arg_parser.add_argument('-p', '--port', default=8080, help='port')
        arg_parser.add_argument('-d', '--debug', default=True, help='debug')
        options = arg_parser.parse_args()
    
        app.run(host='0.0.0.0', port=options.port, debug=options.debug)
    

小结

  • 整个比赛下来收获很多,虽然成绩不是说特别好,但对於第一次参赛的我们,觉得这个经验非常值得。

  • 结束了整个流程,明天来检讨哪些地方可以做改善,以及参考得奖队伍的做法。


相关文章:

  • 亚马逊卖家选择后台关键词的一些注意事项
  • 30天学会C语言: Day 14-全部包轨!
  • Day 28:顺手挖洞给 i 跳-vue-i18n
  • 怎么借助广告转化拉进与买家的距离?
  • 赌场也有打烊的时候 - 盘後回测
  • [30天 Vue学好学满 DAY11] v-on
  • Day[-2] 今天我想来点Kibana的Data Table
  • 终章:TeamCity 进阶学习路径
  • 外贸邮件营销需要注意的五条规则
  • 关于出口退税账务处理的实用案例
  • Day24 ( 高级 ) 骇客任务背景特效
  • Day 25 XIB跳转页面以及UIAlertController的练习(3/3)
  • 虾皮串接实作笔记-串接 API 虾皮订单
  • 如何更好的运用视频前的广告时间?
  • [Day03] 培养人脉,从正向思考开始
  • 数字人民币是什么?什么是数字人民币
  • 香港电话卡怎么在内地使用: CSL Hello/Three/CMHK/Smartone电话卡内地使用方法
  • 寻找印度市场伙伴
  • 2021年10个全球电子商务趋势[信息图] ,所有电商人员都该了解一下
  • WordPress禁用古腾堡编辑器全屏模式
  • 海外适合游戏投放的渠道有哪些?
  • Vultr促销码和2020年最新优惠:Vultr注册教程和使用方法
  • SiteGround主机评测和推荐
  • Google SEO优化排名的技巧:做好这20件事情谷歌排名必定上去
  • 如何使用Hostinger的邮箱服务,Hostinger免费企业邮箱设置教程
  • DNS是什么?DNS有什么用?为什么要用DNS解析域名
  • 虚拟信用卡是什么?虚拟信用卡安全吗?怎么用?怎么申请教程
  • vultr.com怎么申请退款教程和方法
  • 专业提供东南亚-越南线上支付通道
  • WordPress怎么建多语言网站:Polylang怎么用?如何用Polylang建多语言网站