import os
import sys

import torch
import torchvision
import torchvision.transforms as transforms

sys.path.append('../digitclutter')
sys.path.append('digitclutter')
from digitclutter.digclutLoad import return_data_digclut

def sampler(train_data, train_eval_data, val_data, equal_sampling, n, trainset, testset,
            batch_size, num_workers=8, num_classes=10):
    """Build train, train-eval and validation loaders over one index range.

    Args:
        train_data: Dataset served by the training loader. When
            ``equal_sampling`` is true it must expose a ``targets``
            attribute holding one integer label per sample.
        train_eval_data: Dataset served by the train-eval loader; it is
            sampled with the same index subset as ``train_data``.
        val_data: Dataset served by the validation loader.
        equal_sampling: If true, draw a class-balanced training subset
            instead of the first ``trainset`` indices.
        n: Number of indices considered, i.e. ``range(n)``.
        trainset: Number of training samples.
        testset: Index at which the validation range starts; the
            validation loader samples from ``range(n)[testset:]``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per loader.
        num_classes: Number of label values (``0 .. num_classes - 1``)
            balanced over when ``equal_sampling`` is true.

    Returns:
        Tuple ``(train_loader, train_eval_loader, val_loader)``.
    """
    indices = list(range(n))
    if not equal_sampling:
        train_indices = indices[:trainset]
    else:
        # ``targets`` may be a tensor (torchvision MNIST) or a plain list
        # (torchvision CIFAR10); comparing a list with an int yields a bare
        # ``False`` that ``torch.where`` rejects, so normalise to a tensor.
        targets = torch.as_tensor(train_data.targets)
        per_class = int(trainset // num_classes)
        remainder = trainset % num_classes
        train_indices = []
        for label in range(num_classes):
            # The first ``remainder`` classes take one extra sample so the
            # quotas add up to ``trainset``.
            quota = per_class + 1 if label < remainder else per_class
            matches = torch.where(targets == label)[0][:quota]
            # Keep plain ints rather than 0-d tensors as dataset indices.
            train_indices.extend(matches.tolist())
        # NOTE(review): the balanced indices are taken from the whole of
        # ``targets``, so they can fall inside the validation range
        # ``indices[testset:]`` -- confirm whether that overlap is intended.
    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indices)
    val_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[testset:])

    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers
    )
    # Shares the training sampler so both loaders cover the same subset.
    train_eval_loader = torch.utils.data.DataLoader(
        train_eval_data, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers
    )
    val_loader = torch.utils.data.DataLoader(
        val_data, batch_size=batch_size, sampler=val_sampler,
        num_workers=num_workers
    )
    return train_loader, train_eval_loader, val_loader

def get_mnist(trainset=50000, batch_size=128, testset=None,
              equal_sampling=False, normalize=True):
    """Return the (train, train-eval, validation) loaders for MNIST.

    All three datasets are the ``train=True`` split stored under
    ``./_data`` (downloaded on demand). The training dataset applies a
    padded random crop; the other two only convert to tensors. With
    ``normalize`` set, ``normalizer['mnist']`` is appended to both pipelines.

    Args:
        trainset: Number of training samples; ``None`` means 50000.
        batch_size: Batch size for every loader.
        testset: Number of validation samples, taken from the end of the
            60000 indices. ``None`` starts validation at index ``trainset``.
        equal_sampling: Forwarded to ``sampler``.
        normalize: Whether to append the normalisation transform.
    """
    total = 60000
    if trainset is None:
        trainset = 50000
    # ``sampler`` expects the index where the validation range begins.
    val_start = trainset if testset is None else total - testset

    augmented_steps = [transforms.RandomCrop(28, padding=4), transforms.ToTensor()]
    plain_steps = [transforms.ToTensor()]
    if normalize:
        for steps in (augmented_steps, plain_steps):
            steps.append(normalizer['mnist'])
    train_transform = transforms.Compose(augmented_steps)
    eval_transform = transforms.Compose(plain_steps)

    def load(transform):
        return torchvision.datasets.MNIST(root='./_data', train=True,
                                          transform=transform, download=True)

    train_data = load(train_transform)
    train_eval_data = load(eval_transform)
    val_data = load(eval_transform)
    return sampler(
        train_data=train_data, train_eval_data=train_eval_data, val_data=val_data,
        equal_sampling=equal_sampling, n=total,
        trainset=trainset, testset=val_start, batch_size=batch_size
    )

