import os
import numpy as np

import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import Subset


# Public API of this module: the two lookup tables keyed by dataset name and
# the single entry point that dispatches to the per-dataset loaders.
__all__ = [
    "DATASET_NUM_CLASSES",
    "DATASET_REGISTRY",
    "get_datasets"
]


def grayscale_to_rgb(x):
    """Expand a single-channel input to three channels; pass others through.

    Args:
        x: Object exposing ``.shape`` and ``.repeat``. In this module it is
            applied right after ``T.ToTensor()``, so it is presumably a
            channel-first image tensor (TODO confirm for other callers).

    Returns:
        ``x.repeat(3, 1, 1)`` when the leading dimension has size 1,
        otherwise ``x`` itself (the same object, not a copy).
    """
    leading_dim = x.shape[0]
    return x.repeat(3, 1, 1) if leading_dim == 1 else x


def get_transforms(
    mean: tuple,
    std: tuple,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    augment: bool = False,
):
    """Build the training and evaluation preprocessing pipelines.

    Both pipelines end with the same three steps: conversion to a tensor,
    grayscale-to-RGB expansion, and normalization with ``mean`` / ``std``.

    Args:
        mean: Per-channel means passed to ``T.Normalize``.
        std: Per-channel standard deviations passed to ``T.Normalize``.
        train_crop_size: Output size of ``T.RandomResizedCrop``; only used
            when ``augment`` is true.
        test_resize_size: Size passed to ``T.Resize`` in the evaluation
            pipeline.
        test_crop_size: Size passed to ``T.CenterCrop`` in the evaluation
            pipeline.
        augment: When true, the training pipeline uses a random resized crop
            and a random horizontal flip. When false, the evaluation pipeline
            object itself is returned for training as well.

    Returns:
        A ``(train_transforms, test_transforms)`` tuple of ``T.Compose``
        objects.
    """

    def common_tail():
        # Fresh instances per pipeline, mirroring how each Compose originally
        # got its own ToTensor / Lambda / Normalize objects.
        return [
            T.ToTensor(),
            T.Lambda(grayscale_to_rgb),
            T.Normalize(mean, std),
        ]

    eval_pipeline = T.Compose(
        [T.Resize(test_resize_size), T.CenterCrop(test_crop_size)] + common_tail()
    )

    if not augment:
        # No augmentation requested: training reuses the evaluation pipeline.
        return eval_pipeline, eval_pipeline

    augmented_pipeline = T.Compose(
        [T.RandomResizedCrop(train_crop_size), T.RandomHorizontalFlip()] + common_tail()
    )
    return augmented_pipeline, eval_pipeline


def get_cifar10(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the CIFAR-10 training and test datasets.

    Args:
        data_dir: Parent directory; the data lives in its ``CIFAR10``
            subdirectory.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.CIFAR10`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    train_tf, test_tf = get_transforms(
        channel_mean,
        channel_std,
        train_crop_size,
        test_resize_size,
        test_crop_size,
        augment,
    )

    cifar_root = os.path.join(data_dir, "CIFAR10")
    splits = (
        datasets.CIFAR10(cifar_root, train=True, transform=train_tf, download=download),
        datasets.CIFAR10(cifar_root, train=False, transform=test_tf, download=download),
    )
    return splits


def get_cifar100(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the CIFAR-100 training and test datasets.

    Args:
        data_dir: Parent directory; the data lives in its ``CIFAR100``
            subdirectory.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.CIFAR100`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    root = os.path.join(data_dir, "CIFAR100")
    # NOTE(review): these statistics are identical to the ones used by
    # get_cifar10 -- confirm whether CIFAR-100-specific values were intended.
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2023, 0.1994, 0.2010)
    train_transforms, test_transforms = get_transforms(
        mean, std, train_crop_size, test_resize_size, test_crop_size, augment
    )
    train_dataset = datasets.CIFAR100(root, train=True, transform=train_transforms, download=download)
    test_dataset = datasets.CIFAR100(root, train=False, transform=test_transforms, download=download)
    return train_dataset, test_dataset


