import torch

# from ..optim import CentralizedSGD, get_gradient_in_1d, set_gradient_in_1d
from codes.components.worker import ByzantineWorker


# class Bitflipping(CentralizedSGD):
#     def __str__(self):
#         return "Bitflipping()"

#     @torch.no_grad()
#     def step(self, closure=None):
#         """Aggregates the gradients and performs a single optimization step.
#         Arguments:
#             closure (callable, optional): A closure that reevaluates the model
#                 and returns the loss.
#         """
#         grad = get_gradient_in_1d(self.model)
#         aggregated = self.updater.update(-grad)
#         set_gradient_in_1d(self.model, aggregated)
#         loss = super(CentralizedSGD, self).step(closure=closure)
#         return loss


# class BitFlippingWorker(ByzantineWorker):
#     def __str__(self) -> str:
#         return "BitFlippingWorker"
#
#     def get_gradient(self):
#         # Use self.simulator to get all other workers
#         # Note that the byzantine worker does not modify the states directly.
#         return -super().get_gradient()


class BitFlippingWorker(ByzantineWorker):
    """
    Byzantine worker that reports the sign-flipped mean of the honest gradients.

    Rather than computing a gradient of its own, this worker inspects every
    non-Byzantine worker in ``self.simulator.workers`` (see
    ``omniscient_callback``) and reports ``-mean(honest gradients)`` from
    ``get_gradient``.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the worker; all arguments are forwarded to ``ByzantineWorker``.

        The malicious gradient starts as ``None`` and is only populated by
        ``omniscient_callback``.
        """
        super().__init__(*args, **kwargs)
        self._gradient = None

    def get_gradient(self):
        """
        Return the malicious gradient computed by the last ``omniscient_callback``.

        Returns:
            torch.Tensor or None: The negated mean of the honest gradients, or
            ``None`` if ``omniscient_callback`` has not been called yet.
        """
        return self._gradient

    def omniscient_callback(self):
        """
        Compute the malicious gradient as the negated mean of all honest gradients.

        Every worker in ``self.simulator.workers`` that is not an instance of
        ``ByzantineWorker`` is treated as honest. Their ``get_gradient()``
        results are averaged, and the negated average is stored as this
        worker's gradient.

        Raises:
            RuntimeError: If no honest worker is found, because the mean of an
                empty set of gradients is undefined.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not honest_gradients:
            raise RuntimeError(
                "BitFlippingWorker requires at least one honest worker to compute "
                "a malicious gradient, but none were found in simulator.workers."
            )

        # sum() starts from int 0, which adds element-wise onto the first tensor,
        # so the result keeps the tensors' shape and dtype.
        self._gradient = -sum(honest_gradients) / len(honest_gradients)

    def __str__(self) -> str:
        """
        Return a short human-readable name for this worker.

        Returns:
            str: Always ``"BitFlippingWorker"``.
        """
        return "BitFlippingWorker"

    def set_gradient(self, gradient) -> None:
        """
        Unsupported: this worker's gradient is derived in ``omniscient_callback``.

        Args:
            gradient (torch.Tensor): Ignored.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "BitFlippingWorker computes its gradient in omniscient_callback; "
            "set_gradient is not supported."
        )

    def apply_gradient(self) -> None:
        """
        Unsupported: this worker does not update a model from its gradient.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "BitFlippingWorker does not apply gradients to a model; "
            "apply_gradient is not supported."
        )
