"""MNIST-style image/label dataset and data-loader helpers.

 * File: __init__.py
 * Created on: Tue Sep 14 20:39:57 2021
 * Author: Liam Woolley
 * Licence: MIT
"""
import os

import torch
from torch.utils.data import Dataset, DataLoader


class MNISTDataset(Dataset):
    """Map-style dataset of (image, label) pairs stored as one file per sample.

    Sample ``i`` is read from ``{image_dir}/{i}.jpg`` and its integer label
    from ``{label_dir}/{i}.txt``.

    NOTE(review): images are read with ``torch.load``, which deserializes a
    file written by ``torch.save`` (pickle-based) and cannot decode real JPEG
    bytes. Confirm the ``.jpg`` files are torch-serialized tensors, and only
    load them from trusted sources.
    """

    def __init__(self, image_dir='./images', label_dir='./labels'):
        super().__init__()
        self.image_dir = image_dir
        self.label_dir = label_dir

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``label`` is a 0-d tensor holding the integer stored in the label file.

        Raises:
            FileNotFoundError: if the image or label file for ``index`` is missing.
            ValueError: if the label file does not contain an integer.
        """
        img = torch.load(f'{self.image_dir}/{index}.jpg')
        # int() must be applied to the label file's *contents*, not its path.
        with open(f'{self.label_dir}/{index}.txt', encoding='utf-8') as fh:
            label = int(fh.read().strip())
        # A plain 2-tuple: default_collate cannot batch a None element, and
        # MNISTDataLoader.__getitem__ unpacks exactly (img, label).
        return img, torch.tensor(label)

    def __len__(self):
        """Return the number of ``.jpg`` files in ``image_dir``."""
        # Count only the files __getitem__ reads so stray entries
        # (hidden files, notes, etc.) do not inflate the length.
        return sum(1 for name in os.listdir(self.image_dir) if name.endswith('.jpg'))


class MNISTDataLoader(DataLoader):
    """DataLoader with project defaults that also supports per-sample indexing.

    Extra keyword arguments are forwarded unchanged to ``DataLoader``.
    """

    def __init__(self, dataset, batch_size=32, shuffle=True, num_workers=4, **kwargs):
        # DataLoader.__init__ already stores ``dataset`` and its __setattr__
        # raises ValueError if ``dataset`` is reassigned afterwards, so the
        # attribute must not be set again here.
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                         num_workers=num_workers, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, label)`` for a single sample, bypassing batching."""
        img, label = self.dataset[index]
        # torch.tensor() on an existing tensor warns; clone().detach() is the
        # recommended equivalent copy.
        if isinstance(img, torch.Tensor):
            img = img.clone().detach()
        else:
            img = torch.tensor(img)
        return img, label


def load_data(image_dir='./images', label_dir='./labels', **loader_kwargs):
    """Build the dataset and its loader.

    Args:
        image_dir: directory containing ``{i}.jpg`` image files.
        label_dir: directory containing ``{i}.txt`` label files.
        **loader_kwargs: forwarded to ``MNISTDataLoader`` (e.g. ``batch_size``,
            ``shuffle``, ``num_workers``).

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    ds = MNISTDataset(image_dir, label_dir)
    dl = MNISTDataLoader(ds, **loader_kwargs)
    return ds, dl