[神经机器翻译理论与实作] 从头建立英中文翻译器 (VI)

前言

今天接着完成翻译任务实作的第二阶段-模型推论。

翻译器建立实作

重新评估翻译模型

上次由於输入特徵 X 以及原始句对并非一一对应,造成了 BLEU 分数低下的结果。

分别针对训练资料以及测试资料建立原始句对以及其特徵 X_train 、标签 y_train

# Generate training data
train_dataset = np.load("data/eng-cn/train_data_io.npz")
seq_pairs_train, X_train, y_train, src_vocab_size = create_data(train_dataset)

# Generate test data
test_dataset = np.load("data/eng-cn/test_data_io.npz")
seq_pairs_test, X_test, y_test, _ = create_data(test_dataset)

我们重新在训练资料集上评估模型:

# evaluate model on training dataset
eval_NMT(seq2seq, X_train, seq_pairs_train, reverse_tgt_vocab_dict)

列出前五句的英文原句、中文原句以及中文译句:
https://ithelp.ithome.com.tw/upload/images/20211011/20140744sTMX8spio7.jpg
匹配各个 n-grams 而计算出的 BLEU 分数:
https://ithelp.ithome.com.tw/upload/images/20211011/20140744RGhDYPDyoU.jpg

接着轮到在训练资料集上重新评估模型:

# evaluate model on test dataset
eval_NMT(seq2seq, X_test, seq_pairs_test, reverse_tgt_vocab_dict)

列出前五句的英文原句、中文原句以及中文译句:
https://ithelp.ithome.com.tw/upload/images/20211011/20140744QGMG0YwtXi.jpg
匹配各个 n-grams 而计算出的 BLEU 分数:
https://ithelp.ithome.com.tw/upload/images/20211011/20140744thN6GpvURC.jpg

可以见到,当愈多元的 n-grams 列入准确度的计算,得到的 BLEU 就会愈低,这是必然的结果。而此模型不论是在训练资料集或测试资料集上评估而得的 BLEU 分数皆几乎都在 0.7 以上,准确度令人满意。

模型推论(Inference Phase)

首先将预训练好的 seq2seq 模型载入,取出每层神经元以便沿用训练阶段收敛的参数矩阵数值:

from tensorflow.keras.models import load_model

eng_cn_seq2seq = load_model("models/eng-cn_translator_v3.h5")

layers_list = eng_cn_seq2seq.layers
layer_names = [layer.name for layer in layers_list]
weights_list = eng_cn_seq2seq.get_weights()

在训练阶段时,我们将编码器和解码器的输入当作特徵一并传入神经网络当中。然而在实际翻译的情境中,我们并不会预先知道目标语言的文句,也就是解码器的输入。因此在推论阶段( inference loop )时,我们只提供来源语言文句,传入编码器後,再将资讯透过解码器传出。
习惯上我们会先提供解码器的初始输入为指示句首的符号<sos>,透过 token by token 的预测,更新解码器的输入,直到出现指示句末的符号<eos>出现,或是已经超过目标语言最大句长,则停止预测,完成整个句子的翻译。

首先建立可独立预测的编码器模型:

from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

layers_list = eng_cn_seq2seq.layers
layer_names = [layer.name for layer in layers_list]
weights_list = eng_cn_seq2seq.get_weights()

# Build an Encoder Model in an inference mode
enc_inputs = training_seq2seq.input[0]
enc_outputs_1, enc_h1, enc_c1 = layers_list[4].output
enc_outputs_2, enc_h2, enc_c2 = layers_list[6].output
enc_states = [enc_h1, enc_c1, enc_h2, enc_h2]

encoder_model = Model(enc_inputs, [enc_states, enc_outputs_2], name = "encoder_model")

接下来则是承接编码器输出资讯的解码器模型:

# define all input tensors
dec_state_input_h1 = Input(shape = (latent_dim, ))
dec_state_input_c1 = Input(shape = (latent_dim, ))
dec_state_input_h2 = Input(shape = (latent_dim, ))
dec_state_input_c2 = Input(shape = (latent_dim, ))
src_max_seq_length = 38
enc_outputs_final = Input(shape = (src_max_seq_length, latent_dim))
dec_states_inputs = [dec_state_input_h1, dec_state_input_c1, dec_state_input_h2, dec_state_input_c2]


