from typing import Callable, Dict, List, OrderedDict, Optional
import json
import numpy as np
import flwr as fl
import torch
import torch.nn as nn
from flwr.common import Scalar
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.utils.data import DataLoader
from models import test, train_fedavg
from collections import OrderedDict
import logging
from logging import ERROR, INFO
from flwr.common.logger import log, set_logger_propagation
from flwr.common import Context, ParametersRecord, RecordSet
from flwr.common import array_from_numpy, Array
from flwr.common.typing import NDArray


# Project-local helpers (learning-rate schedules, theory utilities, metrics collection)
from utils_helper import cosine_eta_warmfloor, reconstruct_parameters, segment_resnet_parameters, flatten_resnet_parameters, ndarray_to_array, basic_array_deserialisation
from utils_theory import cosine_eta, alpha_with_contraction_floor, check_descent_coupling, alpha_with_contraction_floor_updated, alpha_policy, dense_from_sparse
from utils_collect_res import compute_metrics_before_send, append_jsonl, _comm_bits_from_payload
# pylint: disable=too-many-instance-attributes
class FlowerClientFedAvg(fl.client.NumPyClient):
    """Flower client that trains locally and uploads a sparsified model delta.

    After local training the client sends only a subset of the entries of the
    flattened difference between its trainable parameters and the parameters it
    received.  The entries that are *not* sent (the residual) are stored in
    ``context.state`` under the ``"prev_round"`` record and are re-used in the
    next round: a fraction ``alpha_r`` is added to the starting point of local
    training in ``fit`` and the remaining ``1 - alpha_r`` is added to the delta
    before compression in ``get_parameters``.
    """

    # Settings of the cosine learning-rate schedule evaluated in ``fit``.
    # NOTE: ``fit`` takes its step size from this schedule; the
    # ``learning_rate`` constructor argument is stored but not used for training.
    ETA_MAX: float = 0.01
    ETA_MIN: float = 0.003
    TOTAL_ROUNDS: int = 400
    WARMUP_ROUNDS: int = 10

    # pylint: disable=too-many-arguments
    def __init__(
        self,
        net: torch.nn.Module,
        trainloader: DataLoader,
        valloader: DataLoader,
        device: torch.device,
        num_epochs: int,
        learning_rate: float,
        momentum: float,
        weight_decay: float,
        approach: str,
        sparsify_by: float,
        context: Context,
        comp_type: str = "topk",
        alpha_r: Optional[float] = 0.5
    ) -> None:
        """Store the model, data loaders, hyper-parameters and Flower context.

        Parameters
        ----------
        net : torch.nn.Module
            Local model.
        trainloader, valloader : DataLoader
            Local training / validation data.
        device : torch.device
            Device on which delta tensors are assembled and training runs.
        num_epochs : int
            Number of local epochs passed to ``train_fedavg``.
        learning_rate : float
            Stored only; see the class-level note on the cosine schedule.
        momentum, weight_decay : float
            Forwarded to ``train_fedavg``.
        approach : str
            Used as a directory name in the metrics log path.
        sparsify_by : float
            Fraction of delta entries that is sent to the server.
        context : Context
            Flower context; ``context.state`` holds the residual between rounds.
        comp_type : str
            ``"topk"`` or ``"randk"``.
        alpha_r : Optional[float]
            Share of the residual applied before local training.  ``None`` is
            treated as ``0.0``.
        """
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader
        self.device = device
        self.num_epochs = num_epochs
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.approach = approach
        self.sparsify_by = sparsify_by
        self.context = context
        self.client_state = context.state
        self.node_id = context.node_id
        self.comp_type = comp_type
        self.alpha_r = alpha_r
        # Filled by fit() and completed (comm bits) by get_parameters().
        self._pre_record_cache = None
        self._pre_metrics_cache = None
        self._metrics_for_flower: Dict[str, Scalar] = {}

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _alpha(self) -> float:
        """Return ``alpha_r`` as a float, mapping ``None`` to ``0.0``."""
        return 0.0 if self.alpha_r is None else float(self.alpha_r)

    def _dense_state(self):
        """Return every ``state_dict`` entry as a NumPy array."""
        return [val.cpu().numpy() for _, val in self.net.state_dict().items()]

    def _trainable_parameters(self):
        """Return the parameters with ``requires_grad=True`` in model order."""
        return [p for p in self.net.parameters() if p.requires_grad]

    def _stored_residual(self, server_round: int):
        """Return ``(indices, values)`` of the stored residual, or ``None``.

        ``None`` is returned in round 1 (and earlier) and whenever the client
        state holds no ``"prev_round"`` record.
        """
        if server_round <= 1 or self.client_state is None:
            return None
        records = self.client_state.parameters_records
        if records is None or "prev_round" not in records:
            return None
        rec = records["prev_round"]
        idx = torch.from_numpy(rec["indices"].numpy()).long().to(self.device)
        val = torch.from_numpy(rec["values"].numpy()).to(self.device)
        return idx, val

    def _select_indices(self, flat_delta: torch.Tensor) -> torch.Tensor:
        """Choose which entries of ``flat_delta`` are sent, per ``comp_type``."""
        total = flat_delta.numel()
        # Clamp so an out-of-range ``sparsify_by`` behaves the same for both
        # compressors instead of raising only for top-k.
        keep = max(0, min(total, int(self.sparsify_by * total)))
        if self.comp_type == "topk":
            _, chosen = torch.topk(torch.abs(flat_delta), keep)
            return chosen
        if self.comp_type == "randk":
            # Generated on the CPU generator, then moved next to the delta.
            return torch.randperm(total)[:keep].to(flat_delta.device)
        raise ValueError(f"Unknown comp_type: {self.comp_type}")

    def _batchnorm_stats(self):
        """Return concatenated BatchNorm2d running mean/var as float16 arrays.

        Empty arrays are returned when the model has no BatchNorm2d layer with
        tracked statistics.
        """
        means, variances = [], []
        for module in self.net.modules():
            if isinstance(module, torch.nn.BatchNorm2d) and module.running_mean is not None:
                means.append(module.running_mean)
                variances.append(module.running_var)
        if not means:
            empty = np.empty(0, dtype=np.float16)
            return empty, empty.copy()
        bn_mu = torch.cat(means).cpu().numpy().astype(np.float16)
        bn_var = torch.cat(variances).cpu().numpy().astype(np.float16)
        return bn_mu, bn_var

    # ------------------------------------------------------------------
    # Flower API
    # ------------------------------------------------------------------
    def get_parameters(self, config: Dict[str, Scalar], parameters=None):
        """Return dense weights or, after training, the sparsified delta payload.

        Parameters
        ----------
        config : Dict[str, Scalar]
            Round configuration.  An empty config requests the dense weights.
        parameters : list of np.ndarray, optional
            The weights received at the start of the round, in ``state_dict``
            order.  When omitted the dense weights are returned, because no
            delta can be formed.

        Returns
        -------
        list of np.ndarray
            Either every ``state_dict`` entry, or
            ``[delta_idx, delta_val, bn_mu, bn_var, count]``.

        Raises
        ------
        ValueError
            If ``comp_type`` is neither ``"topk"`` nor ``"randk"``.
        """
        if not config or parameters is None:
            # Initial-parameter request (or a call without reference weights).
            return self._dense_state()

        # Reference weights restricted to the same parameters that are
        # flattened below, so both vectors have identical length and order.
        trainable_names = {
            name for name, p in self.net.named_parameters() if p.requires_grad
        }
        reference_parts = [
            torch.as_tensor(arr, dtype=torch.float32).flatten()
            for (name, _), arr in zip(self.net.state_dict().items(), parameters)
            if name in trainable_names          # buffers and frozen params are skipped
        ]
        flat_reference = torch.cat(reference_parts, dim=0)
        flat_current = torch.nn.utils.parameters_to_vector(
            [p.data for p in self._trainable_parameters()]
        )
        flat_delta = flat_current.to(self.device) - flat_reference.to(self.device)

        # Add the part of the stored residual that fit() did not already apply.
        stored = self._stored_residual(int(config.get("server_round", 0)))
        if stored is not None:
            idx, val = stored
            flat_delta[idx] += val * (1 - self._alpha())

        spars_indices = self._select_indices(flat_delta)

        delta_idx = spars_indices.cpu().numpy().astype(np.int64)
        delta_val = flat_delta[spars_indices].cpu().numpy().astype(np.float32)

        # BatchNorm statistics (optional, for central evaluation).
        bn_mu, bn_var = self._batchnorm_stats()

        # Number of local training examples.
        count = np.array([len(self.trainloader.dataset)], dtype=np.float32)
        flower_payload: list[np.ndarray] = [delta_idx, delta_val, bn_mu, bn_var, count]

        # Residual for the next round: everything that was not sent.
        residual = flat_delta.clone()
        residual[spars_indices] = 0.0
        nonzero_idx = torch.nonzero(residual, as_tuple=False).squeeze(1)
        nonzero_val = residual[nonzero_idx]

        record = ParametersRecord()
        record["indices"] = array_from_numpy(nonzero_idx.cpu().numpy())
        record["values"] = array_from_numpy(nonzero_val.cpu().numpy())
        self.client_state.parameters_records["prev_round"] = record

        # Compute comm bits from the actual payload and write the JSON record.
        # Best effort by design: logging problems must not abort training.
        try:
            bits = _comm_bits_from_payload(flower_payload, include_bn=True)
            if self._pre_record_cache is not None:
                rec = dict(self._pre_record_cache)
                rec.update(bits)
                append_jsonl(f"logs/clients_metrics/{self.approach}/metrics_client_{self.node_id}.jsonl", rec)
                # Also stash for fit() to attach to the Flower metrics.
                if self._pre_metrics_cache is not None:
                    self._pre_metrics_cache["uplink_bits_total"] = float(bits["uplink_bits_total"])
                    self._metrics_for_flower = dict(self._pre_metrics_cache)
        except Exception as ex:  # pylint: disable=broad-except
            print(f"[warn] metrics logging failed: {ex}")

        return flower_payload

    def set_parameters(self, parameters, config):
        """Load ``parameters`` (in ``state_dict`` order) into the local model.

        Entries that cannot be converted are reported and skipped.  Loading is
        tried with ``strict=True`` first and falls back to ``strict=False``.

        Raises
        ------
        ValueError
            If no entry could be converted at all.
        """
        state_dict = OrderedDict()
        for key, value in zip(self.net.state_dict().keys(), parameters):
            try:
                if key.endswith('num_batches_tracked'):
                    # Integer counter; must stay a long tensor.
                    state_dict[key] = torch.tensor(value, dtype=torch.long)
                else:
                    state_dict[key] = torch.Tensor(value)
            except Exception as e:  # pylint: disable=broad-except
                print(f"Error converting parameter {key}: {e}")

        if not state_dict:
            raise ValueError("State dictionary is empty!")

        try:
            self.net.load_state_dict(state_dict, strict=True)
        except Exception as e:  # pylint: disable=broad-except
            print(f"Error loading state dict with strict=True: {e}")
            self.net.load_state_dict(state_dict, strict=False)

    def fit(self, parameters, config: Dict[str, Scalar]):
        """Run one round of local training and return the compressed update.

        Parameters
        ----------
        parameters : list of np.ndarray
            Weights received from the server, in ``state_dict`` order.
        config : Dict[str, Scalar]
            Must contain ``"server_round"``.

        Returns
        -------
        tuple
            ``(payload, num_train_examples, metrics)`` where ``payload`` is the
            list produced by ``get_parameters``.
        """
        self.set_parameters(parameters, config)
        # Metrics from an earlier call must not leak into this round.
        self._metrics_for_flower = {}

        learning_rate = cosine_eta_warmfloor(
            r=config["server_round"], total_R=self.TOTAL_ROUNDS,
            eta_max=self.ETA_MAX, eta_min=self.ETA_MIN, warmup_R=self.WARMUP_ROUNDS
        )
        trainable = self._trainable_parameters()
        flat_global = torch.nn.utils.parameters_to_vector(
            [p.data for p in trainable]
        ).to(self.device)

        # Dense copy of the stored residual (all zeros when there is none).
        e_t = torch.zeros_like(flat_global)
        stored = self._stored_residual(int(config["server_round"]))
        if stored is not None:
            idx, val = stored
            e_t[idx] = val

        alpha = self._alpha()

        # Pre-compute "pre-send" metrics (no payload yet; bits are filled in
        # later by get_parameters).
        pre_record, pre_metrics = compute_metrics_before_send(
            model=self.net,
            device=self.device,
            valloader=self.valloader,
            flat_w_r=flat_global.detach().clone(),
            e_t_dense=e_t.detach().clone(),
            alpha_preview=alpha,
            include_bn_in_bits=False,
            server_round=int(config["server_round"]),
            sparsify_by=self.sparsify_by,
            learning_rate=learning_rate,
            num_epochs=self.num_epochs,
            comp_type=self.comp_type,
            L_est=1.0,
            payload_preview=None,
        )
        self._pre_record_cache = pre_record
        self._pre_metrics_cache = pre_metrics

        # Start local training from the global weights shifted by alpha * residual.
        theta_half = flat_global + alpha * e_t
        torch.nn.utils.vector_to_parameters(theta_half, trainable)

        train_fedavg(
            self.net,
            self.trainloader,
            self.device,
            self.num_epochs,
            learning_rate,
            self.momentum,
            self.weight_decay,
        )
        final_p_np = self.get_parameters(config, parameters)
        return final_p_np, len(self.trainloader.dataset), self._metrics_for_flower

    def evaluate(self, parameters, config):
        """Load ``parameters`` and evaluate on the local validation set.

        Returns
        -------
        tuple
            ``(loss, num_val_examples, {"accuracy": acc})``.
        """
        print("Evaluating model with received parameters.")
        try:
            self.set_parameters(parameters, config)
        except Exception as e:
            print(f"Error in set_parameters: {e}")
            raise
        print("Parameters set successfully, proceeding to evaluation.")
        loss, acc = test(self.net, self.valloader, self.device)
        return float(loss), len(self.valloader.dataset), {"accuracy": float(acc)}

