はじめに
こんにちは、swim-loverです。Pythonで機械学習の一つである手書き数字認識を実装しています。 Pythonを始めたばかりですが、「使いながら覚える」をコンセプトに勉強しています。 第7回は、 Pytorchを使ってMNISTデータの準備で登場したPython機能、関数について勉強しました。今回も続編を進めていきたいと思います。
MNISTデータスクリプト
前回、次のようなPythonコードを書きました。PytorchのDataLoaderについては、今後も多用すると思いますので、もうすこし掘り下げて確認しておきます。
import torch;
import torchvision
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader
train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_data = DataLoader(train_set,batch_size=None,shuffle=False)
data_iter = iter(train_data)
imgs, labels=next(data_iter)
plt.imshow(imgs, cmap='gray')
print("label numer is {}".format(labels))
DataLoader
MNISTデータセットは、60000個のデータの集まりです。機械学習では、データセットを小分けにするのをバッチ化と呼ぶようです。例えば、10個のデータを6000個といったイメージです。
Dataset MNIST Number of datapoints: 60000 Root location: ./data Split: Train
引数のbatch_sizeは、小分けの数を指定できるようです。batch_size=Noneの場合は小分けを作成しません。
次に小分けを作成してみます。
train_data = DataLoader(train_set,batch_size=10,shuffle=False)
batchを作成した場合、next()のところでエラーが発生してしまいました。
TypeError Traceback (most recent call last) <ipython-input-3-e1f9cb968349> in <module>() 11 data_iter = iter(train_data) 12 #first data ---> 13 imgs, labels=next(data_iter) 14 15 fig = plt.figure()
色々しらべてみると、データセットのダウンロードに問題があるようです。成功例を調べてみるとtransforms引数を設定していることがわかりました。
train_set = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
修正後のコードになります。
import torch;
import torchvision
from torchvision import transforms as transforms
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader
#train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_set = torchvision.datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
print(train_set)
print('----------------------------')
train_data = DataLoader(train_set,batch_size=32,shuffle=False)
tmp = train_data.__iter__()
#first data
imgs, labels = tmp.next()
print(imgs.size())
print(labels)
32個の画像(28 pixel x 28 pixel)が取得できるようになりました。
Dataset MNIST Number of datapoints: 60000 Root location: ./data Split: Train StandardTransform Transform: ToTensor() ---------------------------- torch.Size([32, 1, 28, 28]) tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1, 1, 2, 4, 3, 2, 7, 3, 8])
先頭4つのデータを表示してみました。
fig = plt.figure()
imgplot = 1
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[0]),fontsize=20)
plt.imshow(imgs[0][0], cmap='gray')
imgplot = 2
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[1]),fontsize=20)
plt.imshow(imgs[1][0], cmap='gray')
imgplot = 3
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[2]),fontsize=20)
plt.imshow(imgs[2][0], cmap='gray')
imgplot = 4
ax1 = fig.add_subplot(2, 2, imgplot)
ax1.set_title("{}".format(labels[3]),fontsize=20)
plt.imshow(imgs[3][0], cmap='gray')
まとめ
今回、pytorchのDataLoaderについて、バッチ化でデータロードをしてみました。エラーになやまされましたたが、”データダウンロード時にtransform=transforms.ToTensor()”を指定することでエラーが解消できました。
組み込み系ソフトエンジニアをしています。これまでフロントエンド技術は避けてきましたが、食わず嫌いをやめて、勉強を始めました。
趣味は、水泳、ロードバイク、ランニング、登山です。
組み込み系技術ネタ、勉強したフロントエンド技術、たまに趣味の運動について発信していきます。
どうぞよろしくお願いします。
コメント