import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from . import femnist

class CustomDataset(Dataset):
    """Thin ``Dataset`` adapter around an already-materialised, indexable collection.

    The wrapped object is stored as ``self.data``; length and item access
    are delegated to it unchanged.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made)."""
        super().__init__()
        self.data = data

    def __getitem__(self, idx):
        """Return ``self.data[idx]`` exactly as stored."""
        return self.data[idx]

    def __len__(self):
        """Return the number of items in the wrapped collection."""
        return len(self.data)
    

class TorchVisionDataset:
    """Base holder for the train and test splits of a dataset.

    A subclass (or caller) assigns ``data_fn`` -- a callable accepting the
    keyword arguments ``root``, ``train``, ``download`` and ``transform`` --
    and, optionally, ``train_transform`` / ``test_transform``, then calls
    :meth:`prepare_data` to populate ``train_set`` and ``test_set``.
    """

    def __init__(self):
        self.root_dir = './data/'
        self.data_fn = None
        self.train_transform = None
        self.test_transform = None
        # Defined up front so that reading a split before prepare_data()
        # yields None instead of raising AttributeError.
        self.train_set = None
        self.test_set = None

    def prepare_data(self):
        """Build ``train_set`` and ``test_set`` by calling ``data_fn`` twice.

        ``data_fn`` is invoked with ``root=self.root_dir`` and
        ``download=True``; the training split uses ``train=True`` with
        ``train_transform`` and the test split uses ``train=False`` with
        ``test_transform``.

        Raises:
            TypeError: if ``data_fn`` has not been set to a callable.
        """
        if not callable(self.data_fn):
            # Same exception type the bare call would raise, but with a
            # message that points at the actual misconfiguration.
            raise TypeError(
                f"{type(self).__name__}.data_fn must be set to a callable "
                f"dataset constructor before prepare_data(); "
                f"got {self.data_fn!r}"
            )
        self.train_set = self.data_fn(root=self.root_dir,
                                      train=True,
                                      download=True,
                                      transform=self.train_transform)
        self.test_set = self.data_fn(root=self.root_dir,
                                     train=False,
                                     download=True,
                                     transform=self.test_transform)


class MNIST(TorchVisionDataset):
    """MNIST splits loaded through ``torchvision.datasets.MNIST``.

    Both splits share one pipeline: convert to tensor, then normalise the
    single channel with mean 0.1307 and std 0.3081 (presumably the MNIST
    training-set statistics -- confirm if they matter).  Constructing an
    instance immediately calls ``prepare_data()``.
    """

    def __init__(self):
        super().__init__()
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.1307,), (0.3081,))
        self.transform = transforms.Compose([to_tensor, normalize])
        # Train and test intentionally reference the very same pipeline object.
        self.train_transform = self.transform
        self.test_transform = self.transform
        self.data_fn = torchvision.datasets.MNIST
        self.prepare_data()

class FashionMNIST(TorchVisionDataset):
    """Fashion-MNIST splits loaded through ``torchvision.datasets.FashionMNIST``.

    Both splits share one pipeline: convert to tensor, then normalise the
    single channel with mean 0.5 and std 0.5.  Constructing an instance
    immediately calls ``prepare_data()``.
    """

    def __init__(self):
        super().__init__()
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize((0.5,), (0.5,))
        self.transform = transforms.Compose([to_tensor, normalize])
        # Train and test intentionally reference the very same pipeline object.
        self.train_transform = self.transform
        self.test_transform = self.transform
        self.data_fn = torchvision.datasets.FashionMNIST
        self.prepare_data()

class CIFAR10(TorchVisionDataset):
    """CIFAR-10 splits loaded through ``torchvision.datasets.CIFAR10``.

    The training pipeline augments with a random 32-pixel crop (padding 4)
    and a random horizontal flip before converting to tensor; the test
    pipeline only converts to tensor.  No normalisation is applied.
    Constructing an instance immediately calls ``prepare_data()``.
    """

    def __init__(self):
        super().__init__()
        # Augmentation is applied to the training split only.
        train_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        self.train_transform = transforms.Compose(train_steps)
        test_steps = [transforms.ToTensor()]
        self.test_transform = transforms.Compose(test_steps)
        self.data_fn = torchvision.datasets.CIFAR10
        self.prepare_data()

class CIFAR100(TorchVisionDataset):
    """CIFAR-100 splits loaded through ``torchvision.datasets.CIFAR100``.

    The training pipeline augments with a random 32-pixel crop (padding 4)
    and a random horizontal flip before converting to tensor; the test
    pipeline only converts to tensor.  No normalisation is applied.
    Constructing an instance immediately calls ``prepare_data()``.
    """

    def __init__(self):
        super().__init__()
        # Augmentation is applied to the training split only.
        train_steps = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        self.train_transform = transforms.Compose(train_steps)
        test_steps = [transforms.ToTensor()]
        self.test_transform = transforms.Compose(test_steps)
        self.data_fn = torchvision.datasets.CIFAR100
        self.prepare_data()
