PyTorch in Practice: Handwritten Digit Recognition on the MNIST Dataset
Original video: 轻松学Pytorch手写字体识别MNIST ("Learn PyTorch the Easy Way: MNIST Handwritten Digit Recognition")
1. Import the required libraries
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader   # used by the loaders in step 4
```
2. Define hyperparameters
```python
BATCH_SIZE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
```
3. Build the pipeline that preprocesses the images
```python
pipeline = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 and 0.3081 are the mean and std of the MNIST training set
    transforms.Normalize((0.1307,), (0.3081,))
])
```
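For intuition, `Normalize` simply applies `(x - mean) / std` per channel. A minimal sketch verifying this on a made-up tensor (this check is my addition, not part of the original post):

```python
norm = transforms.Normalize((0.1307,), (0.3081,))
x = torch.rand(1, 28, 28)            # a fake single-channel "image" in [0, 1]
y = norm(x)

# Normalize is just an element-wise affine transform per channel
assert torch.allclose(y, (x - 0.1307) / 0.3081)
```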
4. Download and load the dataset
```python
train_set = datasets.MNIST("data", train=True, download=True, transform=pipeline)
test_set = datasets.MNIST("data", train=False, download=True, transform=pipeline)

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)   # shuffling the test set is harmless but not required
```
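Before moving on, it can help to pull one batch and confirm the shapes the network will receive. A quick check of my own, continuing from the loaders above:

```python
images, labels = next(iter(train_loader))
print(images.shape)   # torch.Size([64, 1, 28, 28]) -- batch, channel, height, width
print(labels.shape)   # torch.Size([64])
```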
Once the download finishes, let's take a look at an image inside the dataset.
```python
# the IDX image file has a 16-byte header; each image is 28*28 = 784 bytes
with open("absolute path to the MNIST raw images file", "rb") as f:
    file = f.read()

image1 = [item for item in file[16:16 + 784]]   # iterating bytes in Python 3 yields ints
print(image1)

import cv2
import numpy as np

image1_np = np.array(image1, dtype=np.uint8).reshape(28, 28, 1)
print(image1_np.shape)   # (28, 28, 1)
cv2.imwrite("digit.jpg", image1_np)
```
Output: the extracted digit, saved as digit.jpg.
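Reading the raw IDX file by hand is instructive, but the `Dataset` object already provides indexed access. A hedged alternative sketch using `train_set` from step 4 (the un-normalization constants mirror the pipeline above, and `digit_from_dataset.jpg` is a filename I made up):

```python
import cv2
import numpy as np

img, label = train_set[0]                     # img is a normalized 1x28x28 tensor
print(img.shape, label)

# undo Normalize to recover displayable pixel intensities
pixels = (img * 0.3081 + 0.1307).squeeze(0).numpy() * 255
cv2.imwrite("digit_from_dataset.jpg", pixels.astype(np.uint8))
```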
5. Build the network model
```python
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1 input channel, 10 output channels
        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
        self.fc1 = nn.Linear(20 * 10 * 10, 500)        # 20 channels * 10x10 feature map
        self.fc2 = nn.Linear(500, 10)                  # 10 digit classes

    def forward(self, x):
        input_size = x.size(0)        # batch size
        x = self.conv1(x)             # 28x28 -> 24x24
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)     # 24x24 -> 12x12

        x = self.conv2(x)             # 12x12 -> 10x10
        x = F.relu(x)

        x = x.view(input_size, -1)    # flatten to (batch, 2000)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)

        output = F.log_softmax(x, dim=1)   # log-probabilities over the 10 classes
        return output
```
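Why `20 * 10 * 10` for fc1: with no padding, conv1 (5x5) maps 28x28 to 24x24, the 2x2 max-pool halves that to 12x12, and conv2 (3x3) yields 10x10 over 20 channels, hence 2000 flattened features. A quick sanity check of my own with a dummy batch:

```python
m = Digit()
dummy = torch.zeros(1, 1, 28, 28)   # one fake grayscale MNIST-sized image
print(m(dummy).shape)               # torch.Size([1, 10]) -- log-probabilities for 10 classes
```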
6. Define the optimizer
```python
model = Digit().to(DEVICE)
optimizer = optim.Adam(model.parameters())
```
7. Define the training method
```python
def train_model(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # the model returns log-probabilities, so nll_loss is the matching criterion
        # (F.cross_entropy would apply log_softmax a second time)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_index % 3000 == 0:
            print("Train Epoch: {} \t Loss: {:.6f}".format(epoch, loss.item()))
```
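One detail worth calling out: since `forward` ends with `log_softmax`, `F.nll_loss` is the matching loss, while `F.cross_entropy` expects raw logits because it applies `log_softmax` internally. A small sketch of the equivalence, with made-up scores and labels:

```python
logits = torch.randn(4, 10)           # fake raw scores for 4 samples
target = torch.tensor([3, 1, 4, 1])   # fake labels

a = F.cross_entropy(logits, target)
b = F.nll_loss(F.log_softmax(logits, dim=1), target)
assert torch.allclose(a, b)           # the two formulations agree on raw logits
```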
8. Define the test method
```python
def test_model(model, device, test_loader):
    model.eval()
    correct = 0.0
    test_loss = 0.0
    with torch.no_grad():   # no gradients needed during evaluation
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # sum per-sample losses so the division below gives a true per-sample average
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(1)   # index of the highest log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print("Test -- Average Loss: {:.4f}, Accuracy: {:.3f}\n".format(
        test_loss, 100.0 * correct / len(test_loader.dataset)))
```
9. Run the training loop
```python
for epoch in range(1, EPOCHS + 1):
    train_model(model, DEVICE, train_loader, optimizer, epoch)
    test_model(model, DEVICE, test_loader)
```
Output:
```
Train Epoch: 1   Loss: 2.296158
Train Epoch: 1   Loss: 0.023645
Test -- Average Loss: 0.0027, Accuracy: 98.690
Train Epoch: 2   Loss: 0.035262
Train Epoch: 2   Loss: 0.002957
Test -- Average Loss: 0.0027, Accuracy: 98.750
Train Epoch: 3   Loss: 0.029884
Train Epoch: 3   Loss: 0.000642
Test -- Average Loss: 0.0032, Accuracy: 98.460
Train Epoch: 4   Loss: 0.002866
Train Epoch: 4   Loss: 0.003708
Test -- Average Loss: 0.0033, Accuracy: 98.720
Train Epoch: 5   Loss: 0.000039
Train Epoch: 5   Loss: 0.000145
Test -- Average Loss: 0.0026, Accuracy: 98.840
Train Epoch: 6   Loss: 0.000124
Train Epoch: 6   Loss: 0.035326
Test -- Average Loss: 0.0054, Accuracy: 98.450
Train Epoch: 7   Loss: 0.000014
Train Epoch: 7   Loss: 0.000001
Test -- Average Loss: 0.0044, Accuracy: 98.510
Train Epoch: 8   Loss: 0.001491
Train Epoch: 8   Loss: 0.000045
Test -- Average Loss: 0.0031, Accuracy: 99.140
Train Epoch: 9   Loss: 0.000428
Train Epoch: 9   Loss: 0.000000
Test -- Average Loss: 0.0056, Accuracy: 98.500
Train Epoch: 10  Loss: 0.000001
Train Epoch: 10  Loss: 0.000377
Test -- Average Loss: 0.0042, Accuracy: 98.930
```
Summary and improvements
After finishing the video: the teacher does explain things well, but he never made clear why the network has to be built this particular way. So I went back to review CNNs, and it turns out this network structure also works:
```python
class Digit(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.fc1 = nn.Linear(20 * 4 * 4, 10)   # 20 channels * 4x4 feature map

    def forward(self, x):
        input_size = x.size(0)
        x = self.conv1(x)             # 28x28 -> 24x24
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)     # 24x24 -> 12x12

        x = self.conv2(x)             # 12x12 -> 8x8
        x = F.relu(x)
        x = F.max_pool2d(x, 2, 2)     # 8x8 -> 4x4

        x = x.view(input_size, -1)    # flatten to (batch, 320)

        x = self.fc1(x)
        x = F.relu(x)                 # note: here ReLU is applied to the final scores, before log_softmax

        output = F.log_softmax(x, dim=1)
        return output
```
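The same shape bookkeeping explains `20 * 4 * 4`: conv1 (5x5) gives 24x24, pooling gives 12x12, conv2 (5x5) gives 8x8, and the second pool gives 4x4 over 20 channels, i.e. 320 flattened features. A layer-by-layer trace of my own with a dummy tensor:

```python
x = torch.zeros(1, 1, 28, 28)
x = F.max_pool2d(F.relu(nn.Conv2d(1, 10, 5)(x)), 2, 2)
print(x.shape)   # torch.Size([1, 10, 12, 12])
x = F.max_pool2d(F.relu(nn.Conv2d(10, 20, 5)(x)), 2, 2)
print(x.shape)   # torch.Size([1, 20, 4, 4]) -> 20*4*4 = 320 when flattened
```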
With this CNN, the accuracy suddenly dropped to around 50%, which left me puzzled. I guessed the optimizer might be the problem, so I swapped Adam out for SGD, and the results were indeed better.
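The post doesn't show the exact optimizer call; a minimal sketch of the swap, where the learning rate and momentum are my assumptions rather than values from the original:

```python
# hypothetical settings; the original post does not state lr or momentum
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
```

With SGD in place, the log looked like this: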
```
Train Epoch: 17  Loss: 0.014693
Train Epoch: 17  Loss: 0.000051
Test -- Average Loss: 0.0026, Accuracy: 99.010
```
At epoch 17 the accuracy actually reached 99%. I don't know why yet, so I'll leave this as an open question and come back to fill it in once I've figured it out.