# pylint: disable=too-many-arguments
def gen_client_fn(
    trainloaders: List[DataLoader],
    valloaders: List[DataLoader],
    num_epochs: int,
    learning_rate: float,
    model: DictConfig,
    momentum: float = 0.9,
    weight_decay: float = 1e-5,
    approach: str = "random",
    sparsify_by: float = 0.5,
    comp_type: str = "topk",
    alpha_r: float = 0.5,
) -> Callable[[Context], fl.client.Client]:  # pylint: disable=too-many-arguments
    """Generate the client function that creates the FedAvg flower clients.

    Parameters
    ----------
    trainloaders: List[DataLoader]
        A list of DataLoaders, each pointing to the dataset training partition
        belonging to a particular client.
    valloaders: List[DataLoader]
        A list of DataLoaders, each pointing to the dataset validation partition
        belonging to a particular client.
    num_epochs : int
        The number of local epochs each client should run the training for before
        sending it to the server.
    learning_rate : float
        The learning rate handed to each client.  NOTE: ``FlowerClientFedAvg.fit``
        computes its step size from a cosine schedule, so this value is stored
        by the client but does not drive training.
    model : DictConfig
        Hydra config that is instantiated into a fresh network for every client.
    momentum : float
        The momentum for SGD optimizer of clients
    weight_decay : float
        The weight decay for SGD optimizer of clients
    approach : str
        Label forwarded to the client; it is used in the metrics log path.
    sparsify_by : float
        Fraction of delta entries each client uploads.
    comp_type : str
        Compressor used by the client, ``"topk"`` or ``"randk"``.
    alpha_r : float
        Share of the stored residual the client applies before local training.

    Returns
    -------
    Callable[[Context], fl.client.Client]
        The client function: it takes a Flower ``Context`` and returns the
        ``FlowerClientFedAvg`` for that node, converted with ``to_client()``.
    """

    def client_fn(context: Context) -> fl.client.Client:
        """Create a Flower client representing a single organization."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        net = instantiate(model).to(device)

        # Each client gets its own trainloader/valloader, selected by the
        # partition id in the node config, so it trains and evaluates on its
        # own data.
        partition_id = int(context.node_config["partition-id"])

        # Keyword arguments guard against silently swapping same-typed values
        # in the long constructor signature.
        return FlowerClientFedAvg(
            net=net,
            trainloader=trainloaders[partition_id],
            valloader=valloaders[partition_id],
            device=device,
            num_epochs=num_epochs,
            learning_rate=learning_rate,
            momentum=momentum,
            weight_decay=weight_decay,
            approach=approach,
            sparsify_by=sparsify_by,
            context=context,
            comp_type=comp_type,
            alpha_r=alpha_r,
        ).to_client()

    return client_fn