# Normalisation transform appended to the pipelines of each dataset.
# NOTE(review): the cifar10 mean/std look like the ImageNet channel
# statistics rather than values computed on CIFAR-10 -- confirm intended.
normalizer = dict(
    cifar10=transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
    mnist=transforms.Normalize((0.5,), (0.5,)),
    fashion_mnist=transforms.Normalize((0.5,), (0.5,)),
)

def get_cifar10(trainset=40000, batch_size=128, testset=None,
                equal_sampling=False, num_workers=8, normalize=True):
    """Return the (train, train-eval, validation) loaders for CIFAR-10.

    All three datasets are the ``train=True`` split stored under
    ``./_data`` (downloaded on demand). The training dataset applies a
    padded random crop and a random horizontal flip; the other two only
    convert to tensors. With ``normalize`` set, ``normalizer['cifar10']``
    is appended to both pipelines.

    Args:
        trainset: Number of training samples; ``None`` means 40000.
        batch_size: Batch size for every loader.
        testset: Number of validation samples, taken from the end of the
            50000 indices. ``None`` starts validation at index ``trainset``.
        equal_sampling: Forwarded to ``sampler``.
        num_workers: Worker processes per loader.
        normalize: Whether to append the normalisation transform.
    """
    total = 50000
    if trainset is None:
        trainset = 40000
    # ``sampler`` expects the index where the validation range begins.
    val_start = trainset if testset is None else total - testset

    augmented_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    plain_steps = [transforms.ToTensor()]
    if normalize:
        for steps in (augmented_steps, plain_steps):
            steps.append(normalizer['cifar10'])
    train_transform = transforms.Compose(augmented_steps)
    eval_transform = transforms.Compose(plain_steps)

    def load(transform):
        return torchvision.datasets.CIFAR10(root='./_data', train=True,
                                            transform=transform, download=True)

    train_data = load(train_transform)
    train_eval_data = load(eval_transform)
    val_data = load(eval_transform)
    return sampler(
        train_data=train_data, train_eval_data=train_eval_data, val_data=val_data,
        equal_sampling=equal_sampling, n=total,
        trainset=trainset, testset=val_start, batch_size=batch_size,
        num_workers=num_workers
    )

def get_fashion_mnist(trainset=50000, batch_size=128, testset=None,
                      equal_sampling=False, normalize=True):
    """Return the (train, train-eval, validation) loaders for Fashion-MNIST.

    All three datasets are the ``train=True`` split stored under
    ``./_data`` (downloaded on demand). The training dataset applies a
    padded random crop; the other two only convert to tensors. With
    ``normalize`` set, ``normalizer['fashion_mnist']`` is appended to both
    pipelines.

    Args:
        trainset: Number of training samples; ``None`` means 50000.
        batch_size: Batch size for every loader.
        testset: Number of validation samples, taken from the end of the
            60000 indices. ``None`` starts validation at index ``trainset``.
        equal_sampling: Forwarded to ``sampler``.
        normalize: Whether to append the normalisation transform.
    """
    total = 60000
    if trainset is None:
        trainset = 50000
    # ``sampler`` expects the index where the validation range begins.
    val_start = trainset if testset is None else total - testset

    augmented_steps = [transforms.RandomCrop(28, padding=4), transforms.ToTensor()]
    plain_steps = [transforms.ToTensor()]
    if normalize:
        for steps in (augmented_steps, plain_steps):
            steps.append(normalizer['fashion_mnist'])
    train_transform = transforms.Compose(augmented_steps)
    eval_transform = transforms.Compose(plain_steps)

    def load(transform):
        return torchvision.datasets.FashionMNIST(root='./_data', train=True,
                                                 transform=transform,
                                                 download=True)

    train_data = load(train_transform)
    train_eval_data = load(eval_transform)
    val_data = load(eval_transform)
    return sampler(
        train_data=train_data, train_eval_data=train_eval_data, val_data=val_data,
        equal_sampling=equal_sampling, n=total,
        trainset=trainset, testset=val_start, batch_size=batch_size
    )

