from decimal import *

import numpy as np

import torch

from .base import _BaseAggregator


class FLTrust(_BaseAggregator):
    r"""
    This script implements the FLTrust Algorithm

    Cao, X., Fang, M., Liu, J., & Gong, N. Z. (2020).
    Fltrust: Byzantine-robust federated learning via trust bootstrapping.
    arXiv preprint arXiv:2012.13995.

    The first element passed to ``__call__`` is treated as the server's
    reference update ``g0``; the remaining elements are the client updates
    ``g_i``. Each client gets a trust score ``TS_i = ReLU(cos(g_i, g0))``,
    its update is rescaled to the magnitude of ``g0``, and the result is the
    trust-weighted average::

        g = sum_i TS_i * (||g0|| / ||g_i||) * g_i / sum_j TS_j
    """

    def __init__(self, param1):
        # Stored only; not used by the aggregation below.
        self.param1 = param1  # placeholder for now.
        super().__init__()

    def __call__(self, local_updates):
        """Aggregate client updates using FLTrust trust scores.

        Args:
            local_updates: sequence (or stacked tensor) of update tensors.
                Index 0 is the server update ``g0``; indices 1.. are client
                updates. All entries are assumed to share one shape
                (TODO confirm with callers).

        Returns:
            A tensor with the same shape as each update: the trust-weighted,
            magnitude-normalized sum of the client updates. If there are no
            client updates, or every client has non-positive similarity to
            the server update, the result is all zeros.
        """
        # "iterate"/"update" refers to gradients.
        server_update = local_updates[0]
        client_updates = local_updates[1:]

        if len(client_updates) == 0:
            # Nothing to aggregate; mirror the all-zero-trust outcome instead
            # of failing inside torch.stack on an empty list.
            return torch.zeros_like(server_update)

        # Flatten so that cosine similarity yields one scalar per client even
        # when updates are multi-dimensional (identity for 1-D updates).
        server_flat = server_update.reshape(-1)
        server_norm = torch.norm(server_update)

        # Trust score per *client* only — the server update must not be scored
        # against itself, otherwise the weights shift by one position and the
        # normalizer gains a spurious self-similarity of 1.
        cos_sim = torch.stack([
            torch.cosine_similarity(server_flat, update.reshape(-1), dim=0)
            for update in client_updates
        ])
        # ReLU discards clients whose direction disagrees with the server.
        trust = torch.nn.functional.relu(cos_sim)
        # Small epsilon avoids division by zero when every score is clipped.
        weights = trust / (torch.sum(trust) + 1e-9)

        # Rescale each client update to the server update's norm, then weight
        # it by its normalized trust score.
        scaled_updates = [
            (update * weight * server_norm) / (torch.norm(update) + 1e-9)
            for update, weight in zip(client_updates, weights)
        ]

        # Global update is the sum of the weighted, rescaled client updates.
        return torch.sum(torch.stack(scaled_updates), dim=0)

    def __str__(self):
        return "FLTrust"