import os
import numpy as np
import pdb
import torch
from torchvision import datasets
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms

from . import cub
from . import tiny_imagenet
from . import imagenet32

class DatasetWrapper(Dataset):
    """Wrap a dataset so every sample also reports its own position.

    A tuple sample ``(a, b, ...)`` is returned as ``(a, b, ..., index)``;
    any other sample ``x`` is returned as ``(x, index)``.
    """

    def __init__(self, dataset: Dataset) -> None:
        super().__init__()
        # The wrapped dataset; indexed and measured directly.
        self.dataset = dataset

    def __getitem__(self, index):
        sample = self.dataset[index]
        # Only real tuples are spread; lists and other objects stay whole.
        fields = sample if isinstance(sample, tuple) else (sample,)
        return (*fields, index)

    def __len__(self):
        return len(self.dataset)

def get_dataset(name, train_transform=None, test_transform=None, download=False):
    """Return the ``(train, train_test, test)`` splits for the dataset *name*.

    Data is looked up in a per-dataset subdirectory of the directory that
    contains this module. ``train`` uses ``train_transform``. ``train_test``
    (the training split again) and ``test`` use ``test_transform``.

    Args:
        name: One of 'MNIST', 'FASHIONMNIST', 'SVHN', 'CIFAR10', 'CIFAR100',
            'CUB', 'IMAGENET', 'TINYIMAGENET', 'IMAGENET32', 'IMAGENET64'.
        train_transform: Transform applied to the ``train`` split.
        test_transform: Transform applied to ``train_test`` and ``test``.
        download: Forwarded to the per-dataset loader.

    Returns:
        Whatever the matching ``get_*`` loader returns (a 3-tuple of
        ``DatasetWrapper`` objects for every loader in this module).

    Raises:
        ValueError: if *name* is not one of the supported dataset names.
            (Previously an unknown name silently returned ``None``.)
    """
    path = os.path.dirname(os.path.realpath(__file__))
    # Every loader takes the same keyword arguments; build them once.
    kwargs = dict(train_transform=train_transform, test_transform=test_transform, download=download)
    if name == 'MNIST':
        return get_MNIST(path, **kwargs)
    elif name == 'FASHIONMNIST':
        return get_FashionMNIST(path, **kwargs)
    elif name == 'SVHN':
        return get_SVHN(path, **kwargs)
    elif name == 'CIFAR10':
        return get_CIFAR10(path, **kwargs)
    elif name == 'CIFAR100':
        return get_CIFAR100(path, **kwargs)
    elif name == 'CUB':
        return get_CUB(path, **kwargs)
    elif name == 'IMAGENET':
        return get_ImageNet(path, **kwargs)
    elif name == 'TINYIMAGENET':
        return get_TinyImageNet(path, **kwargs)
    elif name == 'IMAGENET32':
        return get_ImageNet32(path, **kwargs)
    elif name == 'IMAGENET64':
        return get_ImageNet64(path, **kwargs)
    supported = ('MNIST', 'FASHIONMNIST', 'SVHN', 'CIFAR10', 'CIFAR100', 'CUB',
                 'IMAGENET', 'TINYIMAGENET', 'IMAGENET32', 'IMAGENET64')
    raise ValueError('Unknown dataset name {!r}; expected one of {}'.format(name, ', '.join(supported)))

def get_MNIST(path, train_transform=None, test_transform=None, download=False):
    """Load MNIST from ``<path>/MNIST``.

    Returns a ``(train, train_test, test)`` triple of ``DatasetWrapper``s:
    the training split with ``train_transform``, the training split again
    with ``test_transform``, and the test split with ``test_transform``.
    """
    data_dir = os.path.join(path, 'MNIST')
    split_specs = ((True, train_transform), (True, test_transform), (False, test_transform))
    return tuple(
        DatasetWrapper(datasets.MNIST(data_dir, train=is_train, transform=tf, download=download))
        for is_train, tf in split_specs
    )

def get_FashionMNIST(path, train_transform=None, test_transform=None, download=False):
    """Load Fashion-MNIST from ``<path>/FashionMNIST``.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    ``train`` uses ``train_transform``. ``train_test`` is the training split
    with ``test_transform``, and ``test`` is the test split with
    ``test_transform``.
    """
    data_dir = os.path.join(path, 'FashionMNIST')
    split_specs = ((True, train_transform), (True, test_transform), (False, test_transform))
    return tuple(
        DatasetWrapper(datasets.FashionMNIST(data_dir, train=is_train, transform=tf, download=download))
        for is_train, tf in split_specs
    )

def get_SVHN(path, train_transform=None, test_transform=None, download=False):
    """Load SVHN from ``<path>/SVHN`` (splits ``'train'`` and ``'test'``).

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    The first uses ``train_transform``; the other two use ``test_transform``.
    """
    data_dir = os.path.join(path, 'SVHN')

    def _build(split, tf):
        # One wrapped SVHN split with the given transform.
        return DatasetWrapper(datasets.SVHN(data_dir, split=split, transform=tf, download=download))

    return _build('train', train_transform), _build('train', test_transform), _build('test', test_transform)

