"""Federated Averaging with Nestorov Momentum strategy.

This implementation is based on FedMom from [this paper](https://arxiv.org/abs/2002.02090).
It can either partially aggregate updated model parameters as soon as they arrive and the compute
the averaged pseudo-gradient or partially aggregate pseudo-gradients while computing them.
The averaged pseudo-gradient is then used to update the global model parameters.
"""

from collections.abc import Callable, Iterable
from copy import deepcopy
from logging import INFO
from pathlib import Path

import numpy as np
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Scalar,
    log,
)

from repo.conf.base_schema import BaseConfig
from repo.constants import FIRST_MOMENTUM, MODEL_PARAMETERS
from repo.strategy.aggregation import (
    aggregate_cumulative_average,
)
from repo.strategy.metrics import ServerMetricCallback
from repo.strategy.strategy_with_cfg import FedAvgWithConfig
from repo.utils import (
    l2_norm,
)


# flake8: noqa: E501
class FedMom(FedAvgWithConfig):
    """Federated Averaging with Nesterov Momentum strategy.

    Implementation based on FedMom from [this paper](https://arxiv.org/abs/2002.02090).

    On every round the client models are averaged, the pseudo-gradient
    ``global_params - averaged_params`` is formed, and the global model is updated
    with a Nesterov-style look-ahead step (per layer)::

        v_{t+1} = w_t - server_learning_rate * pseudo_gradient
        w_{t+1} = (1 + server_momentum) * v_{t+1} - server_momentum * v_t

    ``w`` is kept in ``self.parameters`` and ``v`` in ``self.momentum_vector``
    (initialised to a deep copy of the initial parameters).

    Parameters
    ----------
    initial_parameters : NDArrays
        Initial global model parameters (required, keyword-only). The list is
        stored by reference as ``self.parameters`` and its entries are replaced
        in place by ``aggregate_fit``.
    saving_path : Path | None, optional
        Path stored as ``self.saving_path``; not used within this class.
        Defaults to the current working directory.
    fraction_fit : float, optional
        Fraction of clients used during training. In case `min_fit_clients`
        is larger than `fraction_fit * available_clients`, `min_fit_clients`
        will still be sampled. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. In case `min_evaluate_clients`
        is larger than `fraction_evaluate * available_clients`, `min_evaluate_clients`
        will still be sampled. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        )
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether to accept rounds containing failures. Defaults to True.
    fit_metrics_aggregation_fn : MetricsAggregationFn | None
        Function aggregating the ``(num_examples, metrics)`` pairs reported by
        the clients during fit. Its output is merged into the metrics returned
        by ``aggregate_fit``. Defaults to None.
    evaluate_metrics_aggregation_fn : MetricsAggregationFn | None
        Metrics aggregation function for evaluation, optional.
    server_learning_rate : float, optional
        Learning rate used by the server-side optimizer. Defaults to 0.7.
    server_momentum : float, optional
        Momentum coefficient used by the server-side optimizer. Defaults to 0.9.
    track_norms : bool, optional
        Whether to compute, log and report the L2 norms of the pseudo-gradient,
        momentum vector, averaged client model and new global model.
        Defaults to True.
    obtain_server_metrics_callback : type[ServerMetricCallback] | None, optional
        Callback class instantiated once per round as ``cls(strategy, server_round)``
        to collect server-side metrics. Defaults to None.
    cfg : BaseConfig | None, optional
        Configuration object stored as ``self.cfg``. Defaults to None.

    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(  # noqa: PLR0913, D107
        self,
        *,
        initial_parameters: NDArrays,
        saving_path: Path | None = None,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        server_learning_rate: float = 0.7,
        server_momentum: float = 0.9,
        track_norms: bool = True,
        obtain_server_metrics_callback: type[ServerMetricCallback] | None = None,
        cfg: BaseConfig | None = None,
    ) -> None:
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.saving_path = Path.cwd() if saving_path is None else saving_path

        # Server optimizer hyper-parameters.
        self.server_learning_rate = server_learning_rate
        self.server_momentum = server_momentum

        # NOTE: the global model is kept as NDArrays and stored by reference (no
        # copy): `self.parameters` aliases the caller's `initial_parameters` list
        # and `aggregate_fit` replaces its entries in place. This avoids
        # converting between Flower parameters and NDArrays every round.
        self.parameters = initial_parameters

        # Set state_keys variables for ease of uploading to S3
        self.state_keys = (
            MODEL_PARAMETERS,
            FIRST_MOMENTUM,
        )

        log(
            INFO,
            "Using Nesterov Momentum (FedMom) with server_learning_rate=%s and"
            " server_momentum=%s",
            self.server_learning_rate,
            self.server_momentum,
        )
        # Look-ahead sequence v_t of FedMom; v_0 is a copy of the initial model.
        self.momentum_vector: NDArrays = deepcopy(initial_parameters)

        self.track_norms = track_norms
        self.cfg = cfg

        self.obtain_server_metrics_callback = obtain_server_metrics_callback

    @staticmethod
    def _norm_metrics(layerwise_norms: dict[str, list[float]]) -> dict[str, Scalar]:
        """Build global and per-layer ``server/...l2_norm_*`` metrics from layer norms.

        Parameters
        ----------
        layerwise_norms : dict[str, list[float]]
            Maps a quantity suffix (e.g. ``"model"``) to the L2 norm of that
            quantity for each aggregated layer, in aggregation order.

        Returns
        -------
        dict[str, Scalar]
            ``server/l2_norm_<suffix>`` (the global norm) and
            ``server/layer/<i>/l2_norm_<suffix>`` (per-layer norms) entries.

        """
        metrics: dict[str, Scalar] = {}
        for suffix, norms in layerwise_norms.items():
            # The global L2 norm equals the L2 norm of the vector of layer norms.
            metrics[f"server/l2_norm_{suffix}"] = np.sqrt(np.sum(np.square(norms)))
        for suffix, norms in layerwise_norms.items():
            for layer_index, norm in enumerate(norms):
                metrics[f"server/layer/{layer_index}/l2_norm_{suffix}"] = norm
        return metrics

    def aggregate_fit(  # type: ignore[override,reportIncompatibleMethodOverride]
        self,
        server_round: int,
        results: Iterable[tuple[FitRes, NDArrays]],
        failures: Iterable[FitRes | BaseException],  # noqa: ARG002
    ) -> tuple[NDArrays | None, dict[str, Scalar]]:
        """Average the client models and apply a server-side Nesterov momentum step.

        Parameters
        ----------
        server_round : int
            Current server round.
        results : Iterable[tuple[FitRes, NDArrays]]
            Iterable of ``(fit_result, returned_parameters)`` pairs. It is
            consumed lazily by ``aggregate_cumulative_average``.
        failures : Iterable[FitRes | BaseException]
            Failed fit results or exceptions. Unused.

        Returns
        -------
        tuple[NDArrays | None, dict[str, Scalar]]
            The updated global parameters (``self.parameters``, updated in place)
            and the aggregated metrics, or ``(None, {})`` when no results were
            aggregated.

        """
        assert self.parameters is not None, (
            "When using server-side optimization, model needs to be initialized."
        )

        assert self.layer_names_and_types is not None, "Keys should be initialized."

        # Map each (layer name, state type) key to its index in self.parameters.
        keys_to_index = {
            (key, typ): i for i, (key, typ) in enumerate(self.layer_names_and_types)
        }

        # Client fit metrics are recorded as a side effect while the aggregation
        # below consumes the generator, so results are never materialised at once.
        fit_metrics: list[tuple[int, dict[str, Scalar]]] = []

        def acc_metrics(fit_res: FitRes) -> FitRes:
            fit_metrics.append((fit_res.num_examples, fit_res.metrics))
            return fit_res

        results = (
            (acc_metrics(fit_res), returned_parameters)
            for fit_res, returned_parameters in results
        )

        metrics_callback: ServerMetricCallback | None = (
            self.obtain_server_metrics_callback(self, server_round)
            if self.obtain_server_metrics_callback is not None
            else None
        )

        fedavg_result = aggregate_cumulative_average(
            results,
            metrics_callback=metrics_callback,
        )
        # Return None if no results were aggregated
        if fedavg_result is None:
            return None, {}

        metrics_aggregated: dict[str, Scalar] = {}
        # Aggregate the client fit metrics with the user-supplied function (as
        # Flower's FedAvg does). getattr guards against a base class that does
        # not store the attribute.
        fit_metrics_aggregation_fn = getattr(self, "fit_metrics_aggregation_fn", None)
        if fit_metrics_aggregation_fn is not None:
            metrics_aggregated |= fit_metrics_aggregation_fn(fit_metrics)

        # Per-layer L2 norms, keyed by the suffix used in the metric names.
        layerwise_norms: dict[str, list[float]] = {
            "pseudo_gradient": [],
            "momentum_vector": [],
            "fedavg_result": [],
            "model": [],
        }
        # Loop over layers, apply the server optimizer and (optionally) collect norms.
        for (layer_name, model_state_name), layer_res in fedavg_result.items():
            i = keys_to_index[layer_name, model_state_name]
            x = self.parameters[i]
            # Pseudo-gradient: current global layer minus the averaged client layer.
            layer_pseudo_gradient = x - layer_res
            # FedMom look-ahead: v_{t+1} = w_t - lr * pseudo_gradient.
            layer_new_momentum_vector = (
                x - self.server_learning_rate * layer_pseudo_gradient
            )
            # Nesterov step: w_{t+1} = (1 + momentum) * v_{t+1} - momentum * v_t.
            layer_new_params = (
                1 + self.server_momentum
            ) * layer_new_momentum_vector - self.server_momentum * self.momentum_vector[
                i
            ]
            # Entries are replaced (not mutated), so `x` still holds the old layer.
            self.parameters[i] = layer_new_params
            self.momentum_vector[i] = layer_new_momentum_vector

            if self.track_norms:
                layerwise_norms["pseudo_gradient"].append(
                    l2_norm([layer_pseudo_gradient]),
                )
                layerwise_norms["momentum_vector"].append(
                    l2_norm([layer_new_momentum_vector]),
                )
                # Norm of the averaged client model, not of the previous global model.
                layerwise_norms["fedavg_result"].append(l2_norm([layer_res]))
                layerwise_norms["model"].append(l2_norm([layer_new_params]))

        if self.track_norms:
            metrics_aggregated |= self._norm_metrics(layerwise_norms)
            log(
                INFO,
                "Nesterov Momentum:"
                " l2_norm(pseudo_gradient)=%s,"
                " l2_norm(self.momentum_vector)=%s,"
                " l2_norm(fedavg_result)=%s,"
                " l2_norm(model)=%s",
                metrics_aggregated["server/l2_norm_pseudo_gradient"],
                metrics_aggregated["server/l2_norm_momentum_vector"],
                metrics_aggregated["server/l2_norm_fedavg_result"],
                metrics_aggregated["server/l2_norm_model"],
            )

        if metrics_callback is not None:
            metrics_aggregated |= metrics_callback.round_end(
                list(fedavg_result.values()),
            )

        return self.parameters, metrics_aggregated
