import torch
import torch.distributed as dist
import numpy as np
from scipy.stats import norm

from codes.components.worker import ByzantineWorker


class ALittleIsEnoughAttack(ByzantineWorker):
    """
    Byzantine worker implementing the "A Little Is Enough" (ALIE) attack.

    The malicious gradient is ``mu - z_max * std``, where ``mu`` and ``std`` are the
    coordinate-wise mean and standard deviation of the honest workers' gradients.

    Args:
        n (int): Total number of workers in the system.
        m (int): Number of Byzantine (malicious) workers.
        z (float, optional): Z-score multiplier applied to the per-coordinate standard
            deviation. When omitted it is derived from ``n`` and ``m``: with
            ``s = floor(n / 2 + 1) - m`` supporters, ``z`` is the standard-normal
            quantile (``norm.ppf``) of ``(n - m - s) / (n - m)``.

    Raises:
        ValueError: If ``z`` is omitted and ``n <= m`` (no honest workers to derive it from).
    """

    def __init__(self, n, m, z=None, *args, **kwargs):
        """
        Initializes the ALittleIsEnoughAttack instance.

        Remaining positional and keyword arguments are forwarded to ``ByzantineWorker``.
        """
        super().__init__(*args, **kwargs)
        if z is not None:
            # Caller-supplied z-score multiplier.
            self.z_max = z
        else:
            if n - m <= 0:
                # Would otherwise fail with a bare ZeroDivisionError below.
                raise ValueError(
                    f"Cannot derive z for ALIE: need n > m, got n={n}, m={m}"
                )
            # s: number of honest "supporters" the attacker needs to sway.
            s = np.floor(n / 2 + 1) - m
            cdf_value = (n - m - s) / (n - m)
            # NOTE: cdf_value is >= 1 when m >= floor(n / 2 + 1) (Byzantine majority)
            # and 0 when n <= 2; norm.ppf then returns inf/nan or -inf respectively.
            self.z_max = norm.ppf(cdf_value)
        # Number of honest workers.
        self.n_good = n - m

    def get_gradient(self):
        """
        Returns the gradient value of this worker.

        ``_gradient`` is assigned by ``omniscient_callback``; reading it before that
        callback has run fails unless the base class initializes it (not visible here).

        Returns:
            torch.Tensor: Gradient value.
        """
        return self._gradient

    def omniscient_callback(self):
        """
        Recomputes the malicious gradient from the honest workers' current gradients.

        Every worker in ``self.simulator.workers`` that is not a ``ByzantineWorker`` is
        treated as honest; ``self._gradient`` is set to ``mu - std * z_max`` where the
        mean and standard deviation are taken across those workers.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]

        # Stack along a new dim 1 so the reductions below run across workers.
        stacked_gradients = torch.stack(honest_gradients, dim=1)
        mu = torch.mean(stacked_gradients, dim=1)
        # torch.std uses the unbiased (N - 1) estimator by default, so a single
        # honest worker yields a NaN std.
        std = torch.std(stacked_gradients, dim=1)

        # Shift the mean by z_max standard deviations.
        self._gradient = mu - std * self.z_max

    def set_gradient(self, gradient) -> None:
        """
        Method to set the gradient. Not implemented for this class since its gradient is maliciously computed.

        Args:
            gradient (torch.Tensor): Gradient value to set.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """
        Method to apply the gradient to update the model. Not implemented for this class.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError


class OptimALittleIsEnoughAttack(ByzantineWorker):
    """
    Byzantine worker for an optimization-based "A Little Is Enough" (ALIE) attack.

    The optimization step is not implemented yet (see TODOs); the current behavior is
    the plain ALIE gradient ``mu - z_max * std``, where ``mu`` and ``std`` are the
    coordinate-wise mean and standard deviation of the honest workers' gradients.

    Args:
        n (int): Total number of workers in the system.
        m (int): Number of Byzantine (malicious) workers.
        z (float, optional): Z-score multiplier applied to the per-coordinate standard
            deviation. When omitted it is derived from ``n`` and ``m``: with
            ``s = floor(n / 2 + 1) - m`` supporters, ``z`` is the standard-normal
            quantile (``norm.ppf``) of ``(n - m - s) / (n - m)``.

    Raises:
        ValueError: If ``z`` is omitted and ``n <= m`` (no honest workers to derive it from).
    """

    def __init__(self, n, m, z=None, *args, **kwargs):
        """
        Initializes the OptimALittleIsEnoughAttack instance.

        Remaining positional and keyword arguments are forwarded to ``ByzantineWorker``.
        """
        super().__init__(*args, **kwargs)
        if z is not None:
            # Caller-supplied z-score multiplier.
            self.z_max = z
        else:
            if n - m <= 0:
                # Would otherwise fail with a bare ZeroDivisionError below.
                raise ValueError(
                    f"Cannot derive z for ALIE: need n > m, got n={n}, m={m}"
                )
            # s: number of honest "supporters" the attacker needs to sway.
            s = np.floor(n / 2 + 1) - m
            cdf_value = (n - m - s) / (n - m)
            # NOTE: cdf_value is >= 1 when m >= floor(n / 2 + 1) (Byzantine majority)
            # and 0 when n <= 2; norm.ppf then returns inf/nan or -inf respectively.
            self.z_max = norm.ppf(cdf_value)
        # Number of honest workers.
        self.n_good = n - m

    def get_gradient(self):
        """
        Returns the gradient value of this worker.

        ``_gradient`` is assigned by ``omniscient_callback``; reading it before that
        callback has run fails unless the base class initializes it (not visible here).

        Returns:
            torch.Tensor: Gradient value.
        """
        return self._gradient

    def omniscient_callback(self):
        """
        Recomputes the malicious gradient from the honest workers' current gradients.

        Every worker in ``self.simulator.workers`` that is not a ``ByzantineWorker`` is
        treated as honest; ``self._gradient`` is set to ``mu - std * z_max`` where the
        mean and standard deviation are taken across those workers.

        TODO: Integrate an optimization-based approach for "A Little Is Enough" attack.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]

        # Stack along a new dim 1 so the reductions below run across workers.
        stacked_gradients = torch.stack(honest_gradients, dim=1)
        mu = torch.mean(stacked_gradients, dim=1)
        # torch.std uses the unbiased (N - 1) estimator by default, so a single
        # honest worker yields a NaN std.
        std = torch.std(stacked_gradients, dim=1)

        # TODO: Optimization-based ALIE (currently identical to plain ALIE).
        self._gradient = mu - std * self.z_max

    def set_gradient(self, gradient) -> None:
        """
        Method to set the gradient. Not implemented for this class since its gradient is maliciously computed.

        Args:
            gradient (torch.Tensor): Gradient value to set.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """
        Method to apply the gradient to update the model. Not implemented for this class.

        Raises:
            NotImplementedError: Always raises since this method should not be called for this class.
        """
        raise NotImplementedError
