import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader

from data_prepration.common import create_lda_partitions


def get_dataset(args):
    """Load the dataset selected by ``args.dataset`` and its client partition.

    Args:
        args: Config namespace. ``args.dataset`` picks the loader
            ('cifar10' or 'mnist'); the chosen loader reads further fields
            from ``args``.

    Returns:
        ``(train_dataset, test_dataset, train_user_groups)`` exactly as
        returned by the selected loader.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a supported name.
    """
    if args.dataset == 'cifar10':
        return get_cifar10(args)
    if args.dataset == 'mnist':
        return get_mnist(args)
    # Name the offending value so misconfigured runs are easy to diagnose.
    raise NotImplementedError(
        f"Unsupported dataset {args.dataset!r}; expected 'cifar10' or 'mnist'."
    )


def get_cifar10(args):
    """Load CIFAR-10 (train and test splits) and partition the train split.

    Both splits are read from ``args.path`` with ``download=True``. The train
    split is then handed to ``get_lda_distribution``, which reads further
    fields from ``args``.

    Args:
        args: Config namespace; ``args.path`` is the dataset root directory.

    Returns:
        ``(train_dataset, test_dataset, train_user_groups)``, where
        ``train_user_groups`` is whatever ``get_lda_distribution`` returns.
    """
    root = args.path

    # Per-channel mean/std normalisation shared by both pipelines.
    # NOTE(review): the std triple is the widely copied (0.2023, 0.1994, 0.2010);
    # the channel std usually quoted for CIFAR-10 is roughly (0.247, 0.243, 0.261).
    # Kept unchanged to preserve behaviour -- confirm which is intended.
    channel_norm = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

    # Training pipeline: random 32x32 crop of a 4-pixel padded image and a
    # random horizontal flip, followed by tensor conversion and normalisation.
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(augmentations + [transforms.ToTensor(), channel_norm])

    # Evaluation pipeline: no augmentation.
    test_transform = transforms.Compose([transforms.ToTensor(), channel_norm])

    def _load(is_train, pipeline):
        # Load one CIFAR-10 split from ``root`` with the given transform.
        return datasets.CIFAR10(root, train=is_train, download=True, transform=pipeline)

    train_dataset = _load(True, train_transform)
    test_dataset = _load(False, test_transform)
    train_user_groups = get_lda_distribution(args, train_dataset, test_dataset)
    return train_dataset, test_dataset, train_user_groups


def get_mnist(args):
    """Load MNIST (train and test splits) and partition the train split.

    Both splits are read from ``args.path`` with ``download=True`` and share
    a single tensor-conversion + normalisation transform. The train split is
    then handed to ``get_lda_distribution``.

    Args:
        args: Config namespace; ``args.path`` is the dataset root directory.

    Returns:
        ``(train_dataset, test_dataset, train_user_groups)``, where
        ``train_user_groups`` is whatever ``get_lda_distribution`` returns.
    """
    root = args.path
    # Single-channel normalisation constants applied after ToTensor.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    def _load(is_train):
        # Load one MNIST split from ``root`` using the shared transform.
        return datasets.MNIST(root, train=is_train, download=True, transform=mnist_transform)

    train_dataset = _load(True)
    test_dataset = _load(False)
    train_user_groups = get_lda_distribution(args, train_dataset, test_dataset)
    return train_dataset, test_dataset, train_user_groups


def get_lda_distribution(args, train_dataset, test_dataset):
    """Split the training set across clients using ``create_lda_partitions``.

    Args:
        args: Config namespace; ``args.num_users`` is passed as the number of
            partitions and ``args.alpha`` as the Dirichlet concentration.
        train_dataset: Dataset exposing a ``targets`` sequence of labels.
        test_dataset: Unused; kept so existing callers keep working.

    Returns:
        A list with one entry per partition; each entry is element ``[0]`` of
        that partition converted to a plain Python list. Callers use these as
        sample indices into ``train_dataset`` (presumably the partition's
        index array -- confirm against ``create_lda_partitions``).
    """
    labels = np.array(train_dataset.targets)
    # dirichlet_dist=None: presumably a fresh class distribution is sampled
    # per call. accept_imbalanced=False: presumably requires an even split
    # across partitions -- confirm both against create_lda_partitions.
    partitions, _dirichlet_dist = create_lda_partitions(
        dataset=labels,
        dirichlet_dist=None,
        num_partitions=args.num_users,
        concentration=args.alpha,
        accept_imbalanced=False,
    )
    return [partition[0].tolist() for partition in partitions]


def get_partitioned_data(args):
    """Build a shared test loader and one training loader per client.

    Args:
        args: Config namespace; ``args.num_users`` sets how many clients are
            built and ``args.local_bs`` their training batch size. Further
            fields are read by ``get_dataset``.

    Returns:
        ``(all_test_data, clients_train_data, clients_weights)``:
            all_test_data: Unshuffled DataLoader over the full test split,
                batch size 128.
            clients_train_data: Dict client id -> shuffled DataLoader over
                that client's Subset of the training split.
            clients_weights: Dict client id -> number of training samples
                assigned to that client.
    """
    train_dataset, test_dataset, train_user_groups = get_dataset(args)
    all_test_data = DataLoader(test_dataset, batch_size=128, shuffle=False)

    client_ids = range(args.num_users)
    clients_train_data = {
        cid: DataLoader(Subset(train_dataset, train_user_groups[cid]),
                        batch_size=args.local_bs, shuffle=True)
        for cid in client_ids
    }
    clients_weights = {cid: len(train_user_groups[cid]) for cid in client_ids}
    return all_test_data, clients_train_data, clients_weights
