import numpy as np
from pathlib import Path
import os
import torch
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from timm.data import create_dataset
import torchvision

# Static per-dataset metadata, keyed by lowercase dataset name.
# Each entry currently holds only the number of target classes.
_DATASET_CFG = {
    dataset_name: {'n_classes': class_count}
    for dataset_name, class_count in (
        ('mnist', 10),
        ('imagenet', 1000),
        ('cifar10', 10),
        ('svhn', 10),
    )
}


def get_loaders(dataset, batch_size, root=None, train_subset=1, num_workers=0, **kwargs):
    """Create one ``DataLoader`` per split returned by ``get_dataset``.

    Args:
        dataset: Dataset name, forwarded to ``get_dataset``.
        batch_size: Batch size used for every loader.
        root: Data directory, forwarded to ``get_dataset``.
        train_subset: Training-subset selector, forwarded to ``get_dataset``.
        num_workers: Number of worker processes for every loader.
        **kwargs: Extra options forwarded to ``get_dataset``.

    Returns:
        A list of loaders in the same order as the splits from
        ``get_dataset`` (train, val, test).
    """
    splits = get_dataset(dataset=dataset, root=root, train_subset=train_subset, **kwargs)

    # Only the first split (the training data) is shuffled; the remaining
    # splits are iterated in their stored order.
    return [
        DataLoader(
            split,
            batch_size=batch_size,
            shuffle=(position == 0),
            num_workers=num_workers,
            pin_memory=True,
        )
        for position, split in enumerate(splits)
    ]


