import os
import torch
import torchvision
import random
from torchvision import transforms as transforms
import numpy as np
import copy
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import math
from utils.DataLoader import DataLoader
import utils
from abc import abstractmethod, ABCMeta
import collections
from torch.utils.data import ConcatDataset
# class AbstractPartitioner(metaclass=ABCMeta):
#     @abstractmethod
#     def __call__(self, *args, **kwargs):
#         pass
#
# class BasicPartitioner(AbstractPartitioner):
#     """This is the basic class of data partitioner. The partitioner will be directly called by the
#     task generator of different benchmarks. By overwriting __call__ method, different partitioners
#     can be realized. The input of __call__ is usually a dataset.
#     """
#     def __call__(self, *args, **kwargs):
#         return
#
#     def register_generator(self, generator):
#         r"""Register the generator as an attribute of self"""
#         self.generator = generator
#
#     def data_imbalance_generator(self, num_clients, datasize, imbalance=0, minvol=1):
#         r"""
#         Split the data size into several parts
#
#         Args:
#             num_clients (int): the number of clients
#             datasize (int): the total data size
#             imbalance (float): the degree of data imbalance across clients
#             minvol (int): the minimal size of dataset
#         Returns:
#             a list of integer numbers that represents local data sizes
#         """
#         if imbalance == 0:
#             samples_per_client = [int(datasize / num_clients) for _ in range(num_clients)]
#             for _ in range(datasize % num_clients): samples_per_client[_] += 1
#         else:
#             imbalance = max(0.1, imbalance)
#             sigma = imbalance
#             mean_datasize = datasize / num_clients
#             mu = np.log(mean_datasize) - sigma ** 2 / 2.0
#             samples_per_client = np.random.lognormal(mu, sigma, (num_clients)).astype(int)
#             crt_data_size = sum(samples_per_client)
#             total_delta = np.abs(crt_data_size-datasize)
#             thresold = max(int(total_delta/10), 1)
#             delta = min(int(0.1 * thresold), 10)
#             # force current data size to match the total data size
#             while crt_data_size != datasize:
#                 if crt_data_size - datasize >= thresold:
#                     maxid = np.argmax(samples_per_client)
#                     maxvol = samples_per_client[maxid]
#                     new_samples = np.random.lognormal(mu, sigma, (10 * num_clients))
#                     while min(new_samples) > maxvol:
#                         new_samples = np.random.lognormal(mu, sigma, (10 * num_clients))
#                     new_size_id = np.argmin(
#                         [np.abs(crt_data_size - samples_per_client[maxid] + s - datasize) for s in new_samples])
#                     samples_per_client[maxid] = new_samples[new_size_id]
#                 elif crt_data_size - datasize >= delta:
#                     maxid = np.argmax(samples_per_client)
#                     if samples_per_client[maxid]>=delta:
#                         samples_per_client[maxid] -= delta
#                     elif samples_per_client[maxid]>1:
#                         samples_per_client[maxid] -= 1
#                 elif crt_data_size - datasize > 0:
#                     maxid = np.argmax(samples_per_client)
#                     crt_delta = (crt_data_size - datasize)
#                     if samples_per_client[maxid]>=crt_delta:
#                         samples_per_client[maxid] -= crt_delta
#                     elif samples_per_client[maxid]>=minvol:
#                         samples_per_client[maxid] -= (crt_delta-minvol)
#                     else:
#                         warnings.warn("Failed to keep the minvol of clients' training data to be larger than {}".format(minvol))
#                         if samples_per_client[maxid] > 1:
#                             samples_per_client[maxid] -=1
#                         else:
#                             raise RuntimeError("Failed to generate distribution due to the conflicts of imbalance and num_clients. Please try to decrease the imbalance term or decrease the number of clients. ")
#                 elif datasize - crt_data_size >= thresold:
#                     minid = np.argmin(samples_per_client)
#                     minvol = samples_per_client[minid]
#                     new_samples = np.random.lognormal(mu, sigma, (10 * num_clients))
#                     while max(new_samples) < minvol:
#                         new_samples = np.random.lognormal(mu, sigma, (10 * num_clients))
#                     new_size_id = np.argmin(
#                         [np.abs(crt_data_size - samples_per_client[minid] + s - datasize) for s in new_samples])
#                     samples_per_client[minid] = new_samples[new_size_id]
#                 elif datasize - crt_data_size >= delta:
#                     minid = np.argmin(samples_per_client)
#                     samples_per_client[minid] += delta
#                 else:
#                     minid = np.argmin(samples_per_client)
#                     samples_per_client[minid] += (datasize - crt_data_size)
#                 crt_data_size = sum(samples_per_client)
#             # let the minimal data size to be larger than 0
#             while min(samples_per_client)==0:
#                 zero_client_idx = np.argmin(samples_per_client)
#                 maxid = np.argmax(samples_per_client)
#                 samples_per_client[maxid] -=1
#                 samples_per_client[zero_client_idx] += 1
#             assert datasize==sum(samples_per_client) and min(samples_per_client)>0
#         return samples_per_client
#
# class DirichletPartitioner(BasicPartitioner):
#     """`Partition the indices of samples in the original dataset according to Dirichlet distribution of the
#     particular attribute. This way of partition is widely used by existing works in federated learning.
#
#     Args:
#         num_clients (int, optional): the number of clients
#         alpha (float, optional): `alpha`(i.e. alpha>=0) in Dir(alpha*p) where p is the global distribution. The smaller alpha is, the higher heterogeneity the data is.
#         imbalance (float, optional): the degree of imbalance of the amounts of different local_movielens_recommendation data (0<=imbalance<=1)
#         error_bar (float, optional): the allowed error when the generated distribution mismatches the distribution that is actually wanted, since there may be no solution for a particular imbalance and alpha.
#         index_func (func, optional): to index the distribution-dependent (i.e. label) attribute in each sample.
#     """
#     def __init__(self, num_clients=100, alpha=0.1, error_bar=1e-6, imbalance=1.0, index_func=lambda X:[xi[-1] for xi in X], minvol=1):
#         self.num_clients = num_clients
#         self.alpha = alpha
#         self.imbalance = imbalance
#         self.index_func = index_func
#         self.minvol = minvol
#         self.error_bar = error_bar
#
#     def __str__(self):
#         name = "dir{:.2f}_err{}".format(self.alpha, self.error_bar)
#         if self.imbalance > 0: name += '_imb{:.1f}'.format(self.imbalance)
#         return name
#
#     def __call__(self, data):
#         attrs = self.index_func(data)
#         num_attrs = len(set(attrs))
#         samples_per_client = self.data_imbalance_generator(self.num_clients, len(data), self.imbalance, minvol=self.minvol)
#         # count the label distribution
#         lb_counter = collections.Counter(attrs)
#         lb_names = list(lb_counter.keys())
#         p = np.array([1.0 * v / len(data) for v in lb_counter.values()])
#         lb_dict = {}
#         attrs = np.array(attrs)
#         for lb in lb_names:
#             lb_dict[lb] = np.where(attrs == lb)[0]
#         proportions = [np.random.dirichlet(self.alpha * p) for _ in range(self.num_clients)]
#         while np.any(np.isnan(proportions)):
#             proportions = [np.random.dirichlet(self.alpha * p) for _ in range(self.num_clients)]
#         sorted_cid_map = {k: i for k, i in zip(np.argsort(samples_per_client), [_ for _ in range(self.num_clients)])}
#         error_increase_interval = 500
#         max_error = self.error_bar
#         loop_count = 0
#         crt_id = 0
#         crt_error = 100000
#         while True:
#             if loop_count >= error_increase_interval:
#                 loop_count = 0
#                 max_error = max_error * 10
#             # generate dirichlet distribution till ||E(proportion) - P(D)||<=1e-5*self.num_classes
#             mean_prop = np.sum([pi * di for pi, di in zip(proportions, samples_per_client)], axis=0)
#             mean_prop = mean_prop / mean_prop.sum()
#             error_norm = ((mean_prop - p) ** 2).sum()
#             if crt_error - error_norm >= max_error:
#                 print("Error: {:.8f}".format(error_norm))
#                 crt_error = error_norm
#             if error_norm <= max_error:
#                 break
#             excid = sorted_cid_map[crt_id]
#             crt_id = (crt_id + 1) % self.num_clients
#             sup_prop = [np.random.dirichlet(self.alpha * p) for _ in range(self.num_clients)]
#             del_prop = np.sum([pi * di for pi, di in zip(proportions, samples_per_client)], axis=0)
#             del_prop -= samples_per_client[excid] * proportions[excid]
#             for i in range(error_increase_interval - loop_count):
#                 alter_norms = []
#                 for cid in range(self.num_clients):
#                     if np.any(np.isnan(sup_prop[cid])):
#                         continue
#                     alter_prop = del_prop + samples_per_client[excid] * sup_prop[cid]
#                     alter_prop = alter_prop / alter_prop.sum()
#                     error_alter = ((alter_prop - p) ** 2).sum()
#                     alter_norms.append(error_alter)
#                 if min(alter_norms) < error_norm:
#                     break
#             if len(alter_norms) > 0 and min(alter_norms) < error_norm:
#                 alcid = np.argmin(alter_norms)
#                 proportions[excid] = sup_prop[alcid]
#             loop_count += 1
#         local_datas = [[] for _ in range(self.num_clients)]
#         self.dirichlet_dist = []  # for efficiently visualizing
#         for lb in lb_names:
#             lb_idxs = lb_dict[lb]
#             lb_proportion = np.array([pi[lb_names.index(lb)] * si for pi, si in zip(proportions, samples_per_client)])
#             lb_proportion = lb_proportion / lb_proportion.sum()
#             lb_proportion = (np.cumsum(lb_proportion) * len(lb_idxs)).astype(int)[:-1]
#             lb_datas = np.split(lb_idxs, lb_proportion)
#             self.dirichlet_dist.append([len(lb_data) for lb_data in lb_datas])
#             local_datas = [local_data + lb_data.tolist() for local_data, lb_data in zip(local_datas, lb_datas)]
#         self.dirichlet_dist = np.array(self.dirichlet_dist).T
#         for i in range(self.num_clients): np.random.shuffle(local_datas[i])
#         len_dist = [len(d) for d in local_datas]
#         while min(len_dist)<=self.minvol:
#             min_did = np.argmin(len_dist)
#             max_did = np.argmax(len_dist)
#             max_d = local_datas[max_did]
#             min_d = local_datas[min_did]
#             if len(max_d)<=self.minvol:
#                 raise RuntimeError("The number of clients is too large to distribute enough samples to each client when minvol=={}. Please decrease the number of clients".format(self.minvol))
#             min_d.extend(max_d[:1])
#             max_d = max_d[1:]
#             local_datas[min_did] = min_d
#             local_datas[max_did] = max_d
#             len_dist = [len(d) for d in local_datas]
#         self.local_datas = local_datas
#         return local_datas
class Non_iid(Dataset):
    """Dataset wrapper that hands out random mini-batches rather than single samples.

    Every item access draws ``batch_size`` positions uniformly at random (with
    replacement) and returns the matching inputs and labels, moved to the GPU
    when CUDA is available.
    """

    def __init__(self, x, y):
        """Store the inputs and labels as tensors of fixed dtype.

        Args:
            x: input tensor; a new axis is inserted at position 1 and the
                values are cast to ``torch.float32``.
            y: label tensor; cast to ``torch.int64``.
        """
        self.x_data = x.unsqueeze(1).to(torch.float32)
        self.y_data = y.to(torch.int64)
        self.batch_size = 32  # size of the random batch returned per access
        self.cuda_available = torch.cuda.is_available()

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.x_data)

    def __getitem__(self, index=None):
        """Return a randomly sampled ``(x, y)`` batch of ``batch_size`` samples.

        Args:
            index: accepted so that ``dataset[i]`` and the ``Dataset`` protocol
                work, but ignored: the batch is always drawn at random. The
                original zero-argument call ``dataset.__getitem__()`` keeps
                working because the parameter is optional.

        Returns:
            A tuple ``(x, y)`` of tensors whose first dimension is
            ``batch_size``; both live on the GPU when CUDA is available.
        """
        sampled = np.random.randint(low=0, high=len(self.x_data), size=self.batch_size)
        x = self.x_data[sampled]
        y = self.y_data[sampled]
        if self.cuda_available:
            return x.cuda(), y.cuda()
        return x, y


