import json, time, math
from pathlib import Path
import torch
import torch.nn as nn

# ---------- JSONL logging ----------
def append_jsonl(path: str, record: dict):
    """Append ``record`` to the file at ``path`` as one JSON line.

    Missing parent directories are created first. The file is opened in
    append mode, so existing lines are kept.
    """
    log_file = Path(path)
    log_file.parent.mkdir(parents=True, exist_ok=True)
    with log_file.open("a") as sink:
        # Serialise only once the file is open, one record per line.
        sink.write(f"{json.dumps(record)}\n")

# ---------- flatten & restore ----------
def _flat_params(module: nn.Module):
    """Return the trainable parameters of ``module`` as a single 1-D vector.

    Only parameters with ``requires_grad`` set are included, in the order
    yielded by ``module.parameters()``.
    """
    trainable = list(filter(lambda param: param.requires_grad, module.parameters()))
    return torch.nn.utils.parameters_to_vector(trainable)

def _set_flat_params(module: nn.Module, vec: torch.Tensor):
    """Write the flat vector ``vec`` back into the trainable parameters of ``module``.

    Parameters without ``requires_grad`` are skipped, so ``vec`` must follow
    the same layout that ``_flat_params`` produces.
    """
    targets = list(filter(lambda param: param.requires_grad, module.parameters()))
    torch.nn.utils.vector_to_parameters(vec, targets)

# ---------- single-batch gradient (vector) ----------
@torch.no_grad()
def _next_batch(dataloader, max_samples=None):
    """Fetch the first ``(x, y)`` batch from ``dataloader``.

    A fresh iterator is created on every call. When ``max_samples`` is given
    and the batch has more rows than that, both tensors are truncated to
    their first ``max_samples`` rows.
    """
    inputs, targets = next(iter(dataloader))
    if max_samples is None or inputs.shape[0] <= max_samples:
        return inputs, targets
    return inputs[:max_samples], targets[:max_samples]

def _grad_vector(model, loss_fn, x, y, device):
    """Return the flattened gradient of ``loss_fn(model(x), y)``.

    The gradient is taken with respect to the parameters that have
    ``requires_grad`` set and is returned as one detached 1-D tensor, in
    ``model.parameters()`` order.

    The model is put in eval mode for the forward pass and its previous
    train/eval mode is restored afterwards, even if the forward pass, the
    loss or the autograd call raises. Any accumulated ``.grad`` fields on
    the model are cleared (set to ``None``) as a side effect.
    """
    was_training = model.training
    model.eval()  # freeze BN/dropout behavior while probing
    try:
        model.zero_grad(set_to_none=True)
        x, y = x.to(device), y.to(device)
        trainable = [p for p in model.parameters() if p.requires_grad]
        # enable_grad lets this work when the caller is inside a no_grad context.
        with torch.enable_grad():
            loss = loss_fn(model(x), y)
            grads = torch.autograd.grad(
                loss, trainable, create_graph=False, retain_graph=False
            )
        return torch.cat([g.reshape(-1) for g in grads]).detach()
    finally:
        # Always hand the model back in the mode the caller left it in.
        model.train(was_training)

# ---------- comm bits ----------
def _bits_for_array(np_array) -> int:
    """Return the raw payload size of ``np_array`` in bits.

    The size is ``np_array.nbytes * 8``, i.e. element size times element
    count, with no allowance for serialization overhead. The argument only
    needs to expose an ``nbytes`` attribute.
    """
    # 8 bits per byte; the result is a bit count, not a size in MiB.
    return int(np_array.nbytes) * 8

def _comm_bits_from_payload(payload: list, include_bn=False) -> dict:
    """Break down the uplink size of ``payload`` as measured by ``_bits_for_array``.

    Expected layout: ``[delta_idx, delta_val, bn_mu, bn_var, count]``,
    optionally followed by ``[e_idx, e_val]``.

    The BN entries (indices 2-4) are counted only when ``include_bn`` is
    true. The error-feedback entries (indices 5-6) are counted whenever the
    payload has at least seven items. A fixed 128 is added to the total for
    metadata overhead.
    """
    size_of = _bits_for_array

    update_size = size_of(payload[0]) + size_of(payload[1])

    bn_size = 0
    if include_bn:
        bn_size = size_of(payload[2]) + size_of(payload[3]) + size_of(payload[4])

    error_size = 0
    if len(payload) > 6:
        error_size = size_of(payload[5]) + size_of(payload[6])

    # small metadata overhead (~ a few small ints)
    metadata_overhead = 128

    return {
        "uplink_bits_update": update_size,
        "uplink_bits_error": error_size,
        "uplink_bits_bn": bn_size,
        "uplink_bits_total": update_size + error_size + bn_size + metadata_overhead,
    }

# ---------- theory plugs ----------
def _delta_from_comp(update_frac: float, comp_type: str) -> float:
    """Return the compression constant δ as the reciprocal of the kept fraction.

    ``update_frac`` is clamped from below at 1e-6 to avoid dividing by zero.
    ``comp_type`` is accepted but not used: the same formula is applied to
    every compressor.
    """
    floor = 1e-6
    keep_fraction = float(update_frac)
    # Anything not strictly above the floor (including NaN) falls back to it.
    if not keep_fraction > floor:
        keep_fraction = floor
    return 1.0 / keep_fraction

