"""Handle aggregation in-place and potentially async."""

import ast
import time
from collections import defaultdict
from collections.abc import Iterable, Iterator
from concurrent.futures import Executor, Future
from functools import partial, reduce
from logging import DEBUG
from queue import Queue
from typing import Any, cast

import numpy as np
from flwr.common import FitRes, NDArrays
from flwr.common.logger import log

from repo.constants import CLIENT_STATE_ACCUMULATOR, LAYER_NAMES, LAYER_TYPES
from repo.strategy.metrics import ServerMetricCallback
from repo.utils import ModelStateNames


def _aggregate_helper(
    prev_tuple: tuple[np.ndarray, int],
    param: np.ndarray,
    num_examples: int,
) -> tuple[np.ndarray, int]:
    """Compute weighted average for one parameter.

    Parameters
    ----------
    prev_tuple : Tuple[np.ndarray, int]
        Previously accumulated parameter and count.
    param : np.ndarray
        Current parameter.
    num_examples : int
        Number of current examples.

    Returns
    -------
    Tuple[np.ndarray, int]
        Updated parameter and total sample count.

    """
    prev_params, prev_total_examples = prev_tuple
    new_total_samples = prev_total_examples + num_examples
    acc_scaling_factor = float(prev_total_examples) / new_total_samples
    scaling_factor = float(num_examples) / new_total_samples
    updated_params = prev_params * acc_scaling_factor + param * scaling_factor
    return updated_params, new_total_samples


def aggregate_parameters(
    executor: Executor | None,
    accumulator: dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None,
    current: tuple[
        Iterator[np.ndarray] | Iterable[np.ndarray],
        int,
        Iterable[str],
        Iterable[ModelStateNames],
    ],
) -> dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]]:
    """Fold one client's parameters into the running accumulator in-place.

    Each layer is keyed by ``(layer name, layer type)``. A layer seen for the
    first time starts a new running average; an existing one is updated with
    ``_aggregate_helper``. When an ``executor`` is given, those per-layer updates
    are submitted to it and run concurrently.

    Parameters
    ----------
    executor : Executor | None
        An externally managed Executor with a fixed number of workers, or None
        to update the layers sequentially in the calling thread.
    accumulator : dict or None
        Mapping ``(layer name, layer type) -> (running average, example count)``
        accumulated so far, or None for the first client.
    current : tuple
        The current client's parameters, number of examples, layer names and
        layer types. The three iterables must have equal length.

    Returns
    -------
    dict
        The updated accumulator (the same dict object when one was passed in).

    Raises
    ------
    ValueError
        If parameters, layer names and layer types differ in length
        (raised by ``zip(..., strict=True)``).

    """
    current_params, num_examples, current_layers, current_types = current

    # perf_counter is monotonic and intended for measuring durations, unlike
    # time.time(), which follows (adjustable) wall-clock time.
    start_time = time.perf_counter()

    layers = zip(current_params, current_layers, current_types, strict=True)

    if accumulator is None:
        # First client: its parameters become the initial running averages.
        accumulator = {
            (name, typ): (param, num_examples) for param, name, typ in layers
        }
    else:
        pending: dict[tuple[str, ModelStateNames], Future[tuple[np.ndarray, int]]] = {}
        for param, name, typ in layers:
            key = (name, typ)
            if key not in accumulator:
                # Layer not seen before: start a new running average.
                accumulator[key] = (param, num_examples)
            elif executor is not None:
                pending[key] = executor.submit(
                    _aggregate_helper,
                    accumulator[key],
                    param,
                    num_examples,
                )
            else:
                accumulator[key] = _aggregate_helper(
                    accumulator[key],
                    param,
                    num_examples,
                )
        # Write the concurrent results back only after every layer has been
        # submitted; result() re-raises any exception from a worker.
        for key, future in pending.items():
            accumulator[key] = future.result()

    log(
        DEBUG,
        f"Aggregated client with samples: {num_examples} "
        f"time: {time.perf_counter() - start_time}",
    )

    return accumulator