def get_CIFAR10(path, train_transform=None, test_transform=None, download=False):
    """Load CIFAR-10 from ``<path>/CIFAR10``.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    ``train`` uses ``train_transform``; ``train_test`` (the training split)
    and ``test`` use ``test_transform``.
    """
    data_dir = os.path.join(path, 'CIFAR10')

    def _build(is_train, tf):
        # One wrapped CIFAR-10 split with the given transform.
        return DatasetWrapper(datasets.CIFAR10(data_dir, train=is_train, transform=tf, download=download))

    return _build(True, train_transform), _build(True, test_transform), _build(False, test_transform)

def get_CIFAR100(path, train_transform=None, test_transform=None, download=False):
    """Load CIFAR-100 from ``<path>/CIFAR100``.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    ``train`` uses ``train_transform``; ``train_test`` (the training split)
    and ``test`` use ``test_transform``.
    """
    data_dir = os.path.join(path, 'CIFAR100')

    def _build(is_train, tf):
        # One wrapped CIFAR-100 split with the given transform.
        return DatasetWrapper(datasets.CIFAR100(data_dir, train=is_train, transform=tf, download=download))

    return _build(True, train_transform), _build(True, test_transform), _build(False, test_transform)

def get_CUB(path, train_transform=None, test_transform=None, download=False):
    """Load CUB-200-2011 via the local ``cub.Cub2011`` class from ``<path>/CUB``.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    The first uses ``train_transform``; the other two use ``test_transform``.
    """
    data_dir = os.path.join(path, 'CUB')
    split_specs = ((True, train_transform), (True, test_transform), (False, test_transform))
    return tuple(
        DatasetWrapper(cub.Cub2011(data_dir, train=is_train, transform=tf, download=download))
        for is_train, tf in split_specs
    )

def get_ImageNet(path, train_transform=None, test_transform=None, download=False):
    """Load ImageNet (ILSVRC2012) from archives already placed in ``<path>/ImageNet``.

    ``torchvision.datasets.ImageNet`` cannot download the data. Current
    torchvision releases do not accept a ``download`` keyword at all, so
    forwarding it raised ``TypeError``. It is therefore no longer passed on.

    Args:
        path: Parent directory; data is read from ``<path>/ImageNet``.
        train_transform: Transform for the ``train`` split.
        test_transform: Transform for ``train_test`` and ``test``.
        download: Must be false; kept for signature compatibility with the
            other ``get_*`` loaders.

    Returns:
        ``(train, train_test, test)`` wrapped in ``DatasetWrapper``: split
        'train' with ``train_transform``, split 'train' with
        ``test_transform``, and split 'val' with ``test_transform``.

    Raises:
        RuntimeError: if ``download`` is true.
    """
    root = os.path.join(path, 'ImageNet')
    if download:
        raise RuntimeError(
            'ImageNet cannot be downloaded automatically; place the ILSVRC2012 '
            'archives in {!r} and call with download=False.'.format(root))
    train = datasets.ImageNet(root, split='train', transform=train_transform)
    train_test = datasets.ImageNet(root, split='train', transform=test_transform)
    test = datasets.ImageNet(root, split='val', transform=test_transform)
    return DatasetWrapper(train), DatasetWrapper(train_test), DatasetWrapper(test)

def get_TinyImageNet(path, train_transform=None, test_transform=None, download=False):
    """Load Tiny ImageNet via ``tiny_imagenet.TinyImageNet`` from ``<path>/TinyImageNet``.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``:
    split 'train' with ``train_transform``, split 'train' with
    ``test_transform``, and split 'val' with ``test_transform``.
    """
    data_dir = os.path.join(path, 'TinyImageNet')

    def _build(split, tf):
        # One wrapped Tiny ImageNet split with the given transform.
        return DatasetWrapper(tiny_imagenet.TinyImageNet(data_dir, split=split, transform=tf, download=download))

    return _build('train', train_transform), _build('train', test_transform), _build('val', test_transform)

def get_ImageNet32(path, train_transform=None, test_transform=None, download=False):
    """Load downsampled ImageNet via ``imagenet32.ImageNet32`` from ``<path>/ImageNet32``.

    ``download`` is accepted for signature parity with the other loaders but
    is not used.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    The first uses ``train_transform``; the other two use ``test_transform``.
    """
    data_dir = os.path.join(path, 'ImageNet32')
    split_specs = ((True, train_transform), (True, test_transform), (False, test_transform))
    return tuple(
        DatasetWrapper(imagenet32.ImageNet32(data_dir, train=is_train, transform=tf))
        for is_train, tf in split_specs
    )

def get_ImageNet64(path, train_transform=None, test_transform=None, download=False):
    """Load downsampled ImageNet from ``<path>/ImageNet64``.

    ``download`` is accepted for signature parity with the other loaders but
    is not used.

    NOTE(review): this reuses ``imagenet32.ImageNet32`` as the loader class.
    Confirm that class does not hard-code a 32x32 image shape when reading
    64x64 data.

    Returns ``(train, train_test, test)`` wrapped in ``DatasetWrapper``.
    The first uses ``train_transform``; the other two use ``test_transform``.
    """
    data_dir = os.path.join(path, 'ImageNet64')
    split_specs = ((True, train_transform), (True, test_transform), (False, test_transform))
    return tuple(
        DatasetWrapper(imagenet32.ImageNet32(data_dir, train=is_train, transform=tf))
        for is_train, tf in split_specs
    )
