


import copy
from pydoc import cli
import torch
import numpy as np
from torch import nn



def FedAvg(w):
    """Return the unweighted element-wise mean of a list of state dicts.

    Args:
        w: non-empty sequence of state dicts sharing the same keys.

    Returns:
        A new dict (deep copy of ``w[0]``) whose every entry is the mean of the
        corresponding entries across ``w``. The inputs are not modified.
    """
    n_clients = len(w)
    result = copy.deepcopy(w[0])
    for key in result.keys():
        # Accumulate in place on the deep copy, so w[0] itself stays untouched.
        for other in w[1:]:
            result[key] += other[key]
        result[key] = result[key] / n_clients
    return result


def diver_cal(model_flag, w_g, w_l):
    """Return the squared L2 distance between two state dicts over ``model_flag``'s parameters.

    Side effect: ``model_flag`` is loaded with the element-wise difference
    ``w_g - w_l`` (for every key in its own state dict). Only tensors exposed
    via ``model_flag.parameters()`` contribute to the sum; buffers are loaded
    but not counted.

    Args:
        model_flag: scratch ``nn.Module`` whose state is overwritten.
        w_g: state dict (e.g. global weights) containing all of model_flag's keys.
        w_l: state dict (e.g. local weights) containing all of model_flag's keys.

    Returns:
        The sum of squared differences, accumulated as NumPy scalars.
    """
    diff_state = model_flag.state_dict()
    for name in diff_state.keys():
        diff_state[name] = w_g[name] - w_l[name]
    model_flag.load_state_dict(diff_state)

    total = 0
    for p in model_flag.parameters():
        total += p.pow(2).sum().detach().cpu().numpy()
    return total

'''
def diver_cal(w_g, w_l):
    w_flag = copy.deepcopy(w_g)
    sum_diver = 0
    for k in w_flag.keys():
        diff = sum((w_g[k] - w_l[k])**2)
        se = sum(diff).cpu().numpy()
        sum_diver += se
    
    return sum_diver
'''

def FedAvg_noniid_classifier(w, dict_len):
    """Sample-size-weighted FedAvg over a list of modules, returning a new module.

    Args:
        w: non-empty list of ``nn.Module`` objects with identical state-dict keys.
            The list and its modules are NOT modified (earlier versions replaced
            each list element with its state dict in place).
        dict_len: per-client weights (e.g. sample counts), aligned with ``w``.

    Returns:
        A deep copy of ``w[0]`` loaded with
        ``sum_i dict_len[i] * state_i / sum(dict_len)``.
    """
    model = copy.deepcopy(w[0])
    # Work on a private list of state dicts instead of overwriting the caller's list.
    states = [m.state_dict() for m in w]
    total = sum(dict_len)

    w_avg = copy.deepcopy(states[0])
    for k in w_avg.keys():
        acc = w_avg[k] * dict_len[0]  # new tensor, so in-place += below is safe
        for i in range(1, len(states)):
            acc += states[i][k] * dict_len[i]
        w_avg[k] = acc / total
    model.load_state_dict(w_avg)
    return model


