import torch

from codes.components.worker_toy import ByzantineWorker
from codes.components.utils import save_txt


class MinSumAttack(ByzantineWorker):
    """Byzantine worker that crafts its gradient with the Min-Sum attack.

    The malicious gradient is ``mu - lamda_succ * deviation``, where ``mu`` is
    the mean of the honest workers' gradients and ``deviation`` is a direction
    chosen by ``dev_type``. ``lamda_succ`` is the largest step found by a
    halving search such that the sum of squared distances from the malicious
    gradient to all honest gradients does not exceed the smallest such sum
    achieved by any single honest gradient. If no step succeeds,
    ``lamda_succ`` stays 0 and the malicious gradient equals ``mu``.

    Args:
        save_dir: passed to ``save_txt`` on every ``get_gradient`` call
            (presumably the output path for the gradient log -- confirm
            against ``save_txt``).
        dev_type: perturbation direction: ``'unit_vec'`` (``mu / ||mu||``),
            ``'sign'`` (``sign(mu)``) or ``'std'`` (coordinate-wise standard
            deviation of the honest gradients). Any other value falls back to
            ``'unit_vec'``.
        *args, **kwargs: forwarded unchanged to ``ByzantineWorker.__init__``.
    """

    def __init__(self, save_dir, dev_type='unit_vec', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dev_type = dev_type
        # Starting step for the halving search. ``self.device`` is assumed to
        # be set by ByzantineWorker.__init__ -- confirm in the base class.
        self.lamda = torch.tensor([50.0], dtype=torch.float32, device=self.device)
        # Search stops once the current lamda is within this of the last success.
        self.threshold_diff = 1e-5
        # Initial step size; it is halved on every search iteration.
        self.lamda_fail = self.lamda
        # Last successful lamda; 0 means "no perturbation" if nothing succeeds.
        self.lamda_succ = 0
        self.save_dir = save_dir
        # History of "<g[0]> <g[1]>" strings, one per get_gradient call.
        self.byz_grads = []

    def get_gradient(self):
        """Log the first two coordinates of the current gradient and return it.

        The whole history in ``self.byz_grads`` is handed to ``save_txt`` on
        every call. Assumes ``self._gradient`` has at least two elements
        (e.g. a toy 2-D problem); otherwise indexing raises ``IndexError``.

        Returns:
            The current ``self._gradient`` tensor (not a copy).
        """
        # tolist() already produces an independent Python copy; convert once.
        coords = self._gradient.detach().cpu().tolist()
        self.byz_grads.append(str(coords[0]) + ' ' + str(coords[1]))
        save_txt(self.byz_grads, self.save_dir)
        return self._gradient

    def _deviation(self, stacked_gradients, mu):
        """Return the perturbation direction selected by ``self.dev_type``."""
        if self.dev_type == 'sign':
            return torch.sign(mu)
        if self.dev_type == 'std':
            # The unbiased std of a single sample is NaN, which would turn the
            # malicious gradient into NaN; treat it as zero spread instead.
            if stacked_gradients.shape[0] < 2:
                return torch.zeros_like(mu)
            return torch.std(stacked_gradients, 0)
        # 'unit_vec' and any unrecognised dev_type: the normalised mean. This
        # points ALONG mu; the attack moves against it via ``mu - lamda * dev``.
        norm = torch.norm(mu)
        if norm == 0:
            # mu / 0 is NaN, and 0 * NaN stays NaN even when lamda_succ is 0,
            # so a zero mean must yield a zero direction (as torch.sign does).
            return torch.zeros_like(mu)
        return mu / norm

    @staticmethod
    def _min_sum_score(stacked_gradients):
        """Return min over rows i of sum_j ||g_j - g_i||^2 (squared L2 norms)."""
        # Row i holds the squared distances from every gradient to gradient i.
        distances = torch.stack([
            torch.norm(stacked_gradients - update, dim=1) ** 2
            for update in stacked_gradients
        ])
        return torch.min(torch.sum(distances, dim=1))

    def omniscient_callback(self):
        """Craft the Min-Sum malicious gradient and store it in ``self._gradient``.

        Reads the gradients of every worker in ``self.simulator.workers`` that
        is not a ``ByzantineWorker``.

        Raises:
            RuntimeError: if there are no honest workers to imitate.
        """
        gradients = [
            w.get_gradient()
            for w in self.simulator.workers
            if not isinstance(w, ByzantineWorker)
        ]
        if not gradients:
            raise RuntimeError('MinSumAttack requires at least one honest worker')

        stacked_gradients = torch.stack(gradients, 0)
        mu = torch.mean(stacked_gradients, 0)
        deviation = self._deviation(stacked_gradients, mu)
        min_score = self._min_sum_score(stacked_gradients)

        # Halving search for the largest lamda whose candidate stays within
        # the best honest sum-of-squared-distances score. The search restarts
        # from the configured values on every call.
        lamda_succ = self.lamda_succ
        lamda_fail = self.lamda_fail
        lamda = self.lamda
        while torch.abs(lamda_succ - lamda) > self.threshold_diff:
            candidate = mu - lamda * deviation
            score = torch.sum(torch.norm(stacked_gradients - candidate, dim=1) ** 2)
            if score <= min_score:
                # Candidate is inconspicuous enough: remember it, try larger.
                lamda_succ = lamda
                lamda = lamda + lamda_fail / 2
            else:
                lamda = lamda - lamda_fail / 2
            lamda_fail = lamda_fail / 2

        self._gradient = mu - lamda_succ * deviation

    def set_gradient(self, gradient) -> None:
        """Not supported: the gradient is produced by ``omniscient_callback``."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported for this attacker."""
        raise NotImplementedError