def get_dataset(dataset, root=None, train_subset=1, **kwargs):
    """Build the (train, val, test) datasets for a named benchmark.

    Args:
        dataset: Dataset name, matched case-insensitively. Supported values
            are ``'mnist'``, ``'cifar10'`` and ``'svhn'``.
        root: Data directory; passed through ``get_data_dir``, which supplies
            a default when this is ``None``.
        train_subset: Training-subset selector. When ``abs(train_subset) < 1``
            the training set is randomly permuted and cut at
            ``ns = int(len(train) * abs(train_subset))``: a non-negative value
            keeps the first ``ns`` permuted samples, a negative value keeps
            the samples after the first ``ns``. Any value with magnitude
            ``>= 1`` keeps the full training set.
        **kwargs: Extra options forwarded to the dataset-specific builder.

    Returns:
        Tuple ``(train_data, val_data, test_data)``. ``train_data`` carries an
        ``n_classes`` attribute taken from ``_DATASET_CFG``.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    root = get_data_dir(data_dir=root)

    # Normalise once so that dispatch and the config lookup below agree;
    # previously 'MNIST' passed the dispatch but failed the config lookup.
    name = dataset.lower()

    if name == 'mnist':
        train_data, val_data, test_data = get_mnist(root=root, **kwargs)
    elif name == 'cifar10':
        train_data, val_data, test_data = get_cifar10(root=root, **kwargs)
    elif name == 'svhn':
        train_data, val_data, test_data = get_svhn(root=root, **kwargs)
    else:
        raise NotImplementedError("Dataset name not found")

    fraction = np.abs(train_subset)
    if fraction < 1:
        train_n = len(train_data)
        ns = int(train_n * fraction)

        # The permutation is drawn from torch's default RNG, so the selected
        # subset depends on the current global seed.
        randperm = torch.randperm(train_n)
        randperm = randperm[ns:] if train_subset < 0 else randperm[:ns]

        # Hand Subset plain Python ints so the wrapped dataset is never
        # indexed with 0-d tensors.
        train_data = Subset(train_data, randperm.tolist())

    train_data.n_classes = _DATASET_CFG[name]['n_classes']
    return train_data, val_data, test_data


def get_data_dir(data_dir=None):
    """Resolve the dataset directory and make sure it exists.

    Resolution order when ``data_dir`` is ``None``:
      1. the ``DATADIR`` environment variable, if set;
      2. the resolved ``datasets`` folder inside the user's home directory.

    Args:
        data_dir: Explicit directory to use, or ``None`` to apply the
            fallbacks above.

    Returns:
        The chosen directory. An explicitly passed ``data_dir`` is returned
        unchanged; the fallbacks yield a string.
    """
    if data_dir is None:
        env_dir = os.environ.get('DATADIR')
        if env_dir is not None:
            data_dir = env_dir
        else:
            data_dir = str((Path().home() / 'datasets').resolve())

    # Create the directory (and any missing parents); a pre-existing
    # directory is not an error.
    Path(data_dir).mkdir(parents=True, exist_ok=True)
    return data_dir


def get_mnist(root=None, extra_transform=None, **_):
    """Create the MNIST train/val/test datasets through timm.

    Args:
        root: Directory handed to ``create_dataset`` as ``root``.
        extra_transform: Optional transform applied before the tensor
            conversion and normalisation.
        **_: Ignored; accepted so callers can pass a shared kwargs dict.

    Returns:
        A list ``[train, val, test]`` of datasets, one per split name passed
        to ``create_dataset``.
    """
    # NOTE(review): which underlying MNIST data timm serves for the 'val'
    # split is decided inside create_dataset and is not visible here.

    def _make_transform():
        # ToTensor followed by single-channel normalisation (presumably the
        # MNIST mean/std - confirm), optionally preceded by extra_transform.
        pipeline = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, )),
        ])
        if extra_transform is None:
            return pipeline
        return transforms.Compose([extra_transform, pipeline])

    # Each split gets its own freshly built transform object.
    return [
        create_dataset('torch/mnist', root=root, split=split_name, transform=_make_transform(), download=True)
        for split_name in ('train', 'val', 'test')
    ]


def get_cifar10(root=None, extra_transform=None, **_):
    """Create the CIFAR-10 datasets with train-time augmentation.

    The training pipeline applies a padded random crop and a random
    horizontal flip before tensor conversion and normalisation; the
    evaluation pipeline only converts and normalises.

    Args:
        root: Directory handed to ``torchvision.datasets.CIFAR10``.
        extra_transform: Optional transform prepended to both pipelines.
        **_: Ignored; accepted so callers can pass a shared kwargs dict.

    Returns:
        Tuple ``(train_data, val_data, test_data)``. ``val_data`` and
        ``test_data`` are the same object (the ``train=False`` dataset).
    """
    # Per-channel statistics used by both pipelines.
    channel_mean = (0.49139968, 0.48215827, 0.44653124)
    channel_std = (0.24703233, 0.24348505, 0.26158768)

    # AutoAugment (CIFAR10 policy) and RandAugment were once candidates for
    # the start of the training pipeline; neither is enabled.
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    eval_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    if extra_transform is not None:
        train_tf = transforms.Compose([extra_transform, train_tf])
        eval_tf = transforms.Compose([extra_transform, eval_tf])

    train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_tf)
    held_out = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=eval_tf)

    # No separate validation data is built: the held-out set serves as both
    # validation and test.
    return train_set, held_out, held_out


def get_svhn(root=None, extra_transform=None, **_):
    """Create the SVHN datasets with a shared, augmentation-free transform.

    Args:
        root: Directory handed to ``torchvision.datasets.SVHN``.
        extra_transform: Optional transform prepended to the pipeline.
        **_: Ignored; accepted so callers can pass a shared kwargs dict.

    Returns:
        Tuple ``(train_data, val_data, test_data)``. ``val_data`` and
        ``test_data`` are the same object (the ``split='test'`` dataset).
    """
    # Tensor conversion, then normalisation with mean 0.5 / std 0.5 on each
    # of the three channels.
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    if extra_transform is not None:
        pipeline = transforms.Compose([extra_transform, pipeline])

    def _load(split_name):
        # Both splits share the same transform object.
        return torchvision.datasets.SVHN(root=root, split=split_name, download=True, transform=pipeline)

    train_set = _load('train')
    held_out = _load('test')

    # No separate validation data is built: the held-out set serves as both
    # validation and test.
    return train_set, held_out, held_out
