import numpy as np
import math
import random
from imgaug import augmenters as iaa
import os
from sklearn.cluster import KMeans
from scipy.cluster.hierarchy import linkage, fcluster
from scipy.spatial.distance import pdist, squareform

# Dataset identifier; Coordinator uses it to name its output directory.
dataset_name = "cifar10"

class Coordinator:
    """Partition clients (devices) into coordinator groups.

    Two strategies are offered:

    * ``assign_clients`` -- greedy: fill each coordinator with up to
      ``self.gamma`` clients. Each step adds the client whose labels bring
      the group's pooled labels closest (lowest KL divergence) to the global
      training labels.
    * ``assign_clients_to`` -- average-linkage hierarchical clustering of the
      clients on pairwise JS divergence between their label sets.

    With ``balance=True``, both methods replace ``self.coordinator`` (a list
    of sets of client indices), print it, and write it to a text file in
    ``self.output_dir``.
    """

    def __init__(self, pcs):
        """Set up the coordinator manager and its output directory.

        Args:
            pcs: client/data provider. This class reads ``size_device``,
                ``global_train_label`` and ``local_train_label[i]``, and calls
                ``get_kl_divergence(a, b)`` / ``get_js_divergence(a, b)``.
                The label entries appear to be 1-D arrays of sample labels
                (they are concatenated with ``np.hstack``) -- TODO confirm.
        """
        self.pcs = pcs
        # td / ta are not read anywhere in this class; presumably used by
        # callers -- TODO document their meaning.
        self.td = 0.1
        self.ta = -0.1
        # One set of client indices per coordinator.
        self.coordinator = []
        # Maximum number of clients a coordinator can communicate with
        # (caps the group size in assign_clients).
        self.gamma = 4

        output_dir = f"{dataset_name}_Management_Information"
        # exist_ok=True replaces the racy "exists() then makedirs()" check.
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def _report_assignment(self, title, filename):
        """Print ``self.coordinator`` under *title* and save it to *filename* in ``self.output_dir``."""
        print(title)
        for idx, members in enumerate(self.coordinator):
            print(f"coordinator {idx} 管理的客户端编号: {sorted(members)}")

        path = os.path.join(self.output_dir, filename)
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(
                f"coordinator {idx} manages clients: {sorted(members)}\n"
                for idx, members in enumerate(self.coordinator)
            )

    def _pick_most_balanced_client(self, client_pool, group_labels):
        """Return the client whose labels, pooled with *group_labels*, minimise KL to the global labels.

        Ties keep the first client in ``client_pool`` iteration order.
        """
        best_client, best_score = None, float('inf')
        for client in client_pool:
            candidate = np.hstack([group_labels, self.pcs.local_train_label[client]])
            score = self.pcs.get_kl_divergence(self.pcs.global_train_label, candidate)
            if score < best_score:
                best_client, best_score = client, score
        if best_client is None:
            # Every score was inf or NaN (never < inf).  Fall back to the
            # first candidate rather than adding None to the group.
            best_client = next(iter(client_pool))
        return best_client

    def assign_clients(self, balance=True):
        """Greedily group clients so each group's labels resemble the global labels.

        Each new coordinator is filled with up to ``self.gamma`` clients. At
        every step, the remaining client that gives the lowest
        ``pcs.get_kl_divergence(global_train_label, pooled_labels)`` is
        added.  The result replaces ``self.coordinator`` and is reported to
        stdout and ``Coodinator_A_clients.txt``.

        Args:
            balance: when False, give every client its own coordinator and
                return without printing the result or writing a file.

        Raises:
            ValueError: if ``self.gamma`` < 1 (no client could ever be placed).
        """
        print("assign_clients start")
        if not balance:
            self.coordinator = [{i} for i in range(self.pcs.size_device)]
            return
        if self.gamma < 1:
            raise ValueError(f"gamma must be >= 1, got {self.gamma}")

        # Start from scratch so repeated calls do not accumulate groups.
        self.coordinator = []
        client_pool = set(range(self.pcs.size_device))
        while client_pool:
            group = set()
            group_labels = np.array([])
            while client_pool and len(group) < self.gamma:
                chosen = self._pick_most_balanced_client(client_pool, group_labels)
                group.add(chosen)
                group_labels = np.hstack([group_labels, self.pcs.local_train_label[chosen]])
                client_pool.remove(chosen)
            self.coordinator.append(group)

        self._report_assignment("\nassign_clients 结果：", "Coodinator_A_clients.txt")

    # NOTE: earlier greedy-per-group and KL-based K-Means variants of
    # assign_clients_to lived here as commented-out code; recover them from
    # version control if needed.
    def assign_clients_to(self, balance=True, num_clusters=12):
        """Group clients by average-linkage hierarchical clustering on JS divergence.

        The result replaces ``self.coordinator`` and is reported to stdout
        and ``Coodinator_B_clients.txt``.  Only non-empty clusters become
        coordinators. They are ordered by their ``fcluster`` label.

        Args:
            balance: when False, give every client its own coordinator and
                return without printing the result or writing a file.
            num_clusters: maximum number of coordinators (the ``maxclust``
                criterion of ``scipy.cluster.hierarchy.fcluster``).

        Raises:
            ValueError: if ``num_clusters`` < 1.
        """
        print("开始使用层次聚类进行客户端分配(基于JS散度)")

        if not balance:
            self.coordinator = [{i} for i in range(self.pcs.size_device)]
            return
        if num_clusters < 1:
            raise ValueError(f"num_clusters must be >= 1, got {num_clusters}")

        # Keep a list of per-client arrays instead of one np.array: clients
        # may hold different numbers of labels, and NumPy >= 1.24 refuses to
        # build such a ragged array.
        label_sets = [np.asarray(self.pcs.local_train_label[i]) for i in range(self.pcs.size_device)]
        num_clients = len(label_sets)

        if num_clients < 2:
            # linkage() needs at least two observations.
            self.coordinator = [{i} for i in range(num_clients)]
        else:
            # Symmetric pairwise JS-divergence matrix with a zero diagonal.
            js_distances = np.zeros((num_clients, num_clients))
            for i in range(num_clients):
                for j in range(i + 1, num_clients):
                    js = self.pcs.get_js_divergence(label_sets[i], label_sets[j])
                    js_distances[i, j] = js_distances[j, i] = js

            # linkage() expects the condensed (upper-triangle) form.
            linkage_matrix = linkage(squareform(js_distances), method='average')
            cluster_labels = fcluster(linkage_matrix, num_clusters, criterion='maxclust')

            # fcluster may produce fewer than num_clusters clusters; grouping
            # by the labels actually returned avoids empty coordinators.
            groups = {}
            for client_id, label in enumerate(cluster_labels):
                groups.setdefault(int(label), set()).add(client_id)
            self.coordinator = [groups[label] for label in sorted(groups)]

        print("客户端分配完成")
        self._report_assignment("\nassign_clients_to 结果：", "Coodinator_B_clients.txt")

def _main():
    """Script entry point: print a short banner identifying this module."""
    print('self balance functions')


if __name__ == '__main__':
    _main()
