當前位置:首頁 > IT技術(shù) > 其他 > 正文

32 模型的測試
2022-05-29 22:43:47

利用已經(jīng)訓(xùn)練好的模型,給它提供輸入,看輸出

一、輸入

1.找到一張圖片

image

  • 對于圖片要進行轉(zhuǎn)換

png格式是四個通道,除了RGB三通道外,還有一個透明度通道,通過下面的命令,可以適應(yīng)png、jpg各種格式的圖片。

命令:

image=image.convert('RGB')

二、實驗對應(yīng)的部分

1.代碼

import torch
import torchvision
from PIL import Image #直接從PIL中導(dǎo)入Image,不是從PIL.Image

# 1. 圖片的輸入
from torch import nn

image_path="./images/plane.png"
image=Image.open(image_path)
print(image)

# 2. 轉(zhuǎn)換為三通道
image=image.convert('RGB')

# 3. 將PIL格式轉(zhuǎn)變?yōu)閠ensor數(shù)據(jù)類型
transform=torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                          torchvision.transforms.ToTensor()])
image=transform(image)
print(image.shape)

# 4.引入模型結(jié)構(gòu):也可以單獨建立文件夾,利用import
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1=nn.Sequential(
            nn.Conv2d(3, 32, 5, stride=1,padding=2),  # in_channel 3;out_channel 32;kernel 5;padding需要計算(一般不會太大)
            nn.MaxPool2d(2),  # kennel_Size=2
            nn.Conv2d(32, 32, 5,stride=1,padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5,stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),  # 展平 :可以把后面的刪掉 獲得輸出的大小
            nn.Linear(64*4*4, 64),  # 看上一層的大小64*4*4=1024 ,可以看到是1024
            nn.Linear(64, 10)  # 輸入大小 64 輸出大小 10
        )
    def forward(self,x):
            x=self.model1(x)
            return x

# 5. 使用訓(xùn)練好的模型
model=torch.load("tudui_49_GPU.pth") #由GPU訓(xùn)練
# model=torch.load("tudui_49_GPU.pth",map_location=torch.device("cpu")) #由GPU模型映射到CPU模型上,后面的image不加cuda了
print(model)

# 6. 輸出
image=torch.reshape(image,(1,3,32,32)) #轉(zhuǎn)變通道數(shù)
image=image.cuda() #因為模型是GPU訓(xùn)練的,數(shù)據(jù)集也要用,要是CPU訓(xùn)練的模型,就去掉這個命令
model.eval()
with torch.no_grad(): #節(jié)省步驟,提升性能
  output=model(image)
print(output)
print(output.argmax(1)) #橫向比較

2.輸出對應(yīng)的類別

image

3.運行結(jié)果:

利用colab運行50輪后的結(jié)果:測試集準確度0.65

預(yù)測狗
image
預(yù)測飛機
image

本文摘自 :https://www.cnblogs.com/

開通會員,享受整站包年服務(wù)立即開通 >