import numpy as np

def weighted_average_numpy(all_model_params, avg_weights):
    """Compute a weighted sum of per-model parameter dicts with numpy.

    Args:
        all_model_params: Sequence of dicts mapping parameter name -> array.
            The set of keys (and the output shapes) is taken from the first
            dict; every model indexed by a weight must contain those keys,
            otherwise a KeyError is raised.
        avg_weights: One weight per model. This can be a numpy array, a list
            or a tuple. Weights are used as given and are not normalized.
            Weight ``i`` applies to ``all_model_params[i]``.

    Returns:
        ``(avg_params, avg_weights)``: ``avg_params`` is a new dict of float32
        arrays. ``avg_weights`` is returned unchanged.

    Keys ending in ``num_batches_tracked`` (BatchNorm step counters) are not
    accumulated, so they stay zero-filled in the result.
    """
    reference = all_model_params[0]
    avg_params = {
        k: np.zeros_like(reference[k], dtype='float32') for k in reference.keys()
    }
    # Decide once which keys participate, rather than re-testing per model.
    averaged_keys = [k for k in avg_params if not k.endswith("num_batches_tracked")]

    # enumerate() accepts any sequence of weights (not only ndarrays). Indexing
    # by position keeps an IndexError when there are more weights than models.
    for i, weight in enumerate(avg_weights):
        model_params = all_model_params[i]
        for k in averaged_keys:
            avg_params[k] += weight * model_params[k]

    return avg_params, avg_weights

def weighted_average_torch(all_model_params, avg_weights, args):
    """Compute a weighted sum of per-model parameter dicts using torch ops.

    The computation stays in torch (not numpy) so that gradients can flow
    back into ``avg_weights``.

    Args:
        all_model_params: Sequence of dicts mapping parameter name -> tensor.
            Parameter names are taken from the first dict.
        avg_weights: Iterable of per-model weights, paired positionally with
            ``all_model_params``.
        args: Object whose ``device`` attribute is the target device for
            each averaged tensor.

    Returns:
        ``(averaged_params, avg_weights)`` with ``avg_weights`` unchanged.
    """
    def _combine(name):
        # Weighted sum for one parameter, then move it to the target device.
        terms = [weight * params[name] for weight, params in zip(avg_weights, all_model_params)]
        return sum(terms).to(args.device)

    averaged_params = {name: _combine(name) for name in all_model_params[0].keys()}
    return averaged_params, avg_weights


def get_averaged_params(all_model_params, avg_weights, args):
    """Dispatch parameter averaging to the torch or numpy implementation.

    Algorithms whose name begins with ``DecAvg``, ``DFedEM`` or ``Federico``
    use the torch path (so ``avg_weights`` can receive gradients). Every
    other algorithm uses the numpy path.

    Returns:
        The ``(averaged_params, avg_weights)`` tuple produced by the chosen
        implementation.
    """
    torch_prefixes = ('DecAvg', 'DFedEM', 'Federico')
    if not args.algorithm.startswith(torch_prefixes):
        result = weighted_average_numpy(all_model_params, avg_weights)
    else:
        result = weighted_average_torch(all_model_params, avg_weights, args)
    return result
