"""Federated Averaging strategy with efficient implementation.

This implementation is based on the SGD implementation in
PyTorch. It can either partially aggregate updated model parameters as soon as they
arrive and then compute the averaged pseudo-gradient or partially aggregate
pseudo-gradients while computing them. The averaged pseudo-gradient is then used to
update the global model parameters.
"""

import time
from collections.abc import Callable, Iterable
from concurrent.futures import Executor, ThreadPoolExecutor
from logging import INFO
from pathlib import Path

import numpy as np
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Scalar,
    log,
)

from repo.conf.base_schema import BaseConfig
from repo.constants import MODEL_PARAMETERS
from repo.strategy.aggregation import (
    aggregate_cumulative_average,
    aggregate_cumulative_average_multi_processing,
)
from repo.strategy.metrics import ServerMetricCallback
from repo.strategy.strategy_with_cfg import FedAvgWithConfig
from repo.utils import ModelStateNames, l2_norm


class FedAvgEfficient(FedAvgWithConfig):
    """Federated Averaging strategy with a streaming, optionally threaded, update.

    Client results are consumed lazily and reduced by the
    ``aggregate_cumulative_average*`` helpers into one aggregated array per
    ``(layer_name, model_state_name)`` key. For every layer the server then forms
    the pseudo-gradient ``x - aggregated``, optionally scales it (see
    ``scaling_fn``) and applies it with ``server_learning_rate``::

        x_new = x - server_learning_rate * scale * (x - aggregated)

    Parameters
    ----------
    initial_parameters : NDArrays
        Initial global model parameters. The list is kept by reference and its
        entries are replaced after every aggregated round.
    saving_path : Path, optional
        Path to save the model parameters. Defaults to current working directory.
    fraction_fit : float, optional
        Fraction of clients used during training. Defaults to 1.0.
    fraction_evaluate : float, optional
        Fraction of clients used during validation. Defaults to 1.0.
    min_fit_clients : int, optional
        Minimum number of clients used during training. Defaults to 2.
    min_evaluate_clients : int, optional
        Minimum number of clients used during validation. Defaults to 2.
    min_available_clients : int, optional
        Minimum number of total clients in the system. Defaults to 2.
    evaluate_fn : Callable, optional
        Optional function used for validation. Defaults to None.
    on_fit_config_fn : Callable, optional
        Function used to configure training. Defaults to None.
    on_evaluate_config_fn : Callable, optional
        Function used to configure validation. Defaults to None.
    accept_failures : bool, optional
        Whether or not to accept rounds containing failures. Defaults to True.
    fit_metrics_aggregation_fn : MetricsAggregationFn, optional
        Metrics aggregation function for training. Defaults to None.
    evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
        Metrics aggregation function for evaluation. Defaults to None.
    server_learning_rate : float, optional
        Learning rate used by the server-side optimizer. Defaults to 0.7.
    track_norms : bool, optional
        Flag for tracking the norms of the aggregated updates. Defaults to True.
    obtain_server_metrics_callback : type[ServerMetricCallback], optional
        Callback for obtaining server metrics. Defaults to None.
    scaling_fn : str | None, optional
        Scaling applied to the aggregated pseudo-gradient. ``None`` applies no
        scaling, ``'linear'`` multiplies it by the number of clients per round
        and ``'sqrt'`` by the square root of that number. The number of clients
        is taken to be ``min_fit_clients``. Defaults to None.
    cfg : BaseConfig, optional
        Configuration object; required by ``aggregate_fit``, which reads the
        ``cfg.fl.aggregation_num_workers_*`` thread-pool sizes. Defaults to None.

    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(  # noqa: PLR0913
        self,
        *,
        initial_parameters: NDArrays,
        saving_path: Path | None = None,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: (
            Callable[
                [int, NDArrays, dict[str, Scalar]],
                tuple[float, dict[str, Scalar]] | None,
            ]
            | None
        ) = None,
        on_fit_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        on_evaluate_config_fn: Callable[[int], dict[str, Scalar]] | None = None,
        accept_failures: bool = True,
        fit_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        evaluate_metrics_aggregation_fn: MetricsAggregationFn | None = None,
        server_learning_rate: float = 0.7,
        track_norms: bool = True,
        obtain_server_metrics_callback: type[ServerMetricCallback] | None = None,
        scaling_fn: str | None = None,
        cfg: BaseConfig | None = None,
    ) -> None:
        """Federated Averaging strategy with efficient implementation.

        Parameters
        ----------
        initial_parameters : NDArrays
            Initial global model parameters (kept by reference).
        saving_path : Path, optional
            Path to save the model parameters. Defaults to current working directory.
        fraction_fit : float, optional
            Fraction of clients used during training. Defaults to 1.0.
        fraction_evaluate : float, optional
            Fraction of clients used during validation. Defaults to 1.0.
        min_fit_clients : int, optional
            Minimum number of clients used during training. Defaults to 2.
        min_evaluate_clients : int, optional
            Minimum number of clients used during validation. Defaults to 2.
        min_available_clients : int, optional
            Minimum number of total clients in the system. Defaults to 2.
        evaluate_fn : Callable, optional
            Optional function used for validation. Defaults to None.
        on_fit_config_fn : Callable, optional
            Function used to configure training. Defaults to None.
        on_evaluate_config_fn : Callable, optional
            Function used to configure validation. Defaults to None.
        accept_failures : bool, optional
            Whether or not to accept rounds containing failures. Defaults to True.
        fit_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for training. Defaults to None.
        evaluate_metrics_aggregation_fn : MetricsAggregationFn, optional
            Metrics aggregation function for evaluation. Defaults to None.
        server_learning_rate : float, optional
            Learning rate used by the server-side optimizer. Defaults to 0.7.
        track_norms : bool, optional
            Flag for tracking the norms of the aggregated updates. Defaults to True.
        obtain_server_metrics_callback : type[ServerMetricCallback], optional
            Callback for obtaining server metrics. Defaults to None.
        scaling_fn : str | None, optional
            Scaling applied to the aggregated pseudo-gradient: None (no scaling),
            'linear' (times the number of clients per round) or 'sqrt' (times the
            square root of that number). ``min_fit_clients`` is used as the
            number of clients. Defaults to None.
        cfg : BaseConfig, optional
            Configuration object. Defaults to None.

        Raises
        ------
        ValueError
            If the scaling function is not None, 'linear' or 'sqrt'.

        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )

        if scaling_fn is not None and scaling_fn not in {"linear", "sqrt"}:
            msg = "Scaling function must be either 'linear' or 'sqrt'."
            raise ValueError(msg)

        def _scaling(num_clients: int) -> float:
            # Return plain Python numbers rather than ``np.float64``: under
            # NumPy >= 2 promotion rules (NEP 50) a NumPy float64 scalar would
            # upcast float32 layers to float64.
            if scaling_fn is None:
                return 1
            if scaling_fn == "linear":
                return num_clients
            return float(np.sqrt(num_clients))

        self.scaling_fn = _scaling

        self.saving_path = Path.cwd() if saving_path is None else saving_path

        # Default optimizer values
        self.server_learning_rate = server_learning_rate

        # NOTE: Converting ``initial_parameters`` into a private copy here would
        # avoid translating between formats on every round but incurs a higher
        # memory peak. Instead we keep a reference to the server's parameter list
        # and replace its entries in ``aggregate_fit``.
        self.parameters = initial_parameters
        # Set state_keys variables for ease of uploading to S3
        self.state_keys = (MODEL_PARAMETERS,)

        log(
            INFO,
            "Using FedAvg with server_learning_rate=%s",
            self.server_learning_rate,
        )

        self.track_norms = track_norms
        self.obtain_server_metrics_callback = obtain_server_metrics_callback
        self.cfg = cfg
        # Thread pools are created lazily on the first ``aggregate_fit`` call.
        self.client_agg_executor: Executor | None = None
        self.ndarrays_agg_executor: Executor | None = None

    def _ensure_executors(self, cfg: BaseConfig) -> None:
        """Create the thread pools requested by ``cfg.fl`` if not yet created.

        Parameters
        ----------
        cfg : BaseConfig
            Configuration providing ``fl.aggregation_num_workers_across_clients``
            and ``fl.aggregation_num_workers_process_ndarrays``; a value of 0
            disables the corresponding pool.

        """
        if (
            cfg.fl.aggregation_num_workers_across_clients != 0
            and self.client_agg_executor is None
        ):
            self.client_agg_executor = ThreadPoolExecutor(
                max_workers=cfg.fl.aggregation_num_workers_across_clients,
            )

        if (
            cfg.fl.aggregation_num_workers_process_ndarrays != 0
            and self.ndarrays_agg_executor is None
        ):
            self.ndarrays_agg_executor = ThreadPoolExecutor(
                max_workers=cfg.fl.aggregation_num_workers_process_ndarrays,
            )

    @staticmethod
    def _norm_metrics(
        layer_norms: list[tuple[float, float, float]],
    ) -> dict[str, Scalar]:
        """Build global and per-layer L2-norm metrics.

        Parameters
        ----------
        layer_norms : list[tuple[float, float, float]]
            Per-layer ``(pseudo_gradient, fedavg_result, model)`` L2 norms.

        Returns
        -------
        dict[str, Scalar]
            Global norms (square root of the sum of squared layer norms) followed
            by per-layer norms keyed by each layer's position in ``layer_norms``.

        """
        pseudo = [norms[0] for norms in layer_norms]
        fedavg = [norms[1] for norms in layer_norms]
        model = [norms[2] for norms in layer_norms]
        metrics: dict[str, Scalar] = {
            "server/l2_norm_pseudo_gradient": np.sqrt(np.sum(np.square(pseudo))),
            "server/l2_norm_fedavg_result": np.sqrt(np.sum(np.square(fedavg))),
            "server/l2_norm_model": np.sqrt(np.sum(np.square(model))),
        }
        for i, (pseudo_norm, fedavg_norm, model_norm) in enumerate(layer_norms):
            metrics[f"server/layer/{i}/l2_norm_pseudo_gradient"] = pseudo_norm
            metrics[f"server/layer/{i}/l2_norm_fedavg_result"] = fedavg_norm
            metrics[f"server/layer/{i}/l2_norm_model"] = model_norm
        return metrics

    def aggregate_fit(  # type: ignore[override,reportIncompatibleMethodOverride]  # noqa: C901
        self,
        server_round: int,
        results: Iterable[tuple[FitRes, NDArrays]],
        failures: Iterable[FitRes | BaseException],  # noqa: ARG002
    ) -> tuple[NDArrays | None, dict[str, Scalar]]:
        """Aggregate fit results and apply the averaged update to the global model.

        Parameters
        ----------
        server_round : int
            The current round of federated learning.
        results : Iterable[tuple[FitRes, NDArrays]]
            Per-client fit results paired with the parameters each client
            returned. Wrapped in a generator and consumed by the aggregation
            helper.
        failures : Iterable[FitRes | BaseException]
            The failures from the clients (unused).

        Returns
        -------
        tuple[NDArrays | None, dict[str, Scalar]]
            The updated global parameters (``self.parameters``) and the
            aggregated metrics, or ``(None, {})`` if nothing was aggregated.

        """
        assert self.layer_names_and_types is not None, "Keys should be initialized."
        assert self.cfg is not None, "Config should be initialized."
        self._ensure_executors(self.cfg)

        keys_to_index = {
            (key, typ): i for i, (key, typ) in enumerate(self.layer_names_and_types)
        }

        # Record (num_examples, metrics) for each client as the aggregation
        # helper pulls results through the generator below.
        fit_metrics: list[tuple[int, dict[str, Scalar]]] = []

        def acc_metrics(fit_res: FitRes) -> FitRes:
            fit_metrics.append((fit_res.num_examples, fit_res.metrics))
            return fit_res

        results = (
            (acc_metrics(fitres), returned_parameters)
            for (fitres, returned_parameters) in results
        )

        metrics_callback: ServerMetricCallback | None = (
            self.obtain_server_metrics_callback(self, server_round)
            if self.obtain_server_metrics_callback is not None
            else None
        )

        fedavg_result = (
            aggregate_cumulative_average_multi_processing(
                results,
                metrics_callback=metrics_callback,
                executor=self.client_agg_executor,
            )
            if self.client_agg_executor is not None
            else aggregate_cumulative_average(
                results,
                metrics_callback=metrics_callback,
                executor=self.ndarrays_agg_executor,
            )
        )
        # Return None if no results were aggregated
        if fedavg_result is None:
            return None, {}

        metrics_aggregated: dict[str, Scalar] = {}
        # Aggregate the collected client fit metrics with the user-supplied
        # function (stored by the base strategy), like Flower's FedAvg does.
        fit_metrics_aggregation_fn = getattr(self, "fit_metrics_aggregation_fn", None)
        if fit_metrics_aggregation_fn is not None:
            metrics_aggregated |= fit_metrics_aggregation_fn(fit_metrics)

        # Timing covers only the server update below, not the streaming
        # aggregation above.
        start_time = time.time()
        # Factor applied to the pseudo-gradient (not to the aggregated parameters).
        scaling_factor = self.scaling_fn(self.min_fit_clients)
        track_norms = self.track_norms

        def update_layer(
            item: tuple[tuple[str, ModelStateNames], np.ndarray],
        ) -> tuple[float, float, float] | None:
            # item is ((layer_name, model_state_name), aggregated_value)
            (layer_name, model_state_name), layer_res = item
            i = keys_to_index[layer_name, model_state_name]
            x = self.parameters[i]
            layer_pseudo_gradient = x - layer_res
            if scaling_factor != 1:
                # Not in place: an integer layer must be allowed to promote.
                layer_pseudo_gradient = layer_pseudo_gradient * scaling_factor
            layer_new = x - self.server_learning_rate * layer_pseudo_gradient
            # Each call writes a distinct index, so concurrent calls do not clash.
            self.parameters[i] = layer_new
            if not track_norms:
                return None
            return (
                l2_norm([layer_pseudo_gradient]),
                l2_norm([layer_res]),
                l2_norm([layer_new]),
            )

        items = fedavg_result.items()
        if self.ndarrays_agg_executor is not None:
            # ``map`` submits every layer up front and yields results in order.
            layer_norms = list(self.ndarrays_agg_executor.map(update_layer, items))
        else:
            layer_norms = [update_layer(item) for item in items]

        if track_norms:
            metrics_aggregated |= self._norm_metrics(
                [norms for norms in layer_norms if norms is not None],
            )
            log(
                INFO,
                "FedAvg:"
                " l2_norm(pseudo_gradient)=%s,"
                " l2_norm(fedavg_result)=%s,"
                " l2_norm(model)=%s",
                metrics_aggregated["server/l2_norm_pseudo_gradient"],
                metrics_aggregated["server/l2_norm_fedavg_result"],
                metrics_aggregated["server/l2_norm_model"],
            )

        if metrics_callback is not None:
            metrics_aggregated |= metrics_callback.round_end(
                list(fedavg_result.values()),
            )
        metrics_aggregated["server/strategy_aggregated_fit"] = time.time() - start_time
        return self.parameters, metrics_aggregated
