import math
import random
import numpy as np
import torch, torchvision
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path

""" Code mostly taken from https://github.com/BerivanIsik/sparse-random-networks/blob/main/src/fedssm/utils/non_iid_cifar.py
to run experiments against it."""

def shuffle_list(data):

    """
        Shuffle every client's samples and return the (same) list.

        Input:
          data : list with one [x, y] pair per client; each pair must be
                 mutable because its two slots are overwritten in place
        Output:
          data : the list that was passed in, where every pair now holds x and
                 y reordered by one common random permutation

        Exactly one permutation per client is drawn (inside shuffle_list_data,
        from the `random` module), so seeding `random` makes the result
        reproducible.
    """

    for client in data:
        # x and y go through the helper together so that sample k of x still
        # belongs to label k of y after the reordering.
        client[0], client[1] = shuffle_list_data(client[0], client[1])
    return data

def shuffle_list_data(x, y):
    """
        Helper that reorders x and y with one shared random permutation,
        so the pairing between x[k] and y[k] survives the shuffle.

        Both arguments are indexed with a list of positions, so they must
        support that kind of indexing (numpy arrays do). The permutation has
        len(x) entries and is drawn with `random.shuffle`.
    """

    order = [*range(len(x))]
    random.shuffle(order)

    shuffled_x = x[order]
    shuffled_y = y[order]
    return shuffled_x, shuffled_y


def get_cifar10(root='data/cifar10', download=True):
    """
        Return CIFAR10 train/test data and labels as numpy arrays

        Input:
          root : directory passed to torchvision as the dataset location
          download : forwarded to torchvision's `download` argument
        Output:
          x_train, y_train, x_test, y_test : images are the datasets' `.data`
              arrays, labels are the datasets' `.targets` converted with
              np.array
    """

    data_train = torchvision.datasets.CIFAR10(root, train=True, download=download)
    data_test = torchvision.datasets.CIFAR10(root, train=False, download=download)

    # Images keep the axis order torchvision stores them in (presumably
    # N x H x W x C, uint8 -- confirm against the torchvision docs); no axis
    # permutation is applied here.
    x_train, y_train = data_train.data, np.array(data_train.targets)
    x_test, y_test = data_test.data, np.array(data_test.targets)

    return x_train, y_train, x_test, y_test


