import os
import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torchvision import transforms as transforms

import sys
import pathlib

sys.path.insert(0, str(pathlib.Path().absolute().parent))

from godin.generate_loaders import GaussianLoader, UniformLoader
from godin.generate_loaders import Normalizer

# Per-channel mean and std on the [0, 1] pixel scale (0-255 values divided by
# 255). They normalise the CIFAR-10 pipelines and are also handed to
# Normalizer for the synthetic Gaussian/Uniform outlier loaders.
r_mean, g_mean, b_mean = 125.3 / 255, 123.0 / 255, 113.9 / 255
r_std, g_std, b_std = 63.0 / 255, 62.1 / 255, 66.7 / 255
# CIFAR-10 pipelines. Inputs are normalised exactly once, with the module-level
# r/g/b statistics that are also passed to Normalizer for the synthetic
# outliers. Chaining a second Normalize (with different CIFAR-10 stats) would
# re-standardise tensors that are already standardised.
# NOTE(review): checkpoints trained with the previous doubly-normalised inputs
# will see a different input scale — retrain or re-evaluate accordingly.
train_transform_cifar10 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((r_mean, g_mean, b_mean), (r_std, g_std, b_std)),
])
test_transform_cifar10 = transforms.Compose([
    # This transform is also used for the ImageFolder outlier set, whose
    # images may be larger than 32x32; the centre crop brings them to 32x32.
    transforms.CenterCrop((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((r_mean, g_mean, b_mean), (r_std, g_std, b_std)),
])
# Per-channel mean / std used to normalise CIFAR-100 inputs in both pipelines.
_CIFAR100_MEAN = (0.5070751592371323, 0.48654887331495095, 0.4409178433670343)
_CIFAR100_STD = (0.2673342858792401, 0.2564384629170883, 0.27615047132568404)

# Training: padded random 32px crop, horizontal flip and random rotation of up
# to 15 degrees, then tensor conversion and normalisation.
train_transform_cifar100 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
])

# Evaluation: no augmentation, only tensor conversion and normalisation.
test_transform_cifar100 = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR100_MEAN, _CIFAR100_STD),
])

# Synthetic OOD dataset name -> loader class that generates its samples.
generating_loaders_dict = dict(Gaussian=GaussianLoader, Uniform=UniformLoader)


def get_dataset(data_ind='CIFAR10', data_dir='./models', batch_size=64,
                num_workers=4, data_ood='SVHN'):
    """
    Return the in-distribution train and test loaders.

    Thin wrapper around ``get_datasets`` that discards the validation loader,
    the outlier loader and the class count. The outlier data is still built
    (and downloaded when needed) because ``get_datasets`` always builds it.

    :param data_ind: name of the in-distribution dataset
    :param data_dir: directory the datasets are stored in / downloaded to
    :param batch_size: batch size of every loader
    :param num_workers: number of CPU worker processes that fetch the data
    :param data_ood: name of the out-of-distribution dataset
    :return: ``(train_loader_in, test_loader_in)``; the test loader excludes
        the samples ``get_datasets`` reserves for validation
    """
    all_loaders = get_datasets(data_dir=data_dir, data_ood=data_ood,
                               batch_size=batch_size, num_workers=num_workers,
                               data_ind=data_ind)
    train_loader, _val_loader, test_loader, _ood_loader, _n_classes = all_loaders
    return train_loader, test_loader


def get_dataset_train_val_test(data_ind='CIFAR10', data_dir='./models',
                               batch_size=64,
                               num_workers=4, data_ood='SVHN'):
    """
    Get in-distribution loaders: train, validation, and test.

    Thin wrapper around ``get_datasets`` that discards the outlier loader and
    the class count (the outlier data is still built, and downloaded when
    needed, because ``get_datasets`` always builds it).

    :param data_ind: name of the in-distribution dataset
    :param data_dir: path to the data files
    :param batch_size: batch size of every loader
    :param num_workers: number of CPU worker processes that fetch the data
    :param data_ood: name of the ood dataset
    :return: ``(train_loader_in, validation_loader_in, test_loader_in)``.
        The validation loader covers the first 1000 samples of the official
        test split and the test loader covers the remaining samples.
    """
    (train_loader_in, validation_loader_in, test_loader_in,
     _outlier_loader, _num_classes) = get_datasets(
        data_dir=data_dir, data_ood=data_ood, batch_size=batch_size,
        num_workers=num_workers, data_ind=data_ind)
    return train_loader_in, validation_loader_in, test_loader_in


def _svhn_transform():
    """Return the SVHN pipeline: tensor conversion + (0.5, 0.5) normalisation."""
    return transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )


