from logging import WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union

from flwr.server.strategy import FedAvg
from flwr.common import (
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

import numpy as np
from functools import reduce


def cal_differ(layer, past_layer, ratio):
    """Return one client's contribution for a single layer.

    Layers with fewer than two dimensions are scaled by ``ratio`` so that
    summing the contributions of all clients yields a weighted average.
    Layers with two or more dimensions are turned into a per-element sign
    vote on how the client's value moved relative to ``past_layer``:
    ``+1`` (increased), ``-1`` (decreased) or ``0`` (unchanged).

    Args:
        layer: The client's updated array for this layer.
        past_layer: The global array for this layer from the previous round.
        ratio: This client's weight (its share of the total example count).

    Returns:
        An array with the same shape (and, for floating inputs, the same
        dtype) as ``layer``.
    """
    if layer.ndim < 2:
        return layer * ratio
    # np.sign maps an exact zero difference to 0, so an element the client
    # did not change abstains. A binary +/-1 encoding would count "no change"
    # as a vote to decrease, steadily dragging frozen weights downwards.
    return np.sign(layer - past_layer)


def cal_update(layer, past_layer, eta):
    """Apply the aggregated contribution for a single layer.

    Layers with fewer than two dimensions already hold the weighted average
    built from ``cal_differ`` outputs and are returned unchanged. For
    higher-dimensional layers ``layer`` is the sum of the clients' sign
    votes. Each element of ``past_layer`` moves by ``eta`` in the majority
    direction and stays where it is when the vote is tied (sum of zero).

    Args:
        layer: Sum over clients of ``cal_differ`` outputs for this layer.
        past_layer: The global array for this layer from the previous round.
        eta: Step size applied to each element.

    Returns:
        The new global array for this layer.
    """
    if layer.ndim < 2:
        return layer
    # np.sign(0) == 0: a tied (or fully abstaining) vote leaves the weight as is.
    return past_layer + eta * np.sign(layer)

class FedSign(FedAvg):
    """Federated strategy that aggregates multi-dimensional layers by sign vote.

    For every layer with two or more dimensions, each client votes (via
    ``cal_differ``) on the direction each element moved relative to the
    global model of the previous round. The server then steps each element
    by ``eta`` according to the summed vote (via ``cal_update``). Layers with
    fewer than two dimensions (e.g. bias vectors) are combined with an
    example-weighted average, as in FedAvg.
    """

    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 1.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 2,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        eta: float = 0.001,
    ) -> None:
        """Create the strategy.

        All arguments except ``eta`` are forwarded unchanged to ``FedAvg``.

        Args:
            initial_parameters: Starting global model. Required, because the
                first round's sign votes are taken relative to it.
            eta: Step size applied to each element of multi-dimensional layers.

        Raises:
            ValueError: If ``initial_parameters`` is None.
        """
        if initial_parameters is None:
            # Without a starting model there is no reference point for the
            # vote, and converting None below would fail with an unhelpful error.
            raise ValueError(
                "FedSign requires `initial_parameters`: sign updates are "
                "computed relative to the previous global model."
            )
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.eta = eta
        # Global model from the previous round: the reference for the votes.
        self.past_parameters: NDArrays = parameters_to_ndarrays(initial_parameters)

    def aggregate(self, results: List[Tuple[NDArrays, int]]) -> NDArrays:
        """Combine client weights into the next global model.

        Args:
            results: One ``(weights, num_examples)`` pair per client, where
                ``weights`` lists the layers in the same order as
                ``self.past_parameters``.

        Returns:
            The new global weights, one array per layer.

        Raises:
            ValueError: If a client returned a different number of layers
                than the current global model has.
        """
        expected_layers = len(self.past_parameters)
        for weights, _ in results:
            if len(weights) != expected_layers:
                raise ValueError(
                    f"Client returned {len(weights)} layers; "
                    f"expected {expected_layers}."
                )

        num_examples_total = sum(num_examples for _, num_examples in results)
        # Per client, per layer: weighted values (ndim < 2) or sign votes.
        contributions = [
            [
                cal_differ(layer, past_layer, num_examples / num_examples_total)
                for layer, past_layer in zip(weights, self.past_parameters)
            ]
            for weights, num_examples in results
        ]

        # Sum each layer across clients (zip(*...) regroups by layer).
        summed_layers: NDArrays = [
            reduce(np.add, layer_parts) for layer_parts in zip(*contributions)
        ]

        return [
            cal_update(layer, past_layer, self.eta)
            for layer, past_layer in zip(summed_layers, self.past_parameters)
        ]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate one round of client fit results.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is False. Otherwise the new global
        weights are stored as the reference for the next round and returned
        as ``Parameters`` together with any aggregated fit metrics.
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}

        weights_results = [
            (parameters_to_ndarrays(fit_res.parameters), fit_res.num_examples)
            for _, fit_res in results
        ]

        updated_weights = self.aggregate(weights_results)
        # The new global model becomes the reference point for the next vote.
        self.past_parameters = updated_weights
        parameters_aggregated = ndarrays_to_parameters(updated_weights)

        # Aggregate custom metrics if aggregation fn was provided
        metrics_aggregated: Dict[str, Scalar] = {}
        if self.fit_metrics_aggregation_fn:
            fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
            metrics_aggregated = self.fit_metrics_aggregation_fn(fit_metrics)
        elif server_round == 1:  # Only log this warning once
            log(WARNING, "No fit_metrics_aggregation_fn provided")

        return parameters_aggregated, metrics_aggregated

