from logging import INFO, WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import os
from collections import OrderedDict, defaultdict

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy import FedAvg
from functools import reduce

import numpy as np
from utils_helper import segment_resnet_parameters, reconstruct_parameters, flatten_resnet_parameters

# Human-readable warning about inconsistent client minimums. The text is kept
# byte-identical to the original triple-quoted literal (leading and trailing
# newline included). It is not referenced by the code visible in this file.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = (
    "\n"
    "Setting `min_available_clients` lower than `min_fit_clients` or\n"
    "`min_evaluate_clients` can cause the server to fail when there are too few clients\n"
    "connected to the server. `min_available_clients` must be set to a value larger\n"
    "than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.\n"
)

# pylint: disable=line-too-long
class FedSpars(FedAvg):
    """Federated Averaging strategy with sparse (partial) model updates.

    Clients upload sparse updates rather than full weights. Each client's
    ``FitRes.parameters`` must decode (via ``parameters_to_ndarrays``) into
    exactly five arrays ``[idx, val, mu, var, cnt]``:

    * ``idx`` - flat indices into the concatenated vector of ``net``'s
      trainable parameters;
    * ``val`` - additive update values at those indices;
    * ``mu`` / ``var`` - concatenated BatchNorm2d means / variances, or empty
      arrays when the client sends no BN statistics;
    * ``cnt`` - ``cnt[0]`` is the client's aggregation weight.

    The server keeps the global model in ``net``. Each round it restores the
    model from the previous round's checkpoint and applies the weighted-average
    sparse update. It then saves a new checkpoint and returns the full
    ``state_dict`` as the aggregated parameters.
    """

    # pylint: disable=too-many-arguments,too-many-instance-attributes, line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 0.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 0,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[
                [int, NDArrays, Dict[str, Scalar]],
                Optional[Tuple[float, Dict[str, Scalar]]],
            ]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        approach: str = "random",
        comp_type: str = "topk",
        sparsify_by: float = 0.01,
        partitioning: str = "iid",
        clients_per_round: int = 10,
        num_clients: int = 100,
        net: Optional[torch.nn.Module] = None,
    ) -> None:
        """Create the strategy.

        All FedAvg keyword arguments are forwarded unchanged to ``FedAvg``.

        Args:
            approach: Sparsification approach label.
            comp_type: Compression type label (e.g. ``"topk"``).
            sparsify_by: Sparsification ratio.
            partitioning: Data-partitioning label.
            clients_per_round: Number of clients per round.
            num_clients: Total number of clients.
            net: Server-side model instance. It is required by
                ``aggregate_fit``, which raises ``ValueError`` if it is None.

        Within this class, the six labels/counts above are only used to build
        the checkpoint directory name.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.approach = approach
        self.comp_type = comp_type
        self.sparsify_by = sparsify_by
        self.partitioning = partitioning
        self.clients_per_round = clients_per_round
        self.num_clients = num_clients
        self.net = net

    def __repr__(self) -> str:
        """Compute a string representation of the strategy."""
        rep = f"FedSpars(accept_failures={self.accept_failures})"
        return rep

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure the next round of training.

        Every sampled client receives the same ``parameters``. Each also gets a
        config dict made from ``on_fit_config_fn`` (if provided) plus a
        ``"server_round"`` entry.
        """
        if server_round == 1:
            print("First round of training")

        config: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Copy so that adding "server_round" never mutates a dict the
            # user's config function may return (and reuse) on every call.
            config = dict(self.on_fit_config_fn(server_round))
        config["server_round"] = server_round
        fit_ins = FitIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)

        log(INFO, "Configuring fit with %s clients", len(clients))
        log(INFO, "Used the config file: %s", config)

        # Return client/config pairs
        return [(client, fit_ins) for client in clients]

    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure the next round of federated evaluation.

        Returns an empty list (no federated evaluation) when
        ``fraction_evaluate`` is 0.
        """
        if self.fraction_evaluate == 0.0:
            return []

        config: Dict[str, Scalar] = {}
        if self.on_evaluate_config_fn is not None:
            # Custom evaluation config function provided
            config = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, config)

        # Sample clients
        sample_size, min_num_clients = self.num_evaluation_clients(
            client_manager.num_available()
        )
        clients = client_manager.sample(
            num_clients=sample_size, min_num_clients=min_num_clients
        )

        # Return client/config pairs
        return [(client, evaluate_ins) for client in clients]

    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Apply the clients' sparse updates to the global model.

        The model is restored from ``model_checkpoint_round_{server_round-1}``.
        The per-coordinate weighted mean of the clients' sparse updates is added
        to it, and BatchNorm2d running statistics are replaced by the pooled
        client statistics (only if any client sent them). The result is saved as
        ``model_checkpoint_round_{server_round}``.

        Returns:
            ``(parameters, metrics)``. ``parameters`` is every ``state_dict``
            tensor in order; ``metrics`` holds example-weighted means of
            selected client metrics. ``(None, {})`` is returned when there are
            no results, or when there are failures and failures are not
            accepted.

        Raises:
            ValueError: If the strategy was constructed without ``net``.
        """
        if not results:
            return None, {}
        # Do not aggregate if there are failures and failures are not accepted
        if not self.accept_failures and failures:
            return None, {}
        if self.net is None:
            raise ValueError("FedSpars requires `net` to aggregate sparse updates")

        print("Aggregating fit results")
        payloads = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results]

        save_dir = self._checkpoint_dir()
        device = next(self.net.parameters()).device
        # NOTE(review): in round 1 this needs `model_checkpoint_round_0.pth`,
        # which this class never writes. It must be created elsewhere.
        prev_ckpt = os.path.join(save_dir, f"model_checkpoint_round_{server_round - 1}.pth")
        # map_location keeps loading working when the checkpoint was saved
        # on a different device than the one the model lives on now.
        self.net.load_state_dict(torch.load(prev_ckpt, map_location=device))

        with torch.no_grad():
            self._apply_sparse_updates(payloads, device)
            self._apply_bn_statistics(payloads, device)

        # save checkpoint for the next round
        torch.save(
            self.net.state_dict(),
            os.path.join(save_dir, f"model_checkpoint_round_{server_round}.pth"),
        )

        parameters_aggregated = ndarrays_to_parameters(
            [v.cpu().numpy() for v in self.net.state_dict().values()]
        )
        return parameters_aggregated, self._aggregate_fit_metrics(results)

    def _checkpoint_dir(self) -> str:
        """Return the run-specific checkpoint directory, creating it if needed."""
        save_dir = (
            f"server_model_checkpoints_{self.approach}_{self.comp_type}_"
            f"{self.sparsify_by}_{self.partitioning}_{self.num_clients}_"
            f"{self.clients_per_round}"
        )
        os.makedirs(save_dir, exist_ok=True)
        return save_dir

    def _bn_modules(self) -> List[torch.nn.BatchNorm2d]:
        """Return ``net``'s BatchNorm2d layers that keep running stats, in module order."""
        return [
            m
            for m in self.net.modules()
            if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None
        ]

    def _apply_sparse_updates(self, payloads: List[NDArrays], device: torch.device) -> None:
        """Add the per-coordinate weighted mean of the clients' sparse updates to ``net``."""
        trainable = [p for p in self.net.parameters() if p.requires_grad]
        vec_global = torch.nn.utils.parameters_to_vector(trainable)
        agg_update = torch.zeros_like(vec_global)
        agg_count = torch.zeros_like(vec_global)

        for idx, val, _mu, _var, cnt in payloads:
            w = float(cnt[0])
            idx_t = torch.from_numpy(idx).to(device=device, dtype=torch.long)
            val_t = torch.from_numpy(val).to(device=device, dtype=agg_update.dtype)
            # index_add_ accumulates repeated indices. `t[idx] += x` would
            # silently keep only one contribution per duplicated index.
            agg_update.index_add_(0, idx_t, val_t * w)
            agg_count.index_add_(0, idx_t, torch.full_like(val_t, w))

        # Coordinates that no client sent are left unchanged.
        mask = agg_count > 0
        vec_global[mask] += agg_update[mask] / agg_count[mask]  # server step size η = 1
        torch.nn.utils.vector_to_parameters(vec_global, trainable)

    def _apply_bn_statistics(self, payloads: List[NDArrays], device: torch.device) -> None:
        """Replace BatchNorm2d running stats with the pooled stats of the clients that sent them."""
        bn_layers = self._bn_modules()
        bn_len = sum(m.running_mean.numel() for m in bn_layers)
        mu_sum = torch.zeros(bn_len, device=device)
        ex2_sum = torch.zeros(bn_len, device=device)
        bn_weight = 0.0  # only clients that actually sent BN stats count here

        for _idx, _val, mu, var, cnt in payloads:
            if not mu.size:  # this client sent no BN statistics
                continue
            w = float(cnt[0])
            # NOTE(review): assumes `mu`/`var` are the BN layers' stats
            # concatenated in module order, with total length `bn_len`.
            mu_t = torch.from_numpy(mu).to(device).float()
            var_t = torch.from_numpy(var).to(device).float()
            mu_sum += mu_t * w
            # Per-client E[X^2] = Var + mean^2, which lets us pool variances.
            ex2_sum += (var_t + mu_t.pow(2)) * w
            bn_weight += w

        if bn_weight <= 0:
            # No statistics received: keep the current running stats rather
            # than overwriting them with zeros / the variance floor.
            return

        mu_bar = mu_sum / bn_weight
        var_bar = (ex2_sum / bn_weight - mu_bar.pow(2)).clamp_(min=1e-3)

        ptr = 0
        for m in bn_layers:
            c = m.running_mean.numel()
            m.running_mean.copy_(mu_bar[ptr:ptr + c])
            m.running_var.copy_(var_bar[ptr:ptr + c])
            ptr += c

    @staticmethod
    def _aggregate_fit_metrics(
        results: List[Tuple[ClientProxy, FitRes]],
    ) -> Dict[str, Optional[float]]:
        """Return example-weighted means of selected client metrics.

        A key maps to None when no client reported that metric.
        """
        fit_metrics = [(res.num_examples, res.metrics) for _, res in results]

        def _mean_of(key: str) -> Optional[float]:
            pairs = [(float(m[key]), n) for n, m in fit_metrics if key in m]
            if not pairs:
                return None
            vals, weights = zip(*pairs)
            if sum(weights) == 0:
                # np.average raises on zero total weight; use an unweighted mean.
                return float(np.mean(vals))
            return float(np.average(vals, weights=weights))

        keys = (
            "grad_norm_sq",
            "residual_energy",
            "grad_mismatch_sq",
            "rho_r",
            "uplink_bits_total",
        )
        return {key: _mean_of(key) for key in keys}

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters with ``evaluate_fn``.

        Returns None when no ``evaluate_fn`` was given or when it returns None.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    