"""
A better name would be the Inner Product Manipulation (IPM) Attack.
"""

from codes.components.worker_toy import ByzantineWorker
from codes.components.utils import save_txt


class IPMAttack(ByzantineWorker):
    """Byzantine worker implementing the Inner Product Manipulation (IPM) attack.

    On each omniscient callback, this worker reads the gradients of every
    non-Byzantine worker in the simulator. It then replaces its own gradient
    with ``-epsilon`` times their mean.

    Args:
        epsilon: Scaling factor applied to the negated mean of the honest
            gradients.
        save_dir: Destination passed to ``save_txt`` when logging the
            gradients this worker emits (presumably a file path -- confirm
            against ``save_txt``).
        *args, **kwargs: Forwarded unchanged to ``ByzantineWorker.__init__``.
    """

    def __init__(self, epsilon, save_dir, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon
        # Filled in by omniscient_callback(); None until it has run once.
        self._gradient = None
        self.save_dir = save_dir
        # History of emitted gradients, one "<g[0]> <g[1]>" string per call.
        self.byz_grads = []

    def get_gradient(self):
        """Return the crafted Byzantine gradient and log it.

        Each call appends the first two coordinates of the gradient, as a
        ``"<g[0]> <g[1]>"`` string, to ``self.byz_grads``. It then passes the
        full history to ``save_txt`` together with ``self.save_dir``.

        Returns:
            The gradient computed by the most recent ``omniscient_callback``.

        Raises:
            RuntimeError: If called before ``omniscient_callback`` has set a
                gradient.
        """
        if self._gradient is None:
            raise RuntimeError(
                "IPMAttack.get_gradient() called before omniscient_callback()"
            )
        # Convert once; tolist() already copies the data, so clone() is unnecessary.
        values = self._gradient.detach().cpu().tolist()
        # Only the first two coordinates are logged. This presumably targets the
        # 2-D toy setting; confirm before using with higher-dimensional models.
        self.byz_grads.append(str(values[0]) + ' ' + str(values[1]))
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    def omniscient_callback(self):
        """Craft the IPM gradient from the honest workers' current gradients.

        Sets ``self._gradient`` to ``-epsilon * mean(honest gradients)``. Here
        "honest" means every worker in ``self.simulator.workers`` that is not
        a ``ByzantineWorker``.

        Raises:
            ValueError: If the simulator has no non-Byzantine workers, in which
                case the mean is undefined.
        """
        gradients = [
            worker.get_gradient()
            for worker in self.simulator.workers
            if not isinstance(worker, ByzantineWorker)
        ]
        if not gradients:
            raise ValueError(
                "IPMAttack requires at least one non-Byzantine worker"
            )
        self._gradient = -self.epsilon * sum(gradients) / len(gradients)

    def set_gradient(self, gradient) -> None:
        """Not supported: this attacker crafts its gradient in omniscient_callback."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported: this attacker does not apply gradient updates."""
        raise NotImplementedError
