"""
Byzantine workers implementing the Inner Product Manipulation (IPM) attack.

A better name would be "Inner Product Manipulation Attack".
"""

from codes.components.worker import ByzantineWorker


class IPMAttack(ByzantineWorker):
    """
    Represents a Byzantine worker implementing the Inner Product Manipulation (IPM) Attack.

    On every omniscient callback the worker averages the gradients reported by
    all honest (non-``ByzantineWorker``) workers in the simulator and submits
    that mean scaled by ``-epsilon``. With a positive ``epsilon`` the submitted
    gradient points opposite to the honest mean, which is intended to disrupt
    the convergence of the distributed learning system.

    Args:
        epsilon (float): Scalar multiplier for the malicious gradient; the
            submitted gradient is ``-epsilon * mean(honest gradients)``.
    """

    def __init__(self, epsilon: float, *args, **kwargs):
        """
        Initializes the IPMAttack instance.

        Args:
            epsilon (float): Scalar multiplier for gradient manipulation.
            *args, **kwargs: Forwarded unchanged to ``ByzantineWorker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.epsilon = epsilon
        # Filled in by omniscient_callback(); stays None until it first runs.
        self._gradient = None

    def get_gradient(self):
        """
        Returns the malicious gradient computed by the last omniscient callback.

        Returns:
            torch.Tensor: Gradient value (same type as the honest workers'
            gradients), or None if ``omniscient_callback`` has not run yet.
        """
        return self._gradient

    def omniscient_callback(self) -> None:
        """
        Computes the malicious gradient from the honest workers' gradients.

        Collects ``get_gradient()`` from every worker in
        ``self.simulator.workers`` that is not a ``ByzantineWorker``, averages
        them, and stores ``-epsilon * mean`` as this worker's gradient.

        Raises:
            RuntimeError: If the simulator contains no honest workers, since
                the mean gradient is then undefined.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not honest_gradients:
            raise RuntimeError(
                "IPMAttack requires at least one honest worker to compute the "
                "mean gradient, but none were found in the simulator."
            )

        # sum() starts from the int 0, so this relies on the gradient type
        # supporting `0 + gradient` (true for torch tensors and plain numbers).
        total = sum(honest_gradients)
        self._gradient = -self.epsilon * total / len(honest_gradients)

    def set_gradient(self, gradient) -> None:
        """
        Method to set the gradient. Not supported for this class since its gradient is maliciously computed.

        Args:
            gradient (torch.Tensor): Gradient value to set (ignored).

        Raises:
            NotImplementedError: Always raised; the gradient is produced by
                ``omniscient_callback`` instead.
        """
        raise NotImplementedError(
            "IPMAttack computes its gradient in omniscient_callback(); "
            "set_gradient() is not supported."
        )

    def apply_gradient(self) -> None:
        """
        Method to apply the gradient to update the model. Not supported for this class.

        Raises:
            NotImplementedError: Always raised; this worker does not update a model.
        """
        raise NotImplementedError(
            "IPMAttack does not maintain a model; apply_gradient() is not supported."
        )


class OptimIPMAttack(ByzantineWorker):
    """
    Represents a Byzantine worker implementing an optimization-based Inner Product Manipulation (IPM) Attack.

    A future version will maliciously optimize the value of epsilon to create a
    more disruptive gradient. Currently the callback behaves like plain IPM:
    it submits ``-epsilon * mean(honest gradients)`` using the fixed
    ``epsilon`` given at construction.

    Args:
        epsilon (float): Initial scalar multiplier for the malicious gradient manipulation.
    """

    def __init__(self, epsilon: float, *args, **kwargs):
        """
        Initializes the OptimIPMAttack instance.

        Args:
            epsilon (float): Initial scalar multiplier for gradient manipulation.
            *args, **kwargs: Forwarded unchanged to ``ByzantineWorker.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Must be stored: omniscient_callback() reads self.epsilon, and omitting
        # this assignment made every callback raise AttributeError.
        self.epsilon = epsilon
        # Filled in by omniscient_callback(); stays None until it first runs.
        self._gradient = None

    def get_gradient(self):
        """
        Returns the malicious gradient computed by the last omniscient callback.

        Returns:
            torch.Tensor: Gradient value (same type as the honest workers'
            gradients), or None if ``omniscient_callback`` has not run yet.
        """
        return self._gradient

    def omniscient_callback(self) -> None:
        """
        Computes the malicious gradient from the honest workers' gradients.

        Collects ``get_gradient()`` from every worker in
        ``self.simulator.workers`` that is not a ``ByzantineWorker``, averages
        them, and stores ``-epsilon * mean`` as this worker's gradient. Epsilon
        is not optimized yet (see TODO below).

        Raises:
            RuntimeError: If the simulator contains no honest workers, since
                the mean gradient is then undefined.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not honest_gradients:
            raise RuntimeError(
                "OptimIPMAttack requires at least one honest worker to compute "
                "the mean gradient, but none were found in the simulator."
            )

        # TODO: Optimization-based IPM (optimize on epsilon)
        # sum() starts from the int 0, so this relies on the gradient type
        # supporting `0 + gradient` (true for torch tensors and plain numbers).
        total = sum(honest_gradients)
        self._gradient = -self.epsilon * total / len(honest_gradients)

    def set_gradient(self, gradient) -> None:
        """
        Method to set the gradient. Not supported for this class since its gradient is maliciously computed.

        Args:
            gradient (torch.Tensor): Gradient value to set (ignored).

        Raises:
            NotImplementedError: Always raised; the gradient is produced by
                ``omniscient_callback`` instead.
        """
        raise NotImplementedError(
            "OptimIPMAttack computes its gradient in omniscient_callback(); "
            "set_gradient() is not supported."
        )

    def apply_gradient(self) -> None:
        """
        Method to apply the gradient to update the model. Not supported for this class.

        Raises:
            NotImplementedError: Always raised; this worker does not update a model.
        """
        raise NotImplementedError(
            "OptimIPMAttack does not maintain a model; apply_gradient() is not supported."
        )
