import torch


def aggregate_gradients(server_model, client_models):
    """Set each server parameter's ``.grad`` to the mean of the clients' values.

    Parameters are matched positionally across all models via
    ``named_parameters()`` (``zip`` stops at the shortest model). For each
    position, the client parameters' ``.data`` tensors -- not their ``.grad`` --
    are averaged, and the result becomes the server parameter's gradient.

    NOTE(review): presumably the clients carry pseudo-gradients (e.g. weight
    deltas) in their parameter values so a server optimizer can step on them;
    confirm with callers.

    If ``client_models`` is empty, the server gradients are only cleared by
    ``zero_grad()`` and nothing is assigned.
    """
    server_model.zero_grad()
    client_streams = [model.named_parameters() for model in client_models]
    for (_, server_param), *client_entries in zip(server_model.named_parameters(), *client_streams):
        # NOTE(review): ``.data`` of a Parameter is never None, so this filter
        # keeps every entry; it may have been intended for ``.grad``.
        tensors = [entry.data for _, entry in client_entries if entry.data is not None]
        if not tensors:
            continue
        server_param.grad = sum(tensors) / len(tensors)


def model_operation(model1, model2, operation=lambda a,b: a+b):
    """Combine two models' parameters element-wise, storing the result in ``model1``.

    Parameters are paired positionally via ``named_parameters()``. For each
    pair, ``operation(p1.data, p2.data)`` is computed and rebound as the new
    ``.data`` of ``model1``'s parameter. ``model2`` is only read.

    Args:
        model1: model that receives the results (mutated in place).
        model2: model supplying the second operand.
        operation: binary callable on tensors; defaults to addition.

    Returns:
        ``model1`` itself, for chaining.
    """
    for (_, target), (_, other) in zip(model1.named_parameters(), model2.named_parameters()):
        target.data = operation(target.data, other.data)
    return model1


def copy_weights(server_model, client_model):
    """Overwrite ``client_model``'s parameter values with copies of ``server_model``'s.

    Parameters are paired positionally via ``named_parameters()``. Each client
    parameter's ``.data`` is replaced by a clone of the matching server tensor,
    so it takes on the server tensor's shape, dtype and device. Afterwards the
    two models hold independent storage.

    Args:
        server_model: source model (only read).
        client_model: destination model (mutated in place).
    """
    for (_, client_param), (_, server_param) in zip(
        client_model.named_parameters(), server_model.named_parameters()
    ):
        # clone() gives the client its own storage. Assigning the server's
        # .data directly would alias the two, so any in-place update (e.g. an
        # optimizer step on the client) would silently change the server too.
        client_param.data = server_param.data.clone()


class AverageMeter(object):
    """Track the running sum, count and mean of a stream of values."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean.

        Args:
            val: value to record (anything supporting ``*`` and ``+``).
            n: weight of ``val``, e.g. the batch size when ``val`` is a
                per-batch mean. Defaults to 1, which records a single
                observation. Should be positive; ``count`` accumulates the
                total weight.
        """
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def ravel_model_params(model, grads=False):
    """
    Squash model parameters or gradients into a single 1-D tensor.

    Parameters are visited in ``model.parameters()`` order. Each one is
    flattened in its logical (row-major) element order, which is the order in
    which a vector is consumed when writing it back parameter by parameter.

    Args:
        model: module whose parameters (or their ``.grad``) are flattened.
        grads: if True, flatten each parameter's ``.grad`` instead of its data.

    Returns:
        A new 1-D tensor (not aliasing the model). Its dtype is promoted to at
        least ``torch.get_default_dtype()``, so float16 parameters come back as
        float32 by default. A model with no parameters yields an empty tensor.

    Raises:
        AttributeError: if ``grads`` is True and some parameter has no gradient.

    Adapted from https://github.com/unc-optimization/FedDR/blob/4097eb447a99c7180388527a2d05974906b77eb1/AsyncFedDR/asyncfeddr/utils/serialization.py
    """
    sources = [p.grad if grads else p.data for p in model.parameters()]
    if not sources:
        # Nothing to flatten: return an empty vector rather than failing.
        return torch.empty(0)
    # Flatten every piece and concatenate once. Growing an accumulator with
    # torch.cat inside a loop re-copies everything gathered so far on each
    # iteration. reshape (not view) also copes with non-contiguous layouts
    # such as channels_last weights, where view(-1) raises.
    flat = torch.cat([s.reshape(-1) for s in sources])
    # Keep the result at least as wide as the default floating dtype; callers
    # may rely on e.g. half-precision models producing float32 vectors.
    return flat.to(torch.promote_types(flat.dtype, torch.get_default_dtype()))


def unravel_model_params(model, parameter_update):
    """
    Write a flat vector back into ``model``'s parameters, in place.

    ``parameter_update`` is consumed front to back. Each parameter, in
    ``model.parameters()`` order, takes the next ``numel()`` entries, reshaped
    to its own size and copied into ``parameter.data`` with ``copy_``. The
    values are copied, so the model does not alias ``parameter_update``.
    NOTE: this function mutates model.parameters.
    NOTE(review): there is no length check. A vector that is too short raises
    from ``view``, while extra trailing entries are silently ignored.

    From https://github.com/unc-optimization/FedDR/blob/4097eb447a99c7180388527a2d05974906b77eb1/AsyncFedDR/asyncfeddr/utils/serialization.py
    """
    start = 0  # read position within parameter_update
    for param in model.parameters():
        target = param.data
        stop = start + target.numel()
        target.copy_(parameter_update[start:stop].view(target.size()))
        start = stop
