環境
- python 3.8
- CUDA 11.3
- pytorch 1.10
- torchvision
カスタムデータの読み込み
データセットがサブディレクトリ名をラベルとして下記のように保存されているとして
dataset
├─dog
│ ├─001.png
│ ├─002.png
│ ...
└─cat
├─001.png
├─002.png
...
ImageFolderを使ってデータを読み込む。OnlineなAugumentationをローダーに付与することもできる。
import torchvision
dataset = torchvision.datasets.ImageFolder(
'./dataset',
torchvision.transform.Compose([
torchvision.transform.RandomHorizontalFlip(),
torchvision.transform.ToTensor(),
torchvision.transform.Normalize(mean = [0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
)
このときラベルはdatset.class_to_idx
でdict型で{'label': 0, ...}
として取得できるのでclasses = list(dataset.class_to_idx.keys())
でidからラベルを得るリストを作れる。
データローダーへのセット
読み込んだデータはそのままDataLoaderへ…
import torch
data_train = torch.utils.data.DataLoader(dataset, batchsize = 64, shuffle = True)
読み込んだ画像がRGBのカラー画像だとして表示するときは
import matplotlib.pyplot as plt
sample_image, label = iter(data_train)
plt.imshow(sample_image[0].permute(1, 2, 0))
No comments:
Post a Comment