"""
A better name would be "Inner Product Manipulation Attack".
"""

from ..simulators.worker import ByzantineWorker
from ..utils import get_vectorized_parameters


class IPMWorker(ByzantineWorker):
    """Byzantine worker for the Inner Product Manipulation (IPM) attack.

    When ``omniscient_callback`` runs, the worker inspects every
    non-Byzantine worker of its simulator and stores ``-epsilon`` times the
    mean of their gradients. That stored value is what ``get_gradient`` and
    ``get_update`` subsequently report.

    Args:
        epsilon: Scale applied to the negated mean of the good workers'
            gradients.
        *args: Forwarded unchanged to ``ByzantineWorker.__init__``.
        **kwargs: Forwarded unchanged to ``ByzantineWorker.__init__``.
    """

    def __init__(self, epsilon, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon
        # Crafted malicious gradient; stays None until omniscient_callback
        # has been called at least once.
        self._gradient = None

    def get_gradient(self):
        """Return the crafted gradient, or ``None`` if none was computed yet."""
        return self._gradient

    def omniscient_callback(self):
        """Compute the malicious gradient from the good workers' gradients.

        Collects ``get_gradient()`` from every worker in
        ``self.simulator.workers`` that is not a ``ByzantineWorker`` and
        stores ``-epsilon * sum(gradients) / len(gradients)``.

        Raises:
            ZeroDivisionError: If the simulator has no non-Byzantine worker
                (assuming a numeric ``epsilon``), because the mean is then
                taken over zero gradients.
        """
        # NOTE(review): self.simulator is presumably attached by the base
        # class or the simulator itself -- it is not set in this class.
        good_gradients = [
            worker.get_gradient()
            for worker in self.simulator.workers
            if not isinstance(worker, ByzantineWorker)
        ]

        # Same evaluation order as before: scale the sum, then divide.
        self._gradient = (
            -self.epsilon * (sum(good_gradients)) / len(good_gradients)
        )

    def __str__(self) -> str:
        return "IPMWorker"

    def get_update(self, server_iterate):
        """Build this worker's update dictionary.

        Args:
            server_iterate: The server's current iterate; only used as a
                fallback while no malicious gradient has been computed.

        Returns:
            A dict with ``'loss'`` and ``'length'`` set to ``0``,
            ``'metrics'`` mapping every name in ``self.metrics`` to ``0.0``,
            and ``'local_iterate'`` holding the crafted gradient (or the
            fallback described above).
        """
        results = {
            'loss': 0,
            'length': 0,
            # The attacker does no real training, so every metric is zero.
            'metrics': {name: 0.0 for name, _ in self.metrics.items()},
        }

        update = self.get_gradient()
        if update is None:
            # for the first round, workers haven't yet sent their updates
            # (so send back the server's iterate)
            try:
                # to enable compatibility with fltrust.
                update = get_vectorized_parameters(server_iterate)
            except Exception:
                # Best-effort fallback: if the iterate cannot be vectorized,
                # send it back as-is. Catching Exception (not a bare except)
                # lets KeyboardInterrupt/SystemExit propagate.
                update = server_iterate

        results['local_iterate'] = update
        return results

