from typing import Callable, Dict, List, OrderedDict, Optional
import json
import numpy as np
import flwr as fl
import torch
import torch.nn as nn
from flwr.common import Scalar
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from collections import OrderedDict
import logging
from logging import ERROR, INFO
from flwr.common.logger import log, set_logger_propagation
from flwr.common import Context, ParametersRecord, RecordSet
from flwr.common import array_from_numpy, Array
from flwr.common.typing import NDArray

from models import test, train_fedavg
from utils_theory import cosine_eta
from utils_helper import cosine_eta_warmfloor
from utils_collect_res import compute_metrics_before_send, append_jsonl, _comm_bits_from_payload

# pylint: disable=too-many-instance-attributes
class FlowerClientFedAvg(fl.client.NumPyClient):
    """Flower NumPy client: local SGD plus sparsified uploads with error feedback.

    Each ``fit`` call trains locally and then builds the upload through
    ``get_parameters``: a sparsified copy of the local update (Top-k or Rand-k,
    fraction ``sparsify_by``). Whatever that compressor drops is accumulated
    into a residual kept in ``context.state`` under the ``"prev_round"`` key.
    On server rounds divisible by ``H`` the ``reset_frac`` largest-magnitude
    residual entries are uploaded as an extra "error-reset" packet and
    remembered under ``"last_e_sent"``. They are reconciled against the
    server's broadcast ``e_bar`` in a later ``fit`` call.
    """

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        net: torch.nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        momentum: float,
        weight_decay: float,
        approach: str,
        sparsify_by: float,  # C2: fraction for main update
        context: Context,
        comp_type: str = "topk",
        H: int = 5,
        reset_frac: float = 0.10,  # C1: fraction for error-reset packets
    ) -> None:
        """Store the model, data, optimizer settings and compression settings.

        Parameters
        ----------
        net, trainloader, valloader, device
            Local model, local train/validation data and the device to use.
        num_epochs, learning_rate, momentum, weight_decay
            Local training settings. NOTE(review): ``learning_rate`` is stored
            but ``fit`` uses its own cosine schedule instead.
        approach : str
            Label used in the per-client metrics log path.
        sparsify_by : float
            Fraction of update coordinates sent on the main (C2) channel.
        context : Context
            Flower context; ``context.state`` persists the residual records.
        comp_type : str
            ``"topk"`` or ``"randk"`` compressor for the main channel.
        H : int
            Error-reset period in server rounds.
        reset_frac : float
            Fraction of coordinates sent in an error-reset packet.
        """
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.approach = approach
        self.sparsify_by = sparsify_by
        self.context = context
        self.client_state = context.state
        self.node_id = context.node_id
        self.comp_type = comp_type
        self.H = H
        self.reset_frac = reset_frac
        # Filled by fit() before get_parameters() completes them with bit counts.
        self._pre_record_cache = None
        self._pre_metrics_cache = None
        # Metrics handed back to Flower from fit(); refreshed every round.
        self._metrics_for_flower = {}

    # ------------------------------------------------------------------ #
    # Persistent sparse-record helpers
    # ------------------------------------------------------------------ #
    def _records(self):
        """Return the persistent ParametersRecord store, or None if unavailable."""
        if self.client_state is None:
            return None
        return self.client_state.parameters_records

    def _load_dense(self, key: str, like: torch.Tensor) -> torch.Tensor:
        """Densify the sparse record ``key`` into a zero tensor shaped like ``like``.

        Returns all zeros when the record does not exist.
        """
        dense = torch.zeros_like(like)
        records = self._records()
        if records is None or key not in records:
            return dense
        rec = records[key]
        idx = torch.as_tensor(rec["indices"].numpy(), dtype=torch.long, device=self.device)
        val = torch.as_tensor(rec["values"].numpy(), dtype=torch.float32, device=self.device)
        dense[idx] = val
        return dense

    def _store_pairs(self, key: str, indices: torch.Tensor, values: torch.Tensor) -> None:
        """Persist an (indices, values) pair under ``key`` in the client state."""
        rec = ParametersRecord()
        rec["indices"] = array_from_numpy(indices.detach().cpu().numpy())
        rec["values"] = array_from_numpy(values.detach().cpu().numpy())
        self.client_state.parameters_records[key] = rec

    def _store_sparse(self, key: str, dense: torch.Tensor) -> None:
        """Persist only the non-zero entries of ``dense`` under ``key``."""
        nz = torch.nonzero(dense, as_tuple=False).squeeze(1)
        self._store_pairs(key, nz, dense[nz])

    # ------------------------------------------------------------------ #
    # Payload helpers
    # ------------------------------------------------------------------ #
    def _trainable_params(self) -> List[torch.nn.Parameter]:
        """Return the model parameters that require gradients, in order."""
        return [p for p in self.net.parameters() if p.requires_grad]

    def _flatten_trainable(self, parameters) -> torch.Tensor:
        """Flatten the trainable entries of a full, state_dict-ordered array list.

        Only names that are trainable (``requires_grad``) parameters are kept,
        so the result lines up with ``parameters_to_vector`` over
        ``_trainable_params()``.
        """
        trainable = {name for name, p in self.net.named_parameters() if p.requires_grad}
        parts = [
            torch.as_tensor(arr, dtype=torch.float32).flatten()
            for name, arr in zip(self.net.state_dict().keys(), parameters)
            if name in trainable
        ]
        return torch.cat(parts, dim=0).to(self.device)

    def _select_update_indices(self, flat_delta: torch.Tensor) -> torch.Tensor:
        """Pick the C2-channel coordinates of ``flat_delta`` to transmit.

        Raises
        ------
        ValueError
            If ``comp_type`` is neither ``"topk"`` nor ``"randk"``.
        """
        n = flat_delta.numel()
        # Clamp to n so fractions > 1 cannot make torch.topk fail.
        k2 = min(n, max(1, int(self.sparsify_by * n)))
        if self.comp_type == "topk":
            _, idx = torch.topk(torch.abs(flat_delta), k2)
        elif self.comp_type == "randk":
            idx = torch.randperm(n, device=self.device)[:k2]
        else:
            raise ValueError(f"Unknown comp_type: {self.comp_type}")
        return idx

    def _bn_running_stats(self):
        """Concatenate BatchNorm2d running mean/var as float16 NumPy arrays.

        Returns empty float16 arrays when the model has no BatchNorm2d layer
        that tracks running statistics.
        """
        means, variances = [], []
        for module in self.net.modules():
            if isinstance(module, torch.nn.BatchNorm2d) and module.running_mean is not None:
                means.append(module.running_mean)
                variances.append(module.running_var)
        if not means:
            return np.array([], dtype=np.float16), np.array([], dtype=np.float16)
        bn_mu = torch.cat(means).detach().cpu().numpy().astype(np.float16)
        bn_var = torch.cat(variances).detach().cpu().numpy().astype(np.float16)
        return bn_mu, bn_var

    def _log_uplink_metrics(self, flower_payload) -> None:
        """Add uplink bit counts to the cached pre-send record and append it to JSONL.

        This is best-effort: any failure is reported and swallowed so that
        logging never aborts training.
        """
        try:
            bits = _comm_bits_from_payload(flower_payload, include_bn=True)
            pre_record = getattr(self, "_pre_record_cache", None)
            if pre_record is None:
                return
            log_record = dict(pre_record)
            log_record.update(bits)
            append_jsonl(f"logs/clients_metrics/{self.approach}/metrics_client_{self.node_id}.jsonl", log_record)
            # Also stash for fit(...) to attach to Flower metrics.
            pre_metrics = getattr(self, "_pre_metrics_cache", None)
            if pre_metrics is not None:
                pre_metrics["uplink_bits_total"] = float(bits["uplink_bits_total"])
                self._metrics_for_flower = dict(pre_metrics)
        except Exception as ex:  # pylint: disable=broad-except
            # Don't crash training on logging issues.
            print(f"[warn] metrics logging failed: {ex}")

    # ------------------------------------------------------------------ #
    # Flower API
    # ------------------------------------------------------------------ #
    def get_parameters(self, config: Dict[str, Scalar], parameters=None):
        """Return the initial weights (empty config) or the compressed uplink payload.

        Parameters
        ----------
        config : Dict[str, Scalar]
            Empty when the server requests initial parameters. Otherwise it is
            the fit config and must contain ``"server_round"``.
        parameters : list of np.ndarray, optional
            The global arrays this round started from, in ``state_dict`` order.
            Required when ``config`` is non-empty.

        Returns
        -------
        list of np.ndarray
            With an empty ``config``: every ``state_dict`` tensor as NumPy.
            Otherwise ``[delta_idx, delta_val, bn_mu, bn_var, count]``,
            followed by ``[e_idx, e_val]`` on reset rounds.

        Raises
        ------
        ValueError
            If ``config`` is non-empty but ``parameters`` is None, or if
            ``comp_type`` is unknown.
        """
        if not config:
            # "Zeroth" round: the server requests initial parameters from a client.
            return [val.cpu().numpy() for val in self.net.state_dict().values()]
        if parameters is None:
            raise ValueError("get_parameters requires `parameters` when config is non-empty")

        flat_prev = self._flatten_trainable(parameters)
        flat_curr = torch.nn.utils.parameters_to_vector(
            [p.data for p in self._trainable_params()]
        ).to(self.device)
        flat_delta = flat_curr - flat_prev  # dense local update
        n = flat_delta.numel()

        # Residual e_t carried over from earlier rounds (sparse -> dense).
        e_t = self._load_dense("prev_round", flat_delta)

        # ------------- C2: main update channel (Top-k / Rand-k) -------------
        upd_idx = self._select_update_indices(flat_delta)
        upd_val = flat_delta[upd_idx]
        sent_update = torch.zeros_like(flat_delta)
        sent_update[upd_idx] = upd_val
        # Fold what C2 dropped this round into the residual.
        e_half = e_t + (flat_delta - sent_update)

        # ------------- Reset branch (every H rounds) -------------
        reset_round = config["server_round"] % self.H == 0
        e_packet = []
        if reset_round:
            k1 = min(n, max(1, int(self.reset_frac * n)))
            # Choose indices by magnitude of e_half.
            _, err_idx = torch.topk(torch.abs(e_half), k1)
            err_val = e_half[err_idx]

            # Locally remove the part sent now.
            e_sent = torch.zeros_like(e_half)
            e_sent[err_idx] = err_val
            e_next = e_half - e_sent

            # Remember what was sent, for reconciliation when e_bar arrives.
            self._store_pairs("last_e_sent", err_idx, err_val)

            e_packet = [
                err_idx.detach().cpu().numpy().astype(np.int64),
                err_val.detach().cpu().numpy().astype(np.float32),
            ]
        else:
            e_next = e_half

        # Persist e_next sparsely for the next round.
        self._store_sparse("prev_round", e_next)

        # ------------- Prepare payload -------------
        bn_mu, bn_var = self._bn_running_stats()
        count = np.array([len(self.trainloader.dataset)], dtype=np.float32)

        # IMPORTANT: keep order stable; the e-packet is appended only on reset rounds.
        flower_payload = [
            upd_idx.detach().cpu().numpy().astype(np.int64),
            upd_val.detach().cpu().numpy().astype(np.float32),
            bn_mu,
            bn_var,
            count,
        ] + e_packet

        self._log_uplink_metrics(flower_payload)
        return flower_payload

    def set_parameters(self, parameters, config):
        """Load a full, state_dict-ordered list of arrays into the local model.

        Arrays that fail to convert are reported and skipped. If a strict load
        then fails, a non-strict load is attempted.

        Raises
        ------
        ValueError
            If no array could be converted.
        """
        state_dict = OrderedDict()
        for key, value in zip(self.net.state_dict().keys(), parameters):
            try:
                if key.endswith('num_batches_tracked'):
                    state_dict[key] = torch.tensor(value, dtype=torch.long)
                else:
                    # torch.tensor avoids the legacy torch.Tensor(...) constructor,
                    # which treats a bare int as a size.
                    state_dict[key] = torch.tensor(value, dtype=torch.float32)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error converting parameter {key}: {e}")

        if not state_dict:
            raise ValueError("State dictionary is empty!")

        try:
            self.net.load_state_dict(state_dict, strict=True)
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error loading state dict with strict=True: {e}")
            self.net.load_state_dict(state_dict, strict=False)

    @staticmethod
    def _parse_e_bar(config):
        """Decode the server's JSON-encoded ``e_bar`` broadcast.

        The server sends ``"[]"`` when there is no broadcast. Undecodable
        input is treated as an empty broadcast.
        """
        try:
            idx = json.loads(config.get("e_bar_idx", "[]"))
            val = json.loads(config.get("e_bar_val", "[]"))
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            return [], []
        return idx, val

    def _reconcile_residual(self, flat_global: torch.Tensor, e_bar_idx, e_bar_val) -> None:
        """Apply ``e <- e + (sent - e_bar)`` if this client sent an error packet.

        Does nothing unless a ``"last_e_sent"`` record exists. On success, the
        updated residual is stored under ``"prev_round"`` and
        ``"last_e_sent"`` is cleared.
        """
        records = self._records()
        if records is None or "last_e_sent" not in records:
            return

        e_bar = torch.zeros_like(flat_global)
        e_bar_idx_t = torch.as_tensor(e_bar_idx, dtype=torch.long, device=self.device)
        e_bar_val_t = torch.as_tensor(e_bar_val, dtype=torch.float32, device=self.device)
        e_bar[e_bar_idx_t] = e_bar_val_t

        sent = self._load_dense("last_e_sent", flat_global)
        residual = self._load_dense("prev_round", flat_global)
        residual = residual + (sent - e_bar)

        self._store_sparse("prev_round", residual)
        records.pop("last_e_sent", None)

    def fit(self, parameters, config: Dict[str, Scalar]):
        """Reconcile the residual, train locally and return the compressed update.

        Parameters
        ----------
        parameters : list of np.ndarray
            Global model arrays in ``state_dict`` order.
        config : Dict[str, Scalar]
            Must contain ``"server_round"``. It may contain JSON-encoded
            ``"e_bar_idx"``/``"e_bar_val"`` from the server.

        Returns
        -------
        tuple
            ``(payload from get_parameters, number of training examples, metrics)``.
        """
        self.set_parameters(parameters, config)
        # Drop any metrics from a previous round so a failed logging pass
        # cannot report stale numbers.
        self._metrics_for_flower = {}

        # Cosine LR schedule. NOTE(review): hard-coded; self.learning_rate is ignored.
        learning_rate = cosine_eta_warmfloor(
            r=config["server_round"], total_R=400,
            eta_max=0.01, eta_min=0.003, warmup_R=10
        )
        flat_global = torch.nn.utils.parameters_to_vector(
            [p.data for p in self._trainable_params()]
        ).to(self.device)

        # Only reconcile if (i) the server broadcast a non-empty e_bar and
        # (ii) THIS client actually sent an error packet earlier.
        e_bar_idx, e_bar_val = self._parse_e_bar(config)
        if len(e_bar_idx) > 0:
            self._reconcile_residual(flat_global, e_bar_idx, e_bar_val)

        # Dense residual after reconciliation (zeros when none is stored).
        e_t = self._load_dense("prev_round", flat_global)

        # Gradient mismatch probe α: 0.0 unless an `alpha_r` attribute is set.
        alpha_preview = float(getattr(self, "alpha_r", 0.0))

        # Pre-send metrics; bit counts are filled in by get_parameters.
        pre_record, pre_metrics = compute_metrics_before_send(
            model=self.net,
            device=self.device,
            valloader=self.valloader,
            flat_w_r=flat_global.detach().clone(),
            e_t_dense=e_t.detach().clone(),
            alpha_preview=alpha_preview,
            include_bn_in_bits=False,
            server_round=int(config["server_round"]),
            sparsify_by=self.sparsify_by,
            learning_rate=learning_rate,
            num_epochs=self.num_epochs,
            comp_type=self.comp_type,
            L_est=1.0,
            payload_preview=None,
        )
        self._pre_record_cache = pre_record
        self._pre_metrics_cache = pre_metrics

        # Re-write flat_global into the trainable parameters before local SGD.
        # NOTE(review): flat_global was read from these same parameters, so this
        # only matters if compute_metrics_before_send mutates self.net — confirm.
        torch.nn.utils.vector_to_parameters(flat_global, self._trainable_params())

        train_fedavg(
            self.net,
            self.trainloader,
            self.device,
            self.num_epochs,
            learning_rate,
            self.momentum,
            self.weight_decay,
        )
        final_p_np = self.get_parameters(config, parameters)
        metrics = getattr(self, "_metrics_for_flower", {})
        return final_p_np, len(self.trainloader.dataset), metrics

    def evaluate(self, parameters, config):
        """Load the given parameters and evaluate on the local validation set.

        Returns
        -------
        tuple
            ``(loss, number of validation examples, {"accuracy": acc})``.
        """
        print("Evaluating model with received parameters.")
        try:
            self.set_parameters(parameters, config)
        except Exception as e:
            print(f"Error in set_parameters: {e}")
            raise
        print("Parameters set successfully, proceeding to evaluation.")
        loss, acc = test(self.net, self.valloader, self.device)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}