def _compute_rho(delta: float, alpha_r: float, s_r: float) -> float:
    """Evaluate ρ_r = (1 - 1/δ) · [2(1 - α_r)² + 24 · s_r² · α_r²]."""
    compression_factor = 1.0 - 1.0 / delta
    correction_term = 2.0 * (1.0 - alpha_r) ** 2
    drift_term = 24.0 * (s_r ** 2) * (alpha_r ** 2)
    return compression_factor * (correction_term + drift_term)

def _estimate_s_r(lr: float, L_est: float, T_local: int) -> float:
    """Return the proxy s_r = lr · L_est · T_local as a float."""
    step_scale = lr * L_est
    return float(step_scale * T_local)

# ---------- master metric helper ----------
def compute_metrics_before_send(
    *,
    model: nn.Module,
    device: torch.device,
    valloader,                      # use the validation loader (held-out)
    flat_w_r: torch.Tensor,         # current global weights at round r
    e_t_dense: torch.Tensor,        # dense residual at start of round
    alpha_preview: float,           # α used for gradient mismatch probe
    include_bn_in_bits: bool,
    server_round: int,
    sparsify_by: float,
    learning_rate: float,
    num_epochs: int,
    comp_type: str,
    L_est: float = 1.0,
    payload_preview=None
):
    """
    Compute per-round diagnostics around the global weights ``flat_w_r``.

    Two gradients are evaluated on the same single batch from ``valloader``
    (capped at 1024 samples) with a mean-reduced cross-entropy loss: one at
    ``flat_w_r`` and one at the preview point
    ``flat_w_r - alpha_preview * e_t_dense``. The model's trainable
    parameters are overwritten temporarily and are restored from a snapshot
    before returning, even if the probe raises.

    Returns a tuple ``(record, metrics_for_flower)``:
      record: dict for JSONL logging with round, grad_norm_sq,
        residual_energy, grad_mismatch_sq, rho_r, s_r, the uplink_bits_*
        breakdown (all zeros when ``payload_preview`` is None) and the
        hyper-parameters alpha_r, lr, local_epochs, sparsify_by, comp_type.
      metrics_for_flower: dict of scalar floats (grad_norm_sq,
        residual_energy, grad_mismatch_sq, rho_r, uplink_bits_total).
    """
    loss_fn = nn.CrossEntropyLoss(reduction="mean")

    # --- build two evaluation points: w_r and w_r - α e_t ---
    # Detached clone: the snapshot must not hold an autograd graph or share
    # storage with the parameters that are about to be overwritten.
    w_backup = _flat_params(model).detach().clone().to(device)
    try:
        # 1) grad at w_r (moved to `device`, matching the preview point below)
        _set_flat_params(model, flat_w_r.to(device))
        x, y = _next_batch(valloader, max_samples=1024)
        g_wr = _grad_vector(model, loss_fn, x, y, device)
        grad_norm_sq = float(g_wr.pow(2).sum().item())

        # 2) grad at preview point, on the same batch so the two are comparable
        preview = (flat_w_r - alpha_preview * e_t_dense).to(device)
        _set_flat_params(model, preview)
        g_prev = _grad_vector(model, loss_fn, x, y, device)
        grad_mismatch_sq = float((g_wr - g_prev).pow(2).sum().item())
    finally:
        # restore model, even if the probe failed part-way
        _set_flat_params(model, w_backup)

    # residual energy
    residual_energy = float(e_t_dense.pow(2).sum().item())

    # contraction proxy
    delta = _delta_from_comp(update_frac=float(sparsify_by), comp_type=str(comp_type))
    s_r = _estimate_s_r(lr=float(learning_rate), L_est=float(L_est), T_local=int(num_epochs))
    rho_r = _compute_rho(delta=delta, alpha_r=float(alpha_preview), s_r=s_r)

    # comm bits (if preview payload provided)
    bits = {"uplink_bits_update": 0, "uplink_bits_error": 0, "uplink_bits_bn": 0, "uplink_bits_total": 0}
    if payload_preview is not None:
        bits = _comm_bits_from_payload(payload_preview, include_bn=include_bn_in_bits)

    # pack
    record = {
        "round": server_round,
        "grad_norm_sq": grad_norm_sq,
        "residual_energy": residual_energy,
        "grad_mismatch_sq": grad_mismatch_sq,
        "rho_r": rho_r,
        "s_r": s_r,
        **bits,
        "alpha_r": float(alpha_preview),
        "lr": float(learning_rate),
        "local_epochs": int(num_epochs),
        "sparsify_by": float(sparsify_by),
        "comp_type": str(comp_type),
    }
    # scalars for Flower FitRes.metrics (must be JSON-serializable scalars)
    metrics_for_flower = {
        "grad_norm_sq": grad_norm_sq,
        "residual_energy": residual_energy,
        "grad_mismatch_sq": grad_mismatch_sq,
        "rho_r": rho_r,
        "uplink_bits_total": float(bits["uplink_bits_total"]),
    }
    return record, metrics_for_flower
