import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns

# When True, `Partitioner.gen_dataidx_map` runs `check_dataidx_map` on every
# generated map (prints per-client/per-label statistics and shows a bubble chart)
# before saving it.
CHECK: bool = True

def build_partition(dataset_name='mnist', num_clients=10, partition='iid', beta=None, map_dir='./partition'):
    r"""Load a previously saved client partition map.

    The map is read from ``{map_dir}/{dataset_name}_M[{num_clients}]_{tag}.txt``,
    where ``tag`` is ``iid``, ``dir[{beta}]`` or ``exdir[{C} {beta}]``.

    Args:
        dataset_name (str): dataset name used in the file name.
        num_clients (int): number of clients used in the file name.
        partition (str): partition scheme: 'iid', 'dir' or 'exdir'.
        beta (sequence, optional): scheme hyper-parameters -- ``[beta]`` for
            'dir', ``[C, beta]`` for 'exdir'; ignored for 'iid'. Values are
            inserted with ``str.format``, so they must be written the same way
            as when the map was generated (e.g. ``100.0`` vs ``100``).
        map_dir (str): directory containing the map files.
    Returns:
        dataidx_map (dict): { client id (int): data indices (numpy.ndarray) }
    Raises:
        ValueError: unknown ``partition``, or ``beta`` does not hold exactly
            two values for 'exdir'.
        IndexError: ``beta`` is empty for 'dir'.
    """
    # Avoid a shared mutable default; an absent beta behaves like an empty list.
    params = [] if beta is None else beta
    if partition == 'iid':
        tag = partition
    elif partition == 'dir':
        tag = '{}[{}]'.format(partition, params[0])
    elif partition == 'exdir':
        C, dir_beta = params
        tag = '{}[{} {}]'.format(partition, C, dir_beta)
    else:
        raise ValueError("unknown partition '{}'; expected 'iid', 'dir' or 'exdir'".format(partition))
    map_path = "{}/{}_M[{}]_{}.txt".format(map_dir, dataset_name, num_clients, tag)
    return Partitioner.read_dataidx_map(map_path)

