import copy

import torch
import models
from torch.nn.utils import vector_to_parameters, parameters_to_vector
import numpy as np
from copy import deepcopy
from torch.nn import functional as F
import logging
from utils import name_param_to_array,  vector_to_model, vector_to_name_param,product_of_experts,KL_between_normals
from scipy.stats import entropy
import hdbscan

class Aggregation():
    """Server-side aggregation and client-selection logic for federated training.

    When ``args.method == 'Grace'`` the server also keeps a per-class global
    latent Gaussian (mean ``dir_global_Z_u`` and standard deviation
    ``dir_global_Z_sigma``). When ``args.selection`` is enabled it keeps
    CUCB-style client-selection statistics.
    """

    def __init__(self, agent_data_sizes, n_params, poisoned_val_loader, args, writer):
        self.agent_data_sizes = agent_data_sizes
        self.args = args
        self.writer = writer
        self.server_lr = args.server_lr
        self.n_params = n_params
        self.poisoned_val_loader = poisoned_val_loader
        self.cum_net_mov = 0

        # Number of classes per dataset; any dataset not listed is treated as 10-class.
        self.n_cls = {"tinyimagenet": 200, "cifar100": 100, "femnist": 62}.get(self.args.data, 10)

        if args.method == 'Grace':
            # Per-class global latent distribution, initialised to mean 0 / std 1.
            self.dir_global_Z_u = torch.zeros(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
            self.dir_global_Z_sigma = torch.ones(self.n_cls, 1, self.args.dimZ, dtype=torch.float32, device=self.args.device)

        if self.args.selection == True:
            # CUCB bandit statistics, one entry per agent.
            self.reward_est = np.zeros(self.args.num_agents)
            self.reward_bias = np.zeros(self.args.num_agents)
            self.T_pull = np.zeros(self.args.num_agents)  # times each agent was selected
            self.reward_global = np.zeros(self.args.rounds)
            self.reward_client = np.zeros(self.args.num_agents)  # running mean reward per agent
            # Running mean of each agent's per-class KL profile (V_pt statistic).
            self.V_pt_avg = np.zeros((self.args.num_agents, self.n_cls))
            self.client_idx_lst = []
            self.T_pull_lst = []

    def aggregate_updates(self, global_model, agent_updates_dict, iter, iter_client=None):
        """Aggregate client updates with ``args.aggr`` and apply them to ``global_model``.

        Args:
            global_model: model whose parameters are updated in place.
            agent_updates_dict: client id -> update; presumably a flat vector
                with ``n_params`` entries (the aggregators treat it that way).
            iter: round index, used by the 'crfl' aggregator.
            iter_client: client ids admitted by FLAME filtering. ``None`` means
                all clients (only used by 'flame').

        Returns:
            ``(updates_dict, neurotoxin_mask)``. ``updates_dict`` is the
            aggregated update split per state-dict entry by
            ``vector_to_name_param``. ``neurotoxin_mask`` is currently always
            an empty dict.
        """
        if self.args.method == "rlr":
            lr_vector, _ = self.compute_robustLR(agent_updates_dict)
        else:
            lr_vector = torch.full((self.n_params,), float(self.server_lr), device=self.args.device)

        cur_global_params = parameters_to_vector(list(global_model.state_dict().values())).detach()

        aggr = self.args.aggr
        aggregated_updates = 0
        if aggr == 'avg':
            aggregated_updates = self.agg_avg(agent_updates_dict)
        elif aggr == "clip_avg":
            # Norm-clip every update to at most 0.1. This mutates the caller's tensors.
            for _id, update in agent_updates_dict.items():
                weight_diff_norm = torch.norm(update).item()
                logging.info(weight_diff_norm)
                update.data = update.data / max(1, weight_diff_norm / 0.1)
            aggregated_updates = self.agg_avg(agent_updates_dict)
            logging.info(torch.norm(aggregated_updates))
        elif aggr == 'comed':
            aggregated_updates = self.agg_comed(agent_updates_dict)
        elif aggr == 'sign':
            aggregated_updates = self.agg_sign(agent_updates_dict)
        elif aggr == "krum":
            aggregated_updates = self.agg_krum(agent_updates_dict)
        # NOTE(review): agg_gm / agg_tm are not defined in this class; these two
        # branches raise AttributeError unless the methods are supplied elsewhere.
        elif aggr == "gm":
            aggregated_updates = self.agg_gm(agent_updates_dict, cur_global_params)
        elif aggr == "tm":
            aggregated_updates = self.agg_tm(agent_updates_dict)
        elif aggr == "crfl":
            aggregated_updates = self.agg_crfl(agent_updates_dict, iter)
        elif aggr == 'flame':
            aggregated_updates, _clip = self.agg_flame(agent_updates_dict, iter_client)

        neurotoxin_mask = {}
        updates_dict = vector_to_name_param(aggregated_updates, copy.deepcopy(global_model.state_dict()))

        # Add Gaussian noise (std = args.noise) in place to every state entry
        # except the integer BatchNorm counters.
        for key, var in global_model.state_dict().items():
            if key.split('.')[-1] == 'num_batches_tracked':
                continue
            var += torch.empty_like(var).normal_(mean=0, std=self.args.noise)

        # Re-read the (now noised) parameters before applying the update.
        cur_global_params = parameters_to_vector(list(global_model.state_dict().values())).detach()
        new_global_params = (cur_global_params + lr_vector * aggregated_updates).float()
        vector_to_model(new_global_params, global_model)

        return updates_dict, neurotoxin_mask

    def compute_robustLR(self, agent_updates_dict):
        """Robust learning rate (RLR): a per-coordinate sign-agreement vote.

        Coordinates where |sum of client update signs| >= ``args.theta`` get
        ``+server_lr``. All other coordinates get ``-server_lr``.

        Returns:
            ``(lr_vector moved to args.device, 0/1 float mask of agreeing coordinates)``
        """
        agent_updates_sign = [torch.sign(update) for update in agent_updates_dict.values()]
        sm_of_signs = torch.abs(sum(agent_updates_sign))
        # Compute both masks before overwriting sm_of_signs in place.
        agree = sm_of_signs >= self.args.theta
        disagree = sm_of_signs < self.args.theta
        mask = torch.zeros_like(sm_of_signs)
        mask[agree] = 1
        sm_of_signs[disagree] = -self.server_lr
        sm_of_signs[agree] = self.server_lr
        return sm_of_signs.to(self.args.device), mask

    def agg_flame(self, agent_updates_dict, iter_client):
        """FLAME-style aggregation: median-norm clipping followed by a data-weighted mean.

        The clipping bound is the median L2 norm over ALL received updates.
        Only updates whose id is in ``iter_client`` (all of them when
        ``iter_client`` is None) are averaged. Each is scaled down to the bound
        when its norm exceeds it, and weighted by that client's data size.

        Returns:
            ``(aggregated_update, clip_value)``
        """
        norms = {_id: torch.norm(update, p=2).item() for _id, update in agent_updates_dict.items()}
        clip_value = np.median(list(norms.values()))
        admitted = agent_updates_dict.keys() if iter_client is None else iter_client

        sm_updates, total_data = 0, 0
        for _id, update in agent_updates_dict.items():
            if _id not in admitted:
                continue
            n_agent_data = self.agent_data_sizes[_id]
            # A zero-norm update needs no clipping (and would divide by zero).
            gama = clip_value / norms[_id] if norms[_id] > 0 else 1.0
            if gama < 1:
                update = update * gama
            sm_updates += n_agent_data * update
            total_data += n_agent_data

        return sm_updates / total_data, clip_value

    def agg_krum(self, agent_updates_dict):
        """Krum: return the update with the smallest summed squared distance to its neighbours.

        Each update's score is the sum of its ``n - args.num_corrupt - 2``
        smallest squared distances to the other updates. The count is clamped
        to at least 1, because a non-positive slice length would silently
        select the wrong neighbours. Works for arbitrary (non-contiguous)
        client ids.
        """
        krum_param_m = 1
        updates = list(agent_updates_dict.values())
        num_client = len(updates)
        n_neighbors = max(num_client - self.args.num_corrupt - 2, 1)

        krum_scores = []
        for i, u_i in enumerate(updates):
            dists = sorted(
                torch.norm(u_i - u_j).item() ** 2
                for j, u_j in enumerate(updates)
                if j != i
            )
            krum_scores.append(sum(dists[:n_neighbors]))

        score_index = torch.argsort(torch.tensor(krum_scores, dtype=torch.float64)).tolist()[:krum_param_m]
        return_gradient = [updates[i] for i in score_index]
        return sum(return_gradient) / len(return_gradient)

    def agg_avg(self, agent_updates_dict):
        """Classic FedAvg: mean of the updates weighted by each agent's data size."""
        sm_updates, total_data = 0, 0
        for _id, update in agent_updates_dict.items():
            n_agent_data = self.agent_data_sizes[_id]
            sm_updates += n_agent_data * update
            total_data += n_agent_data
        return sm_updates / total_data

    def agg_comed(self, agent_updates_dict):
        """Coordinate-wise median of the flattened updates."""
        agent_updates_col_vector = [update.view(-1, 1) for update in agent_updates_dict.values()]
        concat_col_vectors = torch.cat(agent_updates_col_vector, dim=1)
        return torch.median(concat_col_vectors, dim=1).values

    def agg_sign(self, agent_updates_dict):
        """Majority sign per coordinate (signSGD-style vote)."""
        return torch.sign(sum(torch.sign(update) for update in agent_updates_dict.values()))

    def agg_crfl(self, agent_updates_dict, iter_c):
        """CRFL aggregation: data-weighted mean, norm clipping, then small Gaussian noise.

        The clipping bound is ``min(0.25 * iter_c + 4, args.param_clip_thres)``.
        Noise with std 0.001 is added afterwards.
        """
        agg_updates = self.agg_avg(agent_updates_dict)

        max_norm = min(iter_c * 0.25 + 4, self.args.param_clip_thres)
        total_norm = torch.norm(agg_updates)
        if total_norm > max_norm:
            agg_updates = agg_updates * (max_norm / (total_norm + 1e-6))

        noise = torch.randn(agg_updates.size()) * 0.001
        return agg_updates + noise.to(self.args.device)

    def detect_anomalies(self, client_similarities, clt=True):
        """Flag outliers using a median-absolute-deviation (MAD) rule.

        Args:
            client_similarities: when ``clt`` is True, a dict mapping client id
                to a list of scores; the client's statistic is their sum. When
                ``clt`` is False, a sequence of per-client scores, and each
                score's position is used as the client id.
            clt: selects between the two input layouts above.

        Returns:
            dict mapping client id to True when |x - median| > 2 * MAD.
            Note that ``torch.median`` returns the lower of the two middle
            values for even-length inputs.
        """
        client_ids, values = [], []
        if clt:
            for client_id, similarity_list in client_similarities.items():
                client_ids.append(client_id)
                values.append(sum(similarity_list))
        else:
            for client_id, similarity in enumerate(client_similarities):
                client_ids.append(client_id)
                values.append(similarity)

        similarities = torch.tensor(values)
        med = torch.median(similarities)
        absolute_deviation = torch.abs(similarities - med)
        mad = torch.median(absolute_deviation)
        threshold = 2.0  # MAD multiplier; tune as needed

        return {
            client_id: bool(absolute_deviation[i] > mad * threshold)
            for i, client_id in enumerate(client_ids)
        }

    def outlier(self, clients):
        """Score each client by how often its per-class latent Gaussian is an outlier.

        For every class:
        - each client's summed KL to all other clients is MAD-tested, and an
          outlier adds 0.6;
        - each client's KL to the global Gaussian is MAD-tested, and an
          outlier adds 0.4.

        Returns:
            dict mapping each key of ``clients`` to its accumulated score.
        """
        dir_global_Z_u_copy = copy.deepcopy(self.dir_global_Z_u)
        dir_global_Z_sigma_copy = copy.deepcopy(self.dir_global_Z_sigma)

        clients_id = list(clients.keys())
        outlier_score = {cid: 0 for cid in clients_id}

        for cls in range(self.n_cls):
            global_cls = dir_global_Z_u_copy[cls], dir_global_Z_sigma_copy[cls]
            KL_sim_ij = {cid: [] for cid in clients_id}
            KL_global_sim = []  # aligned with clients_id order
            for i in clients_id:
                clet_cls_i = clients[i].dir_Z_u[cls], clients[i].dir_Z_sigma[cls]
                KL_global_sim.append(KL_between_normals(clet_cls_i, global_cls).item())
                for j in clients_id:
                    if i != j:
                        clet_cls_j = clients[j].dir_Z_u[cls], clients[j].dir_Z_sigma[cls]
                        KL_sim_ij[i].append(KL_between_normals(clet_cls_i, clet_cls_j).item())

            outlier_clnt = self.detect_anomalies(KL_sim_ij)
            print(f'class client {cls} outlier is {outlier_clnt}')
            for i, is_outlier in outlier_clnt.items():
                if is_outlier:
                    outlier_score[i] += 0.6

            # This result is keyed by list position; map positions back to client ids.
            outlier_ser = self.detect_anomalies(KL_global_sim, clt=False)
            print(f'class server {cls} outlier is {outlier_ser}')
            for pos, is_outlier in outlier_ser.items():
                if is_outlier:
                    outlier_score[clients_id[pos]] += 0.4

        print(f'outlier score is {outlier_score}')
        return outlier_score

    def global_POE(self, clients):
        """Update the per-class global latent Gaussians from the clients' Gaussians.

        For each class, clients whose (mean, sigma) still equal the
        zeros/ones initialisation are skipped. The rest are fused with
        ``product_of_experts`` and blended into the global state as
        ``(1 - args.beta2) * old + args.beta2 * fused``. Classes with no
        contributing client are left unchanged.
        """
        dir_global_Z_u_copy = copy.deepcopy(self.dir_global_Z_u)
        dir_global_Z_sigma_copy = copy.deepcopy(self.dir_global_Z_sigma)
        init_u = torch.zeros(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        init_sigma = torch.ones(1, self.args.dimZ, dtype=torch.float32, device=self.args.device)
        beta2 = self.args.beta2

        for cls in range(self.n_cls):
            cls_u, cls_sigma = [], []
            for i in clients.keys():
                z_u, z_sigma = clients[i].dir_Z_u[cls], clients[i].dir_Z_sigma[cls]
                if z_u.equal(init_u) and z_sigma.equal(init_sigma):
                    continue  # this client has nothing for this class yet
                cls_u.append(z_u.clone().detach())
                cls_sigma.append(z_sigma.clone().detach())
            if not cls_u:
                continue

            clients_all_Z = torch.cat(cls_u, 0), torch.cat(cls_sigma, 0)
            dir_global_Z_u_copy[cls], dir_global_Z_sigma_copy[cls] = product_of_experts(clients_all_Z)
            self.dir_global_Z_u[cls] = (1 - beta2) * self.dir_global_Z_u[cls] + beta2 * dir_global_Z_u_copy[cls]
            self.dir_global_Z_sigma[cls] = (1 - beta2) * self.dir_global_Z_sigma[cls] + beta2 * dir_global_Z_sigma_copy[cls]

    def select(self, clients, iter, m):
        """Select ``m`` clients for round ``iter`` (CUCB-style).

        During the first ceil(num_agents / m) rounds, clients are taken
        round-robin via :meth:`traverse`. After that, the UCB score
        ``reward_client + alpha_sel * sqrt(3 ln(iter) / (2 T_pull))`` seeds
        the greedy set built by :meth:`get_client_set`. For each selected
        client, the pull count, running mean reward, and running mean
        per-class KL profile are updated.
        """
        warmup_rounds = int(np.ceil(self.args.num_agents / m))

        if iter < warmup_rounds:
            client_idx = self.traverse(self.args.num_agents, m, iter)
        else:
            self.reward_bias = self.reward_client + self.args.alpha_sel * np.sqrt(3 * np.log(iter) / (2 * self.T_pull))
            client_idx = self.get_client_set(self.reward_bias, m, self.args.num_agents, self.V_pt_avg)

        for s in client_idx:
            self.T_pull[s] += 1

        ra_dict = self.computer_per_client_dir(clients)
        for i in range(m):
            cid = client_idx[i]
            V_pt = ra_dict[cid]
            reward_single = 1 / self.cross_entropy(V_pt)
            n_pulls = self.T_pull[cid]
            self.reward_client[cid] = (self.reward_client[cid] * (n_pulls - 1) + reward_single) / n_pulls
            self.V_pt_avg[cid] = np.add(self.V_pt_avg[cid] * (n_pulls - 1), V_pt) / n_pulls

        print('{}-round select client are {}'.format(iter, client_idx))
        return client_idx

    def select_flame(self, clients, iter, m):
        """Same selection as :meth:`select` (kept as a separate entry point for FLAME runs)."""
        return self.select(clients, iter, m)

    def traverse(self, N, K, RR):
        """Return the K client indices for warm-up round RR, so that every client is visited at least once.

        Round RR covers indices ``RR*K .. RR*K+K-1``, wrapped modulo N. If RR
        is outside ``0 .. ceil(N/K)-1``, a float array of K zeros is returned.
        """
        R = int(np.ceil(N / K))
        if RR in np.arange(R):
            return np.arange(int(RR) * K, (int(RR) + 1) * K) % N
        return np.zeros(K)

    def computer_per_client_dir(self, clients):
        """Compute the per-client, per-class KL(client Gaussian || global Gaussian).

        Returns:
            dict mapping client position ``i`` (0..len(clients)-1) to a float32
            NumPy array of length ``n_cls``.
        """
        I_ZX_bound_dict = {}
        for i in range(len(clients)):
            client_u, client_sigma = clients[i].dir_Z_u, clients[i].dir_Z_sigma
            kl_per_cls = [
                self.KL_between_normals(
                    client_u[cls], client_sigma[cls],
                    self.dir_global_Z_u[cls], self.dir_global_Z_sigma[cls],
                ).item()
                for cls in range(self.n_cls)
            ]
            I_ZX_bound_dict[i] = np.asarray(kl_per_cls, dtype=np.float32)
        return I_ZX_bound_dict

    def compute_ratio(self, grad_square_sum_lst, temp=1):
        """Normalise ``min(x) / x`` so that it sums to 1. ``temp`` is currently unused."""
        grad_sum = np.array(grad_square_sum_lst)
        grad_sum = grad_sum.min() / grad_sum
        return grad_sum / grad_sum.sum()

    def KL_between_normals(self, q_distr_u, q_distr_sigma, p_distr_u, p_distr_sigma):
        """KL(q || p) between diagonal Gaussians given means and standard deviations.

        The reduction is over dim 1, so the inputs are presumably shaped
        ``(batch, dim)`` — TODO confirm. Log-determinants clamp sigma at 1e-8.
        """
        mu_q, sigma_q = q_distr_u, q_distr_sigma
        mu_p, sigma_p = p_distr_u, p_distr_sigma  # standard deviations, not variances
        k = mu_q.size(1)

        mu_diff = mu_p - mu_q
        mu_diff_sq = torch.mul(mu_diff, mu_diff)
        logdet_sigma_q = torch.sum(2 * torch.log(torch.clamp(sigma_q, min=1e-8)), dim=1)
        logdet_sigma_p = torch.sum(2 * torch.log(torch.clamp(sigma_p, min=1e-8)), dim=1)

        fs = torch.sum(torch.div(sigma_q ** 2, sigma_p ** 2), dim=1) + torch.sum(torch.div(mu_diff_sq, sigma_p ** 2), dim=1)
        two_kl = fs - k + logdet_sigma_p - logdet_sigma_q
        return two_kl * 0.5

    def cross_entropy(self, y):
        """Cross-entropy H(uniform, y) = H(uniform) + KL(uniform || y).

        scipy normalises ``y``. 1e-15 is added to avoid log(0).
        """
        x = np.ones(y.shape)
        return entropy(x) + entropy(x, y + 1e-15)

    def get_client_set(self, reward, K, N, V_pt_avg):
        """Greedily build a K-client set for CUCB.

        Start from the client with the highest ``reward``. Then repeatedly add
        the client whose inclusion makes the averaged per-class profile
        closest to uniform, i.e. the one with the lowest cross-entropy.
        """
        V_pt_dict = {i: V_pt_avg[i, :] for i in range(N)}
        r_max_index = np.argmax(reward)
        V_pt_dict.pop(r_max_index)
        S = np.array(r_max_index)
        comb_set = np.array(V_pt_avg[r_max_index])

        while S.size < K:
            ce_reward_set = {}
            for key, value in V_pt_dict.items():
                comb_dist = np.vstack([comb_set, value])
                comb_dist_avg = np.sum(comb_dist, axis=0) / comb_dist.shape[0]
                ce_reward_set[key] = 1 / self.cross_entropy_numpy(comb_dist_avg)

            reward_max_idx = max(ce_reward_set, key=ce_reward_set.get)
            S = np.append(S, reward_max_idx)
            comb_set = np.vstack([comb_set, V_pt_dict.pop(reward_max_idx)])

        return S

    def cross_entropy_numpy(self, y):
        """Alias of :meth:`cross_entropy`."""
        return self.cross_entropy(y)

    def FLAME(self, clients, rnd):
        """Cluster clients by the cosine distance of their latent statistics and return the main cluster.

        For every client pair, the distance is the mean of (1 - cosine
        similarity) over the means ``dir_Z_u`` and over the sigmas
        ``dir_Z_sigma``. The distance matrix is appended to
        ``<args.folder>/similiar.txt``. HDBSCAN then runs on the rows. The
        returned positions are those in the largest cluster, or all positions
        when every client is labelled as noise.
        """
        cos = torch.nn.CosineSimilarity(dim=0, eps=1e-6)
        local_Z_u = [client.dir_Z_u for client in clients]
        local_Z_sigma = [client.dir_Z_sigma for client in clients]

        cos_list = []
        for u_i, s_i in zip(local_Z_u, local_Z_sigma):
            row = []
            for u_j, s_j in zip(local_Z_u, local_Z_sigma):
                dist_u = 1 - cos(u_i, u_j)
                dist_sigma = 1 - cos(s_i, s_j)
                row.append(round(((torch.mean(dist_u) + torch.mean(dist_sigma)) / 2).item(), 4))
            cos_list.append(row)

        similar_file = self.args.folder + '/similiar.txt'
        self.write_file(similar_file, rnd, cos_list)

        num_clients = self.args.num_agents
        clusterer = hdbscan.HDBSCAN(min_cluster_size=num_clients // 2 + 1, allow_single_cluster=True).fit(cos_list)
        labels = clusterer.labels_
        print(labels)

        if labels.max() < 0:
            # Everything is noise: trust every client.
            benign_client = list(range(len(clients)))
        else:
            # argmax returns the first largest cluster on ties.
            max_cluster_index = int(np.argmax(np.bincount(labels[labels >= 0])))
            benign_client = [i for i, label in enumerate(labels) if label == max_cluster_index]

        print(f'benign client are {benign_client}')
        return benign_client

    def write_file(self, filename, main_accuracy, backdoor_accuaracy):
        """Append ``main_task_accuracy=<...>`` and ``backdoor_task_accuracy=<...>`` lines to ``filename``.

        FLAME reuses this writer to log the round number and the distance
        matrix under these labels.
        """
        with open(filename, "a", encoding="utf-8") as f:
            f.write("main_task_accuracy=")
            f.write(str(main_accuracy))
            f.write('\n')
            f.write("backdoor_task_accuracy=")
            f.write(str(backdoor_accuaracy))
            f.write('\n')


def parameters_dict_to_vector(net_dict) -> torch.Tensor:
    r"""Flatten the ``weight`` and ``bias`` entries of a state dict into one vector.

    Only entries whose key ends in ``.weight`` / ``.bias`` (or is exactly
    ``weight`` / ``bias``) are kept. Buffers such as running statistics and
    ``num_batches_tracked`` are skipped. Entries are concatenated in the
    dict's iteration order.

    Args:
        net_dict: mapping from parameter name to tensor.

    Returns:
        A 1-D tensor holding the kept entries, flattened.
    """
    kept = [
        param.view(-1)
        for key, param in net_dict.items()
        if key.rsplit('.', 1)[-1] in ('weight', 'bias')
    ]
    return torch.cat(kept)

def vector_to_parameters_dict(vec: torch.Tensor, net_dict) -> dict:
    r"""Copy consecutive slices of a flat vector into every tensor of ``net_dict``.

    Entries are filled in the dict's iteration order. Each tensor's ``.data``
    is rebound to a view of ``vec``, so the tensors afterwards share storage
    with ``vec``. Unlike ``parameters_dict_to_vector``, every entry is filled,
    not only weights and biases.

    Args:
        vec (Tensor): a 1-D tensor with at least as many elements as all
            tensors of ``net_dict`` combined.
        net_dict: mapping from name to tensor; modified in place.

    Returns:
        ``net_dict`` itself (the same object), for convenience.
    """
    pointer = 0
    for param in net_dict.values():
        num_param = param.numel()
        # Slice out this tensor's share of the vector and reshape it in place.
        param.data = vec[pointer:pointer + num_param].view_as(param).data
        pointer += num_param
    return net_dict

