Background
This is the third article in the “Learn PyTorch by Examples” series. In the previous two articles “Learn PyTorch (1): PyTorch Basics and MNIST Handwritten Digit Recognition (I)” and “Learn PyTorch (2): Parameter Selection in MNIST Handwritten Digit Recognition (II)”, we introduced the basic concepts and usage of PyTorch, and implemented MNIST handwritten digit recognition using a simple three-layer fully connected neural network. In this article, we will use a convolutional neural network to implement MNIST handwritten digit recognition, which is an extension of the previous two articles.
The code for this article can be found in the T03_mnist_cnn folder in my GitHub repository https://github.com/jin-li/pytorch-tutorial. The code is adapted from PyTorch’s official example code https://github.com/pytorch/examples.
Convolutional Neural Networks (CNNs) Introduction
In the previous two articles, we used a simple fully connected neural network to solve the MNIST handwritten digit recognition problem. The fully connected network performed well, but it flattens the image into a vector and therefore ignores the spatial relationships between neighboring pixels, i.e., the local features of the image. In image recognition, local features such as edges and textures are very important. Convolutional Neural Networks (CNNs) are neural networks designed specifically for processing images. They can effectively extract the local features of an image and thereby improve recognition accuracy.
The key components of a convolutional neural network are the convolutional layer and the pooling layer. The convolutional layer convolves the input image with convolution kernels to extract local features of the image. The pooling layer downsamples the output of the convolutional layer with a pooling window (for example, keeping the maximum value in each 2x2 region), which reduces the size of the feature map and improves computational efficiency. Convolutional and pooling layers are typically stacked alternately, and the final output is produced by fully connected layers. Next, we will design a simple convolutional neural network to implement MNIST handwritten digit recognition.
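The convolution-then-pooling idea can be made concrete on a toy example. The sketch below assumes a hand-made 6x6 image and a hand-picked 3x3 edge-detecting kernel, neither of which comes from this article. It implements the sliding multiply-and-sum with plain loops, checks the result against PyTorch's F.conv2d, and then applies 2x2 max pooling.

```python
import torch
import torch.nn.functional as F


def conv2d_manual(image, kernel, stride=1):
    """Slide `kernel` over a 2D `image` (no padding) and return the feature map.

    Like PyTorch's Conv2d, this is a cross-correlation: the kernel is not
    flipped before the multiply-and-sum.
    """
    kh, kw = kernel.shape
    out_h = (image.shape[0] - kh) // stride + 1
    out_w = (image.shape[1] - kw) // stride + 1
    out = torch.zeros(out_h, out_w)
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i * stride:i * stride + kh, j * stride:j * stride + kw]
            out[i, j] = (patch * kernel).sum()  # element-wise product, then sum
    return out


# Hypothetical 6x6 toy image: dark left half, bright right half (a vertical edge).
image = torch.zeros(6, 6)
image[:, 3:] = 1.0

# Hypothetical 3x3 vertical-edge detector, chosen by hand for illustration.
kernel = torch.tensor([[-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0],
                       [-1.0, 0.0, 1.0]])

feature_map = conv2d_manual(image, kernel)  # 6x6 input, 3x3 kernel -> 4x4
# PyTorch expects (batch, channels, height, width), hence the [None, None].
reference = F.conv2d(image[None, None], kernel[None, None])[0, 0]
assert torch.allclose(feature_map, reference)

# 2x2 max pooling keeps the strongest response in each 2x2 window: 4x4 -> 2x2.
pooled = F.max_pool2d(feature_map[None, None], 2)[0, 0]
print(feature_map)
print(pooled)
```

Every row of the feature map is [0, 3, 3, 0]: the kernel responds only where its window straddles the edge. Pooling then keeps a 2x2 map of 3s, which shows that the edge is still detected at a quarter of the size.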
The convolution kernel is the core of a convolutional neural network. It is a small matrix of learnable weights that slides over the image to extract local features. Its size, stride, padding, and other parameters are hyperparameters that need to be tuned. The kernel size determines the area the kernel can see at once, i.e., its local receptive field. The stride is the distance the kernel moves at each step. Padding adds a border of zeros around the image. With a 3x3 kernel and a stride of 1, one pixel of padding keeps the output the same size as the input. The size and stride of the pooling window are also hyperparameters. For more information about convolution kernels, you can refer to this article. Here we use a few figures from this article to illustrate how the convolution kernel works:
[Figures: Convolution Calculation | Normal Convolution | Dilated Convolution | Deconvolution]
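The figures show the kernel sliding with different settings. The output size of a convolution follows the formula given in the nn.Conv2d documentation: for input size n, kernel size k, stride s, padding p, and dilation d, the output size is floor((n + 2p - d(k - 1) - 1) / s) + 1. The sketch below checks this formula against real layers on a 28x28 input, and also checks the corresponding formula for a transposed convolution ("deconvolution"). The helper names conv_out_size and conv_transpose_out_size are our own, not part of PyTorch.

```python
import torch
import torch.nn as nn


def conv_out_size(n, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of a convolution (formula from the nn.Conv2d docs)."""
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_transpose_out_size(n, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of a transposed convolution, assuming output_padding=0."""
    return (n - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1


x = torch.randn(1, 1, 28, 28)  # one 28x28 grayscale image

cases = {
    "normal 3x3": dict(kernel_size=3),
    "3x3, padding=1": dict(kernel_size=3, padding=1),
    "3x3, stride=2": dict(kernel_size=3, stride=2),
    "3x3, dilation=2": dict(kernel_size=3, dilation=2),
}
for name, kw in cases.items():
    layer = nn.Conv2d(1, 1, **kw)
    actual = layer(x).shape[-1]
    assert actual == conv_out_size(28, **kw)
    print(f"{name}: {actual}x{actual}")

# A transposed convolution goes the other way: 14x14 -> 28x28 with stride 2.
up = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
assert up(torch.randn(1, 1, 14, 14)).shape[-1] == conv_transpose_out_size(14, 2, stride=2)
```

The four convolution cases give 26x26, 28x28, 13x13, and 24x24. Padding=1 is the setting that keeps a 3x3 convolution's output at the input size. Dilation=2 enlarges the receptive field to 5x5 without adding any weights.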
MNIST Handwritten Digit Recognition
CNN Model Design
- We first determine the input and output. The input of this neural network is a 28x28 grayscale image, and the output is the recognized digit, from 0 to 9. In practice, the network outputs 10 scores, one per digit, and the digit with the highest score is the prediction.
- We need to choose a type of neural network, such as a fully connected neural network, a convolutional neural network, or a recurrent neural network. Here we choose a simple convolutional neural network.
- We need to determine the structure of the network, including the number of layers, the number of neurons in each layer, and the activation functions. Here we use two convolutional layers, followed by a max pooling layer and two fully connected layers.
The code is as follows:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.conv1 = nn.Conv2d(1, 32, 3, 1)   # 1x28x28 -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, 3, 1)  # 32x26x26 -> 64x24x24
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)       # 9216 = 64 * 12 * 12
        self.fc2 = nn.Linear(128, 10)         # one output per digit 0-9

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)                # 64x24x24 -> 64x12x12
        x = self.dropout1(x)
        x = torch.flatten(x, 1)               # flatten all dims except the batch dim
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)      # log-probabilities over the 10 classes
        return output
```
Here we define a class named Net that inherits from nn.Module. In the __init__ method, we define two convolutional layers, conv1 and conv2, two Dropout layers, dropout1 and dropout2, and two fully connected layers, fc1 and fc2. The structure of the network is as follows:
- The arguments of nn.Conv2d are, in order, the number of input channels, the number of output channels, the kernel size, and the stride. Since MNIST images are grayscale, the first convolutional layer has 1 input channel. An RGB color image would have 3. The number of output channels is the number of convolution kernels, i.e., the number of features the layer extracts. Here the first convolutional layer extracts 32 features and the second extracts 64. The kernel size determines the local receptive field of each kernel. Here it is 3x3. The stride is the distance the kernel moves at each step. Here it is 1.
- The 28x28 input passes through the first 3x3 convolution with a stride of 1 and no padding, which outputs 26x26. The second convolutional layer reduces this to 24x24, and the 2x2 max pooling layer halves it to 12x12. The two Dropout layers do not change the tensor shape. The first fully connected layer therefore has 12x12x64 = 9216 input neurons.
- Dropout is a regularization method that helps prevent overfitting. During training, it randomly sets a fraction of neuron outputs to 0 (25% after the convolutional block and 50% after fc1 here), which reduces co-adaptation between neurons. In evaluation mode (model.eval()), Dropout is disabled.
- After the two fully connected layers, the final output has 10 neurons, one for each digit from 0 to 9. Finally, F.log_softmax() converts these outputs into log-probabilities (the logarithms of the class probabilities), which is the form expected by the negative log-likelihood loss F.nll_loss.
- In the forward method, we define the forward pass of the network: the input data passes through each layer in turn, and the network finally outputs its prediction. Specifically:
  - Both convolutional layers are followed by the ReLU activation function, and the second one is followed by a 2x2 max pooling layer (F.max_pool2d(x, 2)).
  - A Dropout operation is then applied. The output is flattened into a one-dimensional vector per sample (torch.flatten(x, 1)) and fed into the two fully connected layers.
  - The first fully connected layer is followed by another ReLU activation and another Dropout operation. The second fully connected layer outputs 10 values, one for each digit from 0 to 9.
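To double-check the shape arithmetic above, the sketch below pushes a random batch through layers configured exactly like those in Net and records the shape after each step. Dropout is left out because it does not change shapes, and the batch size of 8 is an arbitrary choice. The sketch also counts parameters per layer, which shows that fc1 (9216 x 128 weights) holds almost all of the model's parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same layer configuration as Net, recreated so the sketch is self-contained.
conv1 = nn.Conv2d(1, 32, 3, 1)
conv2 = nn.Conv2d(32, 64, 3, 1)
fc1 = nn.Linear(9216, 128)
fc2 = nn.Linear(128, 10)

x = torch.randn(8, 1, 28, 28)  # a batch of 8 random "images"
shapes = {"input": tuple(x.shape)}
x = F.relu(conv1(x))
shapes["conv1 + relu"] = tuple(x.shape)       # (8, 32, 26, 26)
x = F.relu(conv2(x))
shapes["conv2 + relu"] = tuple(x.shape)       # (8, 64, 24, 24)
x = F.max_pool2d(x, 2)
shapes["max_pool2d"] = tuple(x.shape)         # (8, 64, 12, 12)
x = torch.flatten(x, 1)
shapes["flatten"] = tuple(x.shape)            # (8, 9216)
x = F.relu(fc1(x))
shapes["fc1 + relu"] = tuple(x.shape)         # (8, 128)
x = F.log_softmax(fc2(x), dim=1)
shapes["fc2 + log_softmax"] = tuple(x.shape)  # (8, 10)
for name, shape in shapes.items():
    print(f"{name}: {shape}")

# log_softmax returns log-probabilities: exponentiating each row gives
# probabilities that sum to 1.
assert torch.allclose(x.exp().sum(dim=1), torch.ones(8))

# Parameter count per layer (weights + biases).
layers = {"conv1": conv1, "conv2": conv2, "fc1": fc1, "fc2": fc2}
counts = {name: sum(p.numel() for p in layer.parameters()) for name, layer in layers.items()}
total = sum(counts.values())
print(counts, "total:", total)
```

The counts work out to 320 for conv1, 18,496 for conv2, 1,179,776 for fc1, and 1,290 for fc2, for a total of 1,199,882. That is why fc1 is usually the first layer to shrink if the model needs to be smaller.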
Data Loading, Preprocessing, Training, and Testing
This part is the same as for the simple fully connected neural network introduced in the previous article, except that the SimpleNet class defined there is replaced with the Net class defined here.
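As a reminder of what the training and testing steps look like with the new model, here is a minimal sketch. It is not the article's script. It assumes random tensors in place of the MNIST DataLoader, plain SGD with lr=0.01 as the optimizer (an illustrative choice), and an nn.Sequential restatement of Net with the same layers in the same order. It shows one training step with F.nll_loss (which pairs with the log_softmax output), evaluation with accuracy computed via argmax, and the usual CPU/GPU device selection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_net():
    """nn.Sequential restatement of Net (same layers, same order)."""
    return nn.Sequential(
        nn.Conv2d(1, 32, 3, 1), nn.ReLU(),
        nn.Conv2d(32, 64, 3, 1), nn.ReLU(),
        nn.MaxPool2d(2), nn.Dropout(0.25),
        nn.Flatten(),  # same as torch.flatten(x, 1)
        nn.Linear(9216, 128), nn.ReLU(), nn.Dropout(0.5),
        nn.Linear(128, 10), nn.LogSoftmax(dim=1),
    )


def train_step(model, optimizer, data, target):
    model.train()                      # enable Dropout
    optimizer.zero_grad()
    output = model(data)               # log-probabilities, shape (N, 10)
    loss = F.nll_loss(output, target)  # negative log-likelihood on log-probabilities
    loss.backward()
    optimizer.step()
    return loss.item()


@torch.no_grad()
def evaluate(model, data, target):
    model.eval()                       # disable Dropout
    pred = model(data).argmax(dim=1)   # most likely digit for each image
    return (pred == target).float().mean().item()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = make_net().to(device)
# Assumption: plain SGD for illustration, not necessarily the example's optimizer.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Random stand-in data shaped like a batch of MNIST images and labels.
data = torch.randn(16, 1, 28, 28, device=device)
target = torch.randint(0, 10, (16,), device=device)

losses = [train_step(model, optimizer, data, target) for _ in range(5)]
accuracy = evaluate(model, data, target)
print("losses:", losses, "accuracy:", accuracy)
```

With real data, the only changes are iterating over the training and test DataLoaders and moving each batch to device. The model-specific part is just the choice of network.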
The code here is essentially the example code provided by PyTorch, which can be found here, or in the T03_mnist_cnn folder in my GitHub repository https://github.com/jin-li/pytorch-tutorial.
Here we create a Python script to run this CNN model and plot the performance curve of the model.
The Python environment required for the code in this article is the same as in the previous two articles. You can activate the environment with conda activate pytorch-mnist and then run the code with the following command:
Running Results
I ran this CNN model on both a CPU (Intel i5-9600K) and a GPU (NVIDIA GeForce RTX 4060 Ti). The running times were 12 minutes 56 seconds and 2 minutes 25 seconds, respectively, so the GPU was about 5 times faster. The model does not demand much GPU capacity: GPU utilization was about 12%, and GPU memory usage was about 740 MB.
The performance of the model is as follows:
The accuracy of this CNN model on the MNIST dataset is about 99.2%, which is higher than that of the simple fully connected neural network.
Using the Model for Digit Recognition
This article is already the third article on MNIST handwritten digit recognition, but we have only been training the model and have not actually used the trained model. Now we will use the trained CNN model to recognize some handwritten digits.
- First, when training the model, we need to save the model parameters so that they can be loaded directly when the model is used, without retraining. The code to save the model parameters is as follows:
```python
torch.save(model.state_dict(), "mnist_cnn.pt")
```
The code in mnist_cnn.py already has this functionality. We only need to pass the --save-model argument when running mnist_cnn.py, and the model parameters will be saved to the mnist_cnn.pt file:
```shell
python mnist_cnn.py --save-model
```
- Then we can use the saved model parameters to recognize handwritten digits. We use the PIL library to read images, the torchvision library to preprocess them, and our trained CNN model to recognize the digits in them. Note the following:
  - The images in the MNIST training set are 28x28-pixel images of white digits on a black background, so the images we create should also have white digits on a black background.
  - The digits we write ourselves are not necessarily 28x28 pixels, so we need to resize the images to 28x28 pixels.
  - The pixel values of the MNIST training set (after scaling to [0, 1]) have a mean of 0.1307 and a standard deviation of 0.3081, so we normalize our images with the same mean and standard deviation during preprocessing.
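The code for these operations is as follows:
```python
from torchvision import transforms

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # convert to a single channel
    transforms.Resize((28, 28)),                  # scale to the MNIST image size
    transforms.ToTensor(),                        # PIL image -> tensor in [0, 1], shape (1, 28, 28)
    transforms.Normalize((0.1307,), (0.3081,)),   # same mean/std as the MNIST training set
])
```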
The complete code is in the classify_image.py file.
- Here I used the mouse to write a few digits on the computer and saved each digit as a separate image in the numbers directory of the T03_mnist_cnn folder:
- Finally, we can use classify_image.py to recognize a handwritten digit:
```shell
python classify_image.py numbers/number1.png
```
- I ran the model on these 10 handwritten digits, and it correctly recognized 8 of them. The two misrecognized digits were 2 and 9, which the model classified as 4 and 8, respectively. Although some of these mistakes are puzzling, the model's overall performance is acceptable.
Summary
In this article, we used a convolutional neural network to implement MNIST handwritten digit recognition. Compared to the simple fully connected neural network, the CNN model achieved a clear improvement in accuracy (about 99.2%). We also used the trained CNN model to recognize our own handwritten digits, and it correctly recognized 8 of the 10.
MNIST is a classic dataset in the field of computer vision and a good introductory dataset. We have written three articles to discuss it, so let’s take a break here and study some other machine learning or deep learning problems first. If we have the opportunity in the future, we will use the MNIST dataset to study some other deep learning algorithms.