import logging
import torch
from codes.components.worker import ByzantineWorker


class MimicAttacker(ByzantineWorker):
    """Byzantine worker that reports the gradient of one fixed target worker."""

    def __init__(self, target_rank, *args, **kwargs):
        """
        Create an attacker bound to a single worker to copy.

        Args:
            target_rank (int): Index into ``self.simulator.workers`` of the
                worker whose gradient is copied.
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Nothing has been copied until the first omniscient callback runs.
        self._gradient = None
        self.target_rank = target_rank

    def omniscient_callback(self):
        """Copy the target worker's current gradient into this worker."""
        self._gradient = self.simulator.workers[self.target_rank].get_gradient()

    def get_gradient(self):
        """Return the gradient stored by the last ``omniscient_callback``."""
        return self._gradient

    def set_gradient(self, gradient) -> None:
        """Not supported by this attacker; always raises."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported by this attacker; always raises."""
        raise NotImplementedError


class MimicVariantAttacker(ByzantineWorker):
    """
    Byzantine worker that chooses which good worker to mimic from the gradients.

    While ``t < warmup`` each omniscient callback updates a running mean ``mu``
    of the good workers' average gradient and a unit direction ``z``, then
    copies the good gradient whose dot product with ``z`` is the largest
    (``argmax=True``) or the smallest (``argmax=False``). Once the warmup is
    over, the worker selected last is kept as a fixed target.
    """

    def __init__(self, warmup, argmax=True, *args, **kwargs):
        """
        Initialize a variant of the mimic attacker.

        Args:
            warmup (int): Number of callbacks during which ``mu``/``z`` are
                updated and the target is re-selected.
            argmax (bool): Select the good gradient with the largest dot
                product with ``z`` if True, the smallest otherwise.
            *args: Positional arguments forwarded to the parent constructor.
            **kwargs: Keyword arguments forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.warmup = warmup
        self.t = 0  # number of omniscient callbacks processed so far
        self.mu = None  # running mean of the good workers' average gradient
        self.z = None  # unit direction used to score good gradients
        self._gradient = None
        self.argmax = argmax
        self.json_logger = logging.getLogger("stats")
        self.coordinator = None  # None until resolved, then a bool
        # Index in simulator.workers of the worker currently mimicked. It must
        # exist before the first callback so that warmup <= 0 does not fail
        # with an AttributeError in the fixed-target branch.
        self.target_rank = None

    def _get_good_grads(self):
        """
        Collect gradients of all workers that are not ``ByzantineWorker``s.

        Returns:
            tuple: ``(good_ranks, good_grads)`` where ``good_ranks[k]`` is the
            index in ``self.simulator.workers`` of the worker that produced
            ``good_grads[k]``.
        """
        good_grads = []
        good_ranks = []
        for rank, worker in enumerate(self.simulator.workers):
            if not isinstance(worker, ByzantineWorker):
                good_grads.append(worker.get_gradient())
                good_ranks.append(rank)
        return good_ranks, good_grads

    def _scatter_times_z(self, curr_good_grads):
        """
        Return ``sum_g <g - mu, z> * (g - mu)`` over the good gradients.

        This applies the (unnormalised) scatter matrix of the gradients centred
        at ``self.mu`` to the current direction ``self.z``.
        """
        cumu = 0
        for grad in curr_good_grads:
            centered = grad - self.mu
            cumu += centered.dot(self.z) * centered
        return cumu

    def _init_callback(self, curr_good_grads, curr_avg):
        """
        Initialize ``mu`` and ``z`` on the first callback.

        ``mu`` starts as the current average and ``z`` as a random normal
        vector drawn with a fixed seed, which is then pushed once through
        ``_scatter_times_z`` and normalised.
        """
        self.mu = curr_avg
        # Seed is fixed to 0 so the starting direction is reproducible.
        rng = torch.Generator(device=curr_avg.device)
        rng.manual_seed(0)
        self.z = torch.randn(len(curr_avg), generator=rng, device=curr_avg.device)

        direction = self._scatter_times_z(curr_good_grads)
        norm = direction.norm()
        if norm > 0:
            self.z = direction / norm
        else:
            # All good gradients coincide with mu (e.g. a single good worker):
            # dividing by the zero norm would make z NaN forever. Keep the
            # random direction, normalised, instead.
            self.z = self.z / self.z.norm()

    def _warmup_callback(self, curr_good_grads, curr_avg):
        """
        Update ``mu`` and ``z`` during the warmup phase.

        Both are running averages with weight ``t / (1 + t)`` on the previous
        value; ``z`` is re-normalised to unit length afterwards.
        """
        self.mu = self.t / (1 + self.t) * self.mu + curr_avg / (1 + self.t)

        direction = self._scatter_times_z(curr_good_grads)
        norm = direction.norm()
        if norm > 0:
            self.z = self.t / (1 + self.t) * self.z + direction / norm / (1 + self.t)
            self.z = self.z / self.z.norm()
        # else: degenerate step (zero update direction); keep the previous z
        # rather than turning it into NaN.

    def _attack_callback(self, curr_good_grads):
        """
        Pick the good gradient with the extreme dot product with ``z``.

        Args:
            curr_good_grads (list): Gradients of the good workers.

        Returns:
            tuple: ``(value, index, gradient)`` of the selected entry, where
            ``index`` refers to ``curr_good_grads``. All three are ``None``
            when the list is empty. Ties keep the earliest entry.
        """
        mv = None
        mi = None
        mg = None
        for i, grad in enumerate(curr_good_grads):
            d = grad.dot(self.z)
            if mv is None:
                is_better = True
            elif self.argmax:
                is_better = d > mv
            else:
                is_better = d < mv
            if is_better:
                mv = d
                mg = grad
                mi = i
        return mv, mi, mg

    def maybe_setup_coordinator(self):
        """
        Decide once whether this attacker is the coordinator (the one that logs).

        Raises:
            NotImplementedError: If ``self.simulator.workers`` contains no
                ``MimicVariantAttacker``.
        """
        if self.coordinator is not None:
            return

        # The loop does not stop at the first match, so the flag ends up True
        # only for the last MimicVariantAttacker in the worker list.
        for w in self.simulator.workers:
            if isinstance(w, MimicVariantAttacker):
                self.coordinator = w == self

        if self.coordinator is None:
            raise NotImplementedError

    def omniscient_callback(self):
        """
        Update internal state and set the gradient to report for this step.

        On the first call ``mu``/``z`` are initialised; on later calls with
        ``t < warmup`` they are updated. While in warmup (or if no target has
        been selected yet, which happens when ``warmup <= 0``) a target is
        selected from the current good gradients; afterwards the gradient of
        the previously selected target is copied.

        Raises:
            ZeroDivisionError: If there are no good workers to average over.
        """
        self.maybe_setup_coordinator()
        curr_good_ranks, curr_good_grads = self._get_good_grads()
        curr_avg = sum(curr_good_grads) / len(curr_good_grads)

        # Update z and mu
        if self.t == 0:
            self._init_callback(curr_good_grads, curr_avg)
        elif self.t < self.warmup:
            self._warmup_callback(curr_good_grads, curr_avg)

        # Find the target. The extra `is None` check makes warmup <= 0 select
        # a target once on the first call instead of crashing.
        if self.t < self.warmup or self.target_rank is None:
            mv, mi, self._gradient = self._attack_callback(curr_good_grads)
            self.target_rank = curr_good_ranks[mi]

            # Only the coordinator logs the selection.
            if self.coordinator:
                r = {
                    "_meta": {"type": "mmc_count"},
                    "select": self.target_rank,
                    "value": mv.item(),
                }
                self.json_logger.info(r)

        else:
            # Target is fixed after warmup: copy its current gradient.
            target_worker = self.simulator.workers[self.target_rank]
            self._gradient = target_worker.get_gradient()

        self.t += 1

    def get_gradient(self):
        """Return the gradient chosen by the last ``omniscient_callback``."""
        return self._gradient

    def set_gradient(self, gradient) -> None:
        """Not supported by this attacker; always raises."""
        raise NotImplementedError

    def apply_gradient(self) -> None:
        """Not supported by this attacker; always raises."""
        raise NotImplementedError
