import torch

# from ..optim import CentralizedSGD, get_gradient_in_1d, set_gradient_in_1d
from codes.components.worker_toy import ByzantineWorker
from codes.components.utils import save_txt


# class Bitflipping(CentralizedSGD):
#     def __str__(self):
#         return "Bitflipping()"

#     @torch.no_grad()
#     def step(self, closure=None):
#         """Aggregates the gradients and performs a single optimization step.
#         Arguments:
#             closure (callable, optional): A closure that reevaluates the model
#                 and returns the loss.
#         """
#         grad = get_gradient_in_1d(self.model)
#         aggregated = self.updater.update(-grad)
#         set_gradient_in_1d(self.model, aggregated)
#         loss = super(CentralizedSGD, self).step(closure=closure)
#         return loss


# class BitFlippingWorker(ByzantineWorker):
#     def __str__(self) -> str:
#         return "BitFlippingWorker"
#
#     def get_gradient(self):
#         # Use self.simulator to get all other workers
#         # Note that the byzantine worker does not modify the states directly.
#         return -super().get_gradient()


class BitFlippingWorker(ByzantineWorker):
    """Omniscient Byzantine worker that sends the negated mean of honest gradients.

    Despite the name, this worker does not negate its *own* gradient.
    ``omniscient_callback`` collects the gradients of every worker in
    ``self.simulator.workers`` that is not a ``ByzantineWorker``. It stores
    ``-(sum of those gradients) / count``, and ``get_gradient`` then returns
    that stored tensor.

    Each ``get_gradient`` call also records the first two components of the
    returned gradient in ``self.byz_grads`` and passes the whole history to
    ``save_txt``.
    """

    def __init__(self, save_dir, *args, **kwargs):
        """
        Args:
            save_dir: destination handed to ``save_txt`` when logging gradients.
            *args, **kwargs: forwarded unchanged to ``ByzantineWorker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.save_dir = save_dir
        # Malicious gradient; None until omniscient_callback() has run.
        self._gradient = None
        # Logged history: one "<g[0]> <g[1]>" string per get_gradient() call.
        self.byz_grads = []

    def get_gradient(self):
        """Return the malicious gradient computed by ``omniscient_callback``.

        Side effect: appends ``"<g[0]> <g[1]>"`` (string forms of the first
        two entries of the gradient) to ``self.byz_grads``. It then calls
        ``save_txt(self.byz_grads, self.save_dir)`` with the full history.

        Raises:
            RuntimeError: if called before ``omniscient_callback`` has run.
        """
        if self._gradient is None:
            raise RuntimeError(
                "BitFlippingWorker.get_gradient() called before "
                "omniscient_callback()"
            )
        # Convert once; tolist() already returns a copy, so clone() is not needed.
        values = self._gradient.detach().cpu().tolist()
        # Only the first two components are logged. This presumably matches
        # the toy setting's 2-D parameter vector (TODO confirm).
        self.byz_grads.append(str(values[0]) + ' ' + str(values[1]))
        # NOTE(review): the full history is passed on every call; if save_txt
        # rewrites the file, total logging cost grows quadratically with steps.
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    def omniscient_callback(self):
        """Compute and store the negated mean of all honest workers' gradients.

        Reads ``self.simulator.workers`` and skips every ``ByzantineWorker``.
        ``self.simulator`` is not assigned in this class; presumably it is
        attached by the simulator or the base class. The result,
        ``-(sum of honest gradients) / count``, is stored in ``self._gradient``.

        Raises:
            RuntimeError: if there are no non-Byzantine workers to average.
        """
        honest_gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not honest_gradients:
            # Averaging an empty set would otherwise surface as ZeroDivisionError.
            raise RuntimeError(
                "BitFlippingWorker requires at least one non-Byzantine worker"
            )
        self._gradient = -(sum(honest_gradients)) / len(honest_gradients)

    def __str__(self) -> str:
        return "BitFlippingWorker"

    def set_gradient(self, gradient) -> None:
        """Not supported: this worker never accepts an externally set gradient."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported: this worker never applies gradients to a model."""
        raise NotImplementedError
