画像分類カスタムデータの読み込み (Pytorch 1.10)

環境

  • python 3.8
  • CUDA 11.3
  • pytorch 1.10
  • torchvision

カスタムデータの読み込み

データセットがサブディレクトリ名をラベルとして下記のように保存されているとして

dataset
  ├─dog
  │  ├─001.png
  │  ├─002.png
  │  ...
  └─cat
     ├─001.png
     ├─002.png
     ...

ImageFolderを使ってデータを読み込む。OnlineなAugumentationをローダーに付与することもできる。

import torchvision

dataset = torchvision.datasets.ImageFolder(
  './dataset',
  torchvision.transform.Compose([
    torchvision.transform.RandomHorizontalFlip(),
    torchvision.transform.ToTensor(),
    torchvision.transform.Normalize(mean = [0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  ])
)

このときラベルはdatset.class_to_idxでdict型で{'label': 0, ...}として取得できるのでclasses = list(dataset.class_to_idx.keys())でidからラベルを得るリストを作れる。

データローダーへのセット

読み込んだデータはそのままDataLoaderへ…

import torch
data_train = torch.utils.data.DataLoader(dataset, batchsize = 64, shuffle = True)

読み込んだ画像がRGBのカラー画像だとして表示するときは

import matplotlib.pyplot as plt
sample_image, label = iter(data_train)
plt.imshow(sample_image[0].permute(1, 2, 0))

参考

No comments:

Post a Comment