Pythonで機械学習 手書き数字認識 MNISTデータの準備 Part(8)

colaboratory

はじめに

こんにちは、swim-loverです。Pythonで機械学習の一つである手書き数字認識を実装しています。 Pythonを始めたばかりですが、「使いながら覚える」をコンセプトに勉強しています。 第7回は、 Pytorchを使ってMNISTデータの準備で登場したPython機能、関数について勉強しました。今回も続編を進めていきたいと思います。

MNISTデータスクリプト

前回、次のようなPythonコードを書きました。PytorchのDataLoaderについては、今後も多用すると思いますので、もうすこし掘り下げて確認しておきます。

import torch;
import torchvision
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader

train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_data = DataLoader(train_set,batch_size=None,shuffle=False)
data_iter = iter(train_data)
imgs, labels=next(data_iter)
plt.imshow(imgs, cmap='gray')
print("label numer is {}".format(labels))

DataLoader

MNISTデータセットは、60000個のデータの集まりです。機械学習では、データセットを小分けにするのをバッチ化と呼ぶようです。例えば、10個のデータを6000個といったイメージです。

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train

引数のbatch_sizeは、小分けの数を指定できるようです。batch_size=Noneの場合は小分けを作成しません。

次に小分けを作成してみます。

train_data = DataLoader(train_set,batch_size=10,shuffle=False)

batchを作成した場合、next()のところでエラーが発生してしまいました。

TypeError                                 Traceback (most recent call last)
<ipython-input-3-e1f9cb968349> in <module>()
     11 data_iter = iter(train_data)
     12 #first data
---> 13 imgs, labels=next(data_iter)
     14 
     15 fig = plt.figure()

色々しらべてみると、データセットのダウンロードに問題があるようです。成功例を調べてみるとtransforms引数を設定していることがわかりました。

train_set = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)

修正後のコードになります。

import torch;
import torchvision
from torchvision import transforms as transforms
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader
#train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_set = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
print(train_set)
print('----------------------------')
train_data = DataLoader(train_set,batch_size=32,shuffle=False)
tmp = train_data.__iter__()
#first data
imgs, labels = tmp.next()
print(imgs.size())
print(labels)

32個の画像(28 pixel x 28 pixel)が取得できるようになりました。

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: ToTensor()
----------------------------
torch.Size([32, 1, 28, 28])
tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
        1, 2, 4, 3, 2, 7, 3, 8])

先頭4つのデータを表示してみました。

fig = plt.figure()
imgplot = 1
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[0]),fontsize=20)
plt.imshow(imgs[0][0], cmap='gray')

imgplot = 2
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[1]),fontsize=20)
plt.imshow(imgs[1][0], cmap='gray')

imgplot = 3
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[2]),fontsize=20)
plt.imshow(imgs[2][0], cmap='gray')

imgplot = 4
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[3]),fontsize=20)
plt.imshow(imgs[3][0], cmap='gray')

まとめ

今回、pytorchのDataLoaderについて、バッチ化でデータロードをしてみました。エラーになやまされましたたが、”データダウンロード時にtransform=transforms.ToTensor()”を指定することでエラーが解消できました。

コメント

タイトルとURLをコピーしました