import copy
import torch
from torch import nn


def FedAvg(w):
    """Return the unweighted element-wise mean of a list of state dicts.

    Args:
        w: Sequence of state dicts (name -> tensor). The keys of ``w[0]``
            determine the output keys; every other entry must contain them.

    Returns:
        A deep copy of ``w[0]`` whose values are replaced by the mean of the
        corresponding values across all of ``w``. The inputs are not modified.
        ``torch.div`` performs true division, so integer-valued entries come
        back as floating point.

    Raises:
        IndexError: If ``w`` is empty (raised by the ``w[0]`` access).
    """
    num_models = len(w)
    others = [w[idx] for idx in range(1, num_models)]
    merged = copy.deepcopy(w[0])
    for param_name in merged:
        # In-place adds act on the deep copy, never on the caller's tensors.
        total = merged[param_name]
        for state in others:
            total += state[param_name]
        merged[param_name] = torch.div(total, num_models)
    return merged

def FedWeightedAvg(w, coordinator_data_sizes):
    """Average state dicts, weighting each one by its participant's data size.

    Each output entry ``k`` is
    ``sum_i w[i][k] * coordinator_data_sizes[i] / sum(coordinator_data_sizes)``.

    Args:
        w: Non-empty sequence of state dicts (name -> tensor). The keys of
            ``w[0]`` determine the output keys; every other entry must
            contain them.
        coordinator_data_sizes: Iterable of data sizes, one per entry of ``w``
            and in the same order. Their sum must be non-zero.

    Returns:
        A new mapping of the same type as ``w[0]`` holding the weighted
        average. The input state dicts and their tensors are not modified.

    Raises:
        ValueError: If ``w`` is empty, the number of sizes differs from the
            number of state dicts, or the sizes sum to zero.
    """
    states = list(w)
    sizes = list(coordinator_data_sizes)
    if not states:
        raise ValueError("FedWeightedAvg requires at least one state dict")
    # A length mismatch would otherwise either raise a bare IndexError or,
    # with extra sizes, make the weights sum to < 1 and silently shrink
    # every parameter.
    if len(sizes) != len(states):
        raise ValueError(
            f"got {len(states)} state dicts but {len(sizes)} data sizes"
        )
    total_data = sum(sizes)
    if total_data == 0:
        raise ValueError("coordinator_data_sizes must not sum to zero")

    # Normalised per-participant weights, computed once instead of per key.
    weights = [size / total_data for size in sizes]

    # A shallow copy keeps w[0]'s mapping type without deep-copying every
    # tensor; each value is replaced below anyway.
    w_avg = copy.copy(states[0])
    for k in w_avg:
        # The first product allocates a fresh tensor, so the in-place adds
        # that follow never touch the caller's tensors.
        acc = states[0][k] * weights[0]
        for state, weight in zip(states[1:], weights[1:]):
            acc += state[k] * weight
        w_avg[k] = acc
    return w_avg