import copy
from pydoc import cli
import torch
import numpy as np
def FedAvg(w):
    """Return the unweighted element-wise mean of a list of parameter dicts.

    Args:
        w: non-empty sequence of mappings with identical keys (e.g. model
            ``state_dict()`` outputs), one per client.

    Returns:
        A deep copy of ``w[0]`` in which every entry is replaced by the mean
        of that entry over all elements of ``w``.
    """
    num_clients = len(w)
    merged = copy.deepcopy(w[0])
    for name in merged.keys():
        # Accumulate into the deep-copied entry so w[0] itself is untouched.
        running = merged[name]
        for other in w[1:]:
            running += other[name]
        merged[name] = running / num_clients
    return merged
def diver_cal(model_flag, w_g, w_l):
    """Return the squared L2 distance between two parameter dicts.

    ``model_flag`` is used as scratch space. Its state is overwritten with the
    element-wise difference ``w_g[k] - w_l[k]`` for every key ``k`` of its
    ``state_dict()``. Then the squared entries of its ``parameters()`` are
    summed. Non-parameter state (buffers) is loaded but not counted.

    Returns:
        The sum as a NumPy value (or ``0`` if the model has no parameters).
    """
    delta_state = model_flag.state_dict()
    for key in delta_state:
        delta_state[key] = w_g[key] - w_l[key]
    model_flag.load_state_dict(delta_state)
    return sum(
        torch.sum(param ** 2).detach().cpu().numpy()
        for param in model_flag.parameters()
    )
# NOTE(review): the triple-quoted string below is a disabled earlier version of
# diver_cal, kept as a bare module-level string literal. It is never executed;
# the active diver_cal is the function defined above.
'''
def diver_cal(w_g, w_l):
    w_flag = copy.deepcopy(w_g)
    sum_diver = 0
    for k in w_flag.keys():
        diff = sum((w_g[k] - w_l[k])**2)
        se = sum(diff).cpu().numpy()
        sum_diver += se
    return sum_diver
'''
def FedAvg_noniid(w, dict_len):
    """Return the ``dict_len``-weighted element-wise mean of parameter dicts.

    Args:
        w: non-empty sequence of mappings with identical keys, one per client.
        dict_len: per-client weights aligned with ``w`` (presumably local
            dataset sizes); their sum is the normalizer.

    Returns:
        A deep copy of ``w[0]`` in which every entry ``k`` is replaced by
        ``sum_i(w[i][k] * dict_len[i]) / sum(dict_len)``.
    """
    merged = copy.deepcopy(w[0])
    normalizer = sum(dict_len)
    for key in merged.keys():
        acc = merged[key] * dict_len[0]
        for idx, state in enumerate(w[1:], start=1):
            acc += state[key] * dict_len[idx]
        merged[key] = acc / normalizer
    return merged
def FedAvg_Rod(backbone_w_locals, linear_w_locals, dict_len):
    """Aggregate backbone and linear-head client weights separately.

    Both lists are combined with ``FedAvg_noniid`` using the same per-client
    weights ``dict_len``.

    Returns:
        ``(backbone_average, linear_average)``.
    """
    return (
        FedAvg_noniid(backbone_w_locals, dict_len),
        FedAvg_noniid(linear_w_locals, dict_len),
    )
def weno_aggeration(w, dict_len, datasetObj, beta, round, start_round = 25):
    """Aggregate client weights, optionally re-weighting the linear head per class.

    Every entry is first averaged across clients with weights ``dict_len``
    (``sum_i w[i][k] * dict_len[i] / sum(dict_len)``).

    When ``round > start_round``, the ``"linear.weight"`` and ``"linear.bias"``
    entries are blended with a class-wise aggregate. In that aggregate, row
    ``c`` is averaged across clients with weight
    ``(client i's share of samples in class c) * dict_len[i]``:

        head = beta * classwise_head + (1 - beta) * averaged_head

    If no client has any weight for class ``c``, the class-wise row falls back
    to the plain averaged row.

    Args:
        w: non-empty list of per-client parameter dicts with identical keys.
            When the blend is applied, they must contain ``"linear.weight"``
            and ``"linear.bias"``, whose first axis is the class index.
        dict_len: per-client weights, one per entry of ``w`` (presumably local
            dataset sizes).
        datasetObj: object whose ``training_set_distribution`` is indexed as
            ``[client][class]``. Assumed to hold non-negative per-class counts
            (TODO confirm). It is only read when ``round > start_round``.
        beta: blend factor for the class-wise head.
        round: current communication round (name kept for API compatibility).
        start_round: the class-wise blend is applied only when
            ``round > start_round``.

    Returns:
        A new dict; the dicts in ``w`` are not modified.
    """
    total_len = sum(dict_len)
    avg_w = copy.deepcopy(w[0])
    for k in avg_w.keys():
        avg_w[k] = avg_w[k] * dict_len[0]
        for i in range(1, len(w)):
            avg_w[k] += w[i][k] * dict_len[i]
        avg_w[k] = avg_w[k] / total_len

    if round <= start_round:
        # The class-wise head only influences the result after the warm-up.
        return avg_w

    counts = np.asarray(datasetObj.training_set_distribution, dtype=np.float64)
    row_sums = counts.sum(axis=1, keepdims=True)
    # Normalize each client's row to proportions. A client with no samples
    # gets an all-zero row instead of NaN from 0/0. Writing into a fresh
    # `out` array leaves the dataset object's array untouched.
    client_distribution = np.divide(
        counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0
    )

    weno_weight = torch.zeros_like(avg_w["linear.weight"])
    weno_bias = torch.zeros_like(avg_w["linear.bias"])
    for id_cls in range(weno_bias.shape[0]):
        class_total = 0.0
        for id_client in range(len(w)):
            # Weight = this client's own dataset size (indexed by client,
            # not by class) times its share of samples in this class.
            coef = float(client_distribution[id_client][id_cls]) * dict_len[id_client]
            weno_weight[id_cls] += w[id_client]["linear.weight"][id_cls] * coef
            weno_bias[id_cls] += w[id_client]["linear.bias"][id_cls] * coef
            class_total += coef
        if class_total > 0:
            # In-place normalization (the original computed and discarded it).
            weno_weight[id_cls] /= class_total
            weno_bias[id_cls] /= class_total
        else:
            # No client holds this class: use the plain averaged row.
            weno_weight[id_cls] = avg_w["linear.weight"][id_cls]
            weno_bias[id_cls] = avg_w["linear.bias"][id_cls]

    avg_w["linear.weight"] = beta * weno_weight + (1 - beta) * avg_w["linear.weight"]
    avg_w["linear.bias"] = beta * weno_bias + (1 - beta) * avg_w["linear.bias"]
    return avg_w
def Weighted_avg_f1(f1_list,dict_len):
    """Return the mean of per-client F1 scores weighted by ``dict_len``.

    ``f1_list[i]`` is paired with ``dict_len[i]`` for every index of
    ``dict_len``. An empty (or zero-sum) ``dict_len`` raises
    ``ZeroDivisionError``.
    """
    weighted_total = sum(f1_list[idx] * size for idx, size in enumerate(dict_len))
    return weighted_total / sum(dict_len)