背景
这是“实例学PyTorch”系列的第5篇文章。在第4篇文章“实例学PyTorch(4):序列预测(一)——循环神经网络(RNN)”中,我们介绍了序列预测问题,以及如何使用一个简单的循环神经网络(RNN)来实现对正弦函数的预测。在本文中,我们将更进一步,介绍序列预测中另外两种常用的神经网络:门控循环单元(Gated Recurrent Unit,GRU)和长短期记忆网络(Long Short-Term Memory,LSTM)。
本文的代码可以在我的GitHub仓库https://github.com/jin-li/pytorch-tutorial中的T05_series_rnn
文件夹中找到。
RNN的问题与解决方法
在第4篇文章中,我们介绍了循环神经网络(RNN)的基本概念和工作原理。RNN是一种可以处理序列数据的神经网络,它在每个时间步都会保存之前的数据信息,从而可以处理序列数据。但是,RNN也有一些问题,例如梯度消失和梯度爆炸问题。梯度消失和梯度爆炸是深度学习中的一个常见问题,它们会导致模型无法收敛,或者收敛速度非常慢。
梯度消失和梯度爆炸是深度学习中的一个常见问题。在反向传播算法中,梯度是通过链式法则计算的:
$D_n = \sigma^{’}(z_1) w_1 \cdot \sigma^{’}(z_2) w_2 \cdot \ldots \cdot \sigma^{’}(z_{n-1}) w_{n-1} \cdot \sigma^{’}(z_n) w_n$
其中,$D_n$是第$n$层的梯度,$\sigma^{’}(z_i)$是第$i$层的激活函数的导数,$w_i$是第$i$层的权重。可以看到,梯度是通过每一层的激活函数的导数和权重相乘得到的。如果激活函数的导数小于1,那么梯度会随着层数的增加指数级地减小,导致梯度消失;如果激活函数的导数大于1,那么梯度会随着层数的增加指数级地增大,导致梯度爆炸。
长短期记忆网络(LSTM)和门控循环单元(GRU)是为了解决RNN中的梯度消失和梯度爆炸问题而提出的。它们通过引入门控机制来控制信息的流动,从而解决了RNN中的长期依赖问题。
LSTM和GRU简介
LSTM和GRU都引入了门控机制来控制信息的流动。所谓门控机制,就是把数据乘以一个0到1之间的系数,从而控制是否传递数据以及传递多大比例的数据,这个系数是由一个sigmoid激活函数计算得来的。
LSTM和GRU的区别在于LSTM有三个门:遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate),而GRU只有两个门:重置门(Reset Gate)和更新门(Update Gate)。GRU相对于LSTM来说,参数更少,计算量更小,但LSTM的表现一般比GRU更好。
长短期记忆网络(LSTM)
长短期记忆网络(Long Short-Term Memory,LSTM)是一种门控循环神经网络,由Hochreiter和Schmidhuber于1997年提出。LSTM引入了三个门:遗忘门(Forget Gate)、输入门(Input Gate)和输出门(Output Gate),通过这三个门来控制信息的流动,从而解决了RNN中的长期依赖问题。
LSTM的结构如下图所示:
遗忘门、输入门和输出门的具体计算公式的推导这里不再给出,感兴趣的读者可以参考这篇文章Understanding LSTM Networks。我们这里只简单介绍一下LSTM的工作原理:
我们假设LSTM中的记忆数据是$C_t$,隐藏状态是$h_t$,输入数据是$x_t$,遗忘门是$f_t$,输入门是$i_t$,输出门是$o_t$。LSTM的工作原理如下:
-
遗忘门:遗忘门的输入是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h_{t-1}$,这两个输入数据经过sigmoid函数后的输出是一个0到1之间的系数$f_t$。$f_t$决定了上一个时间步的数据需要保留多少,如果$f_t$接近0,那么上一个时间步的数据就会被遗忘;如果$f_t$接近1,那么上一个时间步的数据就会被保留。
-
输入门:输入门的输入是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h_{t-1}$,这两个输入数据经过sigmoid函数后的输出是一个0到1之间的系数$i_t$。$i_t$决定了当前时间步的输入数据需要保留多少,如果$i_t$接近0,那么当前时间步的输入数据就会被忽略;如果$i_t$接近1,那么当前时间步的输入数据就会被保留。
-
更新记忆:更新记忆的公式是$C_t = f_t \cdot C_{t-1} + i_t \cdot \tilde{C}t$,其中$\tilde{C}t$是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h{t-1}$经过tanh函数后的输出。$C_t$ 是当前时间步的记忆数据,$f_t \cdot C{t-1}$是上一个时间步的记忆数据,$i_t\cdot\tilde{C}_t$是当前时间步的输入数据。
-
输出门:输出门的输入是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h_{t-1}$,这两个输入数据经过sigmoid函数后的输出是一个0到1之间的系数$o_t$。$o_t$决定了当前时间步的输出数据$h_t$需要保留多少,如果$o_t$接近0,那么当前时间步的输出数据就会被忽略;如果$o_t$接近1,那么当前时间步的输出数据就会被保留。
门控循环单元(GRU)
门控循环单元(Gated Recurrent Unit,GRU)在LSTM的基础上做了一些简化,由Cho等人于2014年提出。GRU只有两个门:重置门(Reset Gate)和更新门(Update Gate)。相比于LSTM,GRU的参数更少,计算量更小。虽然LSTM的表现一般更好,但GRU由于其简单性也颇受欢迎。
GRU中的重置门和更新门实际上是LSTM中三个门的简化版,其具体计算公式的推导这里不再给出,感兴趣的读者可以参考这篇文章Understanding LSTM Networks。我们这里只简单介绍一下GRU的工作原理:
我们假设GRU中的记忆数据是$h_t$,输入数据是$x_t$,重置门是$r_t$,更新门是$z_t$。GRU的工作原理如下:
-
重置门:重置门的输入是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h_{t-1}$,这两个输入数据经过sigmoid函数后的输出是一个0到1之间的系数$r_t$。$r_t$决定了上一个时间步的数据需要保留多少,如果$r_t$接近0,那么上一个时间步的数据就会被忽略;如果$r_t$接近1,那么上一个时间步的数据就会被保留。
-
更新记忆:更新记忆的公式是$\tilde{h}t = \tanh(W [r_t h{t-1}, x_t]) = \tanh(W_{xh} x_t + r_t \odot W_{hh} h_{t-1})$,其中$\odot$是元素乘法。$\tilde{h}t$是当前时间步的记忆数据,$W{xh} x_t$是当前时间步的输入数据,$r_t \odot W_{hh} h_{t-1}$是上一个时间步的隐藏状态。
-
更新门:更新门的输入是当前时间步的输入数据$x_t$和上一个时间步的隐藏状态$h_{t-1}$,这两个输入数据经过sigmoid函数后的输出是一个0到1之间的系数$z_t$。$z_t$决定了当前时间步的记忆数据需要保留多少,如果$z_t$接近0,那么当前时间步的记忆数据就会被忽略;如果$z_t$接近1,那么当前时间步的记忆数据就会被保留。
使用LSTM和GRU实现序列预测的代码
数据准备
我们继续使用上一篇文章中生成的正弦序列数据,具体参见“实例学PyTorch(4):使用循环神经网络实现序列预测(一)”。
定义模型
LSTM模型
PyTorch中已经实现了LSTM模型,我们这里封装一下以用于本问题。代码如下:
|
|
GRU模型
PyTorch中也已经实现了GRU模型,我们这里封装一下以用于本问题。代码如下:
|
|
训练模型
在上一篇文章中使用的代码的基础上,我们只需要做一些简单的修改即可。
首先我们把上面定义的LSTM和GRU模型放入主文件中,然后修改模型调用的部分。这里我们给main()
函数添加一个命令行参数model_type
,用于指定使用RNN、LSTM还是GRU模型:
|
|
为了对比三种模型的性能,我们另写一个Python脚本来调用这三种模型。之前的代码用的是命令行参数,这里我们将参数改为函数参数,并让函数返回训练中的损失值、测试的准确率,以及模型的预测结果。这样我们就可以在调用函数的脚本中绘制模型的性能曲线。
修改后的主脚本参见GitHub仓库T05_series_gru_lstm
文件夹中的time_series_models.py
文件。比较性能的脚本参见T05_series_gru_lstm
文件夹中的compare_results.py
文件。
运行代码
我们运行比较性能的脚本,可以得到如下结果:
-
三种模型的损失值曲线:
-
三种模型的预测结果:
compare_results.py
脚本依次运行了RNN、LSTM和GRU三种模型,如果在我的电脑上用GPU(NVIDIA GeForce GTX 4060 Ti)运行,总运行时间约30秒,最高显存占用约354 MB。若使用CPU(Intel i5-9600K)运行,总运行时间约3分33秒。
我运行了这个比较性能的脚本多次,每次的结果都有一些差异。大部分情况下,三个模型都能够很好地拟合正弦函数的序列数据,但总体上LSTM和GRU的表现要略好于RNN。
总结
在本文中,我们介绍了门控循环单元(GRU)和长短期记忆网络(LSTM)的工作原理,并使用PyTorch实现了这两种模型。我们使用这两种模型来实现对正弦函数的序列预测,并与简单循环神经网络(RNN)进行了比较。我们发现,LSTM和GRU相对于RNN来说,能够更好地捕捉序列数据中的长期依赖关系,从而提高了模型的性能。