def aggregate_inplace(
    results: Iterable[
        tuple[
            Iterator[np.ndarray] | Iterable[np.ndarray],
            int,
            Iterable[str],
            Iterable[ModelStateNames],
        ]
    ],
    executor: Executor | None = None,
) -> dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None:
    """Fold client results into one weighted-average accumulator.

    The results are consumed one at a time through ``functools.reduce`` with
    ``aggregate_parameters``, so a lazy iterable (e.g. a generator) is never
    fully materialized here.

    Parameters
    ----------
    results : Iterable[tuple[Iterable[np.ndarray], int, Iterable[str], Iterable[ModelStateNames]]]
        Per-client tuples of (parameters, number of examples, layer names,
        layer types), as expected by ``aggregate_parameters``.
    executor : Executor | None
        An externally managed Executor with a fixed number of workers, used
        to aggregate the ndarrays of each client concurrently. None aggregates
        sequentially.

    Returns
    -------
    dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None
        Mapping ``(layer name, layer type) -> (weighted average, total example
        count)``, or None if ``results`` is empty.

    """
    # Starting value for reduce; it is also what is returned for empty input.
    accumulator: dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None = None
    # Bind the executor so reduce only has to supply (accumulator, current).
    aggregation_fn = partial(aggregate_parameters, executor)
    return reduce(aggregation_fn, results, accumulator)


def aggregate_cumulative_average(
    results: Iterable[tuple[FitRes, NDArrays]],
    metrics_callback: ServerMetricCallback | None = None,
    executor: Executor | None = None,
) -> dict[tuple[str, ModelStateNames], np.ndarray] | None:
    """Aggregate client parameters with a running example-weighted average.

    Each ``(FitRes, NDArrays)`` pair is lazily converted into the
    ``(parameters, num_examples, layer_names, layer_types)`` tuple consumed by
    ``aggregate_inplace``. Layer names and layer types are decoded from
    ``fit_res.metrics[LAYER_NAMES]`` and ``fit_res.metrics[LAYER_TYPES]`` with
    ``ast.literal_eval``; the types are looked up in ``ModelStateNames`` by name.

    Parameters
    ----------
    results : Iterable[tuple[FitRes, NDArrays]]
        The fit results and the parameters each client returned.
    metrics_callback : ServerMetricCallback | None
        Optional callback for adding per-client metrics; when set, each
        client's parameters pass through ``process_per_client_results``.
    executor : Executor | None
        An externally managed Executor with a fixed number of workers.
        For aggregating ndarrays

    Returns
    -------
    dict[tuple[str, ModelStateNames], np.ndarray] | None
        Aggregated parameter per (layer name, layer type), with the example
        counts dropped, or None if ``results`` was empty.

    """

    def _client_entries() -> Iterator[tuple[Any, int, Any, list[ModelStateNames]]]:
        # A generator: clients are pulled from ``results`` one at a time as
        # the reduction advances, not all up front.
        for fit_res, returned_parameters in results:
            n_examples = fit_res.num_examples
            if metrics_callback:
                # The callback receives the arrays as a generator; whatever it
                # yields is collected into a list.
                params = list(
                    metrics_callback.process_per_client_results(
                        (array for array in returned_parameters),
                        n_examples,
                    ),
                )
            else:
                params = returned_parameters
            layer_names = ast.literal_eval(str(fit_res.metrics[LAYER_NAMES]))
            raw_types = ast.literal_eval(str(fit_res.metrics[LAYER_TYPES]))
            layer_types = [ModelStateNames[v] for v in raw_types]
            yield params, n_examples, layer_names, layer_types

    aggregated = aggregate_inplace(results=_client_entries(), executor=executor)
    if aggregated is None:
        return None
    # Keep only the averaged arrays; the per-layer example counts are dropped.
    return {key: value for key, (value, _num_examples) in aggregated.items()}


