import torch

from .base import _BaseAggregator


class RSA(_BaseAggregator):
    r"""
    Implements RSA: Byzantine-Robust Stochastic Aggregation.

    Calling the aggregator returns the sign-penalty term

        rsa_lambda * sum_i sign(server_iterate - local_iterates[i])

    which is a subgradient (taking sign(0) = 0) of
    ``rsa_lambda * sum_i ||server_iterate - local_iterates[i]||_1`` with
    respect to ``server_iterate``. Because ``torch.sign`` yields values in
    {-1, 0, 1} for finite inputs, each local iterate contributes at most
    ``|rsa_lambda|`` per coordinate, regardless of how large its deviation is.

    Args:
        rsa_lambda: Weight multiplying the summed signs.
    """

    def __init__(self, rsa_lambda):
        self.rsa_lambda = rsa_lambda
        super().__init__()

    def __call__(self, server_iterate, local_iterates):
        """Return ``rsa_lambda * sum(sign(server_iterate - local))``.

        Args:
            server_iterate: The server's current iterate (tensor, or anything
                ``torch.as_tensor`` accepts).
            local_iterates: Iterable of local iterates. Each must be
                broadcast-compatible with ``server_iterate`` under tensor
                subtraction. It is consumed exactly once, so a generator is fine.

        Returns:
            A tensor holding the scaled sum of signs. If ``local_iterates`` is
            empty, this is a zero tensor shaped like ``server_iterate``.
        """
        # Seed the reduction with a zero tensor instead of the builtin int 0.
        # An empty ``local_iterates`` then still yields a tensor on the right
        # device rather than the Python scalar 0. Adding an exact zero
        # (out-of-place) leaves non-empty results and their dtype unchanged.
        zero = torch.zeros_like(torch.as_tensor(server_iterate))
        total = sum(
            (torch.sign(server_iterate - local) for local in local_iterates),
            zero,
        )
        return self.rsa_lambda * total

    def __str__(self):
        return f"RSA (rsa_lambda={self.rsa_lambda})"
