* File: __init__.py
* Created on: Tue Sep 14 20:39:57 2021
* Author: Liam Woolley
* Licence: MIT
"""
import os

import torch
from torch.utils.data import Dataset, DataLoader
class MNISTDataset(Dataset):
    """Map-style dataset reading one image file and one label file per sample.

    Sample ``i`` is loaded from ``{image_dir}/{i}.jpg`` via ``torch.load`` and
    ``{label_dir}/{i}.txt``, a text file containing the integer class label.

    Args:
        image_dir: Directory holding the ``{index}.jpg`` image files.
        label_dir: Directory holding the ``{index}.txt`` label files.
    """

    def __init__(self, image_dir='./images', label_dir='./labels'):
        super().__init__()
        self.image_dir = image_dir
        self.label_dir = label_dir

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair for sample *index*.

        Returns:
            A 2-tuple: the object loaded from ``{index}.jpg`` and a scalar
            integer tensor holding the label read from ``{index}.txt``.

        Raises:
            FileNotFoundError: If the image or label file for *index* is missing.
            ValueError: If the label file does not contain an integer.
        """
        # NOTE(review): torch.load only reads files written by torch.save (it
        # unpickles them, so never point it at untrusted files). It cannot
        # decode real JPEG images -- confirm these ".jpg" files are saved
        # tensors, otherwise decode them with an image library instead.
        img = torch.load(os.path.join(self.image_dir, f'{index}.jpg'))
        # The label lives *inside* the text file; read and parse its contents
        # rather than the path string.
        label_path = os.path.join(self.label_dir, f'{index}.txt')
        with open(label_path, encoding='utf-8') as fh:
            label = torch.tensor(int(fh.read().strip()))
        return img, label

    def __len__(self):
        """Return the number of ``.jpg`` files in ``image_dir``."""
        # Count only image files so stray entries (hidden files, notes, ...)
        # do not inflate the length and make __getitem__ ask for missing files.
        return sum(1 for name in os.listdir(self.image_dir) if name.endswith('.jpg'))
class MNISTDataLoader(DataLoader):
    """``DataLoader`` with this module's default batching settings.

    Args:
        dataset: Dataset to load from; anything ``DataLoader`` accepts.
        batch_size: Number of samples per batch.
        shuffle: Whether ``DataLoader`` shuffles the sample order.
        num_workers: Worker count passed through to ``DataLoader``.
    """

    def __init__(self, dataset, batch_size=32, shuffle=True, num_workers=4):
        # DataLoader.__init__ already stores ``self.dataset``. Re-assigning it
        # afterwards raises ValueError, because DataLoader forbids setting that
        # attribute once it is initialized.
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                         num_workers=num_workers)

    def __getitem__(self, index):
        """Return one un-batched ``(image, label)`` sample from the dataset.

        The image is returned as a fresh tensor copy.
        """
        img, label = self.dataset[index]
        if isinstance(img, torch.Tensor):
            # Copy without the warning torch.tensor() emits for tensor input.
            img = img.clone().detach()
        else:
            img = torch.tensor(img)
        return img, label
def load_data(image_dir='./images', label_dir='./labels', **loader_kwargs):
    """Create the MNIST dataset and a loader over it.

    Args:
        image_dir: Image directory, forwarded to ``MNISTDataset``.
        label_dir: Label directory, forwarded to ``MNISTDataset``.
        **loader_kwargs: Extra keyword arguments (e.g. ``batch_size``,
            ``shuffle``, ``num_workers``) forwarded to ``MNISTDataLoader``;
            any that are omitted keep that class's defaults.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = MNISTDataset(image_dir=image_dir, label_dir=label_dir)
    loader = MNISTDataLoader(dataset, **loader_kwargs)
    return dataset, loader