def worker_loop(
    input_q: Queue[
        tuple[
            Iterator[np.ndarray] | Iterable[np.ndarray],
            int,
            Iterable[str],
            Iterable[ModelStateNames],
        ]
        | None
    ],
) -> dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None:
    """Drain client results from a shared queue into one accumulator.

    Items are read with ``input_q.get()`` until a ``None`` sentinel arrives.
    Each item is folded in sequentially with ``aggregate_parameters``; no
    nested executor is used inside a worker.

    Parameters
    ----------
    input_q : Queue
        Shared queue of (parameters, number of examples, layer names, layer
        types) tuples, terminated by ``None``.

    Returns
    -------
    dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None
        This worker's accumulator, or None if the sentinel arrived before any
        client result.

    """
    acc: dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] | None = None
    while (client_result := input_q.get()) is not None:
        acc = aggregate_parameters(
            executor=None,
            accumulator=acc,
            current=client_result,
        )
    return acc


def merge_accumulators(
    accumulators: list[dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]]],
) -> dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]]:
    """Merge accumulators from multiple workers.

    Entries sharing a key are combined as an example-weighted average with
    ``_aggregate_helper``. This is the same formula used for per-client
    updates, so merging partial averages gives the same result as
    aggregating all clients in one accumulator.

    Parameters
    ----------
    accumulators : list[dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]]]
        List of accumulators, each mapping (layer name, layer type) to
        (running average, example count).

    Returns
    -------
    dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]]
        Merged accumulators. Keys present in only one input are carried over
        unchanged.

    """
    merged: dict[tuple[str, ModelStateNames], tuple[np.ndarray, int]] = {}
    for acc in accumulators:
        for key, (value, count) in acc.items():
            if key in merged:
                # Reuse the shared weighted-mean helper rather than duplicating
                # its arithmetic here.
                merged[key] = _aggregate_helper(merged[key], value, count)
            else:
                merged[key] = (value, count)
    return merged


def aggregate_cumulative_average_multi_processing(
    results: Iterable[tuple[FitRes, NDArrays]],
    metrics_callback: ServerMetricCallback | None,
    executor: Executor,
) -> dict[tuple[str, ModelStateNames], np.ndarray] | None:
    """Compute a weighted average in parallel using an Executor and a shared queue.

    One ``worker_loop`` task per executor worker is submitted, all reading
    from the same ``queue.Queue``. Client results are enqueued as they arrive
    from ``results``. Afterwards one ``None`` sentinel per worker is enqueued,
    even if producing the results failed, so every worker loop terminates.
    The per-worker accumulators are combined with ``merge_accumulators``.

    ``queue.Queue`` is an in-process queue and cannot be pickled. The
    executor is therefore presumably expected to be thread-based (e.g.
    ``ThreadPoolExecutor``) rather than a process pool.

    Parameters
    ----------
    results : Iterable[tuple[FitRes, NDArrays]]
        Fit results paired with the parameters each client returned.
    metrics_callback : ServerMetricCallback | None
        Optional callback for per-client metrics; when set, each client's
        parameters pass through ``process_per_client_results``.
    executor : Executor
        An externally managed Executor with a fixed number of workers. It
        must expose a ``_max_workers`` attribute.

    Returns
    -------
    dict[tuple[str, ModelStateNames], np.ndarray] | None
        Final aggregated parameters mapping (layer_name, ModelStateNames) to
        np.ndarray, or None if no worker received any client result.

    Raises
    ------
    TypeError
        If ``executor`` has no ``_max_workers`` attribute.

    """
    input_q = cast(
        """Queue[tuple[Iterator[np.ndarray] | Iterable[np.ndarray], int,
        Iterable[str], Iterable[ModelStateNames]] | None]""",
        Queue(),
    )
    # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
    # The worker count is read from the executor's private attribute.
    num_workers = getattr(executor, "_max_workers", None)
    if num_workers is None:
        msg = (
            "Executor must have a '_max_workers' attribute "
            "to determine the number of workers."
        )
        raise TypeError(msg)

    futures = [executor.submit(worker_loop, input_q) for _ in range(num_workers)]

    try:
        for fit_res, returned_parameters in results:
            current = (
                (
                    list(
                        metrics_callback.process_per_client_results(
                            returned_parameters,
                            fit_res.num_examples,
                        ),
                    )
                    if metrics_callback
                    else returned_parameters
                ),
                fit_res.num_examples,
                ast.literal_eval(str(fit_res.metrics[LAYER_NAMES])),
                [
                    ModelStateNames[v]
                    for v in ast.literal_eval(str(fit_res.metrics[LAYER_TYPES]))
                ],
            )
            input_q.put(current)
    finally:
        # Always release the workers. If consuming ``results`` raised, skipping
        # this would leave every worker_loop blocked forever on input_q.get()
        # and permanently occupy the executor's workers.
        for _ in range(num_workers):
            input_q.put(None)

    # Workers that never received an item return None; ignore them.
    worker_accs = [res for future in futures if (res := future.result()) is not None]

    if not worker_accs:
        return None

    merged_acc = merge_accumulators(worker_accs)
    return {key: value for key, (value, _count) in merged_acc.items()}


