PyTorch, DataSetとDataLoaderを扱う

PyTorch ディープラーニングのライブラリを使うとき、Dataset をどのように用意するかが成功のカギであると思います。MNIST や CIFAR-10 などの場合は簡単に Dataset を作成できるようになっています。

Python で画像データを読みこむところは初心者をハマらせてしまいやすいです。ここでイヤになってしまうとライブラリを使ってもらえないので、つよくラップ(Wrap)された簡単メソッドが用意されているのは当然のことといえましょう。

ですが、このせいで MNIST と CIFAR-10 の画像以外も試してみようと思ったときに、Dataset と DataLoader の扱いができなくて、そこで応用学習が停滞してしまうということが考えられます。

この記事では Dataset と DataLoader を扱う最小限のコードを紹介します。Google Colab 上で動作を確認しました。PyTorch のバージョンは1.12.0+cu113 です。

あとあとオートエンコーダに使うことを想定して、

/content/drive/MyDrive/Colab Notebooks/imgs/training というディレクトリに95個のJPEGファイルを配置してください。拡張子は小文字で .jpg としてください。

/content/drive/MyDrive/Colab Notebooks/imgs というディレクトリに5個の推論画像に使うJPEGファイルを配置してください。拡張子は小文字で .jpg としてください。

このディレクトリというのは GoogleDrive サービスのドライブ内のことをイメージしておりますが、自身のパソコンに画像を用意するのならば、画像の格納パスを適宜かきかえてください。

上記の 95 + 5 の合計100個の画像のサイズは幅256ピクセル、高さ256ピクセルとしてください。もし違う幅高サイズの画像しかなければ、48行目のWと49行目のHを、その画像に合わせるように変更してください。

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

import glob
from PIL import Image
import matplotlib.pyplot as plt

#-------------------------------------------------------------------------------
# Dataset を継承したクラスを定義する.
class MyDataSet( Dataset ):
    def __init__(self, img_dir, transform=None):
        # 画像ファイルのパス一覧を取得する.
        search_pattern = img_dir + "/" + "*.jpg"
        list_filepath_src = glob.glob( search_pattern, recursive=False )

        # ファイルパスリストをソートする.
        list_filepath_src.sort()

        # ソート済みのファイルパスリストをメンバに仕込む.
        self.img_paths = list_filepath_src
        self.transform = transform

    def __getitem__( self, index ):
        # 所望のファイルパスから画像を読み込む.
        fp = self.img_paths[index]
        im = Image.open( fp )

        # 前処理がある場合は行う。
        if self.transform is not None:
            im = self.transform( im )

        # 画像を返す.
        return im

    def __len__(self):
        # ファイルパスリストの個数を返す.
        # これはディレクトリ内の画像の個数と同じ.
        return len( self.img_paths )

#-------------------------------------------------------------------------------

# pytorch のバージョンを表示する.
print( torch.__version__ )
print( "" )

# Transform を作成する.
W = 256 # 基本的に画像サイズ幅、この値を変えてみると Transform の効き目がわかる.
H = 256 # 基本的に画像サイズ高、この値を変えてみると Transform の効き目がわかる.
trnsfm = transforms.Compose([transforms.Resize(( W, H )), transforms.ToTensor()])

# トレーニング用のデータセットを作成する.
the_dir_trn = "/content/drive/MyDrive/Colab Notebooks/imgs/training"
dts_trn = MyDataSet( the_dir_trn, trnsfm )
print( "dts_trn length {0}.".format( len( dts_trn )))

# テスト用のデータセットを作成する.
the_dir_tst = "/content/drive/MyDrive/Colab Notebooks/imgs"
dts_tst = MyDataSet( the_dir_tst, trnsfm )
print( "dts_tst length {0}.".format( len( dts_tst )))

# 表示がわかりにくいから空行を入れる.
print( "" )

# トレーニング用とテスト用のデータローダを作成する.
# バッチサイズを変えるとデータの塊の個数が変わる.
dataloader_trn = DataLoader( dts_trn, batch_size = 10 )
dataloader_tst = DataLoader( dts_tst, batch_size =  1 )

# トレーニング用のデータローダの素性を表示する.
counter = 0
for batch in dataloader_trn:
    print( batch.shape )
    counter += 1
print( counter )
print( "" )

# テスト用のデータローダの素性を表示する.
counter = 0
for batch in dataloader_tst:
    print( batch.shape )
    counter += 1
print( counter )
print( "" )

# テスト用のデータをグラフ表示するためのリストを宣言してループで仕込む.
list_tmp = list()
for n in range( len( dts_tst )):
    list_tmp.append( dts_tst[n][0] )

# dts_tst[0][0]
# dts_tst[1][0]
# dts_tst[2][0]
# dts_tst[3][0]
# dts_tst[4][0]

# 書き込み要素を取得する,表示サイズは dpi で調整せよ.
fig = plt.figure( dpi=480 )

# タイルの縦横の個数.
TILE_NUM_ROW = 1
TILE_NUM_COL = 5

# プロットの位置を示すカウンタ(1スタートなのが気持ち悪い).
plot_counter = 1

# リストの要素回数だけループする.
for tmp in list_tmp:
    ax = fig.add_subplot( TILE_NUM_ROW, TILE_NUM_COL, plot_counter )
    ax.set_title( "[{0}]".format( plot_counter - 1 ))
    ax.axis( "off" )
    plt.imshow( tmp, cmap="gray" )
    plot_counter += 1

plt.show()

# 成功終了表示.
print( "finish." )

下記が実行結果です。matplotlib でオートエンコーダ推論にかける5個の画像も表示されるはずです。

1.12.0+cu113

dts_trn length 95.
dts_tst length 5.

torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([10, 1, 256, 256])
torch.Size([5, 1, 256, 256])
10

torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
5