import torch

from collections import defaultdict
from typing import Union, Callable

from ..utils import get_vectorized_parameters


class TorchWorker(object):
    """A worker for distributed training.

    Compute gradients locally and store the gradient.

    Flattened gradient vectors handled by this class are laid out in the order
    of ``optimizer.param_groups`` (group by group, parameter by parameter);
    ``get_gradient`` produces that layout and ``set_gradient`` consumes it.
    """

    def __init__(
        self,
        data_loader: torch.utils.data.DataLoader,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_func: torch.nn.modules.loss._Loss,
        device: Union[torch.device, str],
    ):
        """Store the training components and create empty bookkeeping containers.

        Args:
            data_loader: Iterable yielding ``(data, target)`` batches.
            model: Model that is trained by this worker.
            optimizer: Optimizer whose ``param_groups`` define the parameter order.
            loss_func: Callable invoked as ``loss_func(output, target)``.
            device: Device the batches are moved to before the forward pass.
        """
        self.data_loader = data_loader
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.device = device

        # Scratch values of the current epoch/batch (loader iterator, last batch).
        self.running = {}
        # Metric name -> callback(output, target).
        self.metrics = {}
        # Per-parameter storage keyed by the parameter tensor itself.
        self.state = defaultdict(dict)

    def add_metric(
        self, name: str, callback: Callable[[torch.Tensor, torch.Tensor], float]
    ):
        """Register a metric evaluated on every batch in ``compute_gradient``.

        Args:
            name: Key under which the metric value is reported.
            callback: Called as ``callback(output, target)``.

        Raises:
            KeyError: If ``name`` is already registered or is one of the
                reserved names ``"loss"`` / ``"length"``.
        """
        if name in self.metrics or name in ("loss", "length"):
            raise KeyError(f"Metrics ({name}) already added.")

        self.metrics[name] = callback

    def add_metrics(self, metrics: dict):
        """Register every ``name -> callback`` pair of ``metrics`` via ``add_metric``."""
        for name, callback in metrics.items():
            self.add_metric(name, callback)

    def __str__(self) -> str:
        return "TorchWorker"

    def train_epoch_start(self) -> None:
        """Create a fresh iterator over the data loader and switch the model to train mode."""
        self.running["train_loader_iterator"] = iter(self.data_loader)
        self.model.train()

    def compute_gradient(self) -> dict:
        """Run forward/backward on the next batch and save the resulting gradient.

        ``train_epoch_start`` must have been called first so that the loader
        iterator exists. The batch is also kept in ``self.running`` under the
        keys ``"data"`` and ``"target"``.

        Returns:
            A dict with ``"loss"`` (float), ``"length"`` (``len(target)``) and
            ``"metrics"`` (metric name -> callback result).

        Raises:
            KeyError: If ``train_epoch_start`` has not been called.
            StopIteration: If the loader iterator is exhausted.
        """
        data, target = next(self.running["train_loader_iterator"])
        data, target = data.to(self.device), target.to(self.device)

        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.loss_func(output, target)
        loss.backward()
        self._save_grad()

        self.running["data"] = data
        self.running["target"] = target

        return {
            "loss": loss.item(),
            "length": len(target),
            "metrics": {
                name: metric(output, target) for name, metric in self.metrics.items()
            },
        }

    def get_gradient(self) -> torch.Tensor:
        """Return the saved gradient as one flat vector (optimizer parameter order)."""
        return self._get_saved_grad()

    def apply_gradient(self) -> None:
        """Take one optimizer step using the gradients currently held in ``.grad``."""
        self.optimizer.step()

    def set_gradient(self, gradient: torch.Tensor) -> None:
        """Overwrite the parameters' ``.grad`` with slices of a flat vector.

        The vector is consumed in the same order in which ``get_gradient``
        produces it, i.e. by walking ``optimizer.param_groups``.

        Args:
            gradient: 1-D tensor holding the gradients of all optimizer
                parameters back to back.

        Raises:
            RuntimeError: If a slice of ``gradient`` cannot be reshaped to the
                shape of its parameter (e.g. the vector is too short).
        """
        offset = 0
        # no_grad keeps the in-place copy out of any autograd graph that
        # ``gradient`` may still be attached to.
        with torch.no_grad():
            for group in self.optimizer.param_groups:
                for p in group["params"]:
                    end = offset + p.numel()
                    chunk = gradient[offset:end].reshape(p.shape)
                    if p.grad is None:
                        # No backward pass has populated this slot yet.
                        p.grad = torch.empty_like(p)
                    # In-place copy keeps the gradient's device/dtype and works
                    # for non-contiguous gradient tensors as well.
                    p.grad.copy_(chunk)
                    offset = end

    def _save_grad(self) -> None:
        """Snapshot the current ``.grad`` of every optimizer parameter into ``self.state``."""
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["saved_grad"] = p.grad.detach().clone()

    def _get_saved_grad(self) -> torch.Tensor:
        """Concatenate the saved gradients of all optimizer parameters.

        Raises:
            KeyError: If a parameter has no saved gradient yet.
        """
        # reshape (not view) because a clone keeps the strides of the source
        # gradient, which need not be contiguous.
        layer_gradients = [
            self.state[p]["saved_grad"].reshape(-1)
            for group in self.optimizer.param_groups
            for p in group["params"]
        ]
        return torch.cat(layer_gradients)


