"""Adaptive Federated Optimization using Adam (FedAdam) strategy.

The paper can be found at [this link](https://arxiv.org/abs/2003.00295).
"""

from collections.abc import Callable, Iterable
from logging import INFO

import numpy as np
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Scalar,
    log,
)

from repo.conf.base_schema import BaseConfig
from repo.constants import (
    FIRST_MOMENTUM,
    MODEL_PARAMETERS,
    SECOND_MOMENTUM,
)
from repo.strategy.aggregation import (
    aggregate_cumulative_average,
)
from repo.strategy.metrics import ServerMetricCallback
from repo.strategy.strategy_with_cfg import FedAvgWithConfig
from repo.utils import l2_norm


# pylint: disable=line-too-long
class FedAdam(FedAvgWithConfig):
    """FedAdam - Adaptive Federated Optimization using Adam.

    Implementation based on https://arxiv.org/abs/2003.00295v5

    Each round, the client parameters are aggregated with
    ``aggregate_cumulative_average`` (the FedAvg result). The difference between
    the current global model ``x_t`` and that result is used as a pseudo-gradient
    ``g_t = x_t - fedavg_t``. A server-side Adam step is then applied per layer::

        m_t     = beta_1 * m_{t-1} + (1 - beta_1) * g_t
        v_t     = beta_2 * v_{t-1} + (1 - beta_2) * g_t**2
        x_{t+1} = x_t - eta * (m_t / (1 - beta_1**t))
                  / (sqrt(v_t / (1 - beta_2**t)) + tau)

    ``t`` is the server round. The moments are bias-corrected as in the original
    Adam, which Algorithm 2 of the paper does not do.

    Parameters
    ----------
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        )
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not accept rounds containing failures. Defaults to True.
    initial_parameters : NDArrays
        Initial global model parameters.
    fit_metrics_aggregation_fn : MetricsAggregationFn | None
        Aggregates the client-reported fit metrics each round. Optional.
    evaluate_metrics_aggregation_fn: MetricsAggregationFn | None
        Metrics aggregation function, optional.
    eta : float, optional
        Server-side learning rate. Defaults to 1e-1.
    beta_1 : float, optional
        Momentum parameter. Defaults to 0.9.
    beta_2 : float, optional
        Second moment parameter. Defaults to 0.95.
    tau : float, optional
        Controls the algorithm's degree of adaptability. Defaults to 1e-9.
    track_norms: bool, optional
        Flag for tracking the norms of the aggregated updates. Defaults to True.
    obtain_server_metrics_callback: type[ServerMetricCallback] | None, optional
        Callback class instantiated every round as
        ``callback(strategy, server_round)`` to collect server-side metrics.
        Defaults to None.
    cfg: BaseConfig | None, optional
        Configuration object, stored as ``self.cfg``. Defaults to None.

    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,too-many-locals
    def __init__(  # noqa: PLR0913
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        initial_parameters: NDArrays,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        eta: float = 1e-1,
        beta_1: float = 0.9,
        beta_2: float = 0.95,
        tau: float = 1e-9,
        track_norms: bool = True,
        obtain_server_metrics_callback: type[ServerMetricCallback] | None = None,
        cfg: BaseConfig | None = None,
    ) -> None:
        """Federated Adam strategy.

        Parameters
        ----------
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : (
                Callable[
                    [int, NDArrays, dict[str, Scalar]],
                    tuple[float, dict[str, Scalar]] | None,
                ]
                | None
            )
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable[[int], dict[str, Scalar]], optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable[[int], dict[str, Scalar]], optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not accept rounds containing failures. Defaults to True.
        initial_parameters : NDArrays
            Initial global model parameters. The list object itself is kept
            (not copied) and its entries are replaced as the model is updated.
        fit_metrics_aggregation_fn : MetricsAggregationFn | None
            Aggregates the client-reported fit metrics each round. Optional.
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None
            Metrics aggregation function, optional.
        eta : float, optional
            Server-side learning rate. Defaults to 1e-1.
        beta_1 : float, optional
            Momentum parameter. Defaults to 0.9.
        beta_2 : float, optional
            Second moment parameter. Defaults to 0.95.
        tau : float, optional
            Controls the algorithm's degree of adaptability. Defaults to 1e-9.
        track_norms: bool, optional
            Flag for tracking the norms of the aggregated updates. Defaults to True.
        obtain_server_metrics_callback: type[ServerMetricCallback] | None
            Optional callback class for collecting server-side metrics,
            instantiated every round as ``callback(strategy, server_round)``.
            Defaults to None.
        cfg: BaseConfig | None
            Configuration object. Defaults to None.

        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

        # NOTE: the global model is kept as NDArrays. This is the very list passed
        # as `initial_parameters`, not a copy. Its entries are replaced in place
        # every round instead of being converted to/from another representation.
        self.parameters = initial_parameters

        self.eta = eta
        self.tau = tau
        self.beta_1 = beta_1
        self.beta_2 = beta_2

        # Set state_keys variables for ease of uploading to S3
        self.state_keys = (
            MODEL_PARAMETERS,
            FIRST_MOMENTUM,
            SECOND_MOMENTUM,
        )

        # Adam optimizer state: one zero-initialised first and second moment
        # array per parameter tensor, created eagerly here.
        self.momentum_vector: NDArrays = [np.zeros_like(x) for x in self.parameters]
        self.second_momentum_vector: NDArrays = [
            np.zeros_like(x) for x in self.parameters
        ]

        # Metrics tracking
        self.track_norms = track_norms
        self.obtain_server_metrics_callback = obtain_server_metrics_callback
        self.cfg = cfg

    def __repr__(self) -> str:
        """Compute a string representation of the strategy.

        Returns
        -------
        str
            String representation of the strategy.

        """
        return f"FedAdam(accept_failures={self.accept_failures})"

    def aggregate_fit(  # type: ignore[override,reportIncompatibleMethodOverride]
        self,
        server_round: int,
        results: Iterable[tuple[FitRes, NDArrays]],
        failures: Iterable[FitRes | BaseException],  # noqa: ARG002
    ) -> tuple[NDArrays | None, dict[str, Scalar]]:
        """Average the client results and apply a server-side Adam step.

        Parameters
        ----------
        server_round : int
            Current server round. It is used as the Adam step count for bias
            correction, so it must be >= 1.
        results : Iterable[tuple[FitRes, NDArrays]]
            ``(fit result, returned parameters)`` pairs, one per client. The
            iterable is consumed once, lazily, by the aggregation.
        failures : Iterable[FitRes | BaseException]
            Iterable of failures. Currently ignored.

        Returns
        -------
        tuple[NDArrays | None, dict[str, Scalar]]
            The updated global parameters (``self.parameters``) and the aggregated
            metrics, or ``(None, {})`` when nothing was aggregated.

        """
        assert self.layer_names_and_types is not None, "Keys should be initialized."

        keys_to_index = {
            (key, typ): i for i, (key, typ) in enumerate(self.layer_names_and_types)
        }

        # Client-reported fit metrics are recorded while the aggregation consumes
        # `results`, so the (possibly lazy) iterable is traversed only once.
        fit_metrics: list[tuple[int, dict[str, Scalar]]] = []

        def acc_metrics(
            fit_res: FitRes,
        ) -> FitRes:
            fit_metrics.append((fit_res.num_examples, fit_res.metrics))
            return fit_res

        results = (
            (acc_metrics(fitres), returned_parameters)
            for (fitres, returned_parameters) in results
        )

        metrics_aggregated: dict[str, Scalar] = {}

        metrics_callback: ServerMetricCallback | None = (
            self.obtain_server_metrics_callback(
                self,
                server_round,
            )
            if self.obtain_server_metrics_callback is not None
            else None
        )

        # Get the cumulative average of the results
        fedavg_result = aggregate_cumulative_average(
            results,
            metrics_callback=metrics_callback,
        )
        # Return None if no results were aggregated
        if fedavg_result is None:
            return None, {}

        # Aggregate the client fit metrics with the user-supplied function, as
        # flwr's FedAvg does. Server-side metrics written below take precedence
        # on key collisions.
        if self.fit_metrics_aggregation_fn is not None:
            metrics_aggregated |= self.fit_metrics_aggregation_fn(fit_metrics)

        # Bias-correction factors depend only on the round, not on the layer.
        first_moment_correction = 1 / (1 - self.beta_1**server_round)
        second_moment_correction = 1 / (1 - self.beta_2**server_round)

        # Per-layer L2 norms keyed by the real layer index. Each entry holds, in
        # order, the norms of: pseudo-gradient, first moment, second moment,
        # fedavg result and updated model. Only filled when tracking is enabled.
        layer_norms: dict[int, tuple[float, float, float, float, float]] = {}

        # Loop over layers, apply the server optimizer and compute metrics
        for (layer_name, model_state_name), layer_res in fedavg_result.items():
            i = keys_to_index[layer_name, model_state_name]
            x = self.parameters[i]
            # Pseudo-gradient with gradient sign: current model minus the
            # client average.
            pseudo_gradient = x - layer_res

            self.momentum_vector[i] = (
                self.beta_1 * self.momentum_vector[i]
                + (1 - self.beta_1) * pseudo_gradient
            )
            self.second_momentum_vector[i] = (
                self.beta_2 * self.second_momentum_vector[i]
                + (1 - self.beta_2) * np.multiply(pseudo_gradient, pseudo_gradient)
            )
            adam_step = np.divide(
                self.momentum_vector[i] * first_moment_correction,
                np.sqrt(self.second_momentum_vector[i] * second_moment_correction)
                + self.tau,
            )
            # Descend along the pseudo-gradient. Because it points from the client
            # average towards x, the step must be subtracted; adding it would
            # move the model away from the client average.
            layer_fedadam_result = x - self.eta * adam_step
            # `x` still refers to the previous array after this assignment.
            self.parameters[i] = layer_fedadam_result

            if self.track_norms:
                layer_norms[i] = (
                    l2_norm([pseudo_gradient]),
                    l2_norm([self.momentum_vector[i]]),
                    l2_norm([self.second_momentum_vector[i]]),
                    l2_norm([layer_res]),
                    l2_norm([layer_fedadam_result]),
                )

        if self.track_norms:
            metric_names = (
                "pseudo_gradient",
                "momentum_vector",
                "second_momentum_vector",
                "fedavg_result",
                "model",
            )
            for position, name in enumerate(metric_names):
                per_layer = [norms[position] for norms in layer_norms.values()]
                # Global norm = sqrt of the sum of the squared per-layer norms.
                metrics_aggregated[f"server/l2_norm_{name}"] = np.sqrt(
                    np.sum(np.square(per_layer)),
                )
                for layer_index, norms in layer_norms.items():
                    key = f"server/layer/{layer_index}/l2_norm_{name}"
                    metrics_aggregated[key] = norms[position]
            log(
                INFO,
                "FedAdam:"
                " l2_norm(pseudo_gradient)=%s,"
                " l2_norm(momentum_vector)=%s,"
                " l2_norm(second_momentum_vector)=%s,"
                " l2_norm(fedavg_result)=%s,"
                " l2_norm(model)=%s",
                metrics_aggregated["server/l2_norm_pseudo_gradient"],
                metrics_aggregated["server/l2_norm_momentum_vector"],
                metrics_aggregated["server/l2_norm_second_momentum_vector"],
                metrics_aggregated["server/l2_norm_fedavg_result"],
                metrics_aggregated["server/l2_norm_model"],
            )
        if metrics_callback is not None:
            metrics_aggregated |= metrics_callback.round_end(
                list(fedavg_result.values()),
            )

        return self.parameters, metrics_aggregated
