[Day12] 以神经网络进行时间序列预测 — LSTM

本篇详细介绍 LSTM 及如何以 LSTM 建模预测时间序列。

本日大纲

  • LSTM 介绍
    • LSTM 元件构成
    • LSTM 的分类
  • 实作注意事项
    • 资料集介绍/目标
    • 训练格式转换
    • LSTM 架构参考

LSTM 介绍

长短期记忆模型(Long short-term memory,LSTM)为一特殊的 RNN 模型(递归神经网络)
目的是要解决「长序列」训练过程中的梯度消失梯度爆炸问题,因此相比普通的 RNN,LSTM 能够在更长的序列中有较好的表现。

LSTM 元件构成

LSTM 元件与 RNN 不同的地方在於:

  • state 数量

    • RNN:1 个传输状态 — hidden state https://chart.googleapis.com/chart?cht=tx&chl=h%5Et
    • LSTM:2 个传输状态 — cell state https://chart.googleapis.com/chart?cht=tx&chl=c%5Et , hidden state https://chart.googleapis.com/chart?cht=tx&chl=h%5Et
  • LSTM cell state 的更新阶段(门控)
    LSTM_cell_gates.png
    可以把它想成一个记忆区,由这三种门控决定什麽资料要被模型记忆并更新到 cell state 传输至下一层。

    1. input gate:决定前一层哪些资讯可以进到这个记忆区
    2. forget gate:决定目前记忆区中的哪些资讯不要保存(权重不高、相对不重要的)
    3. output gate:最後决定哪些资讯要离开目前记忆区,更新目前的 cell state 并传输到下一层 LSTM

LSTM 的分类

复习一下在 Day10 曾经列过的分类:

  • Vanilla LSTM:单一层 LSTM
  • Stacked LSTM:叠两层以上的 LSTM,要将第一层的参数 return_sequence 设为 True
  • Bidirectional LSTM:使序列同时以顺向和反向输入模型

时间序列预测实作

资料集介绍

股价包含以下属性:

  • Open:当日开盘价
  • Close:当日收盘价
  • High:当日最高价
  • Low:当日最低价

目标

透过神经网络训练捕捉时序特徵,预测未来时间点的均价(当日最高价与最低价之平均)

训练格式转换

基本上前处理就是看不同资料集怎麽做,在递归神经网络这边比较重点的是训练格式的转换。
要根据每个「要被预测的时间点 (output)」制作 input data,使用滑动视窗的概念,形成长度相同的许多训练样本,代表各个「要被预测的时间点」前「N 个时间点」的资讯。
因为样本是由不同时间窗格中的序列所组成,LSTM 能够去学习前一段时间和所需预测时间点之间的时序关系。

滑动视窗的概念,可以参考下面这张示意图:
sliding_window.jpg

图片来源:researchgate.net

LSTM 架构参考

def build_model():
    model = Sequential()
    model.add(LSTM(128, input_shape=(74, 30), return_sequence=True))
    model.add(Dropout(0.3))
    model.add(LSTM(32))
    model.add(Dropout(0.15))
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer='adam', 
                  loss='mse', 
                  metrics=['mae'])
    model.summary()
    
    return model

(还在更新中)


<<:  【在厨房想30天的演算法】Day 12 资料结构:杂凑表 Hash Table

>>:  [Day 12]从零开始学习 JS 的连续-30 Days---DOM是什麽?

Day 05 GPIO peripherals

Control GPIO peripherals using digital input/outpu...

# Day 6 Supporting PMUs on RISC-V platforms (二)

今天一样是 Supporting PMUs on RISC-V platforms 相关的内容,先来...

【系统程序】1-3简化指令电脑(SIC)

1-3简化指令电脑(SIC) 简介 简化指令电脑(Simplified Instructional ...

goroutine

Golang goroutine 我自己理解goroutine 就很类似其他语言的thread[备注...

[Day19] 参数(上)

前面讲 函式 function 时提到参数,回头看自己打的文章发现错误的地方修正了一下。 Param...