def get_digitclutter(digits, image_size=32, trainset=89999, batch_size=128, testset=None, equal_sampling=False, path='./_data',
                     normalize=True):
    """Return the (train, train-eval, validation) loaders for digitclutter.

    The dataset is loaded with ``return_data_digclut`` from the directory
    ``<path>/digclut<digits>``; the same dataset object backs all three
    loaders, which differ only in the indices they sample.

    Args:
        digits: Forwarded to ``return_data_digclut`` and used in the
            directory name.
        image_size: Forwarded to ``return_data_digclut``.
        trainset: Number of training samples; ``None`` means 89999.
        batch_size: Batch size for every loader.
        testset: Number of validation samples, taken from the end of the
            99999 indices. ``None`` starts validation at index ``trainset``.
        equal_sampling: Forwarded to ``sampler``.
        path: Root directory containing the ``digclut<digits>`` folder.
        normalize: Must be true; the unnormalised variant is not implemented.

    Raises:
        NotImplementedError: If ``normalize`` is false.
    """
    if not normalize:
        raise NotImplementedError()
    total = 99999
    if trainset is None:
        trainset = 89999
    # ``sampler`` expects the index where the validation range begins.
    val_start = trainset if testset is None else total - testset

    folder = os.path.join(path, 'digclut' + str(digits))
    # Only the first element of the returned pair is used here.
    dataset, _ = return_data_digclut(folder, digits, image_size)
    return sampler(
        train_data=dataset, train_eval_data=dataset, val_data=dataset,
        equal_sampling=equal_sampling, n=total,
        trainset=trainset, testset=val_start, batch_size=batch_size
    )

def get_digitclutter_2(trainset=89999, batch_size=128, testset=None, equal_sampling=False, normalize=True, path='./_data'):
    """Call ``get_digitclutter`` with ``digits=2``, forwarding every option."""
    options = dict(
        trainset=trainset,
        batch_size=batch_size,
        testset=testset,
        equal_sampling=equal_sampling,
        normalize=normalize,
        path=path,
    )
    return get_digitclutter(2, **options)

def get_digitclutter_3(trainset=89999, batch_size=128, testset=None, equal_sampling=False, normalize=True, path='./_data'):
    """Call ``get_digitclutter`` with ``digits=3``, forwarding every option."""
    options = dict(
        trainset=trainset,
        batch_size=batch_size,
        testset=testset,
        equal_sampling=equal_sampling,
        normalize=normalize,
        path=path,
    )
    return get_digitclutter(3, **options)

# Maps a dataset name to the function that builds its loaders; used by
# ``get_data`` to dispatch on the name.
data_dict = dict(
    mnist=get_mnist,
    cifar10=get_cifar10,
    fashion_mnist=get_fashion_mnist,
    digitclutter_3=get_digitclutter_3,
    digitclutter_2=get_digitclutter_2,
)

def get_data(string, trainset=None, batch_size=128, testset=None,
             equal_sampling=False, normalize=True):
    """Build the loaders for the dataset registered under ``string``.

    Looks the name up in ``data_dict`` and calls the matching function with
    the remaining arguments passed by keyword.

    Raises:
        KeyError: If ``string`` is not a key of ``data_dict``.
    """
    loader_factory = data_dict[string]
    options = dict(
        trainset=trainset,
        batch_size=batch_size,
        testset=testset,
        equal_sampling=equal_sampling,
        normalize=normalize,
    )
    return loader_factory(**options)

# Input channel count registered for each dataset name.
input_channels = dict(
    mnist=1,
    cifar10=3,
    fashion_mnist=1,
    digitclutter_2=1,
    digitclutter_3=1,
)

# Output count registered for each dataset name (presumably the number of
# target classes -- confirm against the model code that reads it).
outputs = dict(
    mnist=10,
    cifar10=10,
    fashion_mnist=10,
    digitclutter_3=10,
    digitclutter_2=10,
)

# ``stages`` value registered for each dataset name. Its meaning is set by
# whatever reads this table and is not visible in this module.
stages = dict(
    mnist=1,
    cifar10=3,
    fashion_mnist=1,
    digitclutter_3=3,
    digitclutter_2=3,
)
