from utils.DataLoader import DataLoader
import utils
import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy


class DataLoader_cifar100_dir(DataLoader):
    """CIFAR-100 data pool split across clients with Dirichlet label proportions.

    The CIFAR-100 train and test splits are merged, partitioned class by class
    into ``pool_size`` client shards using proportions drawn from
    ``Dirichlet(alpha)``, and every shard is cut into local training and test
    batches.  The finished loader object is pickled to a ``.npy`` file under
    ``utils.pool_folder_path`` and, unless ``recreate`` is set, read back from
    there the next time a loader with the same configuration is constructed.

    Attributes written by the build path: ``target_class_num``,
    ``total_training_number``, ``total_test_number``, ``task_name``,
    ``statistic`` (per client, a list of ``(label, count)`` tuples) and
    ``data_pool`` (per client, a dict of batched training/test data).
    """

    def __init__(self,
                 pool_size=100,
                 alpha=0.1,
                 batch_size=100,
                 input_require_shape=[3, -1, -1],
                 shuffle=True,
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Load the cached data pool, or build and cache it.

        Args:
            pool_size: Number of client shards in the pool.
            alpha: Concentration parameter of the Dirichlet distribution used
                to split every class across the clients.
            batch_size: Number of samples per local batch.
            input_require_shape: Forwarded to the base class; its string form
                is also part of the cache file name.
            shuffle: If true, shuffle each client's samples (using the
                ``random`` module) before the train/test split.
            recreate: If true, rebuild the pool even when a cached file exists.
            params: Optional mapping; when given, its ``'C'``, ``'dir_a'`` and
                ``'batch_size'`` entries override ``pool_size``, ``alpha`` and
                ``batch_size``.
            *args: Accepted and ignored.
            **kwargs: Accepted and ignored.
        """
        if params is not None:
            pool_size = params['C']
            alpha = params['dir_a']
            batch_size = params['batch_size']

        # The default list in the signature is one object shared by all calls.
        # Hand the base class a private copy so nothing it stores or modifies
        # can alias that shared default (or the caller's own list).
        input_require_shape = copy.copy(input_require_shape)

        # `name` doubles as the cache file name, so its format must not change.
        name = ('CIFAR100_dir_pool_' + str(pool_size)
                + 'alpha_' + str(alpha)
                + '_batchsize_' + str(batch_size)
                + '_sort_split_input_require_shape_' + str(input_require_shape))
        nickname = ('cifar100 dir B' + str(batch_size)
                    + ' alpha' + str(alpha)
                    + ' N' + str(pool_size))
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'
        if os.path.exists(file_path) and not recreate:
            # SECURITY: allow_pickle=True unpickles arbitrary objects. Only
            # pool files written by this class itself should ever be loaded.
            cached = np.load(file_path, allow_pickle=True).item()
            for attr, value in vars(cached).items():
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
            return

        self._build_pool(pool_size, alpha, batch_size, shuffle)
        # Saving an arbitrary object with np.save pickles it.
        np.save(file_path, self)

    def _build_pool(self, pool_size, alpha, batch_size, shuffle):
        """Download CIFAR-100, partition it and populate the pool attributes."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            # NOTE(review): these mean/std constants are the ones commonly
            # quoted for CIFAR-10 -- confirm they are intended for CIFAR-100.
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
        trainset = torchvision.datasets.CIFAR100(
            root=utils.data_folder_path, train=True,
            download=True, transform=transform)
        testset = torchvision.datasets.CIFAR100(
            root=utils.data_folder_path, train=False,
            download=True, transform=transform)

        train_inputs, train_targets = self._load_whole_split(trainset)
        test_inputs, test_targets = self._load_whole_split(testset)

        # cal_data_shape (base class) is expected to set self.input_data_shape,
        # which is read on the next statement.
        self.cal_data_shape(train_inputs.shape)
        dataset_input = np.concatenate([
            train_inputs.cpu().detach().numpy(),
            test_inputs.cpu().detach().numpy(),
        ]).reshape([-1] + self.input_data_shape)
        dataset_label = np.concatenate([
            train_targets.cpu().detach().numpy(),
            test_targets.cpu().detach().numpy(),
        ])

        num_train = train_inputs.shape[0]
        num_test = test_inputs.shape[0]
        self.target_class_num = 100
        self.total_training_number = num_train
        self.total_test_number = num_test
        self.task_name = 'cifar100_classification'

        # Each client's shard is split with the global train/test ratio.
        train_prob = num_train / (num_train + num_test)

        client_indices = self._dirichlet_partition(
            dataset_label, pool_size, self.target_class_num, alpha)
        X = [dataset_input[idxs] for idxs in client_indices]
        y = [dataset_label[idxs] for idxs in client_indices]

        statistic = []
        for client_labels in y:
            classes, counts = np.unique(client_labels, return_counts=True)
            statistic.append([(int(c), int(n)) for c, n in zip(classes, counts)])

        for client, client_stat in enumerate(statistic):
            print(f"Client {client}\t Size of data: {len(X[client])}\t Labels: ",
                  np.unique(y[client]))
            print("\t\t Samples of labels: ", list(client_stat))
            print("-" * 50)

        self.statistic = statistic
        self.data_pool = self._create_data_pool(X, y, shuffle, train_prob, batch_size)

    @staticmethod
    def _load_whole_split(dataset):
        """Return ``(inputs, targets)`` for all of *dataset*, transform applied."""
        loader = torch.utils.data.DataLoader(
            dataset, batch_size=dataset.data.shape[0], shuffle=False, num_workers=1)
        # The batch size equals the dataset size, so the first batch is everything.
        return next(iter(loader))

    @staticmethod
    def _dirichlet_partition(labels, num_clients, num_classes, alpha):
        """Assign sample indices to clients, one class at a time.

        For each class the shuffled sample indices are divided according to
        proportions drawn from ``Dirichlet(alpha)``; clients that already hold
        at least ``len(labels) / num_clients`` samples get proportion zero for
        that class.  The whole assignment is redrawn until the smallest client
        holds at least ``num_classes`` samples.

        Returns:
            A list with one (shuffled) list of sample indices per client.
        """
        total = len(labels)
        min_size = 0
        # NOTE(review): the threshold here is the number of classes. The
        # previous version also computed batch_size / (1 - train_prob) as a
        # minimum-sample figure but never used it -- confirm which is intended.
        while min_size < num_classes:
            idx_batch = [[] for _ in range(num_clients)]
            for k in range(num_classes):
                idx_k = np.where(labels == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, num_clients))
                # Zero out clients that already reached their fair share.
                proportions = np.array(
                    [p * (len(idx_j) < total / num_clients)
                     for p, idx_j in zip(proportions, idx_batch)])
                proportions = proportions / proportions.sum()
                split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [idx_j + part.tolist()
                             for idx_j, part in zip(idx_batch, np.split(idx_k, split_points))]
            # Only the size after the last class decides whether to redraw.
            min_size = min(len(idx_j) for idx_j in idx_batch)

        for idx_j in idx_batch:
            np.random.shuffle(idx_j)
        return idx_batch

    @staticmethod
    def _split_into_batches(input_data, target_data, batch_size):
        """Cut paired tensors into consecutive ``(inputs, targets)`` batches."""
        total = len(target_data)
        local_data = []
        for start in range(0, total, batch_size):
            # Index with an explicit position list, as the original code did:
            # this yields independent tensors rather than views of the
            # client's full tensor.
            positions = list(range(start, min(start + batch_size, total)))
            local_data.append((input_data[positions], target_data[positions]))
        return local_data

    @classmethod
    def _create_data_pool(cls, X, y, shuffle, train_prob, batch_size):
        """Build one pool entry per client from its samples ``X`` and labels ``y``.

        Each entry holds the batched training data, the batched test data,
        their sample counts and the client's index as ``data_name``.
        """
        data_pool = []
        for pool_idx, (client_x, client_y) in enumerate(zip(X, y)):
            input_data = torch.Tensor(client_x).float()
            target_data = torch.Tensor(client_y).long()
            if shuffle:
                indices = list(range(len(target_data)))
                random.shuffle(indices)
                input_data = input_data[indices]
                target_data = target_data[indices]

            split = int(train_prob * len(target_data))
            training_input_data = input_data[:split]
            training_target_data = target_data[:split]
            test_input_data = input_data[split:]
            test_target_data = target_data[split:]

            data_pool.append({
                'local_training_data': cls._split_into_batches(
                    training_input_data, training_target_data, batch_size),
                'local_training_number': len(training_target_data),
                'data_name': str(pool_idx),
                'local_test_data': cls._split_into_batches(
                    test_input_data, test_target_data, batch_size),
                'local_test_number': len(test_target_data),
            })
        return data_pool

    def allocate(self, client_list):
        """Give every client in *client_list* a distinct, randomly chosen pool entry.

        Entries are drawn without replacement, so ``np.random.choice`` raises
        ``ValueError`` when there are more clients than pool entries.  Each
        client receives its data through ``client.update_data``.
        """
        chosen_indices = np.random.choice(
            self.pool_size, len(client_list), replace=False)
        for client, pool_idx in zip(client_list, chosen_indices):
            data_pool_item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number'])



            # dataset_input.extend(trainset.data.cpu().detach().numpy())
            # dataset_input.extend(testset.data.cpu().detach().numpy())
            # dataset_label.extend(trainset.targets.cpu().detach().numpy())
            # dataset_label.extend(testset.targets.cpu().detach().numpy())
            #
            # self.cal_data_shape(trainset.data.shape)
            # dataset_input = np.array(dataset_input).reshape(
            #     [-1] + self.input_data_shape)
            # dataset_label = np.array(dataset_label)
            # self.target_class_num = 100
            # self.total_training_number = len(trainset)
            # self.total_test_number = len(testset)
            #
            # train_prob = trainset.data.shape[0] / \
            #     (trainset.data.shape[0] + testset.data.shape[0])
            #
            # X, y, statistic = fp.separate_data((dataset_input, dataset_label), train_prob, pool_size, self.target_class_num,
            #                                    item_classes_num=None, batch_size=batch_size, alpha=alpha, niid=True, balance=None, partition='dir')
            # self.statistic = statistic
            #
            # self.data_pool = fp.create_data_pool(
            #     X, y, pool_size, shuffle, train_prob, batch_size, self.target_class_num)
            #
            # np.save(file_path, self)