def get_caltech101(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create Caltech-101 train/test splits with a fixed per-class train size.

    The first ``NUM_TRAINING_SAMPLES_PER_CLASS`` samples of every class go to
    the training split and all remaining samples go to the test split. Class
    boundaries are detected where the label increases by exactly one between
    neighbouring samples, so this relies on the dataset listing its samples
    grouped by class with consecutive labels.

    Args:
        data_dir: Root directory handed to ``datasets.Caltech101``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.Caltech101`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``Subset`` objects.
    """
    NUM_TRAINING_SAMPLES_PER_CLASS = 30
    root = os.path.join(data_dir)
    mean = (0.5413, 0.5063, 0.4693)
    std = (0.3115, 0.3090, 0.3183)
    train_transforms, test_transforms = get_transforms(
        mean, std, train_crop_size, test_resize_size, test_crop_size, augment
    )

    # Each split needs its own dataset instance. Wrapping a single instance in
    # two Subsets and assigning ``.dataset.transform`` twice makes the second
    # assignment win, so the training split would silently use the test
    # transforms.
    train_source = datasets.Caltech101(root, transform=train_transforms, download=download)
    test_source = datasets.Caltech101(root, transform=test_transforms, download=download)

    labels = np.asarray(train_source.y)
    num_samples = len(labels)

    # A new class starts wherever the label steps up by exactly one.
    boundaries = np.flatnonzero(np.diff(labels) == 1) + 1
    class_starts = np.concatenate(([0], boundaries))
    class_ends = np.concatenate((boundaries, [num_samples]))

    is_train = np.zeros(num_samples, dtype=bool)
    for start, end in zip(class_starts, class_ends):
        # Clamp to the class end so a small class never spills into the next
        # class or past the end of the dataset.
        is_train[start:min(start + NUM_TRAINING_SAMPLES_PER_CLASS, end)] = True

    # Sorted, plain-int index lists keep the splits deterministic.
    train_indices = np.flatnonzero(is_train).tolist()
    test_indices = np.flatnonzero(~is_train).tolist()

    train_dataset = Subset(train_source, train_indices)
    test_dataset = Subset(test_source, test_indices)
    return train_dataset, test_dataset


def get_caltech256(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create Caltech-256 train/test splits with a fixed per-class train size.

    The first ``NUM_TRAINING_SAMPLES_PER_CLASS`` samples of every class go to
    the training split and all remaining samples go to the test split. Class
    boundaries are detected where the label increases by exactly one between
    neighbouring samples, so this relies on the dataset listing its samples
    grouped by class with consecutive labels.

    Args:
        data_dir: Root directory handed to ``datasets.Caltech256``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.Caltech256`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple of ``Subset`` objects.
    """
    NUM_TRAINING_SAMPLES_PER_CLASS = 60
    root = os.path.join(data_dir)
    mean = (0.5438, 0.5141, 0.4821)
    std = (0.3077, 0.3044, 0.3163)
    train_transforms, test_transforms = get_transforms(
        mean, std, train_crop_size, test_resize_size, test_crop_size, augment
    )

    # Each split needs its own dataset instance. Wrapping a single instance in
    # two Subsets and assigning ``.dataset.transform`` twice makes the second
    # assignment win, so the training split would silently use the test
    # transforms.
    train_source = datasets.Caltech256(root, transform=train_transforms, download=download)
    test_source = datasets.Caltech256(root, transform=test_transforms, download=download)

    labels = np.asarray(train_source.y)
    num_samples = len(labels)

    # A new class starts wherever the label steps up by exactly one.
    boundaries = np.flatnonzero(np.diff(labels) == 1) + 1
    class_starts = np.concatenate(([0], boundaries))
    class_ends = np.concatenate((boundaries, [num_samples]))

    is_train = np.zeros(num_samples, dtype=bool)
    for start, end in zip(class_starts, class_ends):
        # Clamp to the class end so a small class never spills into the next
        # class or past the end of the dataset.
        is_train[start:min(start + NUM_TRAINING_SAMPLES_PER_CLASS, end)] = True

    # Sorted, plain-int index lists keep the splits deterministic.
    train_indices = np.flatnonzero(is_train).tolist()
    test_indices = np.flatnonzero(~is_train).tolist()

    train_dataset = Subset(train_source, train_indices)
    test_dataset = Subset(test_source, test_indices)
    return train_dataset, test_dataset


def get_fgvc_aircraft(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the FGVC-Aircraft training and test datasets.

    Training uses the ``'trainval'`` split and evaluation uses the ``'test'``
    split.

    Args:
        data_dir: Root directory handed to ``datasets.FGVCAircraft``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.FGVCAircraft`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.4892, 0.5159, 0.5356)
    channel_std = (0.2275, 0.2200, 0.2476)
    train_tf, test_tf = get_transforms(
        channel_mean,
        channel_std,
        train_crop_size,
        test_resize_size,
        test_crop_size,
        augment,
    )

    # Arguments common to both splits.
    shared = {"root": data_dir, "download": download}
    train_split = datasets.FGVCAircraft(split='trainval', transform=train_tf, **shared)
    test_split = datasets.FGVCAircraft(split='test', transform=test_tf, **shared)
    return train_split, test_split


def get_flowers102(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the Flowers-102 training and test datasets.

    Training uses the ``'train'`` split and evaluation uses the ``'test'``
    split.

    Args:
        data_dir: Root directory handed to ``datasets.Flowers102``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.Flowers102`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.5133, 0.4148, 0.3383)
    channel_std = (0.2959, 0.2502, 0.2900)
    train_tf, test_tf = get_transforms(
        channel_mean,
        channel_std,
        train_crop_size,
        test_resize_size,
        test_crop_size,
        augment,
    )

    # Arguments common to both splits.
    shared = {"root": data_dir, "download": download}
    train_split = datasets.Flowers102(split='train', transform=train_tf, **shared)
    test_split = datasets.Flowers102(split='test', transform=test_tf, **shared)
    return train_split, test_split


def get_stanford_cars(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the Stanford Cars training and test datasets.

    Training uses the ``'train'`` split and evaluation uses the ``'test'``
    split.

    Args:
        data_dir: Root directory handed to ``datasets.StanfordCars``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.StanfordCars`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.4513, 0.4354, 0.4358)
    channel_std = (0.2900, 0.2880, 0.2951)
    train_tf, test_tf = get_transforms(
        channel_mean,
        channel_std,
        train_crop_size,
        test_resize_size,
        test_crop_size,
        augment,
    )

    # Arguments common to both splits.
    shared = {"root": data_dir, "download": download}
    train_split = datasets.StanfordCars(split='train', transform=train_tf, **shared)
    test_split = datasets.StanfordCars(split='test', transform=test_tf, **shared)
    return train_split, test_split


def get_oxford_pets(
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Create the Oxford-IIIT Pet training and test datasets.

    Training uses the ``'trainval'`` split and evaluation uses the ``'test'``
    split; both request the ``'category'`` target type.

    Args:
        data_dir: Root directory handed to ``datasets.OxfordIIITPet``.
        train_crop_size: Forwarded to ``get_transforms``.
        test_resize_size: Forwarded to ``get_transforms``.
        test_crop_size: Forwarded to ``get_transforms``.
        download: Forwarded to the ``datasets.OxfordIIITPet`` constructor.
        augment: Forwarded to ``get_transforms``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    channel_mean = (0.4807, 0.4432, 0.3949)
    channel_std = (0.2598, 0.2537, 0.2597)
    train_tf, test_tf = get_transforms(
        channel_mean,
        channel_std,
        train_crop_size,
        test_resize_size,
        test_crop_size,
        augment,
    )

    # Arguments common to both splits.
    shared = {"root": data_dir, "target_types": 'category', "download": download}
    train_split = datasets.OxfordIIITPet(split='trainval', transform=train_tf, **shared)
    test_split = datasets.OxfordIIITPet(split='test', transform=test_tf, **shared)
    return train_split, test_split


# Maps each supported dataset name to the factory function that builds its
# (train_dataset, test_dataset) pair. get_datasets() dispatches through this.
DATASET_REGISTRY = {
    "cifar10": get_cifar10,
    "cifar100": get_cifar100,
    "caltech101": get_caltech101,
    "caltech256": get_caltech256,
    "fgvc_aircraft": get_fgvc_aircraft,
    "stanford_cars": get_stanford_cars,
    "flowers102": get_flowers102,
    "oxford_pets": get_oxford_pets,
}


# Number of target classes for each supported dataset, keyed by the same
# names as DATASET_REGISTRY.
DATASET_NUM_CLASSES = {
    "cifar10": 10,
    "cifar100": 100,
    "caltech101": 101,
    # NOTE(review): 257 rather than 256 -- presumably counts an extra
    # clutter/background category; confirm against the dataset labels.
    "caltech256": 257,
    "fgvc_aircraft": 100,
    "stanford_cars": 196,
    "flowers102": 102,
    "oxford_pets": 37,
}


def get_datasets(
    dataset: str,
    data_dir: str,
    train_crop_size: int,
    test_resize_size: int,
    test_crop_size: int,
    download: bool = False,
    augment: bool = False,
):
    """Build the train/test datasets for a dataset selected by name.

    Args:
        dataset: Key into ``DATASET_REGISTRY`` selecting the loader.
        data_dir: Forwarded to the selected loader.
        train_crop_size: Forwarded to the selected loader.
        test_resize_size: Forwarded to the selected loader.
        test_crop_size: Forwarded to the selected loader.
        download: Forwarded to the selected loader.
        augment: Forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns; every loader in this module
        returns a ``(train_dataset, test_dataset)`` tuple.

    Raises:
        KeyError: If ``dataset`` is not a key of ``DATASET_REGISTRY``. The
            message lists the supported names.
    """
    try:
        loader = DATASET_REGISTRY[dataset]
    except KeyError:
        # Same exception type as the bare lookup, but with an actionable
        # message instead of just the offending key.
        supported = ", ".join(sorted(DATASET_REGISTRY))
        raise KeyError(
            f"Unknown dataset {dataset!r}; expected one of: {supported}"
        ) from None

    return loader(
        data_dir=data_dir,
        train_crop_size=train_crop_size,
        test_resize_size=test_resize_size,
        test_crop_size=test_crop_size,
        download=download,
        augment=augment,
    )
