from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from copy import deepcopy
import numpy as np


class FedAvgAggregator(object):
    '''
    Server-side aggregator that averages client model parameters, weighting
    each client by the number of samples it reported (FedAvg).

    Arguments:
        model: object the aggregated parameters are loaded into by
            ``update``; it must provide ``load_state_dict``. May be None if
            ``update`` is never called.
        device: stored on the instance as ``self.device``; no method of this
            class reads it.
    '''

    def __init__(self, model=None, device='cpu'):
        self.model = model
        self.device = device

    def aggregate(self, agg_info):
        '''
        Average the client models found in ``agg_info["client_feedback"]``.

        Arguments:
            agg_info (dict): must contain the key ``"client_feedback"``, a
                sequence of ``(sample_size, model)`` pairs where ``model``
                maps parameter names to array-like values.

        Returns:
            A new mapping from parameter name to the weighted average
            (NumPy values). The client models are not modified.

        Raises:
            KeyError: if ``"client_feedback"`` is missing.
            ValueError: if there is no feedback or the total sample size is
                not positive.
        '''
        return self._para_weighted_avg(agg_info["client_feedback"])

    def _para_weighted_avg(self, models):
        '''
        Compute ``sum_i (n_i / N) * model_i[key]`` for every key of the first
        model, where ``n_i`` is client i's sample size and ``N`` their sum.

        Arguments:
            models: sequence of ``(sample_size, model)`` pairs. Every model
                must provide all keys present in the first model.

        Returns:
            A shallow copy of the first model's mapping (same mapping type
            and key order) whose values are replaced by the averages.

        Raises:
            ValueError: if ``models`` is empty or the sample sizes do not
                sum to a positive number.
        '''
        # Local import: the module only imports ``deepcopy`` from ``copy``.
        from copy import copy as shallow_copy

        if len(models) == 0:
            raise ValueError("cannot aggregate: no client feedback was given")

        total_samples = sum(sample_size for sample_size, _ in models)
        if total_samples <= 0:
            raise ValueError(
                "cannot aggregate: total sample size must be positive, "
                "got {}".format(total_samples))

        # The weights do not depend on the parameter key; compute them once.
        weights = [sample_size / total_samples for sample_size, _ in models]

        first_model = models[0][1]
        # Accumulate into a copy so the first client's parameters, which are
        # owned by the caller, are never overwritten.
        averaged = shallow_copy(first_model)
        for key in first_model:
            accumulated = None
            for weight, (_, local_model) in zip(weights, models):
                # The multiplication always allocates a fresh value, so the
                # in-place ``+=`` below cannot touch any client's data.
                contribution = np.asarray(local_model[key]) * weight
                if accumulated is None:
                    accumulated = contribution
                else:
                    accumulated += contribution
            averaged[key] = accumulated

        return averaged

    def update(self, model_parameters):
        '''
        Load the given parameters into ``self.model``.

        Arguments:
            model_parameters (dict): PyTorch Module object's state_dict.
        '''
        self.model.load_state_dict(model_parameters)
