はじめに
こんにちは、swim-loverです。 Pythonで機械学習の一つである手書き数字認識を実装しています。Pythonを始めたばかりですが、「使いながら覚える」をコンセプトに勉強しています。 第6回は、 定番と思われる、手書き数字認識に着手し、Pytorchを使ってMNISTデータの準備をしました 。今回も続編を進めていきたいと思います。
MNISTデータスクリプト
前回、次のようなPythonコードを書きました。初めて使うPythonの記述もありましたので、確認しておきます。
import torch;
import torchvision
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader
train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_data = DataLoader(train_set,batch_size=None,shuffle=False)
data_iter = iter(train_data)
imgs, labels=next(data_iter)
plt.imshow(imgs, cmap='gray')
print("label numer is {}".format(labels))
from xx import yy
- from torch.utils.data import DataLoader を次のように記載しても同じ動作になります。
import torch.utils.data as util
train_data = util.DataLoader(train_set,batch_size=None,shuffle=False)
この場合は、torch.utils.dataモジュール全体をimportしています。(全体のimpoirt)、一方前者(from torch.utils.data import DataLoader)は、torch.utils.dataの中から、DataLoaderのみをimportしています。(部分のimport)モジュールサイズが巨大である場合、部分importのほうがメモリ消費量を少なくるすることができるのかもしれません。今は、このレベルの理解にとどめたいと思います。
iter(),next()
Python標準の組み込み関数の一つで、イテレーターと呼ぶらしいです。配列要素に対して、安全に取り出すことができるのは非常にありがたい機能です。C言語の場合、配列外アクセスなどは定番バグですがこれを排除できるのは、高級言語を使う恩恵だと思います。
array = [1, 2, 3]
a = iter(array)
val=next(a);print(val)
val=next(a);print(val)
val=next(a);print(val)
next()によって値が取り出せました。
1 2 3
再度、next()を実行してみます。どうなるでしょうか?
array = [1, 2, 3]
a = iter(array)
val=next(a);print(val)
val=next(a);print(val)
val=next(a);print(val)
val=next(a);print(val)
3の次は先頭の1が取り出せるのかと思いましたが、Stoplterationというエラーが発生しました。
1 2 3 --------------------------------------------------------------------------- StopIteration Traceback (most recent call last) <ipython-input-7-732f3ac04336> in <module>() 5 val=next(a);print(val) 6 val=next(a);print(val) ----> 7 val=next(a);print(val) StopIteration:
本格的なプログラミングでは、このようなエラーに対する処理を実装しておく必要があります。try/exceptionを追加することで、エラーを捉えることができます。
array = [1, 2, 3]
a = iter(array)
try:
val=next(a);print(val)
val=next(a);print(val)
val=next(a);print(val)
val=next(a);print(val)
except StopIteration:
print('end of array')
結果を確認します。エラーを正しく捉えることができました。
1 2 3 end of array
MINSTデータ複数個表示
next()を使って、MNISTデータを複数個表示してみました。
import torch;
import torchvision
import matplotlib.pyplot as plt
import numpy
from torch.utils.data import DataLoader
train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)
train_data = DataLoader(train_set,batch_size=None,shuffle=False)
data_iter = iter(train_data)
#first data
imgs, labels=next(data_iter)
fig = plt.figure()
imgplot = 1
ax1 = fig.add_subplot(2, 2, imgplot)
plt.imshow(imgs, cmap='gray')
ax1.set_title("{}".format(labels),fontsize=20)
plt.imshow(imgs, cmap='gray')
#2nd data
imgs, labels=next(data_iter)
imgplot = 2
ax1 = fig.add_subplot(2, 2, imgplot)
plt.imshow(imgs, cmap='gray')
ax1.set_title("{}".format(labels),fontsize=20)
plt.imshow(imgs, cmap='gray')
#3rd data
imgs, labels=next(data_iter)
imgplot = 3
ax1 = fig.add_subplot(2, 2, imgplot)
plt.imshow(imgs, cmap='gray')
ax1.set_title("{}".format(labels),fontsize=20)
plt.imshow(imgs, cmap='gray')
#4th data
imgs, labels=next(data_iter)
imgplot = 4
ax1 = fig.add_subplot(2, 2, imgplot)
plt.imshow(imgs, cmap='gray')
ax1.set_title("{}".format(labels),fontsize=20)
plt.imshow(imgs, cmap='gray')
まとめ
今回、手書き数字識別のためにMNISTデータの準備で登場した、from xx import yy, iter(),next()のPythonの機能について書いてみました。次回は、PytorchのDataLoaderという関数について掘り下げてたいと思います。
組み込み系ソフトエンジニアをしています。これまでフロントエンド技術は避けてきましたが、食わず嫌いをやめて、勉強を始めました。
趣味は、水泳、ロードバイク、ランニング、登山です。
組み込み系技術ネタ、勉強したフロントエンド技術、たまに趣味の運動について発信していきます。
どうぞよろしくお願いします。
コメント