# pylint: disable=too-many-arguments
def gen_client_fn(
    trainloaders: List[DataLoader],
    valloaders: List[DataLoader],
    num_epochs: int,
    learning_rate: float,
    model: DictConfig,
    momentum: float = 0.9,
    weight_decay: float = 1e-5,
    approach: str = "random",
    sparsify_by: float = 0.5,
    comp_type: str = "topk",
    H: int = 5,
    reset_frac: float = 0.10,
) -> Callable[[Context], fl.client.Client]:  # pylint: disable=too-many-arguments
    """Generate the client function that creates the FedAvg Flower clients.

    Parameters
    ----------
    trainloaders: List[DataLoader]
        Training DataLoaders, one per client partition.
    valloaders: List[DataLoader]
        Validation DataLoaders, one per client partition.
    num_epochs : int
        The number of local epochs each client trains for before sending its
        update to the server.
    learning_rate : float
        Learning rate passed to each client. NOTE(review): the client's fit
        visible in this file uses its own cosine schedule instead — confirm
        before relying on this value.
    model : DictConfig
        Hydra config instantiated (via ``hydra.utils.instantiate``) into the
        local model.
    momentum : float
        The momentum for the SGD optimizer of clients.
    weight_decay : float
        The weight decay for the SGD optimizer of clients.
    approach : str
        Label forwarded to the client; the client uses it in its metrics log path.
    sparsify_by : float
        Fraction of update coordinates sent on the main compression channel.
    comp_type : str
        Main-channel compressor name (``"topk"`` or ``"randk"``).
    H : int
        Period, in server rounds, of the error-reset packets.
    reset_frac : float
        Fraction of coordinates sent in each error-reset packet.

    Returns
    -------
    Callable[[Context], fl.client.Client]
        A function that maps a Flower ``Context`` to a ready-to-use client.
    """

    def client_fn(context: Context) -> fl.client.Client:
        """Build the client for the partition named by ``node_config["partition-id"]``."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = instantiate(model).to(device)

        # Each client trains and evaluates on its own data partition.
        partition_id = int(context.node_config["partition-id"])
        trainloader = trainloaders[partition_id]
        valloader = valloaders[partition_id]

        return FlowerClientFedAvg(
            net,
            trainloader,
            valloader,
            device,
            num_epochs,
            learning_rate,
            momentum,
            weight_decay,
            approach,
            sparsify_by,
            context,
            comp_type,
            H,
            reset_frac,
        ).to_client()

    return client_fn