def split_image_data(data, labels, n_clients=100, classes_per_client=10, shuffle=True):
    """
        Splits (data, labels) among 'n_clients' so that each client draws its
        samples class by class, aiming at 'classes_per_client' classes each.

        Every client gets a random budget of samples (see clients_rand) and a
        per-class quota of budget // classes_per_client (at least 1). Starting
        at a random class it takes up to one quota from each consecutive class,
        wrapping around, until the budget is used up. When a class runs out of
        samples less than a quota is taken from it, so a client can end up
        touching more than 'classes_per_client' classes.

        Input:
          data : [n_data x shape]
          labels : [n_data (x 1)] from 0 to n_labels
          n_clients : number of clients
          classes_per_client : number of classes per client
          shuffle : True/False => True for shuffling the dataset, False otherwise
        Output:
          clients_split : list with one [client_data, client_labels] pair per
                          client
        Raises:
          ValueError : if data and labels differ in length
    """

    # The budgets below add up to len(data) while the index pool holds
    # len(labels) entries: with fewer labels than data the pool would run dry
    # and the `while budget > 0` loop would spin forever.
    if len(data) != len(labels):
        raise ValueError(
            "data and labels must have the same length, got %d and %d"
            % (len(data), len(labels)))

    n_labels = np.max(labels) + 1

    data_per_client = clients_rand(len(data), n_clients)
    data_per_client_per_class = [max(1, nd // classes_per_client) for nd in data_per_client]

    # group the sample indices by label
    data_idcs = [[] for _ in range(n_labels)]
    for j, label in enumerate(labels):
        data_idcs[label].append(j)
    if shuffle:
        for idcs in data_idcs:
            np.random.shuffle(idcs)

    # split data among clients
    clients_split = []
    for budget, per_class in zip(data_per_client, data_per_client_per_class):
        client_idcs = []

        # each client starts at a random class and then walks the classes
        # round-robin
        c = np.random.randint(n_labels)
        while budget > 0:
            # never take more than the quota, than what the class still
            # offers, or than what the client is still owed
            take = min(per_class, len(data_idcs[c]), budget)

            client_idcs.extend(data_idcs[c][:take])
            data_idcs[c] = data_idcs[c][take:]

            budget -= take
            c = (c + 1) % n_labels

        clients_split.append([data[client_idcs], labels[client_idcs]])

    return clients_split

def clients_rand(train_len, n_clients):
    """
        Draw a random local-dataset size (number of images) for every client.

        Each client gets an integer weight from np.random.randint(10, 100);
        its size is the floor of its weight share of 'train_len'. Whatever the
        flooring leaves over is added to the last client, so the returned
        sizes always sum to 'train_len'.

        Output:
          list with 'n_clients' entries, one size per client
    """

    weights = np.random.randint(10, 100, n_clients)
    shares = weights / np.sum(weights)
    sizes = np.floor(shares * train_len).astype(int)

    # samples lost to rounding down go to the last client
    leftover = train_len - sizes.sum()
    per_client = list(sizes)
    per_client[-1] += leftover
    return per_client

class CustomImageDataset(Dataset):
    """
    A custom Dataset class for images
    inputs : numpy array [n_data x shape]
    labels : numpy array [n_data (x 1)]
    transforms : optional callable applied to every image when it is fetched

    Items are returned as dicts: {'img': <image>, 'label': <int64 tensor>}.
    """

    def __init__(self, inputs, labels, transforms=None):
        assert inputs.shape[0] == labels.shape[0]
        # Images are kept as passed in; any conversion is left to `transforms`.
        self.inputs = inputs
        # Build the int64 tensor directly. Going through torch.Tensor (float32)
        # first would round integer labels above 2**24.
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.transforms = transforms

    def __getitem__(self, index):
        img, label = self.inputs[index], self.labels[index]

        if self.transforms is not None:
            img = self.transforms(img)

        return {'img': img, 'label': label}

    def __len__(self):
        return self.inputs.shape[0]
    
def get_default_data_transforms(train=True, verbose=False):
    """
        Build the default preprocessing pipelines.

        Both pipelines are ToTensor followed by Normalize(mean=0.5, std=0.5);
        they are separate Compose objects with the same steps.

        Input:
          train : accepted but not consulted by this function
          verbose : if True, print the steps of the training pipeline
        Output:
          (training pipeline, evaluation pipeline)
    """

    def _build_pipeline():
        return transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=0.5, std=0.5)])

    train_pipeline = _build_pipeline()
    eval_pipeline = _build_pipeline()

    if verbose:
        print("\nData preprocessing: ")
        for step in train_pipeline.transforms:
            print(' -', step)
        print()

    return train_pipeline, eval_pipeline


def get_data_loaders(nclients, batch_size, classes_pc=10, real_wd=False):
    """
        Download CIFAR10, split the training set among clients and wrap every
        client's share in a shuffling DataLoader.

        Input:
          nclients : number of clients to split the training data among
          batch_size : batch size of every client loader
          classes_pc : forwarded to split_image_data as 'classes_per_client'
                       (only used when real_wd is False)
          real_wd : if True, split with split_image_data_realwd instead
        Output:
          client_loaders : list with one DataLoader per client; batches are
                           dicts with the keys 'img' and 'label'
    """
    # Only training loaders are returned, so the test split and the evaluation
    # transforms are not needed here.
    x_train, y_train, _, _ = get_cifar10()

    transforms_train, _ = get_default_data_transforms(verbose=True)

    if real_wd:
        # NOTE(review): split_image_data_realwd is not defined in the visible
        # part of this file -- confirm it exists before using real_wd=True.
        split = split_image_data_realwd(x_train, y_train, n_clients=nclients)
    else:
        split = split_image_data(x_train, y_train, n_clients=nclients,
                                 classes_per_client=classes_pc)

    split = shuffle_list(split)

    client_loaders = [DataLoader(CustomImageDataset(x, y, transforms_train),
                                 batch_size=batch_size, shuffle=True)
                      for x, y in split]

    return client_loaders
