其他
PyTorch 教程-CIFAR-10数据集上LeNet模型的测试
在前一个主题中,我们发现我们的LeNet模型与卷积神经网络能够对MNIST数据集图像进行分类。MNIST数据集包含灰度图像,但在CHIFAR-10数据集中,图像是彩色的且包含不同的物体。因此,我们最大的问题是我们的LeNet模型是否能够对CIFAR-10数据集的图像进行分类。我们将复制我们先前主题的代码,即CNN测试,并在代码的图像转换、实现、训练、验证和测试部分进行以下更改:
注意:如果您是新手,请务必了解我们先前主题的知识,以更有效地理解这个主题。
图像转换部分的更改:
在图像转换部分,我们将进行以下更改:
步骤1:
training_dataset=datasets.CIFAR10(root='./data',train=True,download=True,transform=transform1)
validation_dataset=datasets.CIFAR10(root='./data',train=False,download=True,transform=transform1)
步骤2:
transform1=transforms.Compose([transforms.Resize((32,32)),transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
现在,如果我们绘制我们的CIFAR-10图像,将得到以下输出:
步骤3:
classes={'plane','car','bird','cat','dear','dog'.'frog','horse','ship','truck'}
步骤4:
ax.set_title(classes[labels[idx].item()])
它将给出以下输出:
实现、训练和验证部分的更改:
我们的Lenet模型是为MNIST图像实现的。MNIST图像是灰度图像,但我们必须为包含彩色图像的CIFAR-10数据集实现我们的模型。因此,我们必须在代码中进行以下更改:
步骤1:
self.conv1=nn.Conv2d(3,20,5,1)
步骤2:
现在,我们必须训练大量的参数。经过5x5内核的卷积后,图像变为28x28,然后通过下一个池化变为14x14,执行具有相同大小内核的另一个卷积。图像再次减小为4x4,最后通过另一个最大池化,将馈送到完全连接的网络中的向量将是5x5x50。
self.fully1=nn.Linear(5*5*50,500)
步骤3:
xx=x.view(-1,5*5*50) #Reshaping the output into desired shape
现在,找到总损失和验证损失以及准确度和验证准确度并绘制,然后它将给出以下输出:
步骤4:
现在,我们将使用它来预测来自Web的图像,以简单地获取模型准确性的视觉透视。我们将使用以下图像:https://3c1703fe8d.site.internapcdn.net/newman/gfx/news/hires/2018/2-dog.jpg
当我们绘制这个图像时,它将显示为:
步骤5:
url='https://ichef.bbci.co.uk/news/912/cpsprodpb/160B4/production/_103229209_horsea.png'
response=requests.get(url,stream=True)
img=Image.open(response.raw)
img=transform1(img)
plt.imshow(im_convert(img))
在转换后,我们获得了图像的更抽象表示。它缩小为一个较小的32x32表示。
步骤6:
image1=img.to(device).unsqueeze(0)
output=model(image1)
_,pred=torch.max(output,1)
print(classes[pred.item()])
测试部分的更改:
dataiter=iter(validation_loader)
images,labels=dataiter.next()
imagesimages_=images.to(device)
labelslabels=labels.to(device)
output=model(images_)
_,preds=torch.max(output,1)
fig=plt.figure(figsize=(25,4))
for idx in np.arange(20):
ax=fig.add_subplot(2,10,idx+1,xticks=[],yticks=[])
plt.imshow(im_convert(images[idx]))
ax.set_title("{}({})".format(str(classes[preds[idx].item()]),str(classes[labels[idx].item())),color=("green" if preds[idx]==labels[idx] else "red"))
plt.show()