import torch

from codes.components.worker import ByzantineWorker


class MinMaxAttack(ByzantineWorker):
    """
    Byzantine worker implementing the Min-Max attack.

    Every round the attacker reads the gradients of all honest workers and sends
    ``mu - lamda * deviation``. Here ``mu`` is the element-wise mean of the honest gradients, and
    ``deviation`` is a direction chosen by ``dev_type``. ``lamda`` is the largest scale found by a
    halving-step search for which the malicious update is no farther (squared L2 distance) from any
    honest gradient than the largest squared pairwise distance among the honest gradients themselves.

    Args:
        dev_type (str): How the deviation direction is computed:
            'unit_vec' uses the honest mean scaled to unit L2 norm;
            'sign' uses the element-wise sign of the honest mean;
            'std' uses the element-wise standard deviation of the honest gradients.
            Any other value silently falls back to 'unit_vec'. Default is 'unit_vec'.
        *args: Additional positional arguments for the superclass.
        **kwargs: Additional keyword arguments for the superclass.
    """

    def __init__(self, dev_type='unit_vec', *args, **kwargs):
        """
        Initializes the MinMaxAttack instance.

        Attributes:
            dev_type (str): Type of deviation; selects how the deviation direction is computed.
            lamda (torch.Tensor): First scale tried by the search (50.0, placed on ``self.device``).
            threshold_diff (float): Stopping tolerance of the search. The search ends once the
                current candidate scale is within this distance of the last successful one.
            lamda_fail (torch.Tensor): Initial step basis of the search. Each iteration moves
                ``lamda`` by half of it, then halves it.
            lamda_succ (int): Initial "last successful" scale (0, i.e. send the plain mean).
        """
        super().__init__(*args, **kwargs)
        self.dev_type = dev_type
        self.lamda = torch.Tensor([50.0]).float().to(self.device)
        self.threshold_diff = 1e-5
        self.lamda_fail = self.lamda
        self.lamda_succ = 0

    def get_gradient(self):
        """
        Returns the gradient value of this worker.

        Returns:
            torch.Tensor: The malicious gradient last stored by ``omniscient_callback``.
        """
        return self._gradient

    def omniscient_callback(self):
        """
        Computes the malicious gradient from the honest workers' gradients and stores it in ``self._gradient``.

        Steps:
        1. Collect gradients from every simulator worker that is not a ``ByzantineWorker``.
        2. Compute their element-wise mean and pick the deviation direction according to ``dev_type``.
        3. Compute the largest squared L2 distance between any two honest gradients.
        4. Search for the largest scale ``lamda`` whose update ``mu - lamda * deviation`` stays within
           that distance of every honest gradient, and store that update.

        Raises:
            RuntimeError: If no honest worker is available to read gradients from.
        """
        # Step 1: gather the gradients of honest workers.
        gradients = [
            w.get_gradient() for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not gradients:
            raise RuntimeError("MinMaxAttack needs at least one honest worker gradient")

        # NOTE(review): the dim=1 norms below assume each gradient is a flat 1-D tensor — TODO confirm.
        stacked_gradients = torch.stack(gradients, 0)
        mu = torch.mean(stacked_gradients, 0)

        # Step 2: deviation direction.
        if self.dev_type == 'sign':
            deviation = torch.sign(mu)
        elif self.dev_type == 'std':
            deviation = torch.std(stacked_gradients, 0)
        else:
            # 'unit_vec' and any unrecognised dev_type: the mean rescaled to unit L2 norm. It points
            # along the honest mean; the update below subtracts it, so the attack moves against it.
            # NOTE(review): produces NaN if mu is exactly zero.
            deviation = mu / torch.norm(mu)

        # Step 3: largest squared L2 distance between any two honest gradients. Reduce each row to its
        # max immediately instead of concatenating the whole distance matrix just to take its max.
        max_distance = torch.stack([
            torch.max(torch.norm(stacked_gradients - update, dim=1) ** 2)
            for update in stacked_gradients
        ]).max()

        # Step 4: halving-step search for the largest acceptable lamda.
        lamda_succ = self.lamda_succ
        lamda_fail = self.lamda_fail
        lamda = self.lamda
        threshold_diff = self.threshold_diff

        while torch.abs(lamda_succ - lamda) > threshold_diff:
            mal_update = mu - lamda * deviation
            max_d = torch.max(torch.norm(stacked_gradients - mal_update, dim=1) ** 2)
            if max_d <= max_distance:
                # Candidate stays within the honest spread: remember it and try a larger scale.
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                # Candidate would stand out: back off.
                lamda = lamda - lamda_fail / 2
            lamda_fail = lamda_fail / 2

        # Use the last scale that passed the check (0 if none did, i.e. the plain mean).
        self._gradient = mu - lamda_succ * deviation

    def set_gradient(self, gradient) -> None:
        """
        Method to set the gradient. This is not implemented for this class since its gradient is maliciously
        computed and shouldn't be externally set.

        Args:
            gradient (torch.Tensor): Gradient value to set.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """
        Method to apply the gradient to update the model. This is not implemented for this class.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError
