#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @python: 3.6

import copy
import torch
import torch.nn.functional as F
from scipy import linalg
import numpy as np


# #orig
# def aggregate_att(w_clients, w_server, stepsize, metric, dp):
#     """
#     Attentive aggregation
#     :param w_clients: list of client model parameters
#     :param w_server: server model parameters
#     :param stepsize: step size for aggregation
#     :param metric: similarity
#     :param dp: magnitude of randomization
#     :return: updated server model parameters
#     """
#     w_next = copy.deepcopy(w_server)
#     att, att_mat = {}, {}
#     for k in w_server.keys():
#         w_next[k] = torch.zeros_like(w_server[k]).cpu()
#         att[k] = torch.zeros(len(w_clients)).cpu()
#     for k in w_next.keys():
#         for i in range(0, len(w_clients)):
#             att[k][i] = torch.from_numpy(np.array(linalg.norm(w_server[k]-w_clients[i][k], ord=metric)))
#     for k in w_next.keys():
#         att[k] = F.softmax(att[k], dim=0)
#     for k in w_next.keys():
#         att_weight = torch.zeros_like(w_server[k])
#         for i in range(0, len(w_clients)):
#             att_weight += torch.mul(w_server[k]-w_clients[i][k], att[k][i])
#         w_next[k] = w_server[k] - torch.mul(att_weight, stepsize) + torch.mul(torch.randn(w_server[k].shape), dp)
#     return w_next


def aggregate_att(w_clients, w_server, stepsize, metric, dp):
    """
    Attentive aggregation
    :param w_clients: list of client model parameters
    :param w_server: server model parameters
    :param stepsize: step size for aggregation
    :param metric: similarity
    :param dp: magnitude of randomization
    :return: updated server model parameters
    """
    # w_next = copy.deepcopy(w_server)
    # att, att_mat = {}, {}

    # w_next = torch.zeros_like(w_server).cpu()
    w_clients = w_clients.detach().cpu()
    w_server = w_server.detach().cpu()
    att = torch.zeros(len(w_clients)).cpu()

    for i in range(0, len(w_clients)):
        att[i] = torch.from_numpy(np.array(linalg.norm(w_server-w_clients[i], ord=metric)))
        # att[i] = linalg.norm(w_server - w_clients[i], ord=metric)
    att = F.softmax(att, dim=0)
    att_weight = torch.zeros_like(w_server)
    for i in range(0, len(w_clients)):
        att_weight += torch.mul(w_clients[i], att[i])
    w_next = att_weight + torch.mul(torch.randn(w_server.shape), dp)
    return w_next