import numpy as np
from functools import partial, reduce
from typing import Any, Callable, Union

from flwr.common import FitRes, NDArray, NDArrays, parameters_to_ndarrays
from flwr.server.client_proxy import ClientProxy


def aggregate(results: list[tuple[NDArrays, int]]) -> NDArrays:
    """Return the example-weighted mean of every layer across clients.

    Args:
        results: One ``(arrays, num_examples)`` pair per client. Each client's
            arrays are multiplied by its ``num_examples`` before summing.

    Returns:
        One array per layer position: the weighted sum of that layer over all
        clients divided by the total ``num_examples``. An empty ``results``
        yields an empty list.
    """
    total = sum(count for _, count in results)

    # Pre-scale each client's layers by that client's example count.
    scaled_per_client = [[arr * count for arr in arrays] for arrays, count in results]

    averaged: NDArrays = []
    # zip(*...) regroups the scaled arrays by layer position across clients.
    for same_layer in zip(*scaled_per_client):
        running = same_layer[0]
        for other in same_layer[1:]:
            running = np.add(running, other)
        averaged.append(running / total)
    return averaged


def aggregate_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Average client parameters, weighting each client by its share of examples.

    The arrays decoded from the first client act as accumulators. Every
    client's arrays are scaled by ``num_examples / total_examples`` and summed
    into them via ``_try_inplace``, which writes in place.

    Args:
        results: ``(client, fit_res)`` pairs; only ``fit_res.num_examples`` and
            ``fit_res.parameters`` are read.

    Returns:
        The aggregated arrays, one per layer position.
    """
    total_examples = sum(fit_res.num_examples for _, fit_res in results)
    weights = np.asarray(
        [fit_res.num_examples / total_examples for _, fit_res in results]
    )

    # Seed the accumulators with the first client's scaled arrays.
    head = results[0][1]
    accumulated = [
        _try_inplace(arr, weights[0], np_binary_op=np.multiply)
        for arr in parameters_to_ndarrays(head.parameters)
    ]

    # Fold each remaining client into the accumulators: scale, then add.
    for weight, (_, fit_res) in zip(weights[1:], results[1:]):
        accumulated = [
            _try_inplace(
                acc,
                _try_inplace(arr, weight, np_binary_op=np.multiply),
                np_binary_op=np.add,
            )
            for acc, arr in zip(accumulated, parameters_to_ndarrays(fit_res.parameters))
        ]

    return accumulated


def aggregate_median(results: list[tuple[NDArrays, int]]) -> NDArrays:
    """Return the element-wise median of each layer across clients.

    Example counts in ``results`` are ignored; every client counts equally.

    Args:
        results: One ``(arrays, num_examples)`` pair per client.

    Returns:
        One array per layer position, holding the median over clients taken
        along a new leading (client) axis.
    """
    # Regroup the per-client array lists into per-layer tuples.
    per_layer = zip(*(arrays for arrays, _ in results))
    return [np.median(np.asarray(stacked), axis=0) for stacked in per_layer]


# ====================================================================================================== #

def aggregate_masked_inplace(results: list[tuple[ClientProxy, FitRes]], alpha) -> NDArrays:
    """Compute an in-place weighted average using a blend of two client weightings.

    Client ``k`` receives weight ``alpha * m_k + (1 - alpha) * t_k``, where
    ``t_k`` is its share of the summed ``fit_res.num_examples`` and ``m_k`` is
    its share of the summed ``fit_res.metrics['masked_samples']``.

    Args:
        results: ``(client, fit_res)`` pairs. Every ``fit_res.metrics`` must
            contain the key ``'masked_samples'``.
        alpha: Coefficient on the masked-sample share; ``1 - alpha`` multiplies
            the example share. Presumably meant to lie in [0, 1] (not checked).

    Returns:
        The aggregated arrays, one per layer position. The first client's
        decoded arrays serve as accumulators and are updated via ``_try_inplace``.

    Raises:
        IndexError: If ``results`` is empty.
    """
    num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
    num_masked_total = sum(fit_res.metrics['masked_samples'] for (_, fit_res) in results)

    # Each client's share of the training examples.
    t_scaling_factors = np.asarray(
        [fit_res.num_examples / num_examples_total for _, fit_res in results]
    )
    if num_masked_total == 0:
        # No client reported masked samples, so the masked share is undefined and
        # dividing by the zero total would raise (or produce NaN weights). Reuse the
        # example share instead, which reduces to the example-weighted average.
        m_scaling_factors = t_scaling_factors
    else:
        m_scaling_factors = np.asarray(
            [fit_res.metrics['masked_samples'] / num_masked_total for _, fit_res in results]
        )

    scaling_factors = alpha * m_scaling_factors + (1 - alpha) * t_scaling_factors

    # Seed the accumulators with the first client's scaled arrays.
    params = [
        _try_inplace(x, scaling_factors[0], np_binary_op=np.multiply)
        for x in parameters_to_ndarrays(results[0][1].parameters)
    ]

    # Add each remaining client's scaled arrays into the accumulators.
    for i, (_, fit_res) in enumerate(results[1:], start=1):
        res = (
            _try_inplace(x, scaling_factors[i], np_binary_op=np.multiply)
            for x in parameters_to_ndarrays(fit_res.parameters)
        )
        params = [
            _try_inplace(acc, upd, np_binary_op=np.add)
            for acc, upd in zip(params, res)
        ]

    return params


def aggregate_classifier_inplace(results: list[tuple[ClientProxy, FitRes]], cls_idx) -> NDArrays:
    """Compute an in-place weighted average with a per-array choice of weighting.

    Arrays whose position (in ``parameters_to_ndarrays`` order) is contained in
    ``cls_idx`` are weighted by each client's share of the summed
    ``metrics['masked_samples']``. All other arrays are weighted by each
    client's share of the summed ``num_examples``.

    Args:
        results: ``(client, fit_res)`` pairs. Every ``fit_res.metrics`` must
            contain the key ``'masked_samples'``.
        cls_idx: Container supporting ``in`` that holds the array positions which
            use the masked-sample weighting (e.g. a list or set of ints).

    Returns:
        The aggregated arrays, one per layer position. The first client's
        decoded arrays serve as accumulators and are updated via ``_try_inplace``.

    Raises:
        IndexError: If ``results`` is empty.
    """
    num_examples_total = sum(fit_res.num_examples for (_, fit_res) in results)
    num_masked_total = sum(fit_res.metrics['masked_samples'] for (_, fit_res) in results)

    # Each client's share of the training examples.
    t_scaling_factors = np.asarray(
        [fit_res.num_examples / num_examples_total for _, fit_res in results]
    )
    if num_masked_total == 0:
        # No client reported masked samples, so the masked share is undefined and
        # dividing by the zero total would raise (or produce NaN weights). Weight the
        # cls_idx arrays by the example share instead.
        m_scaling_factors = t_scaling_factors
    else:
        m_scaling_factors = np.asarray(
            [fit_res.metrics['masked_samples'] / num_masked_total for _, fit_res in results]
        )

    # Seed the accumulators with the first client's scaled arrays; j is the array position.
    params = [
        _try_inplace(x, m_scaling_factors[0] if j in cls_idx else t_scaling_factors[0], np.multiply)
        for j, x in enumerate(parameters_to_ndarrays(results[0][1].parameters))
    ]

    # Add each remaining client (index i) into the accumulators.
    for i, (_, fit_res) in enumerate(results[1:], start=1):
        res = (
            _try_inplace(x, m_scaling_factors[i] if j in cls_idx else t_scaling_factors[i], np.multiply)
            for j, x in enumerate(parameters_to_ndarrays(fit_res.parameters))
        )
        params = [
            _try_inplace(acc, upd, np_binary_op=np.add)
            for acc, upd in zip(params, res)
        ]

    return params


def aggregate_inp_inplace(results: list[tuple[ClientProxy, FitRes]]) -> NDArrays:
    """Average client parameters weighted by ``metrics['in_p_mean']``.

    Each client's weight is its share of the summed ``in_p_mean`` values. When
    that sum equals 0, the weights fall back to each client's share of
    ``num_examples``. The chosen scheme name ('in_p_mean_mask' or 'in_mask')
    and the weights are printed to stdout.

    Args:
        results: ``(client, fit_res)`` pairs. Every ``fit_res.metrics`` must
            contain the key ``'in_p_mean'``.

    Returns:
        The aggregated arrays, one per layer position, accumulated into the
        first client's decoded arrays via ``_try_inplace``.
    """
    example_counts = [fit_res.num_examples for _, fit_res in results]
    in_p_means = [fit_res.metrics['in_p_mean'] for _, fit_res in results]
    total_examples = sum(example_counts)
    total_in_p = sum(in_p_means)

    # Pick the weighting scheme: in_p_mean shares, or example shares when they sum to 0.
    if total_in_p == 0:
        mask_name, numerators, denominator = 'in_mask', example_counts, total_examples
    else:
        mask_name, numerators, denominator = 'in_p_mean_mask', in_p_means, total_in_p
    scaling_factors = np.asarray([value / denominator for value in numerators])

    print(f"{mask_name}: {scaling_factors}")

    # Seed the accumulators with the first client's scaled arrays.
    head = results[0][1]
    merged = [
        _try_inplace(arr, scaling_factors[0], np_binary_op=np.multiply)
        for arr in parameters_to_ndarrays(head.parameters)
    ]

    # Fold each remaining client in: scale its arrays, then add them.
    for weight, (_, fit_res) in zip(scaling_factors[1:], results[1:]):
        merged = [
            _try_inplace(
                acc,
                _try_inplace(arr, weight, np_binary_op=np.multiply),
                np_binary_op=np.add,
            )
            for acc, arr in zip(merged, parameters_to_ndarrays(fit_res.parameters))
        ]

    return merged

# ====================================================================================================== #

def _try_inplace(x: NDArray, y: Union[NDArray, np.float64], np_binary_op: np.ufunc) -> NDArray:
    """Apply ``np_binary_op(x, y)`` and write the result back into ``x``.

    If ``y`` cannot be cast to ``x.dtype`` under ``same_kind`` rules, it is
    first converted with ``np.array(y, x.dtype)`` so the result fits in ``x``.
    For example, a float factor applied to an integer array is converted to
    the integer dtype first.

    Returns:
        ``x`` itself, now holding the result.
    """
    if np.can_cast(y, x.dtype, casting="same_kind"):
        operand = y
    else:
        operand = np.array(y, x.dtype)
    return np_binary_op(x, operand, out=x)  # type: ignore[no-any-return]