class DataLoader_cifar100(DataLoader):
    """CIFAR-100 data pool split across clients with a Dirichlet label partition.

    On construction the pool is either restored from a cached ``.npy`` file
    under ``utils.pool_folder_path`` or built from scratch: the CIFAR-100 train
    and test splits are concatenated, every class is divided among
    ``pool_size`` clients using proportions drawn from a Dirichlet
    distribution, and each per-class, per-client share is split 80/20 into
    local training and local test samples that are then cut into batches.
    """

    def __init__(self,
                 dir_a=0.1,
                 batch_size=50,
                 pool_size=100,
                 input_require_shape=[3, -1, -1],
                 shuffle=True,
                 types='dirichlet',
                 recreate=False,
                 params=None,
                 *args,
                 **kwargs):
        """Load the cached data pool, or build and cache a new one.

        Args:
            dir_a: concentration value used for every component of the
                Dirichlet distribution that splits each class among clients.
            batch_size: batch size handed to the parent constructor; the
                batches are cut with ``DataLoader.separate_list`` using
                ``self.batch_size``.
            pool_size: number of clients (entries of the data pool).
            input_require_shape: forwarded unchanged to the parent constructor.
                NOTE(review): this is a mutable list default shared across
                calls; confirm the parent class never mutates it.
            shuffle: accepted but not used by this class.
            types: only used as a suffix of the nickname; the cache file name
                always contains the literal ``'Dirichlet'``.
            recreate: when truthy, rebuild the pool even if a cache file exists.
            params: optional mapping; when given, its ``'dir_a'`` and
                ``'batch_size'`` entries override the keyword arguments
                (``KeyError`` if either is missing).
            *args: accepted and ignored.
            **kwargs: accepted and ignored.
        """
        if params is not None:
            dir_a = params['dir_a']
            batch_size = params['batch_size']
        name = 'CIFAR100_pool_' + str(pool_size) + '_batchsize_' + str(batch_size) + '_types_' + 'Dirichlet' + '_dir_a_' + str(dir_a)
        nickname = 'cifar100 dirichlet B' + str(batch_size) + ' C' + str(pool_size) + ' types' + str(types)
        super().__init__(name, nickname, pool_size, batch_size, input_require_shape)

        file_path = utils.pool_folder_path + name + '.npy'

        if os.path.exists(file_path) and not recreate:
            # SECURITY NOTE: the cache is a pickled Python object
            # (allow_pickle=True); unpickling can execute arbitrary code, so
            # only pool files produced by this project should be loaded.
            cached_loader = np.load(file_path, allow_pickle=True).item()
            # Adopt every instance attribute of the cached loader.
            for attr, value in list(vars(cached_loader).items()):
                setattr(self, attr, value)
            print('Successfully Read the Data Pool.')
            return

        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))])
        trainset = torchvision.datasets.CIFAR100(root=utils.data_folder_path, train=True,
                                                 download=True, transform=transform)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=trainset.data.shape[0],
                                                  shuffle=True, num_workers=1)
        testset = torchvision.datasets.CIFAR100(root=utils.data_folder_path, train=False,
                                                download=True, transform=transform)
        testloader = torch.utils.data.DataLoader(testset, batch_size=testset.data.shape[0],
                                                 shuffle=False, num_workers=1)
        totalset = ConcatDataset([trainset, testset])
        totalloader = torch.utils.data.DataLoader(totalset, batch_size=len(totalset),
                                                  shuffle=True, num_workers=1)

        def load_full_batch(loader):
            """Iterate ``loader`` to the end and return its last (inputs, targets) batch."""
            inputs, targets = None, None
            for inputs, targets in loader:
                pass
            return inputs, targets

        # Each loader's batch size equals its dataset size, so one pass yields
        # the whole split as a single batch.
        train_input_data, _ = load_full_batch(trainloader)
        # NOTE(review): the test-split pass produces nothing that is used
        # afterwards. It is kept because dropping it would presumably change
        # how much torch RNG state is consumed before the shuffled pass over
        # the combined set -- confirm before removing it.
        load_full_batch(testloader)
        total_input_data, total_target_data = load_full_batch(totalloader)

        # TODO: the initialization at this point is known to be problematic
        # (original author's note).
        self.cal_data_shape(train_input_data.shape)
        self.target_class_num = 100

        self.global_training_data = []
        self.global_test_data = []
        self.output_size = 100
        self.model4data = 'resnet'
        self.task_name = 'cifar100_classification'
        # TODO: global_training_data / global_test_data are left empty here;
        # total_training_number / total_test_number are filled in by
        # create_data_pool below.

        def visualization(client_idcs):
            """Save a stacked histogram of the label distribution over clients.

            Args:
                client_idcs: one sequence of sample indices per client; the
                    indices refer to ``total_target_data``.
            """
            fig = plt.figure(figsize=(25, 20))
            try:
                # label_distribution[k] holds the owning client id once per
                # sample of class k.
                label_distribution = [[] for _ in range(100)]
                for c_id, idc in enumerate(client_idcs):
                    for idx in idc:
                        label_distribution[total_target_data[idx]].append(c_id)

                plt.hist(label_distribution, stacked=True,
                         bins=np.arange(-0.5, pool_size + 1.5, 1),
                         label=[i for i in range(100)], rwidth=0.5)
                plt.xticks(np.arange(pool_size), ["Client %d" %
                                                  c_id for c_id in range(pool_size)])
                plt.xlabel("Client ID")
                plt.ylabel("Number of samples")
                plt.legend()
                plt.title("Display Label Distribution on Different Clients")

                plt.savefig(file_path + '.png')
            finally:
                # Release the figure; pyplot otherwise keeps it alive.
                plt.close(fig)

        def create_data_pool(data_pool, input_data, target_data):
            """Partition the samples among clients and fill ``data_pool`` in place.

            Args:
                data_pool: list with one dict per client; each dict receives
                    the batched local training/test data, their sample counts
                    and a ``'data_name'``.
                input_data: all input samples, indexable by sample index.
                target_data: label of every sample, aligned with ``input_data``.
            """
            n_classes = 100
            n_clients = pool_size
            # (n_classes, n_clients) matrix: row k is the fraction of class k
            # assigned to each client.
            label_distribution = np.random.dirichlet([dir_a] * n_clients, n_classes)
            # Sample indices of every class, as plain 1-D NumPy arrays.
            # Converting the labels first matters: calling np.argwhere on a
            # torch tensor can hand back tensors, and random.shuffle on a
            # tensor swaps element *views*, which duplicates some indices and
            # drops others (so a sample could end up in both train and test).
            labels = np.asarray(target_data)
            class_idcs = [np.flatnonzero(labels == y) for y in range(n_classes)]

            # Per client: list of index arrays, one per class.
            client_train_idcs = [[] for _ in range(n_clients)]
            client_test_idcs = [[] for _ in range(n_clients)]
            self.total_training_number = 0
            self.total_test_number = 0
            for k_idcs, fracs in zip(class_idcs, label_distribution):
                # Cut the indices of this class into n_clients consecutive
                # chunks sized according to fracs; chunk i belongs to client i.
                split_points = (np.cumsum(fracs)[:-1] * len(k_idcs)).astype(int)
                for i, idcs in enumerate(np.split(k_idcs, split_points)):
                    random.shuffle(idcs)
                    n_train = int(len(idcs) * 0.8)  # 80/20 train/test split
                    train_idcs = idcs[:n_train]
                    test_idcs = idcs[n_train:]
                    self.total_training_number += len(train_idcs)
                    self.total_test_number += len(test_idcs)
                    client_train_idcs[i].append(train_idcs)
                    client_test_idcs[i].append(test_idcs)
            client_train_idcs = [np.concatenate(idcs) for idcs in client_train_idcs]
            client_test_idcs = [np.concatenate(idcs) for idcs in client_test_idcs]
            visualization(client_train_idcs)

            def build_batches(sample_indices):
                """Cut ``sample_indices`` into batches of (inputs, targets)."""
                batches = []
                for batch_indices in DataLoader.separate_list(sample_indices, self.batch_size):
                    batch_inputs = input_data[batch_indices].reshape([-1] + self.input_data_shape).float()
                    batches.append((batch_inputs, target_data[batch_indices]))
                return batches

            for pool_idx, (client_train_idc, client_test_idc) in enumerate(
                    zip(client_train_idcs, client_test_idcs)):
                data_pool[pool_idx]['local_training_data'] = build_batches(client_train_idc)
                data_pool[pool_idx]['local_training_number'] = len(client_train_idc)
                data_pool[pool_idx]['local_test_data'] = build_batches(client_test_idc)
                data_pool[pool_idx]['local_test_number'] = len(client_test_idc)
                data_pool[pool_idx]['data_name'] = str(pool_idx)

        data_pool = [{} for _ in range(self.pool_size)]

        create_data_pool(data_pool, total_input_data, total_target_data)
        self.data_pool = data_pool
        # Cache the whole loader object (pickled inside the .npy file).
        np.save(file_path, self)

    def allocate(self, client_list):
        """Hand the i-th data pool entry to the i-th client.

        Args:
            client_list: clients exposing ``update_data(pool_index,
                training_data, training_number, test_data, test_number)``.

        Raises:
            IndexError: if ``client_list`` has more entries than
                ``self.pool_size``.
        """
        # Deterministic identity mapping client i -> pool entry i. A random
        # assignment (np.random.choice without replacement) is a possible
        # alternative.
        choose_data_pool_item_indices = list(range(self.pool_size))

        for idx, client in enumerate(client_list):
            pool_idx = choose_data_pool_item_indices[idx]
            data_pool_item = self.data_pool[pool_idx]
            client.update_data(pool_idx,
                               data_pool_item['local_training_data'],
                               data_pool_item['local_training_number'],
                               data_pool_item['local_test_data'],
                               data_pool_item['local_test_number']
                               )

