import torch

from codes.components.worker import ByzantineWorker


class MinSumAttack(ByzantineWorker):
    """
    This class implements a Byzantine attack strategy named 'MinSumAttack'.
    It is derived from the `ByzantineWorker` class.

    The malicious gradient is ``mu - lamda * deviation``, where ``mu`` is the mean of the
    honest (non-Byzantine) workers' gradients. ``lamda`` is found by a halving search that
    looks for the largest value whose malicious gradient has a sum of squared distances
    to the honest gradients no larger than the smallest such sum among the honest
    gradients themselves.

    Args:
        dev_type (str): The deviation type for the attack. It can be one of 'unit_vec', 'sign', or 'std'.
                        Any other value falls back to 'unit_vec'. Default is 'unit_vec'.
        *args: Additional arguments to be passed to the base class.
        **kwargs: Additional keyword arguments to be passed to the base class.
    """

    # Safety cap on the lamda search. Normally the search stops after about
    # log2(50 / 1e-5) ~= 23 halvings. The cap only matters if float32 rounding keeps
    # |lamda_succ - lamda| above threshold_diff after the step has become too small
    # to move lamda, which would otherwise loop forever.
    max_search_iters = 200

    def __init__(self, dev_type='unit_vec', *args, **kwargs):
        """
        Initialize the MinSumAttack worker.

        Attributes:
            dev_type (str): The type of deviation. Determines the direction of the malicious update.
            lamda (torch.Tensor): Initial scaling factor for the deviation (search starting point).
            threshold_diff (float): The search stops once |lamda_succ - lamda| <= threshold_diff.
            lamda_fail (torch.Tensor): Initial search step size; halved every search iteration.
            lamda_succ (int): Initial "last accepted" lamda; 0 means no shift away from the mean.
        """
        super().__init__(*args, **kwargs)
        self.dev_type = dev_type
        self.lamda = torch.Tensor([50.0]).float().to(self.device)
        self.threshold_diff = 1e-5
        self.lamda_fail = self.lamda
        self.lamda_succ = 0

    def get_gradient(self):
        """
        Returns the gradient value of this worker.

        Returns:
            torch.Tensor: Gradient value (the malicious gradient set by `omniscient_callback`).
        """
        return self._gradient

    def omniscient_callback(self):
        """
        Calculate and set the malicious gradient based on other workers' gradients.

        Steps:
            1. Stack the gradients of all non-Byzantine workers in the simulator.
            2. Compute their mean ``mu`` and the deviation direction for ``dev_type``.
            3. Compute the minimum, over honest gradients, of the sum of squared
               distances to all honest gradients (the "min-sum" score).
            4. Search for ``lamda`` so that ``mu - lamda * deviation`` scores no worse
               than that minimum, and store the resulting malicious gradient.

        Raises:
            RuntimeError: If the simulator contains no non-Byzantine workers.
        """
        # Step 1: collect gradients of good workers
        good_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not good_gradients:
            raise RuntimeError("MinSumAttack requires at least one non-Byzantine worker.")

        stacked_gradients = torch.stack(good_gradients, 0)
        mu = torch.mean(stacked_gradients, 0)

        # Step 2: deviation direction
        deviation = self._compute_deviation(stacked_gradients, mu)

        # Step 3: best (smallest) min-sum score achieved by an honest gradient
        min_score = self._min_sum_score(stacked_gradients)

        # Step 4: largest accepted lamda, then the malicious gradient itself
        lamda_succ = self._search_lamda(stacked_gradients, mu, deviation, min_score)
        self._gradient = mu - lamda_succ * deviation

    def _compute_deviation(self, stacked_gradients, mu):
        """Return the direction along which the malicious gradient is shifted away from ``mu``."""
        if self.dev_type == 'sign':
            return torch.sign(mu)
        if self.dev_type == 'std':
            # Unbiased torch.std of a single sample is NaN; with one honest worker
            # there is no spread to exploit, so use no deviation instead.
            if stacked_gradients.shape[0] < 2:
                return torch.zeros_like(mu)
            return torch.std(stacked_gradients, 0)
        # 'unit_vec' and any unrecognized dev_type: unit vector along mu. The caller
        # subtracts it, which pushes the update opposite to the honest direction.
        norm = torch.norm(mu)
        if norm == 0:
            # mu / 0 would be NaN; a zero mean has no direction to oppose.
            return torch.zeros_like(mu)
        return mu / norm

    @staticmethod
    def _min_sum_score(stacked_gradients):
        """Return the smallest sum of squared distances from one honest gradient to all of them."""
        # Score one gradient at a time so peak memory stays O(n * d), not O(n^2 * d).
        scores = torch.stack([
            torch.sum(torch.norm(stacked_gradients - update, dim=1) ** 2)
            for update in stacked_gradients
        ])
        return torch.min(scores)

    def _search_lamda(self, stacked_gradients, mu, deviation, min_score):
        """
        Halving search for the largest accepted scaling factor ``lamda``.

        A candidate ``mu - lamda * deviation`` is accepted when its sum of squared
        distances to the honest gradients is at most ``min_score``. Starting from
        ``self.lamda`` with step ``self.lamda_fail / 2``, the search moves up after an
        acceptance and down otherwise, halving the step every iteration.

        Returns:
            The last accepted lamda, or ``self.lamda_succ`` (0) if none was accepted.
        """
        lamda_succ = self.lamda_succ
        lamda_fail = self.lamda_fail
        lamda = self.lamda

        iters = 0
        while torch.abs(lamda_succ - lamda) > self.threshold_diff and iters < self.max_search_iters:
            mal_update = mu - lamda * deviation
            score = torch.sum(torch.norm(stacked_gradients - mal_update, dim=1) ** 2)

            if score <= min_score:
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2

            lamda_fail = lamda_fail / 2
            iters += 1

        return lamda_succ

    def set_gradient(self, gradient) -> None:
        """Set the gradient value. This method is not implemented for this class."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Apply the gradient to the model. This method is not implemented for this class."""
        raise NotImplementedError
