背景
这是“实例学PyTorch”系列的第3篇文章。在前两篇文章“实例学PyTorch(1):MNIST手写数字识别(一)——PyTorch基础和神经网络基础”和“实例学PyTorch(2):MNIST手写数字识别(二)——神经网络中的参数选择”中,我们介绍了PyTorch的基本概念和使用方法,使用一个简单的三层全连接神经网络实现了MNIST手写数字识别,并简单讨论了一下神经网络中参数选择的问题,这算是深度学习或计算机视觉领域的“Hello World"。在这篇文章中,我们将使用一个卷积神经网络实现MNIST手写数字识别,算是对之前两篇文章的一个扩展。
本文的代码可以在我的GitHub仓库https://github.com/jin-li/pytorch-tutorial中的T03_mnist_cnn
文件夹中找到。该代码是基于PyTorch官方的示例代码https://github.com/pytorch/examples改编的。
卷积神经网络(CNN)简介
在前两篇文章中,我们用了一个简单的全连接神经网络来解决MNIST手写数字识别问题。全连接神经网络的效果已经不错了,但细想一下,全连接神经网络有一个很大的缺点,就是没有考虑到图像的局部特征。在图像识别中,图像的局部特征是非常重要的,例如图像的边缘、纹理等。卷积神经网络(Convolutional Neural Network,CNN)是一种专门用于处理图像的神经网络,它可以有效地提取图像的局部特征,从而提高图像识别的准确度。
卷积神经网络的关键点在于卷积层(Convolutional Layer)和池化层(Pooling Layer)。卷积层是用一个卷积核(Convolutional Kernel)对输入图像进行卷积操作,从而提取图像的局部特征。池化层是用一个池化核(Pooling Kernel)对卷积层的输出进行池化操作,从而减少特征图的大小,提高计算效率。卷积层和池化层交替出现,最后通过全连接层得到输出。接下来我们将设计一个简单的卷积神经网络来实现MNIST手写数字识别。
卷积核是卷积神经网络的核心,它是一个小的矩阵,用来提取图像的局部特征。卷积核的大小、步长、填充等参数都是需要调整的超参数。卷积核的大小决定了卷积核能感受的区域的大小,即局部感受野。步长是卷积核每次移动的距离,填充是在图像周围填充一圈0,可以保持图像的大小不变。池化核的大小、步长等参数也是需要调整的超参数。关于卷积核的更多信息可以参考这篇文章。这里我们引用这篇文章中的几个图来说明卷积核的工作原理:
卷积计算 | 普通卷积 | 空洞卷积 | 反卷积 |
MNIST手写数字识别
CNN模型设计
-
我们可以先确定输入和输出。显然这个神经网络的输入是28x28的灰度图像,输出是0到9之间的一个数字。
-
我们需要选择一个神经网络类型,例如全连接神经网络、卷积神经网络、循环神经网络等。这里我们选择使用一个简单的卷积神经网络。
-
我们需要确定网络的结构,包括网络的层数、每层的神经元数、激活函数等。这里我们选择一个简单的卷积神经网络,包括两个卷积层和两个全连接层。
代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout(0.25) self.dropout2 = nn.Dropout(0.5) self.fc1 = nn.Linear(9216, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) output = F.log_softmax(x, dim=1) return output
这里我们定义了一个名为
Net
的类,继承自nn.Module
。在__init__
方法中,我们定义了两个卷积层conv1
和conv2
,和两个全连接层fc1
和fc2
。网络的结构如下图所示:- 由于MNIST数据集是灰度图像,所以输入通道数是1。如果是RGB彩色图像,输入通道数是3。输出通道数表示卷积核的个数,即每个卷积层想要提取的特征数,这里假设第一个卷积层提取32个特征,第二个卷积层提取64个特征。卷积核的大小是每个卷积核能感受的区域的大小,即局部感受野,这里假设卷积核大小是3x3。步长是卷积核每次移动的距离,这里假设步长是1。
- 28x28的输入经过步长为1的3x3卷积核,输出是26x26;再经过第二个卷积层,输出是24x24;再经过最大池化层,输出是12x12。两个dropout层是为了防止过拟合,不影响输出参数的形状。因此,第一个全连接层的输入神经元数是12x12x64=9216。
- 中间的Dropout层是为了防止过拟合,Dropout是一种正则化方法,可以随机地将一些神经元的输出设置为0,从而减少神经元之间的依赖关系。
- 经过两个全连接层,最后输出10个神经元,分别表示0到9之间的数字。最后使用
F.log_softmax()
函数将输出转换为概率。 - 在
forward
方法中,我们定义了网络的前向传播过程,即输入数据经过每一层的计算,最后输出预测结果。其中:- 两个卷积层都使用ReLU激活函数,第二个卷积层后面跟了一个最大池化层。
- 然后执行一次Dropout操作,将输出展平为一维向量,输入到两个全连接层中。
- 在第一个全连接层后面又执行了一次ReLU激活函数,然后再执行一次Dropout操作。最后输出10个神经元,分别表示0到9之间的数字。
数据加载、预处理、训练、测试
这部分内容其实和上一篇文章中介绍的简单全连接神经网络是一样的,只是需要把上一篇文章中定义的SimpleNet
类换成这里定义的Net
类。
这里的代码实际上就是PyTorch官方给的示例,可以参考这里,或者在我的GitHub仓库https://github.com/jin-li/pytorch-tutorial中的T03_mnist_cnn
文件夹中找到。
这里我们创建了一个Python脚本来运行这个CNN模型,并绘制模型的性能曲线。
本文代码所需的Python环境和之前两篇文章是一样的,可以通过conda activate pytorch-mnist
激活环境,然后使用如下命令运行代码:
|
|
运行结果
我分别使用CPU(Intel i5-9600K)和GPU(NVIDIA GeForce RTX 4060 Ti)运行了这个CNN模型,运行时间分别为12分56秒和2分25秒。可以看到,使用GPU运行速度快了很多。运行所需的GPU资源实际并不高,GPU占用率约12%,显存占用率约740MB。
模型的表现如下图所示:
可见,这个CNN模型在MNIST数据集上的准确率约为99.2%,比之前的全连接神经网络模型要好。
使用模型进行数字识别
本文已经是介绍MNIST手写数字识别的第三篇文章了,但是我们一直都只是在训练模型,没有实际使用我们训练好的模型。现在我们就来使用我们训练好的CNN模型来识别一些手写数字。
-
首先我们在训练模型时需要把模型的参数保存下来,这样我们在待会儿使用模型时就可以直接加载这些参数,而不必重新训练。保存模型参数的代码如下:
1
torch.save(model.state_dict(), "mnist_cnn.pt")
在
mnist_cnn.py
的代码中已经有这个功能了,我们需要在运行mnist_cnn.py
时指定--save-model
参数,这样模型参数就会被保存到mnist_cnn.pt
文件中:1
python mnist_cnn.py --save-model
-
然后我们可以使用保存的模型参数来识别手写数字。我们可以使用
PIL
库来读取图片,然后使用torchvision
库来对图片进行预处理,最后使用我们训练好的CNN模型来识别图片中的数字。需要注意的是:
- 我们这里使用的MNIST训练集中的图片是黑底白字的28x28像素的图片,所以我们在创建手写图片时也要保持黑底白字。
- 我们自己手写的数字不一定是28x28像素的,所以我们需要对图片进行缩放,使其变成28x28像素的图片。
- MNIST训练集中的图片满足均值为0.1307,标准差为0.3081的正态分布,所以我们在对图片进行预处理时需要对图片进行归一化。
做这些处理的代码如下:
1 2 3 4 5 6 7 8
from torchvision import transforms transform = transforms.Compose([ transforms.Grayscale(num_output_channels=1), transforms.Resize((28, 28)), transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ])
完整的代码在
classify_image
文件中。 -
这里我在电脑上使用鼠标手写了几个数字,对每个数字分别截图保存到
T03_mnist_cnn
文件夹中numbers
目录下: -
最后我们可以使用
classify_image.py
来识别一个手写数字:1
python classify_image.py numbers/number1.png
-
我分别识别了这10个手写数字,发现模型正确识别出了8个数字,没识别出的两个是数字
2
和9
,模型将它们分别识别成了4
和8
,虽然有些令我费解,但这个模型的表现也称得上是差强人意了。
总结
在这篇文章中,我们使用了一个卷积神经网络来实现MNIST手写数字识别,相比之前的全连接神经网络,CNN模型的准确率有了明显的提升。我们还使用了训练好的CNN模型来识别一些手写数字,模型的表现也还算不错。
MNIST作为计算机视觉领域的“Hello World”,是一个非常经典的数据集,也是一个非常好的入门数据集。我们已经写了三篇文章来讨论它,这里不妨先告一段落,先来研究一些其他机器学习或深度学习的问题,等以后有机会我们再来使用MNIST数据集来研究一些其他的深度学习算法。