def cls_norm_agg(w, dict_len, l_heads, distributions):
    """Aggregate linear classifier heads with per-class, sample-weighted averaging.

    Row ``c`` of the aggregated ``weight``/``bias`` is a convex combination of
    the clients' rows ``c``; client ``i`` gets a weight proportional to
    ``dict_len[i] * distributions[i][c]`` (normalized over clients per class).
    If no participating client has samples of class ``c``, that class falls
    back to plain size weighting ``dict_len[i] / sum(dict_len)``. Without
    this fallback the weights would be 0/0 = NaN.

    Args:
        w: list of client classifier modules whose state dicts have ``weight``
            and ``bias`` keys. Neither the list nor the modules are modified.
        dict_len: per-client sample counts, aligned with ``w``.
        l_heads: client heads; ``l_heads[0].out_features`` gives the number of
            classes and ``l_heads[0].weight.dtype`` the dtype of the weights.
        distributions: per-client label counts indexable as ``[client][class]``
            (converted with ``np.asarray``); rows beyond ``len(w)`` are ignored.

    Returns:
        A deep copy of ``w[0]`` loaded with the aggregated parameters. State
        entries other than ``weight``/``bias`` are loaded as zeros.

    NOTE(review): despite the name, the L2 norms of ``l_heads`` were computed
    but never used by the previous implementation; they are not computed here.
    """
    model = copy.deepcopy(w[0])
    states = [m.state_dict() for m in w]
    w_avg = {k: torch.zeros_like(v) for k, v in states[0].items()}

    n_clients = len(states)
    classes = l_heads[0].out_features
    dtype = l_heads[0].weight.dtype

    sizes = torch.tensor([float(dict_len[i]) for i in range(n_clients)], dtype=dtype)
    counts = torch.from_numpy(np.asarray(distributions)).to(dtype)[:n_clients, :classes]

    # Unnormalized per-(client, class) weight. Global factors such as
    # 1/sum(dict_len) or 1/(total count of class c) cancel after the
    # per-class normalization below, so they are omitted.
    raw = sizes.unsqueeze(1) * counts
    totals = raw.sum(dim=0)
    has_samples = totals > 0
    safe_totals = torch.where(has_samples, totals, torch.ones_like(totals))
    size_weights = (sizes / sizes.sum()).unsqueeze(1).expand_as(raw)
    coef = torch.where(has_samples, raw / safe_totals, size_weights)  # (clients, classes)

    for i, state in enumerate(states):
        # Keep the coefficients on the same device as the client tensors.
        c = coef[i].to(state['weight'].device)
        w_avg['weight'][:classes] += state['weight'][:classes] * c.unsqueeze(1)
        w_avg['bias'][:classes] += state['bias'][:classes] * c

    model.load_state_dict(w_avg)
    return model


def FedAvg_noniid(w, dict_len):
    """Return the ``dict_len``-weighted mean of a list of state dicts.

    Args:
        w: non-empty sequence of state dicts sharing the same keys.
        dict_len: per-entry weights (e.g. client sample counts), aligned with ``w``.

    Returns:
        A deep copy of ``w[0]`` whose entries are
        ``sum_i dict_len[i] * w[i][key] / sum(dict_len)``.
    """
    denominator = sum(dict_len)
    result = copy.deepcopy(w[0])
    for name in result.keys():
        running = result[name] * dict_len[0]
        for idx in range(1, len(w)):
            running += w[idx][name] * dict_len[idx]
        result[name] = running / denominator
    return result

def FedAvg_noniid_class_means(class_means_for_agg, dict_len):
    """Sample-weighted average of per-client class-mean dicts, skipping absent clients.

    Entries of ``class_means_for_agg`` that are ``None`` contribute nothing. Their
    ``dict_len`` is therefore also excluded from the normalizer. Otherwise the
    result would be biased toward zero by clients that sent no means.

    Args:
        class_means_for_agg: list of dicts (or ``None``) sharing the same keys;
            values support ``*`` by a number and ``+``.
        dict_len: per-client weights (e.g. sample counts), aligned with the list.

    Returns:
        A deep copy of the first non-``None`` dict holding the weighted means,
        or ``None`` if every entry is ``None``.
    """
    present = [i for i, means in enumerate(class_means_for_agg) if means is not None]
    if not present:
        return None

    first = present[0]
    total = sum(dict_len[i] for i in present)
    aggregated_means = copy.deepcopy(class_means_for_agg[first])
    for k in aggregated_means.keys():
        # Out-of-place accumulation avoids dtype-casting errors for integer arrays.
        acc = aggregated_means[k] * dict_len[first]
        for i in present[1:]:
            acc = acc + class_means_for_agg[i][k] * dict_len[i]
        aggregated_means[k] = acc / total
    return aggregated_means



def FedAvg_Rod(backbone_w_locals, linear_w_locals, dict_len):
    """Average backbone and linear-head state dicts separately with the same weights.

    Args:
        backbone_w_locals: list of client backbone state dicts.
        linear_w_locals: list of client linear-head state dicts.
        dict_len: per-client weights shared by both averages.

    Returns:
        ``(backbone_average, linear_average)``, each computed by ``FedAvg_noniid``.
    """
    backbone_avg, linear_avg = (
        FedAvg_noniid(part, dict_len) for part in (backbone_w_locals, linear_w_locals)
    )
    return backbone_avg, linear_avg