class WorkerWithMomentum(TorchWorker):
    """
    Note that we use `WorkerWithMomentum` instead of using multiple `torch.optim.Optimizer`
    because we need to explicitly update the `momentum_buffer`.

    The value reported by ``get_gradient`` is the per-parameter momentum buffer
    ``buf = momentum * buf + grad`` rather than the raw gradient.
    """

    def __init__(self, momentum, *args, **kwargs):
        """Create the worker.

        Args:
            momentum: Factor the existing momentum buffer is multiplied by
                before the new gradient is added.
            *args: Forwarded to ``TorchWorker.__init__``.
            **kwargs: Forwarded to ``TorchWorker.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.momentum = momentum

    def _save_grad(self) -> None:
        """Fold the current ``.grad`` of every optimizer parameter into its momentum buffer."""
        for group in self.optimizer.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    # First gradient seen for this parameter: the buffer starts as a copy.
                    param_state["momentum_buffer"] = torch.clone(p.grad).detach()
                else:
                    param_state["momentum_buffer"].mul_(self.momentum).add_(p.grad)

    def _get_saved_grad(self) -> torch.Tensor:
        """Concatenate the momentum buffers of all optimizer parameters into one flat vector.

        Raises:
            KeyError: If a parameter has no momentum buffer yet.
        """
        # reshape (not view): the buffer was cloned from the gradient and keeps
        # its strides, which need not be contiguous.
        layer_gradients = [
            self.state[p]["momentum_buffer"].reshape(-1)
            for group in self.optimizer.param_groups
            for p in group["params"]
        ]
        return torch.cat(layer_gradients)

    def __str__(self) -> str:
        return "WorkerWithMomentum"


class ByzantineWorker(TorchWorker):
    """Worker that is given a handle on the simulator.

    The base implementation behaves like a plain ``TorchWorker``; the only
    additions are ``configure`` and the ``omniscient_callback`` hook, which is
    left unimplemented here.
    """

    def configure(self, simulator):
        """Remember ``simulator`` and register ``omniscient_callback`` with it.

        Call this after the DistributedSimulator has been defined.
        """
        self.simulator = simulator
        register = simulator.register_omniscient_callback
        register(self.omniscient_callback)

    def compute_gradient(self) -> dict:
        """Delegate to the parent implementation.

        ``self.simulator`` can be used to reach all other workers. Note that
        the byzantine worker does not modify the states directly.
        """
        results = super().compute_gradient()
        return results

    def get_gradient(self) -> torch.Tensor:
        """Delegate to the parent implementation.

        ``self.simulator`` can be used to reach all other workers.
        """
        gradient = super().get_gradient()
        return gradient

    def omniscient_callback(self):
        """Hook registered with the simulator in ``configure``; not implemented here."""
        raise NotImplementedError

    def __str__(self) -> str:
        return "ByzantineWorker"


class RSAWorker(TorchWorker):
    """Worker whose local step adds ``rsa_lambda * sign(local - server)`` to the gradient."""

    def __init__(self, rsa_lambda, *args, **kwargs):
        """Store ``rsa_lambda`` and forward the remaining arguments to ``TorchWorker``."""
        self.rsa_lambda = rsa_lambda
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        return "RSAWorker"

    def get_update(self, server_iterate):
        """Perform one penalised local step and report the pre-step iterate.

        Args:
            server_iterate: Server parameters; subtracted element-wise from the
                vectorized local parameters.

        Returns:
            The dict produced by ``compute_gradient`` with an additional
            ``"local_iterate"`` entry holding the vectorized parameters that
            were read before the step was applied.
        """
        # Read the local parameters before anything is updated.
        iterate_before = get_vectorized_parameters(self.optimizer)

        results = self.compute_gradient()

        # Gradient plus the sign penalty, accumulated in place.
        step = self.get_gradient()
        penalty = self.rsa_lambda * torch.sign(iterate_before - server_iterate)
        step.add_(penalty)

        # Update the local iterate.
        self.set_gradient(step)
        self.apply_gradient()

        # Send the original local iterate (from before the update).
        results["local_iterate"] = iterate_before
        return results


class FLTrustWorker(TorchWorker):
    """Worker that copies the server's parameters before computing its gradient."""

    def __init__(self, param1, *args, **kwargs):
        """Store ``param1`` and forward the remaining arguments to ``TorchWorker``."""
        self.param1 = param1  # placeholder
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        return "FLTrustWorker"

    def get_update(self, server_optimizer):
        """Synchronise with the server parameters, then compute a local gradient.

        Args:
            server_optimizer: Optimizer whose ``param_groups`` hold the server
                parameters; they are paired positionally with the parameters
                of this worker's optimizer.

        Returns:
            The dict produced by ``compute_gradient`` with an additional
            ``"local_iterate"`` entry holding the flat gradient returned by
            ``get_gradient`` (the local update without the effect of lr).
        """
        # Synchronize local model with server model, layer by layer.
        # The values are copied in place so that the local parameters keep
        # their own storage, device and dtype instead of being rebound to
        # freshly allocated clones of the server tensors.
        with torch.no_grad():
            for local_group, server_group in zip(
                self.optimizer.param_groups, server_optimizer.param_groups
            ):
                for local_p, server_p in zip(
                    local_group["params"], server_group["params"]
                ):
                    local_p.copy_(server_p)

        results = self.compute_gradient()
        results["local_iterate"] = self.get_gradient()
        return results
