Pythonで機械学習 手書き数字認識 MNISTデータの準備 Part(6)

colaboratory

はじめに

こんにちは、swim-loverです。 Pythonで機械学習の一つである手書き数字認識を実装しています。Pythonを始めたばかりですが、「使いながら覚える」をコンセプトに勉強しています。 第5回は、 numpyで行列演算を行い、 3層ニューラルネットを作って みました。 といっても実際に何かを識別するようなことはしていません。今回は、定番と思われる、手書き数字認識を行ってみたいと思います。

参考書籍

今回、機械学習の参考書籍として、”斎藤康毅著 ゼロから作るDeep Learning オライリージャパン 2016年9月”を使用しました。 

機械学習フレームワーク

手書き数字の識別するためには、MNISTデータを使用します。MNISTデータは、大変有名な手書き数字のデータセットです。参考書籍では、書籍で提供するPythonスクリプトでWEBサイトから取得できるようになっています。この投稿では、別のアプローチを考えてみました。調べて見ると、機械学習(深層学習)のフレームワークとして、”tensorflow”と”Pytorch”の2つがよく使われているようです。折角なら、これらのフレームワークにも触れならが勉強できれば、きっと役に立つはずです。そこで、”Pytorch”を使って手書き数字認識を作って見たいと思います。

Pytorch version

Python実行環境として使っているcolabolatoryには、Pytorchが予めインストールされていました。念のためVersionを確認してみました。インストールエラーなどのトラブルを回避できるので便利だと思います。

import torch;
torch.__version__

ver 1.10.0がインストールされているようです。

1.10.0+cu111

MNISTデータセットの準備

Pytorch本体とコンピュータビジョン用のライブラリtorchvisionを使用します。 torchvision ライブラリの中にMNISTをダウンロードする機能が用意されていました。

import torch;
import torchvision

train_set = torchvision.datasets.MNIST(root='./data',train=True,download=True)

MNISTのあるWebサイトからダウンロードされています。

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
9913344/? [00:00<00:00, 49601900.51it/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

(省略)

colabではいつくかのlinux コマンドが使えるようですので、Downlaodされたものを確認してみました。pwd,cd,lsコマンドを使用して、以下の様なデータがダウンロードされていることが確認できました。

pwd
/content/data/MNIST/raw
ls
t10k-images-idx3-ubyte     train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte     train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz

MNISTデータセットの表示

次にMNISTデータセットの画像を表示してみます。pytorchにDataLoaderという機能が準備されています。

import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

train_data = DataLoader(train_set,batch_size=None,shuffle=False)
data_iter = iter(train_data)
imgs, labels=next(data_iter)
plt.imshow(imgs, cmap='gray')
print("label numer is {}".format(labels))

画像が表示されました。手書き数字は5でした。

まとめ

今回、手書き数字識別のためにMNISTデータの準備をしました。MNISTデータセットは、深層学習のフレームワークのPytorchを使って、ダウンロードしてみました。DataLoaderという関数(モジュール)を使ってみましたが、今後も多用すると思いますのでもう少し掘り下げたほうがよさそうです。これは次回の記事で触れたいと思いおます。また、iter(),next()というPythonの機能も新しく出てきました。これも次回の記事で触れたいと思います。

コメント

タイトルとURLをコピーしました