def _get_in_distribution_sets(data_ind, data_dir):
    """
    Build the in-distribution train and test datasets.

    :param data_ind: 'CIFAR10', 'CIFAR100' or 'SVHN'
    :param data_dir: root folder; each dataset uses its own sub-folder
    :return: ``(train_set, test_set, num_classes)``
    :raises ValueError: if ``data_ind`` is not supported
    """
    if data_ind == 'CIFAR10':
        root = f'{data_dir}/cifar10'
        train_set = torchvision.datasets.CIFAR10(
            root=root, train=True, download=True,
            transform=train_transform_cifar10)
        test_set = torchvision.datasets.CIFAR10(
            root=root, train=False, download=True,
            transform=test_transform_cifar10)
        return train_set, test_set, 10

    if data_ind == 'CIFAR100':
        # Must load the CIFAR100 class: loading CIFAR10 here would pair
        # 10-class data with a 100-way class count.
        root = f'{data_dir}/cifar100'
        train_set = torchvision.datasets.CIFAR100(
            root=root, train=True, download=True,
            transform=train_transform_cifar100)
        test_set = torchvision.datasets.CIFAR100(
            root=root, train=False, download=True,
            transform=test_transform_cifar100)
        return train_set, test_set, 100

    if data_ind == 'SVHN':
        # Fixed folder (the same one the SVHN outlier loader uses), rather than
        # a folder named after whatever OOD dataset was requested.
        root = f'{data_dir}/SVHN'
        transform = _svhn_transform()
        train_set = torchvision.datasets.SVHN(
            root=root, split="train", download=True, transform=transform)
        test_set = torchvision.datasets.SVHN(
            root=root, split="test", download=True, transform=transform)
        return train_set, test_set, 10

    raise ValueError(f"Unknown in-distribution dataset: {data_ind}.")


def _get_outlier_loader(data_ood, data_dir, batch_size, num_workers):
    """
    Build the loader over the out-of-distribution (outlier) data.

    :param data_ood: 'Gaussian', 'Uniform', 'Imagenet' or 'SVHN'
    :param data_dir: root folder holding the datasets
    :param batch_size: batch size of the loader
    :param num_workers: number of CPU worker processes that fetch the data
    :return: an iterable loader over the outlier data
    :raises ValueError: if ``data_ood`` is not supported
    """
    if data_ood in ('Gaussian', 'Uniform'):
        # NOTE(review): synthetic outliers are always normalised with the
        # CIFAR-10 statistics, even when the in-distribution set is CIFAR100
        # or SVHN — confirm this matches the pipeline being evaluated.
        normalizer = Normalizer(r_mean, g_mean, b_mean, r_std, g_std, b_std)
        # Presumably ~10000 samples in total (num_batches * batch_size).
        return generating_loaders_dict[data_ood](
            batch_size=batch_size, num_batches=int(10000 / batch_size),
            transformers=[normalizer])

    if data_ood == 'Imagenet':
        # NOTE(review): uses the CIFAR-10 test transform regardless of data_ind.
        outlier_set = torchvision.datasets.ImageFolder(
            f'{data_dir}/{data_ood}', transform=test_transform_cifar10)
        return DataLoader(outlier_set, batch_size=batch_size,
                          shuffle=False, num_workers=num_workers)

    if data_ood == 'SVHN':
        # Only the test split serves as outliers, so the train split is not
        # downloaded here.
        test_set = torchvision.datasets.SVHN(
            root=f'{data_dir}/{data_ood}', split="test", download=True,
            transform=_svhn_transform())
        return DataLoader(test_set, batch_size=batch_size, shuffle=True,
                          num_workers=num_workers, pin_memory=True)

    raise ValueError(f"Unsupported OOD dataset: {data_ood}.")


def get_datasets(data_dir, data_ood, batch_size, data_ind, num_workers):
    """
    Get in-distribution and out-of-distribution datasets.

    The official in-distribution test split is divided into a validation part
    (its first 1000 samples) and a test part (the remaining samples).

    :param data_dir: path to the folder with data (created if missing)
    :param data_ood: name of the out-of-distribution data
    :param batch_size: batch size
    :param data_ind: name of the in-distribution data
    :param num_workers: number of workers on the cpu to fetch the data
    :return: train_loader_in, validation_loader_in, test_loader_in,
        outlier_loader, num_classes
    :raises ValueError: if ``data_ind`` or ``data_ood`` is not supported
    """
    os.makedirs(data_dir, exist_ok=True)

    train_set_in, full_test_set_in, num_classes = _get_in_distribution_sets(
        data_ind, data_dir)
    outlier_loader = _get_outlier_loader(
        data_ood, data_dir, batch_size, num_workers)

    num_validation = 1000
    test_indices = list(range(len(full_test_set_in)))
    validation_set_in = Subset(full_test_set_in, test_indices[:num_validation])
    test_set_in = Subset(full_test_set_in, test_indices[num_validation:])

    train_loader_in = DataLoader(train_set_in, batch_size=batch_size,
                                 shuffle=True, num_workers=num_workers)
    validation_loader_in = DataLoader(validation_set_in, batch_size=batch_size,
                                      shuffle=False, num_workers=num_workers)
    test_loader_in = DataLoader(test_set_in, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers)

    return (train_loader_in, validation_loader_in, test_loader_in,
            outlier_loader, num_classes)
