"""Federated Averaging with Nesterov Momentum strategy.

This implementation is based on the SGD with Nesterov momentum implementation in
PyTorch (``torch.optim.SGD`` with ``nesterov=True``). It can either partially
aggregate updated model parameters as soon as they arrive and then compute the
averaged pseudo-gradient, or partially aggregate pseudo-gradients while computing
them. The averaged pseudo-gradient then drives a Nesterov-momentum update of the
global model parameters.
"""

from collections.abc import Callable, Iterable
from logging import INFO
from pathlib import Path

import numpy as np
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Scalar,
    log,
)

from repo.conf.base_schema import BaseConfig
from repo.constants import FIRST_MOMENTUM, MODEL_PARAMETERS
from repo.strategy.aggregation import (
    aggregate_cumulative_average,
)
from repo.strategy.metrics import ServerMetricCallback
from repo.strategy.strategy_with_cfg import FedAvgWithConfig
from repo.utils import (
    l2_norm,
)


class FedNesterov(FedAvgWithConfig):
    """Federated Averaging with Nesterov Momentum strategy.

    Each round the server averages the client models, treats
    ``g_t = x_t - avg(client_params)`` as a pseudo-gradient and applies one step
    of SGD with Nesterov momentum, mirroring ``torch.optim.SGD(nesterov=True)``
    without dampening or weight decay::

        b_t     = mu * b_{t-1} + g_t
        x_{t+1} = x_t - lr * (g_t + mu * b_t)

    Parameters
    ----------
    initial_parameters : NDArrays
        Initial global model parameters.
    saving_path : Path, optional
        Path to save the model parameters. Defaults to current working directory.
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Callable, optional
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable, optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable, optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not to accept rounds containing failures. Defaults to True.
    fit_metrics_aggregation_fn : MetricsAggregationFn, optional
        Metrics aggregation function for training. Defaults to None.
    evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
        Metrics aggregation function for evaluation. Defaults to None.
    server_learning_rate : float, optional
        Learning rate used by the server-side optimizer. Defaults to 0.7.
    server_momentum : float, optional
        Momentum coefficient used by the server-side optimizer. Defaults to 0.9.
    track_norms : bool, optional
        Flag for tracking the norms of the aggregated updates. Defaults to True.
    obtain_server_metrics_callback : type[ServerMetricCallback], optional
        Callback for obtaining server metrics. Defaults to None.
    scaling_fn : str | None, optional
        Scaling applied to the pseudo-gradient before the optimizer step. It can
        be None (no scaling), 'linear' (multiply by ``min_fit_clients``) or
        'sqrt' (multiply by the square root of ``min_fit_clients``).
        Defaults to None.
    cfg : BaseConfig, optional
        Configuration object. Defaults to None.

    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(  # noqa: PLR0913
        self,
        *,
        initial_parameters: NDArrays,
        saving_path: Path | None = None,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        server_learning_rate: float = 0.7,
        server_momentum: float = 0.9,
        track_norms: bool = True,
        obtain_server_metrics_callback: type[ServerMetricCallback] | None = None,
        scaling_fn: str | None = None,
        cfg: BaseConfig | None = None,
    ) -> None:
        """Initialize the FedNesterov strategy.

        Parameters
        ----------
        initial_parameters : NDArrays
            Initial global model parameters. The list is kept by reference and
            updated in place every round.
        saving_path : Path, optional
            Path to save the model parameters. Defaults to current working directory.
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Callable, optional
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable, optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable, optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures. Defaults to True.
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for training. Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for evaluation. Defaults to None.
        server_learning_rate : float, optional
            Learning rate used by the server-side optimizer. Defaults to 0.7.
        server_momentum : float, optional
            Momentum coefficient used by the server-side optimizer. Defaults to 0.9.
        track_norms : bool, optional
            Flag for tracking the norms of the aggregated updates. Defaults to True.
        obtain_server_metrics_callback : type[ServerMetricCallback], optional
            Callback for obtaining server metrics. Defaults to None.
        scaling_fn : str | None, optional
            Scaling applied to the pseudo-gradient before the optimizer step. It
            can be None (no scaling), 'linear' (multiply by ``min_fit_clients``)
            or 'sqrt' (multiply by the square root of ``min_fit_clients``).
            Defaults to None.
        cfg : BaseConfig, optional
            Configuration object. Defaults to None.

        Raises
        ------
        AssertionError
            If ``scaling_fn`` is neither None, 'linear' nor 'sqrt'.

        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

        if scaling_fn is not None:
            assert scaling_fn in {
                "linear",
                "sqrt",
            }, "Scaling function must be either 'linear' or 'sqrt'."
        # Maps a client count to the multiplicative factor for the pseudo-gradient.
        self.scaling_fn = lambda x: (
            1 if scaling_fn is None else (x if scaling_fn == "linear" else np.sqrt(x))
        )

        self.saving_path = Path.cwd() if saving_path is None else saving_path

        # Server optimizer hyper-parameters.
        self.server_learning_rate = server_learning_rate
        self.server_momentum = server_momentum

        # The global model is kept as NDArrays so the server optimizer can update
        # it without converting to/from flwr Parameters every round.
        # NOTE(review): the list is aliased, not copied, so the caller's
        # `initial_parameters` list is mutated by `aggregate_fit`.
        self.parameters = initial_parameters

        # Set state_keys variables for ease of uploading to S3
        self.state_keys = (
            MODEL_PARAMETERS,
            FIRST_MOMENTUM,
        )

        log(
            INFO,
            "Using Nesterov Momentum with server_learning_rate=%s and"
            " server_momentum=%s",
            self.server_learning_rate,
            self.server_momentum,
        )
        # Momentum buffer: one zero array per layer, same shape/dtype as the layer.
        self.momentum_vector: NDArrays = [np.zeros_like(x) for x in self.parameters]

        self.track_norms = track_norms
        self.obtain_server_metrics_callback = obtain_server_metrics_callback
        self.cfg = cfg

    def aggregate_fit(  # type: ignore[override,reportIncompatibleMethodOverride]
        self,
        server_round: int,
        results: Iterable[tuple[FitRes, NDArrays]],
        failures: Iterable[FitRes | BaseException],  # noqa: ARG002
    ) -> tuple[NDArrays | None, dict[str, Scalar]]:
        """Average client models and apply one Nesterov-momentum server step.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        results : Iterable[tuple[FitRes, NDArrays]]
            Pairs of client fit result and the parameters returned by that client.
            The iterable is consumed once, lazily, by the aggregation.
        failures : Iterable[FitRes | BaseException]
            The failures from the clients. Not used by this method.

        Returns
        -------
        tuple[NDArrays | None, dict[str, Scalar]]
            The updated global parameters (the same list object as
            ``self.parameters``) and the aggregated metrics, or ``(None, {})``
            when the aggregation returned no result.

        """
        assert self.parameters is not None, (
            "When using server-side optimization, model needs to be initialized."
        )
        assert self.layer_names_and_types is not None, "Keys should be initialized."

        keys_to_index = {
            (key, typ): i for i, (key, typ) in enumerate(self.layer_names_and_types)
        }

        # Client metrics are recorded as a side effect while the aggregation
        # consumes `results`, so a lazy iterable is only traversed once.
        fit_metrics: list[tuple[int, dict[str, Scalar]]] = []

        def acc_metrics(fit_res: FitRes) -> FitRes:
            fit_metrics.append((fit_res.num_examples, fit_res.metrics))
            return fit_res

        results = (
            (acc_metrics(fit_res), returned_parameters)
            for fit_res, returned_parameters in results
        )

        metrics_callback: ServerMetricCallback | None = (
            self.obtain_server_metrics_callback(self, server_round)
            if self.obtain_server_metrics_callback is not None
            else None
        )

        # Averaged client parameters keyed by (layer_name, model_state_name).
        fedavg_result = aggregate_cumulative_average(
            results,
            metrics_callback=metrics_callback,
        )
        # Return None if no results were aggregated
        if fedavg_result is None:
            return None, {}

        metrics_aggregated: dict[str, Scalar] = {}
        # Aggregate the client-side training metrics collected above.
        if self.fit_metrics_aggregation_fn is not None:
            metrics_aggregated |= self.fit_metrics_aggregation_fn(fit_metrics)

        # The pseudo-gradient (not the averaged parameters) is scaled.
        # NOTE: the factor is based on `min_fit_clients`, not on the number of
        # results actually aggregated this round.
        scaling_factor = self.scaling_fn(self.min_fit_clients)
        momentum = self.server_momentum
        learning_rate = self.server_learning_rate

        norms_pseudo_gradient: list[float] = []
        norms_momentum_vector: list[float] = []
        norms_fedavg_result: list[float] = []
        norms_model: list[float] = []
        # Loop over layers, apply the server optimizer and collect norms.
        for (layer_name, model_state_name), layer_avg in fedavg_result.items():
            i = keys_to_index[layer_name, model_state_name]
            x = self.parameters[i]

            # Pseudo-gradient g_t for this layer.
            pseudo_gradient = x - layer_avg
            if scaling_factor != 1:
                pseudo_gradient = pseudo_gradient * scaling_factor

            # torch.optim.SGD semantics: b_t = mu * b_{t-1} + g_t
            self.momentum_vector[i] = momentum * self.momentum_vector[i] + pseudo_gradient
            # Nesterov look-ahead direction: g_t + mu * b_t
            nesterov_direction = pseudo_gradient + momentum * self.momentum_vector[i]
            new_layer = x - learning_rate * nesterov_direction

            self.parameters[i] = new_layer

            if self.track_norms:
                norms_pseudo_gradient.append(l2_norm([pseudo_gradient]))
                norms_momentum_vector.append(l2_norm([self.momentum_vector[i]]))
                norms_fedavg_result.append(l2_norm([layer_avg]))
                norms_model.append(l2_norm([new_layer]))

        if self.track_norms:
            # Global norm = sqrt of the sum of squared layer-wise norms.
            metrics_aggregated |= {
                "server/l2_norm_pseudo_gradient": np.sqrt(
                    np.sum(np.square(norms_pseudo_gradient)),
                ),
                "server/l2_norm_momentum_vector": np.sqrt(
                    np.sum(np.square(norms_momentum_vector)),
                ),
                "server/l2_norm_fedavg_result": np.sqrt(
                    np.sum(np.square(norms_fedavg_result)),
                ),
                "server/l2_norm_model": np.sqrt(
                    np.sum(np.square(norms_model)),
                ),
            }
            # Layer indices follow the iteration order of `fedavg_result`.
            for idx, (pg_norm, mv_norm, avg_norm, model_norm) in enumerate(
                zip(
                    norms_pseudo_gradient,
                    norms_momentum_vector,
                    norms_fedavg_result,
                    norms_model,
                    strict=True,
                ),
            ):
                prefix = f"server/layer/{idx}"
                metrics_aggregated[f"{prefix}/l2_norm_pseudo_gradient"] = pg_norm
                metrics_aggregated[f"{prefix}/l2_norm_momentum_vector"] = mv_norm
                metrics_aggregated[f"{prefix}/l2_norm_fedavg_result"] = avg_norm
                metrics_aggregated[f"{prefix}/l2_norm_model"] = model_norm
            log(
                INFO,
                "FedNestorov:"
                " l2_norm(pseudo_gradient)=%s,"
                " l2_norm(self.momentum_vector)=%s,"
                " l2_norm(fedavg_result)=%s,"
                " l2_norm(model)=%s",
                metrics_aggregated["server/l2_norm_pseudo_gradient"],
                metrics_aggregated["server/l2_norm_momentum_vector"],
                metrics_aggregated["server/l2_norm_fedavg_result"],
                metrics_aggregated["server/l2_norm_model"],
            )

        if metrics_callback is not None:
            # The callback receives the (unscaled) averaged client parameters.
            metrics_aggregated |= metrics_callback.round_end(
                list(fedavg_result.values()),
            )

        return self.parameters, metrics_aggregated