class Partitioner():
    r"""Base class for generating, saving, loading and inspecting client data maps.

    A data index map (``dataidx_map``) is a dict ``{client id (int): data indices}``.
    Subclasses implement :meth:`partition_data` and define the ``dataset_name``
    and ``output_name`` attributes that :meth:`gen_dataidx_map` uses to name
    the output file.
    """
    def __init__(self):
        pass

    def partition_data(self, labels=None, num_clients=None, num_classes=None):
        r"""Partition data indices to clients. Must be overridden by subclasses.

        Args:
            labels (numpy.ndarray, list): labels of the whole dataset.
            num_clients (int): number of clients.
            num_classes (int): number of classes.
        Returns:
            dataidx_map (dict): { client id (int): data indices (numpy.ndarray or list) }, e.g., {0: [0,1,4], 1: [2,3,5]}
        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("{} must implement partition_data()".format(type(self).__name__))

    def gen_dataidx_map(self, labels, num_clients, num_classes, map_dir):
        r"""Partition ``labels`` across clients and save the map as JSON.

        The map is written to
        ``{map_dir}/{self.dataset_name}_M[{num_clients}]_{self.output_name}.txt``.
        When the module-level ``CHECK`` flag is truthy, the map is first
        inspected with :meth:`check_dataidx_map`.

        Returns:
            dataidx_map (dict): the generated map.
        """
        dataidx_map = self.partition_data(labels, num_clients, num_classes)

        if CHECK:
            self.check_dataidx_map(dataidx_map, labels, num_clients, num_classes)
        map_path = "{}/{}_M[{}]_{}.txt".format(map_dir, self.dataset_name, num_clients, self.output_name)
        self.dumpmap(dataidx_map, map_path)
        return dataidx_map

    @classmethod
    def read_dataidx_map(cls, map_path):
        r"""Load a data index map written by :meth:`dumpmap`.

        Returns:
            dataidx_map (dict): { client id (int): data indices (numpy.ndarray) }
        """
        return cls.loadmap(map_path)

    @classmethod
    def check_dataidx_map(cls, dataidx_map=None, labels=None, num_clients=10, num_classes=10):
        r"""Print statistics of a data index map and draw a bubble chart of it.

        Prints, in order: samples per class per client, samples per client,
        samples per label, and the number of clients holding each label.

        Args:
            dataidx_map (dict): { client id (int): data indices }, with keys 0..num_clients-1.
            labels (numpy.ndarray, list): labels of the whole dataset; each
                ``labels[i]`` must be convertible with ``int()``.
            num_clients (int): number of clients in ``dataidx_map``.
            num_classes (int): number of classes; labels index a row of length ``num_classes``.
        """
        # counts[cid, k] = number of samples of class k held by client cid
        counts = np.zeros((num_clients, num_classes), dtype=int)
        for cid in range(num_clients):
            for idx in dataidx_map[cid]:
                counts[cid, int(labels[idx])] += 1
        n_sample_per_class_per_client = {cid: counts[cid].tolist() for cid in range(num_clients)}
        print("\n***** the number of samples per class per client *****")
        print(n_sample_per_class_per_client)

        print("\n***** the number of samples per client *****")
        print(counts.sum(axis=1))

        print("\n*****the number of samples per label*****")
        print(counts.sum(axis=0))
        print("\n*****the number of clients per label*****")
        print((counts != 0).sum(axis=0))

        cls.bubble(n_sample_per_class_per_client, num_clients, num_classes)
        #cls.heatmap(n_sample_per_class_per_client, num_clients, num_classes)

    @classmethod
    def bubble(cls, n_sample_per_class_per_client, num_clients, num_classes):
        r"""Draw a bubble chart of the local data distribution.

        One bubble per (client, label) pair; its area is 0.2 x the sample count.

        Args:
            n_sample_per_class_per_client (dict): { client id: [number of samples of Class 0, number of samples of Class 1, ...] }
        """
        x = [cid for cid in range(num_clients) for _ in range(num_classes)]
        y = [k for _ in range(num_clients) for k in range(num_classes)]
        size = [n_sample_per_class_per_client[cid][k] * 0.2 for cid, k in zip(x, y)]

        plt.figure()
        # `alpha` (opacity) -- `beta` is not a valid scatter keyword and raised at runtime.
        plt.scatter(x, y, s=size, alpha=1)
        plt.xlabel("Client ID")
        plt.ylabel("Label")
        plt.show()

    @classmethod
    def heatmap(cls, n_sample_per_class_per_client, num_clients, num_classes):
        r"""Draw an annotated heat map (labels x clients) of the local data distribution.

        Args:
            n_sample_per_class_per_client (dict): { client id: [number of samples of Class 0, ...] }
        """
        num_sample_per_client = []
        heatmap_data = np.zeros((num_classes, num_clients), int)
        for i in range(num_clients):
            heatmap_data[:, i] = np.array(n_sample_per_class_per_client[i])
            num_sample_per_client.append(sum(n_sample_per_class_per_client[i]))
        fig, ax = plt.subplots(figsize=(12, 6))
        ax = sns.heatmap(heatmap_data, ax=ax, annot=True, fmt="d", linewidths=.9, cmap="YlGn",)
        ax.set_xticklabels(['{}'.format(i) for i in range(num_clients)], rotation=0)
        #ax.set_xticklabels(['{} ({})'.format(i, num_sample_per_client[i]) for i in range(num_clients)], rotation=0)
        ax.set_yticklabels([str(i) for i in range(num_classes)], rotation=0)
        ax.set_xlabel("Client ID", fontsize=15)
        ax.set_ylabel("Label", fontsize=15)
        plt.show()

    @classmethod
    def dumpmap(cls, dataidx_map, map_path):
        r"""Write ``dataidx_map`` to ``map_path`` as JSON.

        Index arrays/lists are converted to plain lists of Python ints for
        serialization; the caller's dict is left unmodified. JSON stores the
        integer client ids as string keys, which :meth:`loadmap` converts back.
        """
        serializable = {cid: np.asarray(idxs).tolist() for cid, idxs in dataidx_map.items()}
        with open(map_path, 'w') as f:
            json.dump(serializable, f)

    @classmethod
    def loadmap(cls, map_path):
        r"""Read a JSON map written by :meth:`dumpmap`.

        Returns:
            dataidx_map (dict): { client id (int): data indices (numpy.ndarray) },
            ordered by ascending client id.
        """
        with open(map_path, 'r') as f:
            temp = json.load(f)
        # `json.load` yields dict{'0': [...]}; turn the keys back into ints.
        return {int(cid): np.array(temp[cid]) for cid in sorted(temp, key=int)}


class IIDPartitioner(Partitioner):
    r"""Uniformly random partition: shuffle all sample indices, cut into near-equal shards."""

    def __init__(self, dataset_name='mnist'):
        super().__init__()
        self.name = 'iid'
        self.dataset_name = dataset_name
        # No hyper-parameters, so the file-name tag is just the scheme name.
        self.output_name = self.name

    def partition_data(self, labels, num_clients, num_classes):
        r"""Assign a random permutation of ``range(len(labels))`` to clients.

        Shard sizes differ by at most one (``numpy.array_split``);
        ``num_classes`` is unused.

        Returns:
            dataidx_map (dict): { client id (int): data indices (numpy.ndarray) }
        """
        shuffled = np.random.permutation(len(labels))
        shards = np.array_split(shuffled, num_clients)
        return dict(enumerate(shards))


class DirPartitioner(Partitioner):
    r"""Label-skewed partition with per-class client shares drawn from Dirichlet(``beta``).

    Smaller ``beta`` gives more heterogeneous clients. A client that already
    holds at least ``len(labels) / num_clients`` samples gets no share of the
    following classes (balance step). The whole draw is repeated until every
    client holds at least 10 samples.
    """
    def __init__(self, dataset_name='mnist', beta=10.0):
        super(DirPartitioner, self).__init__()
        self.name = 'dir'
        self.dataset_name = dataset_name
        self.beta = beta
        self.output_name = '{}[{}]'.format(self.name, self.beta)

    def partition_data(self, labels, num_clients, num_classes):
        r"""Split sample indices class by class according to Dirichlet proportions.

        Args:
            labels (numpy.ndarray, list): labels of the whole dataset (values 0..num_classes-1).
            num_clients (int): number of clients.
            num_classes (int): number of classes.
        Returns:
            dataidx_map (dict): { client id (int): shuffled data indices (list) }
        """
        beta = self.beta
        min_size = 0
        min_require_size = 10  # the minimum size of samples per client is required to be 10
        num_labels = len(labels)
        labels = np.array(labels)
        avg_size = num_labels / num_clients  # capacity used by the balance step

        while min_size < min_require_size:
            idx_per_client = [[] for _ in range(num_clients)]  # data sample indices per client
            for k in range(num_classes):
                idx_k = np.where(labels == k)[0]  # data sample indices of class k
                np.random.shuffle(idx_k)
                raw = np.random.dirichlet(np.repeat(beta, num_clients))
                not_full = np.array([len(idx_j) < avg_size for idx_j in idx_per_client])
                proportions = raw * not_full
                if proportions.sum() <= 0:
                    # Every client with a non-zero share is already full; dividing by a
                    # zero sum would give NaN split points, so use the unmasked draw.
                    proportions = raw
                proportions = proportions / proportions.sum()
                # Cumulative split points into idx_k; client j gets idx_k[split[j-1]:split[j]].
                split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_per_client = [idx_j + idx.tolist() for idx_j, idx in zip(idx_per_client, np.split(idx_k, split_points))]
            min_size = min(len(idx_j) for idx_j in idx_per_client)

        dataidx_map = {}
        for j in range(num_clients):
            np.random.shuffle(idx_per_client[j])
            dataidx_map[j] = idx_per_client[j]
        return dataidx_map


class ExDirPartitioner(Partitioner):
    r"""Extended Dirichlet partition: each client only receives ``C`` classes.

    Every client is first assigned ``C`` distinct classes (:meth:`allocate_classes`).
    Each class is then split among the clients that own it, in proportions
    drawn from Dirichlet(``beta``), with the same balance step as the plain
    Dirichlet scheme (owners already holding ``len(labels) / num_clients``
    samples get no further share while another owner still can).
    """
    def __init__(self, dataset_name='mnist', C=10, beta=10.0):
        super(ExDirPartitioner, self).__init__()
        self.name = 'exdir'
        self.dataset_name = dataset_name
        self.C, self.beta = C, beta
        self.output_name = '{}[{} {}]'.format(self.name, self.C, self.beta)

    def allocate_classes(self, num_clients, num_classes):
        r"""Allocate ``C`` distinct classes to each client.

        The allocation is redrawn until every class is owned by at least
        ``max(C * num_clients // num_classes // 5, 1)`` clients.

        Returns:
            clientidx_map (dict): { class id (int): client indices (list, ascending) }
        Raises:
            ValueError: if ``C`` exceeds ``num_classes``, or if there are too few
                client/class slots for the requirement to ever be met (which
                previously looped forever).
        """
        if self.C > num_classes:
            raise ValueError("C={} exceeds the number of classes ({})".format(self.C, num_classes))
        min_size_per_class = 0
        min_require_size_per_class = max(self.C * num_clients // num_classes // 5, 1)
        if self.C * num_clients < num_classes * min_require_size_per_class:
            raise ValueError(
                "cannot give every one of {} classes at least {} owner(s) with {} clients x C={}".format(
                    num_classes, min_require_size_per_class, num_clients, self.C))
        while min_size_per_class < min_require_size_per_class:
            clientidx_map = {k: [] for k in range(num_classes)}
            for cid in range(num_clients):
                selected_classes = np.random.choice(range(num_classes), self.C, replace=False)
                for k in selected_classes:
                    clientidx_map[k].append(cid)
            min_size_per_class = min(len(clients) for clients in clientidx_map.values())
        return clientidx_map

    def partition_data(self, labels, num_clients, num_classes):
        r"""Split sample indices so that each client only gets its allocated classes.

        Args:
            labels (numpy.ndarray, list): labels of the whole dataset (values 0..num_classes-1).
            num_clients (int): number of clients.
            num_classes (int): number of classes.
        Returns:
            dataidx_map (dict): { client id (int): shuffled data indices (list) }
        """
        beta = self.beta
        labels = np.array(labels)
        min_size = 0
        min_require_size = 10  # every client must end up with at least 10 samples
        num_labels = len(labels)

        clientidx_map = self.allocate_classes(num_clients, num_classes)
        print("\n*****clientidx_map*****")
        print(clientidx_map)
        print("\n*****Number of clients per label*****")
        print([len(clientidx_map[k]) for k in range(num_classes)])

        while min_size < min_require_size:
            idx_per_client = [[] for _ in range(num_clients)]
            for k in range(num_classes):
                idx_k = np.where(labels == k)[0]
                np.random.shuffle(idx_k)
                raw = np.random.dirichlet(np.repeat(beta, num_clients))
                owners = clientidx_map[k]
                is_owner = np.zeros(num_clients, dtype=bool)
                is_owner[owners] = True
                # Balance: owners that already hold an average client's worth of data
                # get no share (drop `not_full` from the mask to disable balancing).
                not_full = np.array([len(idx_j) < num_labels / num_clients for idx_j in idx_per_client])
                proportions = raw * (is_owner & not_full)
                if proportions.sum() <= 0:
                    # All owners of class k are already full. Dividing by the zero sum
                    # would produce NaN split points (undefined when cast to int), so
                    # split the class among its owners without the balance mask.
                    proportions = raw * is_owner
                proportions = proportions / proportions.sum()
                split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                # Float rounding can leave the cumulative sum slightly below 1, which would
                # hand the last samples to a non-owner after the last owner. Pin every split
                # point from the last owner onwards to the end so that owner absorbs them.
                split_points[max(owners):] = len(idx_k)
                idx_per_client = [idx_j + idx.tolist() for idx_j, idx in zip(idx_per_client, np.split(idx_k, split_points))]
            min_size = min(len(idx_j) for idx_j in idx_per_client)

        dataidx_map = {}
        for j in range(num_clients):
            np.random.shuffle(idx_per_client[j])
            dataidx_map[j] = idx_per_client[j]
        return dataidx_map

def main():
    r"""Command-line entry point.

    Loads the dataset named by ``-d`` with ``build_dataset``, partitions the
    training-set labels across ``-n`` clients with the chosen scheme and writes
    the map under ``../../partition``.
    """
    import argparse

    def str2bool(value):
        # argparse's `type=bool` turns any non-empty string (even "False") into True.
        return str(value).strip().lower() in ('1', 'true', 't', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('-d', type=str, default='cifar10', help='dataset name')
    parser.add_argument('-n', type=int, default=5, help='divide into n clients')
    parser.add_argument('--partition', type=str, default='exdir', choices=['iid', 'dir', 'exdir'],
                        help='partition scheme: iid, dir or exdir')
    parser.add_argument('--balance', type=str2bool, default=True, help='balanced or imbalanced (currently unused)')
    parser.add_argument('--beta', type=float, default=100.0, help='the beta of dirichlet distribution')
    parser.add_argument('-C', type=int, default=2, help='number of classes per client (exdir only)')
    args = parser.parse_args()
    print(args)

    dataset_dir = '../../../datasets/'  # the directory path of datasets
    output_dir = '../../partition'  # the directory path of outputs
    dataset_name = args.d  # the name of the dataset
    num_clients = args.n  # number of clients
    partition = args.partition  # partition way
    beta = args.beta
    C = args.C

    # Fail fast on an unknown dataset instead of after the (slow) dataset load.
    num_class_dict = {'fashionmnist': 10, 'cifar10': 10, 'cinic10': 10, 'ham': 7}
    if dataset_name not in num_class_dict:
        parser.error("unknown dataset '{}'; expected one of {}".format(dataset_name, sorted(num_class_dict)))
    num_classes = num_class_dict[dataset_name]

    # Imported here so `main()` also works when called from another module,
    # not only when this file is run as a script.
    from datasets import build_dataset
    train_dataset, test_dataset = build_dataset(dataset_name=dataset_name, dataset_dir=dataset_dir)
    # Only the training set is partitioned.
    labels = [label for _, label in train_dataset]

    if partition == 'iid':
        p = IIDPartitioner(dataset_name=dataset_name)
    elif partition == 'dir':
        p = DirPartitioner(dataset_name=dataset_name, beta=beta)
    else:  # 'exdir' -- argparse `choices` rules out anything else
        p = ExDirPartitioner(dataset_name=dataset_name, C=C, beta=beta)
    p.gen_dataidx_map(labels=labels, map_dir=output_dir, num_clients=num_clients, num_classes=num_classes)


if __name__ == '__main__':
    # Script entry: bind `build_dataset` from the `datasets` module as a module
    # global (the name `main()` calls to load the dataset), then run the CLI.
    from datasets import build_dataset
    main()