# Build a Decoder Model in an inference mode
embed_vec = layers_list[3](dec_inputs) # embedding layer
dec_outputs_1, dec_h1, dec_c1 = layers_list[5](embed_vec, initial_state = dec_states_inputs[:2]) # 1st LSTM layer
dec_outputs_2, dec_h2, dec_c2 = layers_list[7](dec_outputs_1, initial_state = dec_states_inputs[2:])   # 2nd LSTM layer
attention_scores = layers_list[8]([dec_outputs_2, enc_outputs_final]) # dot: attention function to get attention scores
attention_weights = layers_list[9](attention_scores)
context_vec = layers_list[10]([attention_weights, enc_outputs_final])
ht_context_vec = layers_list[11]([context_vec, dec_outputs_2])
attention_vec = layers_list[12](ht_context_vec)
logits = layers_list[13](attention_vec)
dec_outputs_final = layers_list[14](logits)

dec_states = [dec_h1, dec_c1, dec_h2, dec_c2]

# decoder_model = Model([dec_inputs] + dec_states_inputs + [enc_outputs_final], [dec_outputs_final] + dec_states, name = "decoder_model_inference_10_epochs")
decoder_model = Model([dec_inputs] + dec_states_inputs + [enc_outputs_final], [dec_outputs_final] + dec_states + [attention_weights], name = "decoder_model")

接下来则是利用编码器承接输入文句,传递给解码器,完成整个句子的翻译:

def decode_sequence(input_sentence):
    # visualise association between tokens using the alignment matrix
    attention_plot = np.zeros(shape = (tgt_max_seq_length, src_max_seq_length))

    input_sentences = [input_sentence]
    input_seq = encode_input_sequences(eng_tokeniser, src_max_seq_length, input_sentences)
    enc_states, enc_outputs_2 = encoder_model.predict(input_seq)
    # enc_states = [enc_h1, enc_c1, enc_h2, enc_c2]
    dec_states = enc_states
    # generate empty target sequence of length 1
    tgt_seq = np.zeros(shape = (1, 1))
    tgt_seq[0, 0] = tgt_vocab_dict["<sos>"]

    decoded_sentence = ""
    stop_cond = False
    while not stop_cond:
        output_tokens, new_dec_h1, new_dec_c1, new_dec_h2, new_dec_c2 = decoder_model.predict([tgt_seq] + dec_states + [enc_outputs_2])
        # Sample a token: Label encode the current one-hot encoded output
        # Greedy search: each time find the most likely token (last position in the sequence)
        sampled_word_idx = np.argmax(output_tokens[0, -1, :])
        sampled_word = reverse_tgt_vocab_dict[sampled_word_idx]
        decoded_sentence += sampled_word

        # Exit condition: either hit max length
        # or find stop token.
        if (sampled_word == "<eos>" or (len(decoded_sentence) > tgt_max_seq_length)):
            stop_cond = True
        # Update the target sequence (of length 1)
        tgt_seq[0, 0] = sampled_word_idx

        # Update decoder states
        dec_states = [new_dec_h1, new_dec_c1, new_dec_h2, new_dec_c2]
    return decoded_sentence

结语

以上为推论阶段的程序,内容有一些小 bug 会再修正并更新。
今天的文章更新就先到这里,晚安!

阅读更多

  1. Attention Mechanisms in Recurrent Neural Networks (RNNs) With Keras

<<:  Day27:【技术篇】Webpack5 - Webpack 之运作流程

>>:  【D26】熟练一下厨具-bid and ask #1:什麽是选择权价差单

路由把关者- Navigation Guards

前言 Vue Router 提供 Navigation Guards,可以在路由变更前後去呼叫相关的...

Day 23 dio函数库

昨天提到Flutter最常用的网路函数库有HttpClient和http函数库,但其实还有一种叫做d...

AI新世界

人的科技文明发展始终来自於人性 在未来的世界当中,随着科技的发展及进步,AI越来越成熟,在很多人类的...

D16 - 如何用 Apps Script 自动化地创造与客制 Google Docs?(三)Element 的读取与创造

今天的目标 要怎麽简单快速地做出客制化地文件?今天,我们会教用 GAS 搭配 Goolge Doc。...

D14: 工程师太师了: 第7.5话

工程师太师了: 第7.5话 杂记: 注解是程序语言中用来解释程序码中的部分,可增加程序的可读性、可维...