def weighted_average(
    metrics: list[tuple[int, dict]],
) -> dict:
    """Compute an example-weighted average over client metrics.

    Clients with an empty metrics dict are ignored. Every non-``str`` and
    non-``bytes`` value is multiplied by the client's number of examples,
    summed per key, and divided by the total number of examples of the
    remaining clients.

    Three keys are never averaged:

    * ``"cid"`` with ``"client_state"``: the state string is parsed with
      ``ast.literal_eval`` and stored under that client id;
    * ``CLIENT_STATE_ACCUMULATOR``: a stringified mapping of client states,
      parsed with ``ast.literal_eval`` and merged in.

    The collected client states, if any, are returned stringified under
    ``CLIENT_STATE_ACCUMULATOR``. The input dictionaries are not modified.

    Parameters
    ----------
    metrics : List[Tuple[int, Dict]]
        Pairs of (number of examples, metrics dict), one per client.

    Returns
    -------
    Dict
        The weighted average over pre-defined metrics, plus the merged client
        states when present.

    Raises
    ------
    ZeroDivisionError
        If numeric metrics are present but the total number of examples is 0.

    """
    client_state_accumulator: dict[int | str, dict[str, Any]] = {}
    # Filter out empty metrics
    non_empty = [(num_examples, metric) for num_examples, metric in metrics if metric]
    total_num_examples = sum(num_examples for num_examples, _ in non_empty)
    weighted_metrics: dict = defaultdict(float)

    for num_examples, client_metrics in non_empty:
        # Work on a shallow copy so pop() does not strip keys from the
        # caller's (e.g. the client result's) metrics dict.
        metric = dict(client_metrics)
        cid = metric.pop("cid", None)
        client_state = metric.pop("client_state", None)
        client_state_acc = metric.pop(CLIENT_STATE_ACCUMULATOR, None)
        for key, value in metric.items():
            # Text and binary values cannot be averaged; multiplying bytes by
            # an int and adding it to a float would raise TypeError.
            if not isinstance(value, (str, bytes)):
                weighted_metrics[key] += num_examples * value
        if cid is not None and client_state is not None:
            client_state_accumulator[cid] = ast.literal_eval(client_state)
        if client_state_acc is not None:
            client_state_accumulator |= ast.literal_eval(client_state_acc)

    ret_dict = {
        key: value / total_num_examples for key, value in weighted_metrics.items()
    }
    if client_state_accumulator:
        ret_dict |= {CLIENT_STATE_ACCUMULATOR: str(client_state_accumulator)}

    return ret_dict
