# Helpers that build PyTorch DataLoaders for the CIFAR-10 and MNIST datasets.

import random

import torch
from torchvision import datasets, transforms

def get_cifar(
        batch_size=16, download=True, augment=False,
        binarize=False, downsample_params=None):
    """Build CIFAR-10 train/test DataLoaders.

    Training data can optionally be augmented, downsampled, and label-binarized.
    The test dataset is never modified.

    Args:
        batch_size: batch size for both loaders.
        download: forwarded to ``datasets.CIFAR10`` (root ``./data``).
        augment: if True, training images get ``RandomCrop(32, padding=4)`` and
            ``RandomHorizontalFlip`` before normalization.
        binarize: if True, training labels are replaced by ``int(label >= 5)``.
            Ignored for the ``'twoclass'`` downsample method, whose labels are
            already 0/1. Test labels are left unchanged.
        downsample_params: optional dict with keys
            ``'size'`` (default 1000), the total number of training examples
            to keep, and ``'method'`` (default ``'uniform'``):
            ``'uniform'`` keeps the first ``size // 10`` examples of each of
            the 10 classes; ``'twoclass'`` keeps the first ``size // 2``
            examples of classes 0 and 1.

    Returns:
        ``(trn_dl, test_dl)``: a shuffled training loader and an unshuffled
        test loader.

    Raises:
        NotImplementedError: if ``downsample_params['method']`` is not
            ``'uniform'`` or ``'twoclass'``.
    """
    trn_kwargs = {
        'batch_size': batch_size,
        'num_workers': 0,
        'shuffle': True,
    }
    test_kwargs = {
        'batch_size': batch_size,
        'num_workers': 0,
        'shuffle': False,
    }

    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    # Augmentation is optional because it should not be used for numerical
    # assumption tests or for training the MLP.
    if augment:
        augmentation = [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        augmentation = []
    trn_transform = transforms.Compose(
        augmentation + [transforms.ToTensor(), normalize])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    trn_ds = datasets.CIFAR10(
            root='./data', train=True, download=download, transform=trn_transform)
    test_ds = datasets.CIFAR10(
            root='./data', train=False, download=download, transform=test_transform)

    if downsample_params is not None:
        downsample_size = downsample_params.get('size', 1000)
        downsample_method = downsample_params.get('method', 'uniform')
        if downsample_method == 'uniform':
            classes = list(range(10))
        elif downsample_method == 'twoclass':
            classes = [0, 1]
        else:
            raise NotImplementedError('Not a supported downsample method')
        ex_per_class = downsample_size // len(classes)

        # Keep the first ex_per_class examples (in dataset order) of each
        # selected class; the kept indices are grouped class by class.
        valid_indices = {c: [] for c in classes}
        for index, target in enumerate(trn_ds.targets):
            tgt = int(target)
            if tgt in valid_indices and len(valid_indices[tgt]) < ex_per_class:
                valid_indices[tgt].append(index)
        all_valid_indices = [i for c in classes for i in valid_indices[c]]

        valid_index_tensor = torch.LongTensor(all_valid_indices)
        trn_ds.data = trn_ds.data[valid_index_tensor, :]
        trn_ds.targets = torch.LongTensor(trn_ds.targets)[valid_index_tensor]
        # 'twoclass' labels are already 0/1; binarizing them with >= 5 would
        # collapse every label to 0. Use != (value equality), not `is not`.
        if binarize and downsample_method != 'twoclass':
            trn_ds.targets = [int(x >= 5) for x in trn_ds.targets]
    elif binarize:
        trn_ds.targets = [int(x >= 5) for x in trn_ds.targets]
    trn_dl = torch.utils.data.DataLoader(trn_ds, **trn_kwargs)

    # test dataset is never modified during downsampling
    test_dl = torch.utils.data.DataLoader(test_ds, **test_kwargs)
    return trn_dl, test_dl

def get_mnist(
        batch_size=16, download=True, downsample_params=None):
    """Return ``(train_loader, test_loader)`` for MNIST.

    Both splits use ``ToTensor`` followed by ``Normalize((0.1307,), (0.3081,))``
    and are stored under ``./data``. The training loader shuffles; the test
    loader does not. Only the training set is ever downsampled or relabeled.

    Args:
        batch_size: batch size for both loaders.
        download: forwarded to ``datasets.MNIST``.
        downsample_params: optional dict with keys ``'size'`` (default 1000)
            and ``'method'`` (default ``'uniform'``):
            ``'uniform'`` keeps the first ``size // 10`` examples of each
            digit, and training targets become the boolean tensor
            ``label >= 5``; ``'twoclass'`` keeps the first ``size // 2``
            examples of digits 6 and 9, relabeled as 0 and 1.

    Raises:
        NotImplementedError: for any other downsample method.
    """
    loader_defaults = {'batch_size': batch_size, 'num_workers': 0}
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    train_set = datasets.MNIST(
        root='./data', train=True, download=download, transform=mnist_transform)
    test_set = datasets.MNIST(
        root='./data', train=False, download=download, transform=mnist_transform)

    if downsample_params is not None:
        size = downsample_params.get('size', 1000)
        method = downsample_params.get('method', 'uniform')
        if method == 'uniform':
            kept_classes = tuple(range(10))
        elif method == 'twoclass':
            kept_classes = (6, 9)
        else:
            raise NotImplementedError('Not a supported downsample method')
        per_class = size // len(kept_classes)

        # Collect, per kept digit, the dataset positions of its first
        # `per_class` occurrences.
        picked = {label: [] for label in kept_classes}
        for position in range(train_set.targets.shape[0]):
            bucket = picked.get(int(train_set.targets[position]))
            if bucket is not None and len(bucket) < per_class:
                bucket.append(position)

        # Relabel on the full tensor, then slice out the kept rows.
        if method == 'uniform':
            train_set.targets = train_set.targets >= 5  # binarize the output classes
        else:
            train_set.targets[picked[6]] = 0  # binarize the output classes
            train_set.targets[picked[9]] = 1
        keep = torch.LongTensor(
            [position for label in kept_classes for position in picked[label]])
        train_set.data = train_set.data[keep, :]
        train_set.targets = train_set.targets[keep]

    train_loader = torch.utils.data.DataLoader(
        train_set, shuffle=True, **loader_defaults)
    # test dataset is never modified during downsampling
    test_loader = torch.utils.data.DataLoader(
        test_set, shuffle=False, **loader_defaults)
    return train_loader, test_loader

if __name__ == '__main__':
    # Smoke test: build augmented CIFAR-10 loaders and print the number of
    # training batches.
    train_loader, _unused_test_loader = get_cifar(batch_size=128, augment=True)
    print(len(train_loader))