def _weno_classifier(w, dict_len, distribution, avg_w):
    """Return (weight, bias) of the per-class, sample-weighted average of clients' linear heads."""
    n_clients = len(w)
    ref_w = avg_w["linear.weight"]
    ref_b = avg_w["linear.bias"]
    n_classes = ref_b.shape[0]

    # Row-normalize each client's label histogram. np.asarray does not copy a
    # float64 input, but nothing below writes into it. Empty rows stay all-zero
    # instead of becoming NaN.
    dist = np.asarray(distribution, dtype=np.float64)[:n_clients]
    row_sums = dist.sum(axis=1, keepdims=True)
    shares = np.divide(dist, row_sums, out=np.zeros_like(dist), where=row_sums > 0)

    # Estimated number of samples of each class on each client: share * client size.
    sizes = np.array([dict_len[i] for i in range(n_clients)], dtype=np.float64)
    counts = shares[:, :n_classes] * sizes[:, None]
    totals = counts.sum(axis=0)
    coef = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    coef_t = torch.from_numpy(coef).to(dtype=ref_w.dtype, device=ref_w.device)

    weight = torch.zeros_like(ref_w)
    bias = torch.zeros_like(ref_b)
    for i in range(n_clients):
        weight += w[i]["linear.weight"] * coef_t[i].unsqueeze(1)
        bias += w[i]["linear.bias"] * coef_t[i]

    # Classes no participating client has: keep the plain FedAvg rows.
    missing = torch.from_numpy(totals <= 0).to(ref_b.device)
    weight[missing] = ref_w[missing]
    bias[missing] = ref_b[missing]
    return weight, bias


def weno_aggeration(w, dict_len, datasetObj, beta, round, start_round=25):
    """FedAvg client state dicts, blending in a class-weighted classifier after warm-up.

    All entries are first averaged with weights ``dict_len``. When
    ``round > start_round``, ``linear.weight``/``linear.bias`` are replaced by
    ``beta * weno + (1 - beta) * fedavg``. Here ``weno`` averages each class row
    over clients with weights proportional to the client's estimated number of
    samples of that class. That estimate is the row-normalized
    ``datasetObj.training_set_distribution[client]`` times ``dict_len[client]``.

    Args:
        w: list of client state dicts containing ``linear.weight`` and ``linear.bias``.
        dict_len: per-client sample counts, aligned with ``w``.
        datasetObj: object with ``training_set_distribution``, indexable as
            ``[client][class]`` (presumably label counts aligned with ``w`` —
            TODO confirm against caller).
        beta: blending factor for the class-weighted classifier.
        round: current communication round (name kept for compatibility).
        start_round: rounds up to and including this value use plain FedAvg.

    Returns:
        The aggregated state dict (a deep copy of ``w[0]`` with new values).
    """
    total = sum(dict_len)
    avg_w = copy.deepcopy(w[0])
    for k in avg_w.keys():
        acc = avg_w[k] * dict_len[0]
        for i in range(1, len(w)):
            acc += w[i][k] * dict_len[i]
        avg_w[k] = acc / total

    # The class-weighted head is only needed once the warm-up is over.
    if round > start_round:
        weno_weight, weno_bias = _weno_classifier(
            w, dict_len, datasetObj.training_set_distribution, avg_w
        )
        avg_w["linear.weight"] = beta * weno_weight + (1 - beta) * avg_w["linear.weight"]
        avg_w["linear.bias"] = beta * weno_bias + (1 - beta) * avg_w["linear.bias"]

    return avg_w

def Weighted_avg_f1(f1_list, dict_len):
    """Return the mean of ``f1_list`` weighted by ``dict_len``.

    Args:
        f1_list: per-client scores, at least as long as ``dict_len``.
        dict_len: per-client weights (e.g. test-set sizes).

    Returns:
        ``sum_i f1_list[i] * dict_len[i] / sum(dict_len)``.
    """
    weighted_total = 0
    for idx, size in enumerate(dict_len):
        weighted_total += f1_list[idx] * size
    return weighted_total / sum(dict_len)