import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Expose an in-memory container through the ``Dataset`` interface.

    The wrapped object only needs to support ``len()`` and integer
    indexing; items are handed back exactly as stored.
    """

    def __init__(self, data):
        """Keep a reference to ``data`` (no copy is made)."""
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of items in the wrapped container."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the item stored at position ``idx``, unmodified."""
        return self.data[idx]
    

class TorchVisionDataset:
    """Base holder for a train/test dataset pair built by one constructor.

    Subclasses assign ``data_fn`` (a callable accepting the keyword
    arguments ``root``, ``train``, ``download`` and ``transform``) and,
    optionally, ``train_transform`` / ``test_transform``; they then call
    ``prepare_data()`` to populate ``train_set`` and ``test_set``.

    Attributes:
        root_dir: Directory passed to ``data_fn`` as ``root``.
        data_fn: Dataset constructor; ``None`` until a subclass sets it.
        train_transform: Transform forwarded for the training split.
        test_transform: Transform forwarded for the test split.
        train_set: Result of ``data_fn(train=True, ...)``; ``None`` until
            ``prepare_data()`` has run.
        test_set: Result of ``data_fn(train=False, ...)``; ``None`` until
            ``prepare_data()`` has run.
    """

    def __init__(self, root_dir=None):
        """Record the data directory, defaulting to ``'../data/'``."""
        self.root_dir = '../data/' if root_dir is None else root_dir
        self.data_fn = None
        self.train_transform = None
        self.test_transform = None
        # Define the split attributes up front so reading them before
        # prepare_data() yields None instead of raising AttributeError.
        self.train_set = None
        self.test_set = None

    def prepare_data(self):
        """Build ``train_set`` and ``test_set`` by calling ``data_fn`` twice.

        Both calls pass ``download=True``.

        Raises:
            TypeError: If ``data_fn`` has not been set to a callable.
        """
        # Calling an unset data_fn would raise TypeError anyway; raise the
        # same exception type with a message that names the actual problem.
        if not callable(self.data_fn):
            raise TypeError(
                f"{type(self).__name__}.data_fn must be set to a dataset "
                f"constructor before prepare_data() is called "
                f"(got {self.data_fn!r})"
            )
        self.train_set = self._load_split(train=True,
                                          transform=self.train_transform)
        self.test_set = self._load_split(train=False,
                                         transform=self.test_transform)

    def _load_split(self, train, transform):
        """Invoke ``data_fn`` for one split and return whatever it returns."""
        return self.data_fn(root=self.root_dir,
                            train=train,
                            download=True,
                            transform=transform)


class MNIST(TorchVisionDataset):
    """MNIST train/test splits, built as soon as the object is created.

    Both splits share a single pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean ``0.1307`` and std ``0.3081``. The pipeline is
    also exposed as ``self.transform``.
    """

    def __init__(self, root_dir=None):
        """Configure the MNIST constructor and transform, then load data."""
        super().__init__(root_dir)
        self.data_fn = torchvision.datasets.MNIST

        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        # One shared pipeline object serves all three attributes.
        self.transform = pipeline
        self.train_transform = pipeline
        self.test_transform = pipeline

        self.prepare_data()

class FashionMNIST(TorchVisionDataset):
    """Fashion-MNIST train/test splits, built as soon as the object is created.

    Both splits share a single pipeline: ``ToTensor`` followed by
    ``Normalize`` with mean ``0.5`` and std ``0.5``. The pipeline is also
    exposed as ``self.transform``.
    """

    def __init__(self, root_dir=None):
        """Configure the Fashion-MNIST constructor and transform, then load."""
        super().__init__(root_dir)
        self.data_fn = torchvision.datasets.FashionMNIST

        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])
        # One shared pipeline object serves all three attributes.
        self.transform = pipeline
        self.train_transform = pipeline
        self.test_transform = pipeline

        self.prepare_data()

class CIFAR10(TorchVisionDataset):
    """CIFAR-10 train/test splits, built as soon as the object is created.

    The training pipeline applies ``RandomCrop(32, padding=4)`` and
    ``RandomHorizontalFlip`` before ``ToTensor``; the test pipeline is
    ``ToTensor`` only. Neither pipeline normalizes.
    """

    def __init__(self, root_dir=None):
        """Configure the CIFAR-10 constructor and transforms, then load."""
        super().__init__(root_dir)
        self.data_fn = torchvision.datasets.CIFAR10

        # Augmentation steps are applied to the training split only.
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
        self.train_transform = transforms.Compose(
            augmentation + [transforms.ToTensor()]
        )
        self.test_transform = transforms.Compose([transforms.ToTensor()])

        self.prepare_data()

# class CIFAR100(TorchVisionDataset):
#     def __init__(self, root_dir=None):
#         super().__init__(root_dir)
#         self.train_transform = transforms.Compose([
#             transforms.RandomCrop(32, padding=4),
#             transforms.RandomHorizontalFlip(),
#             transforms.ToTensor(),
#         ])
#         self.test_transform = transforms.Compose([
#             transforms.ToTensor(),
#         ])
#         self.data_fn = torchvision.datasets.CIFAR100
#         self.prepare_data()

class CIFAR100(TorchVisionDataset):
    """CIFAR-100 train/test splits, built as soon as the object is created.

    Train and test use separate but equivalent pipelines: ``ToTensor``
    followed by a three-channel ``Normalize``. No augmentation is applied.
    """

    def __init__(self, root_dir=None):
        """Configure the CIFAR-100 constructor and transforms, then load."""
        super().__init__(root_dir)
        self.data_fn = torchvision.datasets.CIFAR100

        # NOTE(review): these per-channel statistics look like the values
        # commonly quoted for CIFAR-10 -- confirm they are intended here.
        channel_mean = (0.4914, 0.4822, 0.4465)
        channel_std = (0.2023, 0.1994, 0.2010)

        self.train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])
        self.test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ])

        self.prepare_data()
