from logging import INFO
from pathlib import Path
from typing import Dict, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.nn as nn
from flwr.common.logger import log
from flwr.common.typing import Config, NDArrays, Scalar
from torch.nn.modules.loss import _Loss
from torch.utils.data import DataLoader

from fl4health.clients.instance_level_privacy_client import InstanceLevelPrivacyClient
from fl4health.clients.numpy_fl_client import NumpyFlClient
from fl4health.parameter_exchange.packing_exchanger import ParameterExchangerWithPacking
from fl4health.utils.metrics import AverageMeter, Meter, Metric

# Return type of ScaffoldClient.train_step: the (loss, predictions) pair produced for one batch.
ScaffoldTrainStepOutput = Tuple[torch.Tensor, torch.Tensor]


class ScaffoldClient(NumpyFlClient):
    """
    Federated Learning Client for Scaffold strategy.

    Implementation based on https://arxiv.org/pdf/1910.06378.pdf.

    Model weights and control variates travel to and from the server as a single packed NDArrays
    (see ``get_parameters`` / ``set_parameters``). During local training every gradient is corrected by
    ``c - c_i`` (``modify_grad``). After local training, the client control variates are refreshed with
    option 2 of Equation 4 in the paper (``update_control_variates``).
    """

    def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device) -> None:
        super().__init__(data_path, device)
        self.metrics = metrics
        self.client_control_variates: Optional[NDArrays] = None  # c_i in paper
        self.client_control_variates_updates: Optional[NDArrays] = None  # delta_c_i in paper
        self.server_control_variates: Optional[NDArrays] = None  # c in paper
        # The attributes below are only declared here. They must be assigned before training/evaluation,
        # presumably by setup_client (called lazily from fit/evaluate) — TODO confirm in subclasses.
        self.model: nn.Module
        self.train_loader: DataLoader
        self.val_loader: DataLoader
        self.criterion: _Loss
        self.optimizer: torch.optim.SGD  # Scaffold requires vanilla SGD as optimizer
        self.learning_rate_local: float  # eta_l in paper
        self.server_model_state: Optional[NDArrays] = None  # model state from server
        self.server_model_weights: Optional[NDArrays] = None  # x in paper
        self.num_examples: Dict[str, int]
        self.parameter_exchanger: ParameterExchangerWithPacking[NDArrays]

    def fit(self, parameters: NDArrays, config: Config) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
        """
        Runs one round of local Scaffold training.

        Sets up the client on first use and loads the packed server weights and server control variates.
        It then trains for ``config["local_steps"]`` optimizer steps and returns the packed local weights
        together with the client control variate updates.

        Args:
            parameters (NDArrays): Server model weights packed with the server control variates.
            config (Config): Round configuration; must provide an integer ``local_steps``.

        Returns:
            Tuple[NDArrays, int, Dict[str, Scalar]]: Packed parameters for the server, the number of
            training examples on this client, and the training metrics.
        """
        if not self.initialized:
            self.setup_client(config)

        self.set_parameters(parameters, config)
        local_steps = self.narrow_config_type(config, "local_steps", int)
        # Default SCAFFOLD uses an average meter
        meter = AverageMeter(self.metrics, "global")

        # Scaffold is trained by steps by default
        metric_values = self.train_by_steps(local_steps, meter)

        # FitRes should contain local parameters, number of examples on client, and a dictionary holding metrics
        # calculation results.
        return (
            self.get_parameters(config),
            self.num_examples["train_set"],
            metric_values,
        )

    def evaluate(self, parameters: NDArrays, config: Config) -> Tuple[float, int, Dict[str, Scalar]]:
        """
        Evaluates the server model (unpacked from ``parameters``) on the local validation set.

        Args:
            parameters (NDArrays): Server model weights packed with the server control variates.
            config (Config): Round configuration, forwarded to setup and parameter exchange.

        Returns:
            Tuple[float, int, Dict[str, Scalar]]: Mean validation loss, number of validation examples and
            the validation metrics.
        """
        if not self.initialized:
            self.setup_client(config)

        self.set_parameters(parameters, config)
        # Default SCAFFOLD uses an average meter
        meter = AverageMeter(self.metrics, "global")
        loss, metric_values = self.validate(meter)
        # EvaluateRes should return the loss, number of examples on client, and a dictionary holding metrics
        # calculation results.
        return (
            loss,
            self.num_examples["validation_set"],
            metric_values,
        )

    def get_parameters(self, config: Config) -> NDArrays:
        """
        Packs the model parameters and control variate updates into a single NDArrays to be sent to the
        server for aggregation.

        Requires ``update_control_variates`` to have run at least once (i.e. after local training).
        """
        assert self.model is not None and self.parameter_exchanger is not None

        model_weights = self.parameter_exchanger.push_parameters(self.model, config=config)

        # Weights and control variate updates are sent to the server for aggregation.
        # The updates are sent (rather than c_i itself) because only the client has access to its
        # previous control variates, so delta_c_i can only be computed locally.
        assert self.client_control_variates_updates is not None
        packed_params = self.parameter_exchanger.pack_parameters(model_weights, self.client_control_variates_updates)
        return packed_params

    def set_parameters(self, parameters: NDArrays, config: Config) -> None:
        """
        Assumes that the parameters being passed contain model parameters concatenated with
        server control variates. They are unpacked for the client to use in training.

        Also snapshots the trainable weights after loading (x in the paper) and, on first use,
        initializes the client control variates to zeros as the paper prescribes.
        """
        assert self.model is not None and self.parameter_exchanger is not None

        server_model_state, server_control_variates = self.parameter_exchanger.unpack_parameters(parameters)
        self.server_control_variates = server_control_variates
        self.server_model_state = server_model_state
        self.parameter_exchanger.pull_parameters(server_model_state, self.model, config)
        # x: trainable weights right after loading the server state, used later in (x - y_i).
        self.server_model_weights = [
            model_params.cpu().detach().numpy()
            for model_params in self.model.parameters()
            if model_params.requires_grad
        ]

        # If client control variates do not exist, initialize with zeros as per paper
        if self.client_control_variates is None:
            self.client_control_variates = [np.zeros_like(weight) for weight in self.server_control_variates]

    def update_control_variates(self, local_steps: int) -> None:
        """
        Updates local control variates along with the corresponding updates
        according to option 2 in Equation 4 in https://arxiv.org/pdf/1910.06378.pdf.
        To be called after weights of local model have been updated.

        Args:
            local_steps (int): Number of local optimizer steps K that were performed.
        """
        assert self.client_control_variates is not None
        assert self.server_control_variates is not None
        assert self.server_model_weights is not None
        assert self.learning_rate_local is not None

        # y_i
        client_model_weights = [val.cpu().detach().numpy() for val in self.model.parameters() if val.requires_grad]

        # (x - y_i)
        delta_model_weights = self.compute_parameters_delta(self.server_model_weights, client_model_weights)

        # (c_i - c)
        delta_control_variates = self.compute_parameters_delta(
            self.client_control_variates, self.server_control_variates
        )

        updated_client_control_variates = self.compute_updated_control_variates(
            local_steps, delta_model_weights, delta_control_variates
        )
        # delta_c_i = c_i^plus - c_i, the quantity shipped to the server.
        self.client_control_variates_updates = self.compute_parameters_delta(
            updated_client_control_variates, self.client_control_variates
        )

        # c_i = c_i^plus
        self.client_control_variates = updated_client_control_variates

    def modify_grad(self) -> None:
        """
        Modifies the gradient of the local model to correct for client drift by adding (c - c_i)
        to the gradient of every trainable parameter.
        To be called after the gradients have been computed on a batch of data.
        Updates are not applied to params until step is called on the optimizer.
        """
        assert self.client_control_variates is not None
        assert self.server_control_variates is not None

        model_params_with_grad = [
            model_params for model_params in self.model.parameters() if model_params.requires_grad
        ]

        for param, client_cv, server_cv in zip(
            model_params_with_grad, self.client_control_variates, self.server_control_variates
        ):
            assert param.grad is not None
            # Match the gradient's dtype so the in-place addition below is well-defined.
            tensor_type = param.grad.dtype
            server_cv_tensor = torch.from_numpy(server_cv).type(tensor_type)
            client_cv_tensor = torch.from_numpy(client_cv).type(tensor_type)
            update = server_cv_tensor.to(self.device) - client_cv_tensor.to(self.device)
            param.grad += update

    def compute_parameters_delta(self, params_1: NDArrays, params_2: NDArrays) -> NDArrays:
        """
        Computes the elementwise difference of two lists of NDArrays,
        where elements in params_2 are subtracted from elements in params_1.
        Pairs are formed with zip, so the result is as long as the shorter input.
        """
        parameter_delta: NDArrays = [param_1 - param_2 for param_1, param_2 in zip(params_1, params_2)]

        return parameter_delta

    def compute_updated_control_variates(
        self, local_steps: int, delta_model_weights: NDArrays, delta_control_variates: NDArrays
    ) -> NDArrays:
        """
        Computes the updated local control variates according to option 2 in Equation 4 of the paper.

        Args:
            local_steps (int): Number of local optimizer steps K.
            delta_model_weights (NDArrays): (x - y_i), server weights minus local weights.
            delta_control_variates (NDArrays): (c_i - c), client minus server control variates.

        Returns:
            NDArrays: c_i^plus = c_i - c + 1/(K * eta_l) * (x - y_i).
        """

        # coef = 1 / (K * eta_l)
        scaling_coefficient = 1 / (local_steps * self.learning_rate_local)

        # c_i^plus = c_i - c + 1/(K*lr) * (x - y_i)
        updated_client_control_variates = [
            delta_control_variate + scaling_coefficient * delta_model_weight
            for delta_control_variate, delta_model_weight in zip(delta_control_variates, delta_model_weights)
        ]
        return updated_client_control_variates

    def _handle_logging(self, loss: float, metrics_dict: Dict[str, Scalar], is_validation: bool = False) -> None:
        """Logs the loss and tab-separated metrics, prefixed with "Training" or "Validation"."""
        metric_string = "\t".join([f"{key}: {str(val)}" for key, val in metrics_dict.items()])
        metric_prefix = "Validation" if is_validation else "Training"
        log(
            INFO,
            f"Client {metric_prefix} Loss: {loss} \n" f"Client {metric_prefix} Metrics: {metric_string}",
        )

    def train_step(self, input: torch.Tensor, target: torch.Tensor) -> ScaffoldTrainStepOutput:
        """
        Performs one drift-corrected SGD step on a single batch.

        Args:
            input (torch.Tensor): Batch of model inputs, already on ``self.device``.
            target (torch.Tensor): Matching batch of targets, already on ``self.device``.

        Returns:
            ScaffoldTrainStepOutput: The (loss, predictions) pair for the batch.
        """
        # Forward/backward pass on the local model
        self.optimizer.zero_grad()
        pred = self.model(input)
        loss = self.criterion(pred, target)
        loss.backward()

        # modify grad to correct for client drift before applying the update
        self.modify_grad()
        self.optimizer.step()

        return loss, pred

    def train_by_steps(
        self,
        local_steps: int,
        meter: Meter,
    ) -> Dict[str, Scalar]:
        """
        Trains the local model for a fixed number of optimizer steps, then updates the control variates.

        The training loader is restarted whenever it is exhausted, so ``local_steps`` may exceed the
        number of batches in one pass over the data.

        Args:
            local_steps (int): Number of local optimizer steps (K in the paper). Must be positive.
            meter (Meter): Meter accumulating training metrics; cleared before training starts.

        Returns:
            Dict[str, Scalar]: Training metrics computed by ``meter`` over all local steps.

        Raises:
            ValueError: If ``local_steps`` is not positive or the training loader yields no batches.
        """
        if local_steps <= 0:
            # The loss average and the control variate scaling 1 / (K * eta_l) both divide by local_steps;
            # a negative K would silently produce corrupted control variates.
            raise ValueError(f"local_steps must be a positive integer, got {local_steps}")

        self.model.train()
        running_loss = 0.0
        meter.clear()

        # Pass loader to iterator so we can step through train loader across epoch boundaries
        train_iterator = iter(self.train_loader)
        for _ in range(local_steps):
            try:
                inputs, target = next(train_iterator)
            except StopIteration:
                # The dataset ended: reinitialize the data loader and keep training.
                train_iterator = iter(self.train_loader)
                try:
                    inputs, target = next(train_iterator)
                except StopIteration as err:
                    raise ValueError("Training data loader is empty; cannot perform local steps.") from err

            inputs, target = inputs.to(self.device), target.to(self.device)
            loss, pred = self.train_step(inputs, target)

            running_loss += loss.item()
            meter.update(pred, target)

        running_loss = running_loss / local_steps

        metrics = meter.compute()
        self._handle_logging(running_loss, metrics)
        self.update_control_variates(local_steps)
        return metrics

    def train_by_epochs(self, epochs: int, meter: Meter) -> Dict[str, Scalar]:
        """
        Trains the local model for a number of full passes over the training loader, then updates the
        control variates with K = epochs * len(train_loader).

        Args:
            epochs (int): Number of local epochs. Must be positive.
            meter (Meter): Meter accumulating training metrics; cleared at the start of every epoch.

        Returns:
            Dict[str, Scalar]: Training metrics of the final epoch.

        Raises:
            ValueError: If ``epochs`` is not positive or the training loader has no batches.
        """
        if epochs <= 0:
            raise ValueError(f"epochs must be a positive integer, got {epochs}")
        num_batches = len(self.train_loader)
        if num_batches == 0:
            # Averaging the loss and the control variate update both divide by the number of batches.
            raise ValueError("Training data loader is empty; cannot perform local epochs.")

        self.model.train()

        for _ in range(epochs):
            meter.clear()
            running_loss = 0.0
            for inputs, target in self.train_loader:
                inputs, target = inputs.to(self.device), target.to(self.device)
                loss, pred = self.train_step(inputs, target)

                running_loss += loss.item()
                meter.update(pred, target)

            metrics = meter.compute()
            running_loss = running_loss / num_batches

        # Equation to update control variates requires the number of local_steps
        local_steps = num_batches * epochs
        self.update_control_variates(local_steps)

        log(INFO, f"Performed {epochs} Epochs of Local training")
        self._handle_logging(running_loss, metrics)

        return metrics  # return final training metrics

    def validate(self, meter: Meter) -> Tuple[float, Dict[str, Scalar]]:
        """
        Evaluates the local model on the validation loader without tracking gradients.

        Args:
            meter (Meter): Meter accumulating validation metrics; cleared before evaluation.

        Returns:
            Tuple[float, Dict[str, Scalar]]: Mean per-batch validation loss and the validation metrics.
        """
        self.model.eval()
        running_loss = 0.0
        meter.clear()
        with torch.no_grad():
            for inputs, target in self.val_loader:
                inputs, target = inputs.to(self.device), target.to(self.device)
                pred = self.model(inputs)
                loss = self.criterion(pred, target)

                running_loss += loss.item()
                meter.update(pred, target)

        running_loss = running_loss / len(self.val_loader)
        metrics = meter.compute()
        self._handle_logging(running_loss, metrics, is_validation=True)
        self._maybe_checkpoint(running_loss)
        return running_loss, metrics


class DPScaffoldClient(ScaffoldClient, InstanceLevelPrivacyClient):  # type: ignore
    """
    Federated Learning client for Instance Level Differentially Private Scaffold strategy.

    Combines the Scaffold training logic of ScaffoldClient with InstanceLevelPrivacyClient. ScaffoldClient
    is listed first, so its methods take precedence wherever both bases define the same name.

    Implemented as specified in https://arxiv.org/abs/2111.09278
    """

    def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device) -> None:
        # Initializes the Scaffold state (metrics and control variates) and the private-client state.
        # NOTE(review): ScaffoldClient.__init__ calls super().__init__(data_path, device). If
        # InstanceLevelPrivacyClient follows ScaffoldClient in this class's MRO (e.g. when it also
        # subclasses NumpyFlClient), that call already runs InstanceLevelPrivacyClient.__init__. The
        # explicit call below would then initialize it a second time and could reset attributes set
        # by the first pass. Confirm against InstanceLevelPrivacyClient's bases before relying on either.
        ScaffoldClient.__init__(self, data_path, metrics, device)
        InstanceLevelPrivacyClient.__init__(self, data_path, device)
