from logging import INFO, WARNING
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import os
import json
import numpy as np
import torch
from collections import OrderedDict, defaultdict

from flwr.common import (
    EvaluateIns,
    EvaluateRes,
    FitIns,
    FitRes,
    MetricsAggregationFn,
    NDArrays,
    Parameters,
    Scalar,
    ndarrays_to_parameters,
    parameters_to_ndarrays,
)
from flwr.common.logger import log
from flwr.server.client_manager import ClientManager
from flwr.server.client_proxy import ClientProxy

from flwr.server.strategy import FedAvg

# Warning text about `min_available_clients` being configured below
# `min_fit_clients` / `min_evaluate_clients`. The message intentionally
# starts and ends with a newline.
WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = (
    "\n"
    "Setting `min_available_clients` lower than `min_fit_clients` or\n"
    "`min_evaluate_clients` can cause the server to fail when there are too few clients\n"
    "connected to the server. `min_available_clients` must be set to a value larger\n"
    "than or equal to the values of `min_fit_clients` and `min_evaluate_clients`.\n"
)

class CSER(FedAvg):
    """FedAvg + CSER resets: add averaged error packets every H rounds.

    The strategy keeps a server-side ``torch.nn.Module`` (``net``) and, in every
    round, applies the clients' sparse updates to the flattened vector of its
    trainable parameters. Each client result is decoded into a list of arrays
    that ``aggregate_fit`` reads as::

        payload[0]  idx  - flat indices into the trainable-parameter vector
        payload[1]  val  - update values for those indices
        payload[2]  mu   - concatenated BatchNorm2d means (may be empty)
        payload[3]  var  - concatenated BatchNorm2d variances (may be empty)
        payload[4]  cnt  - array whose first element is the client weight
        payload[5]  (optional) error-packet indices
        payload[6]  (optional) error-packet values

    On rounds where ``server_round % H == 0`` the error packets are averaged
    uniformly over their senders, added to the model, and the averaged packet
    is sent to the clients selected in the next ``configure_fit`` call through
    the JSON-encoded config entries ``e_bar_idx`` / ``e_bar_val``.

    NOTE(review): ``fit_metrics_aggregation_fn`` is forwarded to ``FedAvg`` but
    is not called by ``aggregate_fit`` here; a fixed set of metric keys is
    averaged instead.
    """

    # Client metric keys that aggregate_fit averages (weighted by num_examples).
    _METRIC_KEYS = (
        "grad_norm_sq",
        "residual_energy",
        "grad_mismatch_sq",
        "rho_r",
        "uplink_bits_total",
    )

    # pylint: disable=too-many-arguments,too-many-instance-attributes,line-too-long
    def __init__(
        self,
        *,
        fraction_fit: float = 1.0,
        fraction_evaluate: float = 0.0,
        min_fit_clients: int = 2,
        min_evaluate_clients: int = 0,
        min_available_clients: int = 2,
        evaluate_fn: Optional[
            Callable[[int, NDArrays, Dict[str, Scalar]], Optional[Tuple[float, Dict[str, Scalar]]]]
        ] = None,
        on_fit_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        on_evaluate_config_fn: Optional[Callable[[int], Dict[str, Scalar]]] = None,
        accept_failures: bool = True,
        initial_parameters: Optional[Parameters] = None,
        fit_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        evaluate_metrics_aggregation_fn: Optional[MetricsAggregationFn] = None,
        approach: str = "random",
        net: Optional[torch.nn.Module] = None,
        H: int = 5,  # reset period
        comp_type: str = "topk",
        sparsify_by: float = 0.01,
        partitioning: str = "iid",
        clients_per_round: int = 10,
        num_clients: int = 100,
    ) -> None:
        """Create the strategy.

        The first twelve keyword arguments are passed unchanged to ``FedAvg``.

        Args:
            approach, comp_type, sparsify_by, partitioning, clients_per_round,
            num_clients: Stored on the instance; within this class they are
                only used to build the checkpoint directory name.
            net: Server-side model that ``aggregate_fit`` updates in place.
                Required by ``aggregate_fit`` (a ``ValueError`` is raised there
                if it is still ``None``).
            H: Reset period. Error packets are assimilated on rounds where
                ``server_round % H == 0``; ``H == 0`` disables resets.
        """
        super().__init__(
            fraction_fit=fraction_fit,
            fraction_evaluate=fraction_evaluate,
            min_fit_clients=min_fit_clients,
            min_evaluate_clients=min_evaluate_clients,
            min_available_clients=min_available_clients,
            evaluate_fn=evaluate_fn,
            on_fit_config_fn=on_fit_config_fn,
            on_evaluate_config_fn=on_evaluate_config_fn,
            accept_failures=accept_failures,
            initial_parameters=initial_parameters,
            fit_metrics_aggregation_fn=fit_metrics_aggregation_fn,
            evaluate_metrics_aggregation_fn=evaluate_metrics_aggregation_fn,
        )
        self.approach = approach
        self.net = net
        self.H = H
        self.comp_type = comp_type
        self.sparsify_by = sparsify_by
        self.partitioning = partitioning
        self.clients_per_round = clients_per_round
        self.num_clients = num_clients
        # Will hold (e_bar_idx: List[int], e_bar_val: List[float]) to send next round
        self._last_e_bar: Optional[Tuple[List[int], List[float]]] = None

    def __repr__(self) -> str:
        return f"CSER(H={self.H}, accept_failures={self.accept_failures})"

    # ------------------------- HELPERS -------------------------
    def _checkpoint_dir(self) -> str:
        """Return the directory name used for per-round server checkpoints."""
        return (
            f"server_model_checkpoints_{self.approach}_{self.comp_type}"
            f"_{self.sparsify_by}_{self.partitioning}"
            f"_{self.num_clients}_{self.clients_per_round}"
        )

    @staticmethod
    def _weighted_metric_mean(
        key: str, fit_metrics: List[Tuple[int, Dict[str, Scalar]]]
    ) -> Optional[float]:
        """Average ``key`` over the clients that reported it.

        The average is weighted by each client's ``num_examples``. Returns
        ``None`` when no client reported the key. If every reporting client has
        zero examples the weights cannot be normalised, so an unweighted mean
        is returned instead of raising ``ZeroDivisionError``.
        """
        reported = [(float(m[key]), n) for n, m in fit_metrics if key in m]
        if not reported:
            return None
        vals = [v for v, _ in reported]
        weights = [n for _, n in reported]
        if sum(weights) == 0:
            return float(np.mean(vals))
        return float(np.average(vals, weights=weights))

    # ------------------------- CONFIGURE FIT -------------------------
    def _base_fit_config(self, server_round: int) -> Dict[str, Scalar]:
        """Build the fit config for ``server_round``.

        Starts from ``on_fit_config_fn(server_round)`` (if set), adds
        ``server_round`` and the JSON-encoded ``e_bar_idx`` / ``e_bar_val``
        lists. A pending averaged error packet is consumed by this call, so it
        is broadcast exactly once; otherwise both lists are empty.
        """
        cfg: Dict[str, Scalar] = {}
        if self.on_fit_config_fn is not None:
            # Copy: the keys added below must not leak into a dict the
            # user-supplied callback may be reusing between rounds.
            cfg = dict(self.on_fit_config_fn(server_round))
        # Always include round number
        cfg["server_round"] = server_round

        # Inject e_bar (from prior reset round), JSON-encoded for Flower Scalar
        e_idx, e_val = self._last_e_bar if self._last_e_bar is not None else ([], [])
        cfg["e_bar_idx"] = json.dumps(e_idx)
        cfg["e_bar_val"] = json.dumps([float(v) for v in e_val])
        self._last_e_bar = None  # consume after broadcasting

        return cfg

    def configure_fit(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, FitIns]]:
        """Configure next round; also broadcast e_bar if available."""
        cfg = self._base_fit_config(server_round)
        fit_ins = FitIns(parameters, cfg)

        # Sample clients
        sample_size, min_num_clients = self.num_fit_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)

        log(INFO, "Configuring fit with %s clients", len(clients))
        return [(client, fit_ins) for client in clients]

    # ------------------------- CONFIGURE EVAL ------------------------
    def configure_evaluate(
        self, server_round: int, parameters: Parameters, client_manager: ClientManager
    ) -> List[Tuple[ClientProxy, EvaluateIns]]:
        """Configure client-side evaluation; returns ``[]`` when ``fraction_evaluate`` is 0."""
        if self.fraction_evaluate == 0.0:
            return []
        cfg = {}
        if self.on_evaluate_config_fn is not None:
            cfg = self.on_evaluate_config_fn(server_round)
        evaluate_ins = EvaluateIns(parameters, cfg)

        sample_size, min_num_clients = self.num_evaluation_clients(client_manager.num_available())
        clients = client_manager.sample(num_clients=sample_size, min_num_clients=min_num_clients)
        return [(client, evaluate_ins) for client in clients]

    # ------------------------- AGGREGATE FIT ------------------------
    def aggregate_fit(
        self,
        server_round: int,
        results: List[Tuple[ClientProxy, FitRes]],
        failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]],
    ) -> Tuple[Optional[Parameters], Dict[str, Scalar]]:
        """Aggregate updates; on reset rounds also assimilate averaged error packets.

        Steps:
          1. Load the checkpoint written for ``server_round - 1`` into ``net``.
          2. Add the per-index weighted average of the client updates to the
             trainable-parameter vector.
          3. On reset rounds, add the uniformly averaged error packet and keep
             it for the next ``configure_fit`` broadcast.
          4. Overwrite BatchNorm2d running statistics with the weighted average
             of the statistics sent by clients (only if some client sent any).
          5. Save the checkpoint for ``server_round`` and return the full
             ``state_dict`` as Flower parameters together with averaged metrics.

        Returns ``(None, {})`` when there are no results, or when failures
        occurred and ``accept_failures`` is False.

        Raises:
            ValueError: if ``net`` was never provided.
            FileNotFoundError: if the previous round's checkpoint is missing
                (raised by ``torch.load``).
        """
        if not results:
            return None, {}
        if not self.accept_failures and failures:
            return None, {}
        if self.net is None:
            raise ValueError("CSER.aggregate_fit requires `net` to be a torch.nn.Module, got None")

        log(INFO, "Aggregating fit results")

        # Decode client payloads as NDArrays
        payloads = [parameters_to_ndarrays(fit_res.parameters) for _, fit_res in results]

        # Load previous checkpoint (existing path)
        save_dir = self._checkpoint_dir()
        os.makedirs(save_dir, exist_ok=True)
        device = next(self.net.parameters()).device
        load_checkpoint_path = f"{save_dir}/model_checkpoint_round_{server_round-1}.pth"
        self.net.load_state_dict(torch.load(load_checkpoint_path, map_location=device))

        trainable = [p for p in self.net.parameters() if p.requires_grad]
        # BN layers without running statistics have nothing to aggregate.
        bn_layers = [
            m
            for m in self.net.modules()
            if isinstance(m, torch.nn.BatchNorm2d) and m.running_mean is not None
        ]

        with torch.no_grad():
            vec_global = torch.nn.utils.parameters_to_vector(
                [p.data for p in trainable]
            ).to(device)

            agg_update = torch.zeros_like(vec_global)
            agg_count = torch.zeros_like(vec_global)
            bn_len = sum(m.running_mean.numel() for m in bn_layers)
            mu_sum = torch.zeros(bn_len, device=device)
            ex2_sum = torch.zeros(bn_len, device=device)
            # Weight of the clients that actually sent BN statistics; kept
            # separate from the update weights so that clients without BN
            # stats neither dilute the average nor trigger a spurious reset.
            bn_weight = 0.0

            # --- collect error-packets for reset rounds ---
            e_packets: List[Tuple[np.ndarray, np.ndarray]] = []

            # --------- aggregate main updates and BN stats ----------
            for payload in payloads:
                # Expect 5 arrays; on reset rounds some clients will send 7
                idx, val, mu, var, cnt = payload[:5]
                w = float(np.asarray(cnt).reshape(-1)[0])

                idx_t = torch.from_numpy(idx).to(device=device, dtype=torch.long).reshape(-1)
                val_t = torch.from_numpy(val).to(device=device, dtype=vec_global.dtype).reshape(-1)

                # index_add_ accumulates correctly even if an index repeats
                # within one packet (plain indexed `+=` keeps only one write).
                agg_update.index_add_(0, idx_t, val_t, alpha=w)
                agg_count.index_add_(0, idx_t, torch.ones_like(val_t), alpha=w)

                if mu.size:  # central BN path
                    mu_t = torch.from_numpy(mu).to(device).float()
                    var_t = torch.from_numpy(var).to(device).float()
                    mu_sum += mu_t * w
                    # E[x^2] = Var[x] + E[x]^2, so the pooled variance can be
                    # recovered after averaging.
                    ex2_sum += (var_t + mu_t.pow(2)) * w
                    bn_weight += w

                # Optional error packet present?
                if len(payload) >= 7:
                    e_idx_np, e_val_np = payload[5], payload[6]
                    if e_idx_np.size > 0:
                        e_packets.append((e_idx_np, e_val_np))

            # Average per-index and apply to model (η = 1)
            mask = agg_count > 0
            agg_update[mask] /= agg_count[mask]
            vec_global[mask] += agg_update[mask]

            # --------- CSER: on reset rounds, ASSIMILATE averaged error ---------
            # H == 0 means "no resets" rather than a ZeroDivisionError.
            is_reset_round = self.H != 0 and server_round % self.H == 0
            if is_reset_round and e_packets:
                e_acc = torch.zeros(vec_global.numel(), device=device, dtype=torch.float32)
                for e_idx_np, e_val_np in e_packets:
                    ei = torch.from_numpy(e_idx_np).to(device=device, dtype=torch.long).reshape(-1)
                    ev = torch.from_numpy(e_val_np).to(device=device, dtype=torch.float32).reshape(-1)
                    e_acc.index_add_(0, ei, ev)
                e_acc /= float(len(e_packets))  # uniform average across senders

                vec_global += e_acc  # ASSIMILATION

                # Store sparse e_bar to broadcast next round (JSON lists)
                nz = torch.nonzero(e_acc, as_tuple=False).squeeze(1)
                e_bar_idx = nz.detach().cpu().tolist()
                e_bar_val = e_acc[nz].detach().cpu().tolist()
                self._last_e_bar = (e_bar_idx, e_bar_val)

            # Write weights back
            torch.nn.utils.vector_to_parameters(vec_global, trainable)

            # Aggregate BN stats centrally, only if some client supplied them;
            # otherwise the loaded running statistics are left untouched.
            if bn_weight > 0:
                mu_bar = mu_sum / bn_weight
                var_bar = (ex2_sum / bn_weight) - mu_bar.pow(2)
                var_bar.clamp_(min=1e-3)
                ptr = 0
                for m in bn_layers:
                    c = m.running_mean.numel()
                    m.running_mean.copy_(mu_bar[ptr:ptr + c])
                    m.running_var.copy_(var_bar[ptr:ptr + c])
                    ptr += c

        # Save checkpoint for the next round
        torch.save(self.net.state_dict(), f"{save_dir}/model_checkpoint_round_{server_round}.pth")

        parameters_aggregated = ndarrays_to_parameters(
            [v.cpu().numpy() for v in self.net.state_dict().values()]
        )

        # Metrics aggregation. A key that no client reported maps to None
        # (kept for backward compatibility with existing consumers).
        fit_metrics = [(res.num_examples, res.metrics) for _, res in results]
        metrics_aggregated = {
            key: self._weighted_metric_mean(key, fit_metrics) for key in self._METRIC_KEYS
        }

        return parameters_aggregated, metrics_aggregated

    def evaluate(
        self, server_round: int, parameters: Parameters
    ) -> Optional[Tuple[float, Dict[str, Scalar]]]:
        """Evaluate model parameters using an evaluation function.

        Returns ``None`` when no ``evaluate_fn`` is configured or when it
        returns ``None``; otherwise the ``(loss, metrics)`` pair it produced.
        """
        if self.evaluate_fn is None:
            # No evaluation function provided
            return None
        parameters_ndarrays = parameters_to_ndarrays(parameters)
        eval_res = self.evaluate_fn(server_round, parameters_ndarrays, {})
        if eval_res is None:
            return None
        loss, metrics = eval_res
        return loss, metrics

    