"""Layerwise quantization pipeline.

Goal
----
Quantize a model one layer (weight matrix) at a time.
For each new layer:
  1) run the *partially quantized* model on a calibration set
  2) compute the Hessian/covariance statistic for the target module
  3) quantize that module's weight with GPTQ or ZSIC
  4) save the artifact, apply it to the model
  5) repeat

Artifacts are always saved to disk so you can resume.
"""

from __future__ import annotations

import json
import os
import time
from dataclasses import dataclass, asdict, replace
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch

from quant_layerwise.bucket import get_bucket_path, model_reg
from quant_layerwise.data import get_wikitext2, split_dataset, take_nseq
from quant_layerwise.hessian_runtime import (
    compute_module_hessian,
    compute_module_hessian_cached,
    ActivationCache,
)
from quant_layerwise.methods.gptq import GPTQConfig, compress_gptq_wrapper as compress_gptq
from quant_layerwise.methods.zsic import ZSICConfig, compress_zsic
from quant_layerwise.partial_model import apply_layer_artifact
from quant_layerwise.storage import LayerArtifact, RunManifest, safe_stem
from quant_layerwise.names import get_hess_name, get_weight_name
from quant_layerwise.hadamard import (
    BlockRandomHadamard,
    HadamardType,
    HadamardConfig,
    apply_hadamard_to_weight,
    apply_hadamard_to_hessian,
    inverse_hadamard_weight,
)
from quant_layerwise.rate_control import RateControlConfig, RateController


def _infer_is_llama2(model_name: str) -> bool:
    """Return True when the registry name carries the ``"2-"`` prefix.

    The name is coerced with ``str`` first, so non-string inputs are accepted
    and simply compared by their string form.
    """
    prefix = "2-"
    return str(model_name)[: len(prefix)] == prefix


def _is_qwen3(model_name: str) -> bool:
    """Return True when the name starts with ``"qwen3"``, ignoring case.

    The name is coerced with ``str`` and lower-cased before the comparison.
    """
    marker = "qwen3"
    lowered = str(model_name).lower()
    return lowered[: len(marker)] == marker


def ensure_single_process_distributed(*, local_rank: int, master_port: int = 29500):
    """Export the env vars for a one-process "distributed" group.

    Lets `parallel.start.start(...)` be called without torchrun, *if*
    WORLD_SIZE=1. This is only safe when:
      * you are running a single process per job
      * your ckpt_dir contains exactly one `.pth` shard

    Sharded checkpoints (len(checkpoints)>1) still need torchrun with
    WORLD_SIZE equal to the shard count.

    Args:
        local_rank: GPU index of the caller; only echoed in the log line.
        master_port: Port written to MASTER_PORT (coerced through ``int``).
    """
    env = os.environ
    # Values are overwritten unconditionally (no setdefault) so that every
    # spawned process ends up with its own port even when the parent process
    # already exported these variables.
    env["MASTER_ADDR"] = "127.0.0.1"
    env.update(
        MASTER_PORT=str(int(master_port)),
        RANK="0",
        WORLD_SIZE="1",
        # LOCAL_RANK is pinned to 0: each process is rank 0 of its own
        # single-process group, and parallel/start.py silences output for
        # LOCAL_RANK>0, which would hide the logs of processes on GPU>0.
        LOCAL_RANK="0",
    )
    print(f"[dist] MASTER_PORT={master_port}, GPU={local_rank}", flush=True)


def _load_qwen3(
    model_name: str,
    *,
    local_rank: int = 0,
    max_seq_len: int | None = None,
):
    """Load a Qwen3 model from a local bucket copy or from HuggingFace.

    ``model_reg[model_name]`` is used as the HuggingFace repo id. Before going
    to the hub, the bucket directory is searched for a local copy under, in
    order:

      1. the repo's base name (text after the last ``/``, e.g. ``Qwen3-8B``),
      2. ``model_name`` itself (e.g. ``qwen3-8B``),
      3. the repo id with ``/`` replaced by ``_`` (e.g. ``Qwen_Qwen3-8B``).

    The first existing path wins. The base name is derived from the registry
    entry rather than hard-coded, so a different Qwen3 variant never picks up
    another variant's local folder.

    Args:
        model_name: Name of the model (e.g., "qwen3-8B")
        local_rank: GPU device to use (passed as ``cuda:{local_rank}``)
        max_seq_len: Override max sequence length; forwarded to the adapter
            as the ``max_seq_len`` override when not None

    Returns:
        model, tokenizer (wrapped in the Qwen3 -> Llama adapters)
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from parallel.qwen3_adapter import Qwen3ToLlamaAdapter, Qwen3TokenizerAdapter

    # HuggingFace repo id from the registry.
    repo_id = model_reg[model_name]
    repo_basename = repo_id.rstrip("/").rsplit("/", 1)[-1]

    # Look for a local copy in the bucket under the folder names people
    # commonly use for manual downloads. Empty names are skipped so that the
    # bucket directory itself is never mistaken for a checkpoint.
    bucket = get_bucket_path()
    candidate_names = [
        repo_basename,               # e.g., Qwen3-8B (common manual download name)
        model_name,                  # e.g., qwen3-8B
        repo_id.replace("/", "_"),   # e.g., Qwen_Qwen3-8B
    ]
    local_path = None
    for name in candidate_names:
        if not name:
            continue
        candidate = bucket / name
        if candidate.exists():
            local_path = candidate
            break

    if local_path is not None:
        hf_path = str(local_path)
        print(f"[qwen3] loading from local: {hf_path}", flush=True)
    else:
        hf_path = repo_id
        print(f"[qwen3] loading from HuggingFace: {hf_path}", flush=True)
        print(
            f"[qwen3] tip: to cache locally, run: huggingface-cli download {repo_id} "
            f"--local-dir $QUANT_BUCKET/{repo_basename}",
            flush=True,
        )

    # Tokenizer, wrapped to expose the interface the rest of the pipeline expects.
    hf_tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)
    tokenizer = Qwen3TokenizerAdapter(hf_tokenizer)

    # Model weights in bfloat16, placed directly on the requested GPU.
    hf_model = AutoModelForCausalLM.from_pretrained(
        hf_path,
        torch_dtype=torch.bfloat16,
        device_map=f"cuda:{local_rank}",
        trust_remote_code=True,
    )

    # Wrap in adapter, forwarding the optional sequence-length override.
    override_params = {}
    if max_seq_len is not None:
        override_params["max_seq_len"] = max_seq_len

    model = Qwen3ToLlamaAdapter(hf_model, override_params)

    return model, tokenizer


def load_model_and_tokenizer(
    model_name: str,
    *,
    local_rank: int = 0,
    max_seq_len: int | None = None,
):
    """Load a Transformer and its tokenizer for the given registry name.

    Qwen3 names are routed to the HuggingFace-based loader; everything else
    goes through the existing `parallel.start.start` entry point.

    Args:
        model_name: Name of the model to load (e.g., "3.2-1B", "3-8B", "qwen3-8B")
        local_rank: GPU device to use. Only forwarded on the Qwen3 path; the
                    `parallel.start` path does not receive it.
        max_seq_len: Override max sequence length for KV cache and RoPE.
                     If None, uses model default (typically 2048).
                     Set to 4096 for longer context evaluation.

    Returns:
        model, tokenizer
    """
    # Qwen3 checkpoints are loaded through HuggingFace.
    if _is_qwen3(model_name):
        return _load_qwen3(model_name, local_rank=local_rank, max_seq_len=max_seq_len)

    # Everything else is a Llama checkpoint handled by the in-house loader.
    from parallel.start import start
    from parallel.config import no_q_config

    checkpoint_dir = get_bucket_path() / model_reg[model_name]
    llama2_flag = _infer_is_llama2(model_name)

    # Only pass an override when the caller asked for one.
    overrides = {} if max_seq_len is None else {"max_seq_len": max_seq_len}

    loaded_model, loaded_tokenizer = start(
        str(checkpoint_dir),
        llama2_flag,
        no_q_config,
        override_params=overrides,
    )
    return loaded_model, loaded_tokenizer


@dataclass(frozen=True)
class PipelineConfig:
    """Immutable settings for one layerwise quantization run.

    A single instance describes which model is quantized, with which method,
    in which layer order, how calibration statistics are gathered, where
    artifacts are written, and which optional Qronos / residual-compensation
    features are active.
    """

    # Registry key of the model; also used as a path component of the run directory.
    model_name: str
    # Quantization method, matched case-insensitively by quantize_one_layer:
    # "gptq", or "zsic" (alias "sic").
    method: str  # "gptq" | "zsic"

    # Layer order list of tuples (layer_id, weight) using your naming.
    # Layers are processed in exactly this order.
    layers: Sequence[Tuple[int, str]]

    # Calibration
    seqlen: int = 2048
    calib_nsamples: Optional[int] = None  # None means use all available samples
    hessian_batch_size: int = 1  # Batch size for Hessian computation (higher = faster but more memory)

    # Hadamard (same settings for all layers in this run)
    # `hadamard_type` is only consulted when `hadamard` is True; an unrecognized
    # type string falls back to "row" in quantize_one_layer.
    hadamard: bool = False
    hadamard_type: str = "row"  # "none", "row", "column", "row_column"
    hadamard_seed: int = 0

    # Output
    # An absolute `run_root` is used as-is; a relative one is resolved against the bucket path.
    run_root: str = "quant_runs"  # relative to QUANT_BUCKET
    # Empty string means "generate one" (see default_run_id).
    run_id: str = ""
    # When True and a manifest already exists in the run directory, it is loaded
    # instead of being recreated.
    resume: bool = True

    # Method-specific configs
    # The one matching `method` must be set unless an explicit per-call config
    # is passed to quantize_one_layer.
    gptq: Optional[GPTQConfig] = None
    zsic: Optional[ZSICConfig] = None

    # Optional: global rate control (mainly for ZSIC).
    # Only takes effect when it is enabled and the method is "zsic"/"sic".
    rate_control: Optional[RateControlConfig] = None

    # Skip quantization for specific (layer_id, weight) pairs - store in full precision
    # Format: list of tuples like [(0, "wq"), (0, "wk"), (1, "wq"), (1, "wk")]
    # Weight names are compared case-insensitively.
    skip_quantize_layers: Sequence[Tuple[int, str]] = ()

    # Qronos: compute and save Σ_X̂ and Σ_XX̂ statistics for each layer
    # This requires maintaining both unquantized and quantized model copies
    qronos: bool = False

    # Qronos layer range: only apply Qronos targeting (W*) to layers in [min, max)
    # Stats are still computed for all layers. Use None for no limit.
    # Example: qronos_layer_min=14 means only layers 14+ use Qronos targeting
    qronos_layer_min: Optional[int] = None
    qronos_layer_max: Optional[int] = None

    # Skip Qronos targeting for specific (layer_id, weight) pairs (stats still computed)
    # Format: list of tuples, e.g., [(2, "wo"), (3, "wo")] to skip L2_wo and L3_wo
    qronos_skip_layers: Sequence[Tuple[int, str]] = ()

    # Skip Qronos targeting for specific weight types globally (all layers)
    # Format: list of weight names, e.g., ["wq", "wk", "wv"] to skip all Q/K/V
    qronos_skip_weights: Sequence[str] = ()

    # Skip Qronos for Q/K/V (wq, wk, wv) in the first N layers
    # Use this to avoid error accumulation in early attention layers
    qronos_skip_qkv_prefix: int = 0

    # Auto-skip Qronos when min(diag(Σ_{X,X̂})) < threshold
    # This detects degenerate cross-covariance (near-zero correlation between X and X̂)
    # Set to 0 to disable auto-skip. Recommended: 1e-5 or 1e-6
    qronos_auto_skip_min_diag: float = 0.0

    # Collect Qronos stats for diagnostics even when qronos=False
    # This allows plotting activation MSE without applying Qronos targeting
    collect_qronos_stats: bool = False

    # Plot activation MSE at the end of the run (requires collect_qronos_stats or qronos)
    plot_activation_mse: bool = False

    # Use Hessians from unquantized model (avoids error propagation in layerwise quant)
    # Similar to Qronos but only for Hessian computation, not the full Qronos statistics
    unquant_hessians: bool = False

    # Residual stream compensation for wo/w2 layers
    # When enabled, modifies the quantization target to account for residual stream error:
    # ŷ = (W Σ_{X,X̂} + Σ_{ΔR,X̂}) (L̂^T)^{-1}  where Σ_{ΔR,X̂} = E[(R - R̂)X̂^T]
    # Requires both unquant and quant models (automatically loads unquant model if needed)
    # Can be used with or without qronos=True
    residual_compensation: bool = False
    # Skip residual compensation on the first N layers (0 = apply to all)
    # e.g., rescomp_skip_prefix=8 means skip layers 0-7, apply to layers 8+
    rescomp_skip_prefix: int = 0


def default_run_id(cfg: PipelineConfig) -> str:
    """Build a run id of the form ``{model}.{method}.r{rate:.2f}.{timestamp}``.

    The rate comes from the config that belongs to ``cfg.method``: for
    "zsic"/"sic" the ZSIC config is preferred, otherwise the GPTQ config.
    If the preferred config is missing, the other one is used, and if neither
    is set the rate is 0. This keeps the label correct when both method
    configs are populated at the same time.

    Args:
        cfg: Pipeline configuration; reads ``model_name``, ``method``,
            ``gptq`` and ``zsic``.

    Returns:
        The generated run id. The timestamp is local time at call time,
        formatted as ``%Y%m%d_%H%M%S``.
    """
    ts = time.strftime("%Y%m%d_%H%M%S")

    method = str(cfg.method).lower()
    if method in ("zsic", "sic") and cfg.zsic is not None:
        # ZSIC run: label with the ZSIC target even if a GPTQ config is also present.
        rate = cfg.zsic.target_rate_bits
    elif cfg.gptq is not None:
        rate = cfg.gptq.target_rate
    elif cfg.zsic is not None:
        rate = cfg.zsic.target_rate_bits
    else:
        rate = 0

    return f"{cfg.model_name}.{cfg.method}.r{rate:.2f}.{ts}"


def get_run_dir(cfg: PipelineConfig) -> Path:
    """Resolve the directory that holds this run's outputs.

    Layout is ``<root>/<model_name>/<run_id>``. A relative ``cfg.run_root`` is
    anchored at the bucket path; an absolute one is used unchanged. An empty
    ``cfg.run_id`` is replaced by a freshly generated default id.
    """
    base = Path(cfg.run_root)
    base = base if base.is_absolute() else get_bucket_path() / base
    run_name = cfg.run_id if cfg.run_id else default_run_id(cfg)
    return base.joinpath(cfg.model_name, run_name)


def should_skip_quantize(cfg: PipelineConfig, layer_id: int, weight: str) -> bool:
    """Tell whether ``(layer_id, weight)`` is listed in ``cfg.skip_quantize_layers``.

    Layer ids are compared as integers and weight names case-insensitively.
    """
    return any(
        int(listed_id) == int(layer_id)
        and str(listed_weight).lower() == str(weight).lower()
        for listed_id, listed_weight in cfg.skip_quantize_layers
    )


@torch.no_grad()
def create_fullprec_artifact(
    model: torch.nn.Module,
    module_name: str,
    weight_name: str,
) -> LayerArtifact:
    """Build a "fullprec" artifact that keeps a module's weight unquantized.

    The weight of ``module_name`` is copied to CPU as float16 and stored under
    ``"W_full"``. The bookkeeping fields report zero loss, zero overhead and a
    rate of 16 bits per parameter. No Hadamard metadata is attached.

    Args:
        model: Model containing the target module.
        module_name: Key of the module in ``model.named_modules()``.
        weight_name: Weight identifier recorded on the artifact.

    Returns:
        The artifact (not yet saved to disk).
    """
    target = dict(model.named_modules())[module_name]
    weight = target.weight.detach()

    # fp16 copy on CPU keeps the artifact small.
    stored_weight = weight.to(torch.float16).cpu()
    dims = tuple(int(extent) for extent in weight.shape)

    return LayerArtifact(
        method="fullprec",
        module_name=module_name,
        weight_name=weight_name,
        shape=dims,
        payload={
            "W_full": stored_weight,
            "loss": 0.0,
            "entropy": 16.0,  # 16 bits per param (fp16)
            "rate_overhead": 0.0,
        },
        hadamard={},
    )


@torch.no_grad()
def quantize_one_layer(
    *,
    model: torch.nn.Module,
    module_name: str,
    weight_name: str,
    H: torch.Tensor,
    cfg: PipelineConfig,
    gptq_cfg: Optional[GPTQConfig] = None,
    zsic_cfg: Optional[ZSICConfig] = None,
    # Qronos stats (optional)
    Sig_X: Optional[torch.Tensor] = None,
    Sig_hX: Optional[torch.Tensor] = None,
    Sig_X_hX: Optional[torch.Tensor] = None,
    # Residual compensation (optional, for wo/w2 layers)
    Sig_delta_R_Xhat: Optional[torch.Tensor] = None,
) -> LayerArtifact:
    """Quantize a single module weight and return its artifact (unsaved).

    Steps: fetch the weight, promote half-precision inputs to float32, apply
    the configured Hadamard rotation to the weight / Hessian / optional
    statistics, run the selected compressor, and pack its outputs into a
    ``LayerArtifact``.

    Args:
        model: Model that owns the target module.
        module_name: Key of the module in ``model.named_modules()``.
        weight_name: Weight identifier recorded on the artifact.
        H: Hessian / covariance statistic for the module input.
        cfg: Pipeline settings (method, Hadamard options, default method configs).
        gptq_cfg: Per-call GPTQ config; overrides ``cfg.gptq`` when given.
        zsic_cfg: Per-call ZSIC config; overrides ``cfg.zsic`` when given.
        Sig_X: Unquantized activation covariance E[X X^T] (Qronos).
        Sig_hX: Quantized activation covariance E[X_hat X_hat^T] (Qronos).
        Sig_X_hX: Cross-covariance E[X X_hat^T] (Qronos).
        Sig_delta_R_Xhat: Residual compensation term E[(R - R̂) X̂^T], where R
            is the residual stream (for wo/w2 layers).

    Returns:
        The artifact for this layer. Its ``method`` is "gptq" or "zsic"
        ("sic" is stored as "zsic").

    Raises:
        ValueError: if the method is unknown, or if no config is available
            for the selected method.

    Note:
        The optional statistics are only forwarded to the ZSIC compressor;
        the GPTQ path receives just the weight and the Hessian.
    """
    target = dict(model.named_modules())[module_name]

    # The weight is still floating point here. fp16/bf16 are promoted to
    # float32 for the solver; any other dtype is kept.
    W0 = target.weight.detach()
    if W0.dtype in (torch.float16, torch.bfloat16):
        work_dtype = torch.float32
    else:
        work_dtype = W0.dtype

    def _cast(stat: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Optional statistics pass through untouched when absent.
        return None if stat is None else stat.to(work_dtype)

    W = W0.to(work_dtype)
    hessian = H.to(work_dtype)
    sig_x = _cast(Sig_X)
    sig_hx = _cast(Sig_hX)
    sig_x_hx = _cast(Sig_X_hX)
    # Residual compensation term: shape (out_features, in_features)
    sig_delta_r = _cast(Sig_delta_R_Xhat)

    # Resolve the Hadamard mode; unknown strings fall back to ROW.
    if cfg.hadamard:
        modes_by_name = {
            "none": HadamardType.NONE,
            "row": HadamardType.ROW,
            "column": HadamardType.COLUMN,
            "row_column": HadamardType.ROW_COLUMN,
        }
        hadamard_type = modes_by_name.get(cfg.hadamard_type.lower(), HadamardType.ROW)
    else:
        hadamard_type = HadamardType.NONE

    had_cfg = {
        "enabled": bool(cfg.hadamard) and hadamard_type != HadamardType.NONE,
        "type": str(hadamard_type.value),
        "seed": int(cfg.hadamard_seed),
    }

    # One transform per weight dimension, built only when the mode needs it.
    out_dim, in_dim = W.shape
    uses_out_side = hadamard_type in (HadamardType.COLUMN, HadamardType.ROW_COLUMN)
    uses_in_side = hadamard_type in (HadamardType.ROW, HadamardType.ROW_COLUMN)
    had_row = (
        BlockRandomHadamard(out_dim, seed=cfg.hadamard_seed, device=W.device, dtype=work_dtype)
        if uses_out_side
        else None
    )
    had_col = (
        BlockRandomHadamard(in_dim, seed=cfg.hadamard_seed, device=W.device, dtype=work_dtype)
        if uses_in_side
        else None
    )

    # Rotate the weight and the Hessian.
    # Row Hadamard (W @ H_col): Hessian transforms as H_col^T @ Sigma @ H_col
    # Column Hadamard (H_row @ W): Hessian unchanged
    # Row + Column (H_row @ W @ H_col): Hessian transforms as above
    Wq_in = apply_hadamard_to_weight(W, hadamard_type, had_row=had_row, had_col=had_col)
    Hq_in = apply_hadamard_to_hessian(hessian, hadamard_type, had=had_col)

    def _rotate_cov(stat: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        # Covariance-like stats transform the same way as the Hessian: H^T @ Sigma @ H
        if stat is None:
            return None
        return apply_hadamard_to_hessian(stat, hadamard_type, had=had_col)

    sig_x = _rotate_cov(sig_x)
    sig_hx = _rotate_cov(sig_hx)
    sig_x_hx = _rotate_cov(sig_x_hx)
    # Residual compensation term transforms the same way as weights (H_row @ Sig @ H_col)
    if sig_delta_r is not None:
        sig_delta_r = apply_hadamard_to_weight(
            sig_delta_r, hadamard_type, had_row=had_row, had_col=had_col
        )

    method = cfg.method.lower()
    if method not in ("gptq", "zsic", "sic"):
        raise ValueError(f"Unknown method: {cfg.method!r}")

    payload: Dict[str, Any]
    if method == "gptq":
        gcfg = cfg.gptq if gptq_cfg is None else gptq_cfg
        if gcfg is None:
            raise ValueError("PipelineConfig.gptq must be provided for method='gptq'")
        _, loss, _, frame = compress_gptq(Wq_in, Hq_in, cfg=gcfg)

        scales = frame["scales"]
        zeros = frame["zeros"]
        artifact_method = "gptq"
        payload = {
            "Qint": frame["Qint"].to(torch.uint8).cpu(),
            "scales": None if scales is None else scales.to(torch.float16).cpu(),
            "zeros": None if zeros is None else zeros.to(torch.float16).cpu(),
            "groupsize": int(frame["groupsize"]),
            "blocksize": int(frame["blocksize"]),
            "percdamp": float(frame["percdamp"]),
            "actorder": bool(frame["actorder"]),
            "maxq": int(frame["maxq"]),
            "target_rate": float(gcfg.target_rate),
            "entropy": float(frame["entropy"]),
            "rate_overhead": float(frame["rate_overhead"]),
            "loss": float(loss),
            "relative_mse": float(frame["relative_mse"]),
        }
    else:
        zcfg = cfg.zsic if zsic_cfg is None else zsic_cfg
        if zcfg is None:
            raise ValueError("PipelineConfig.zsic must be provided for method='zsic'")

        # Forward Qronos stats and residual compensation when available.
        _, loss, _, frame = compress_zsic(
            Wq_in,
            Hq_in,
            cfg=zcfg,
            Sig_X=sig_x,
            Sig_hX=sig_hx,
            Sig_X_hX=sig_x_hx,
            Sig_delta_R_Xhat=sig_delta_r,
        )

        def _optional_half(key: str) -> Optional[torch.Tensor]:
            # Optional frame tensors are stored as fp16 on CPU, or None when missing.
            value = frame.get(key, None)
            return None if value is None else value.to(torch.float16).cpu()

        artifact_method = "zsic"
        payload = {
            "Z": frame["Z"].to(torch.int32).cpu(),
            "alpha": frame["alpha"].to(torch.float16).cpu(),
            "alpha_base": _optional_half("alpha_base"),
            "zero_point": _optional_half("zero_point"),
            "apply_tgamma": bool(frame.get("apply_tgamma", False)),
            "t_vec": _optional_half("t_vec"),
            "g_vec": _optional_half("g_vec"),
            "sic_variant": str(frame.get("sic_variant", "compress_w2q")),
            "target_rate_bits": float(frame.get("target_rate_bits", zcfg.target_rate_bits)),
            "entropy": float(frame.get("entropy", 0.0)),
            "rate_overhead": float(frame.get("rate_overhead", 0.0)),
            "loss": float(loss),
            "c_param": float(frame.get("c_param", 0.0)),
            "percdamp": float(frame.get("percdamp", 0.0)),
            "hessian_damp_used": float(frame.get("hessian_damp_used", 0.0)),
            "cholesky_tries": int(frame.get("cholesky_tries", 0)),
            "qronos": bool(frame.get("qronos", False)),
            "residual_compensation": bool(frame.get("residual_compensation", False)),
            # Binary search fields (when binary_search=True in config)
            "binary_search_target_used": frame.get("binary_search_target_used", None),
            "binary_search_desired": frame.get("binary_search_desired", None),
            "binary_search_final_diff": frame.get("binary_search_final_diff", None),
            "binary_search_iterations": frame.get("binary_search_iterations", None),
        }

    # The recorded shape is that of the original (unrotated) weight.
    return LayerArtifact(
        method=artifact_method,
        module_name=module_name,
        weight_name=weight_name,
        shape=tuple(int(extent) for extent in W0.shape),
        hadamard=had_cfg if cfg.hadamard else {"enabled": False},
        payload=payload,
    )


def _artifact_relpath(module_name: str, *, method: str) -> str:
    """Return the artifact path relative to the run directory.

    The result is ``layers/<stem>.<method>.pt`` where ``<stem>`` is
    ``safe_stem(module_name)``.
    """
    filename = f"{safe_stem(module_name)}.{method}.pt"
    return str(Path("layers", filename))


def run_pipeline(cfg: PipelineConfig, *, local_rank: int = 0) -> Path:
    """Run a full layerwise quantization job.

    Returns:
        Path to run directory.
    """
    run_dir = get_run_dir(cfg)
    run_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = run_dir / "manifest.json"
    log_path = run_dir / "layer_logs.jsonl"

    # Load or create manifest.
    if cfg.resume and manifest_path.exists():
        manifest = RunManifest.load(manifest_path)
    else:
        manifest = RunManifest(
            model_name=str(cfg.model_name),
            method=str(cfg.method),
            run_id=str(run_dir.name),
            config=asdict(cfg),
            artifacts={},
        )
        manifest.save(manifest_path)

    # Load model/tokenizer and apply any already-saved layers (resume).
    print(f"[pipeline] [{cfg.run_id}] loading model...", flush=True)
    model, tokenizer = load_model_and_tokenizer(cfg.model_name, local_rank=local_rank)
    print(f"[pipeline] [{cfg.run_id}] model loaded", flush=True)

    # Need second unquantized model when:
    # - Qronos: computes X (original activations) for Σ_X, Σ_XX̂ statistics + applies Qronos targeting
    # - collect_qronos_stats: computes stats for diagnostics only (no Qronos targeting)
    # - unquant_hessians: uses unquantized model for Hessian computation (avoids error propagation)
    # - residual_compensation: computes Σ_{ΔR,X̂} for wo/w2 layers
    # The main `model` becomes `model_quant` (progressively quantized)
    need_unquant_model = cfg.qronos or cfg.collect_qronos_stats or cfg.unquant_hessians or cfg.residual_compensation
    model_unquant = None
    if need_unquant_model:
        reasons = []
        if cfg.qronos:
            reasons.append("qronos")
        if cfg.collect_qronos_stats:
            reasons.append("collect_qronos_stats")
        if cfg.unquant_hessians:
            reasons.append("unquant_hessians")
        if cfg.residual_compensation:
            reasons.append("residual_compensation")
        print(f"[{'+'.join(reasons)}] loading second model copy (unquantized)", flush=True)
        model_unquant, _ = load_model_and_tokenizer(cfg.model_name, local_rank=local_rank)
        # model_unquant stays frozen (never gets quantized weights applied)

    # Resume: apply existing artifacts to model for correct Hessians.
    # For Qronos: only apply to the quantized model, not the unquantized one.
    print(f"[pipeline] [{cfg.run_id}] applying {len(manifest.artifacts)} existing artifacts...", flush=True)
    for module_name, relpath in manifest.artifacts.items():
        art = LayerArtifact.load(run_dir / relpath)
        apply_layer_artifact(model, art)

    # Optional rate control (primarily for ZSIC): global budget + regression inversion.
    rate_ctrl: RateController | None = None
    if (
        cfg.rate_control is not None
        and bool(cfg.rate_control.enabled)
        and str(cfg.method).lower() in ("zsic", "sic")
    ):
        modules = dict(model.named_modules())
        layer_meta: Dict[str, Dict[str, Any]] = {}
        for layer_id, weight in cfg.layers:
            mod = get_hess_name(layer_id, weight)
            if mod not in modules:
                raise KeyError(f"Module not found in model: {mod}")
            w = modules[mod].weight
            if w.ndim != 2:
                raise ValueError(f"Expected 2D weight for {mod}, got shape {tuple(w.shape)}")
            a, n = int(w.shape[0]), int(w.shape[1])
            layer_meta[mod] = {"numel": int(a * n), "shape": [a, n], "weight": str(weight)}

        existing: Dict[str, Dict[str, Any]] = {}
        for mod, relpath in manifest.artifacts.items():
            # Only use artifacts in this run's target layer list.
            if mod not in layer_meta:
                continue
            art = LayerArtifact.load(run_dir / relpath, map_location="cpu")
            entropy = float(art.payload.get("entropy", 0.0))
            overhead = float(art.payload.get("rate_overhead", 0.0))
            actual = float(entropy + overhead)
            # For ZSIC artifacts, this is the input parameter we passed.
            target_x = float(art.payload.get("target_rate_bits", actual))
            existing[mod] = {
                "numel": int(layer_meta[mod]["numel"]),
                "actual_rate": float(actual),
                "target_x": float(target_x),
            }

        rate_ctrl = RateController(cfg=cfg.rate_control, layer_meta=layer_meta, existing=existing)
        # Save initial controller state for reproducibility.
        rate_ctrl.save_json(str(run_dir / "rate_control_state.json"))

    # Calibration data.
    train_tokens = split_dataset(get_wikitext2(tokenizer, split="train"), cfg.seqlen)
    train_tokens = take_nseq(train_tokens, cfg.calib_nsamples)  # None means all samples
    actual_nsamples = train_tokens.shape[0]
    print(f"[calib] [{cfg.run_id}] using {actual_nsamples} calibration samples (seqlen={cfg.seqlen})", flush=True)

    # Initialize activation cache for O(N) instead of O(N^2) complexity.
    # The cache stores hidden states at transformer block boundaries.
    device = next(model.parameters()).device
    cache = ActivationCache(
        model=model,
        dataset=train_tokens,
        seqlen=int(cfg.seqlen),
        nsamples=actual_nsamples,
        device=device,
        # dtype auto-detected from model (e.g., bfloat16)
        batch_size=cfg.hessian_batch_size,  # Resize KV caches if needed
    )
    print(f"[cache] [{cfg.run_id}] initialized with {cache.nsamples} samples, starting at block 0, batch_size={cfg.hessian_batch_size}", flush=True)

    # Second cache for unquantized model when:
    # - Qronos: computes X (original activations) for Σ_X, Σ_XX̂ statistics
    # - collect_qronos_stats: computes stats for diagnostics (no Qronos targeting)
    # - unquant_hessians: uses unquantized activations for Hessian computation (avoids error propagation)
    cache_unquant = None
    if need_unquant_model and model_unquant is not None:
        cache_unquant = ActivationCache(
            model=model_unquant,
            dataset=train_tokens,
            seqlen=int(cfg.seqlen),
            nsamples=actual_nsamples,
            device=device,
            batch_size=cfg.hessian_batch_size,
        )
        print(f"[unquant-cache] initialized with {cache_unquant.nsamples} samples")

    # If resuming, advance cache through already-quantized blocks.
    # Find the first layer_id that needs to be quantized.
    layers_to_process = [(lid, w) for lid, w in cfg.layers if not manifest.has(get_hess_name(lid, w))]
    if layers_to_process:
        first_layer_id = layers_to_process[0][0]
        # Advance cache to the first block we need to process
        while cache.current_block_idx < first_layer_id:
            print(f"[cache] advancing through block {cache.current_block_idx} (resuming)")
            cache.advance_through_block(cache.current_block_idx, batch_size=cfg.hessian_batch_size)
            # Qronos/unquant_hessians: also advance unquantized cache
            if cache_unquant is not None:
                print(f"[unquant-cache] advancing through block {cache_unquant.current_block_idx} (resuming)")
                cache_unquant.advance_through_block(cache_unquant.current_block_idx, batch_size=cfg.hessian_batch_size)

    # Build a set of (layer_id, weight) pairs to know when we're at the last weight of a block.
    # We advance the cache after processing the last weight in each block.
    layer_weights_map: Dict[int, List[str]] = {}
    for layer_id, weight in cfg.layers:
        if layer_id not in layer_weights_map:
            layer_weights_map[layer_id] = []
        layer_weights_map[layer_id].append(weight)

    # Main loop.
    for layer_id, weight in cfg.layers:
        module_name = get_hess_name(layer_id, weight)
        weight_name = get_weight_name(layer_id, weight)

        if manifest.has(module_name):
            print(f"[skip] already quantized: {module_name}")
            continue

        # Check if this layer should skip quantization and use full precision
        if should_skip_quantize(cfg, layer_id, weight):
            print(f"\n[layerwise] SKIPPING quantization for {module_name} (storing full precision)", flush=True)
            art = create_fullprec_artifact(model, module_name, weight_name)

            # Save artifact
            relpath = _artifact_relpath(module_name, method=art.method)
            art.save(run_dir / relpath)

            # Update manifest
            manifest.add(module_name, relpath)
            manifest.save(manifest_path)

            # Log the skip
            line = {
                "module": module_name,
                "weight": weight_name,
                "shape": list(art.shape),
                "method": "fullprec",
                "nseq": 0,
                "ntokens": 0,
                "actual_rate": 16.0,
                "target_rate_bits_used": None,
                "rate_control": None,
                "payload": {
                    "loss": 0.0,
                    "entropy": 16.0,
                    "rate_overhead": 0.0,
                },
                "ts": time.time(),
            }
            with open(log_path, "a") as f:
                f.write(json.dumps(line) + "\n")
            continue

        print(f"\n[layerwise] [{cfg.run_id}] quantizing {module_name} ({cfg.method})", flush=True)

        had_cfg = {
            "enabled": bool(cfg.hadamard),
            "type": str(cfg.hadamard_type),
            "seed": int(cfg.hadamard_seed),
        }

        # Use cached Hessian computation - O(1) blocks instead of O(N)
        # Note: Row Hadamard doesn't require Hessian rotation, so hadamard_cfg=None
        # Normalize by tokens to get true E[X X^T] covariance matrices
        #
        # unquant_hessians mode: Use activations from unquantized model for Hessian
        # This avoids error propagation from earlier quantized layers affecting
        # Hessian quality for later layers.
        hess_model = model_unquant if cfg.unquant_hessians and model_unquant is not None else model
        hess_cache = cache_unquant if cfg.unquant_hessians and cache_unquant is not None else cache
        if cfg.unquant_hessians:
            print(f"[unquant_hessians] using unquantized model for Hessian computation", flush=True)

        H, nseq_used, ntokens_used = compute_module_hessian_cached(
            hess_model,
            hess_cache,
            layer_id,
            module_name,
            hadamard_cfg=None,  # Row Hadamard: Hessian unchanged
            normalize=True,
            normalize_by="tokens",  # True E[X X^T] normalization
            dtype=torch.float32,
            verbose=True,
            batch_size=cfg.hessian_batch_size,
        )

        # Compute Qronos stats (Σ_X, Σ_X̂, Σ_XX̂) when:
        # - qronos=True: stats are computed and used for Qronos targeting
        # - collect_qronos_stats=True: stats are computed for diagnostics only (no targeting)
        # Note: residual_compensation does NOT require full Qronos stats - it only needs Σ_{ΔR,X̂}
        qronos_stats = None
        should_compute_stats = (cfg.qronos or cfg.collect_qronos_stats) and model_unquant is not None and cache_unquant is not None
        if should_compute_stats:
            from quant_layerwise.hessian_runtime import compute_qronos_stats_cached

            qronos_stats = compute_qronos_stats_cached(
                model_unquant=model_unquant,
                model_quant=model,
                cache_unquant=cache_unquant,
                cache_quant=cache,
                layer_id=layer_id,
                module_name=module_name,
                normalize=True,
                normalize_by="tokens",  # True E[X X^T] normalization
                dtype=torch.float32,
                verbose=True,
                batch_size=cfg.hessian_batch_size,
            )
            # Note: Qronos stats are saved AFTER residual stats computation (see below)
            # so that both can be saved together for wo/w2 layers

        # Optional rate-control: choose a per-layer target_rate_bits to hit a global budget.
        zsic_cfg_layer: ZSICConfig | None = None
        rc_info: Dict[str, Any] | None = None
        target_x_used: float | None = None

        # If Qronos mode is enabled and we have stats, enable it in ZSIC config
        # But only if the layer is in the specified range [qronos_layer_min, qronos_layer_max)
        # and not in the skip list
        qronos_enabled = cfg.qronos and qronos_stats is not None
        if qronos_enabled:
            # Check layer range restrictions
            if cfg.qronos_layer_min is not None and layer_id < cfg.qronos_layer_min:
                qronos_enabled = False
                print(f"[qronos] layer {layer_id} < qronos_layer_min ({cfg.qronos_layer_min}), using standard quantization", flush=True)
            if cfg.qronos_layer_max is not None and layer_id >= cfg.qronos_layer_max:
                qronos_enabled = False
                print(f"[qronos] layer {layer_id} >= qronos_layer_max ({cfg.qronos_layer_max}), using standard quantization", flush=True)
            # Check skip list (layer_id, weight) pairs
            for skip_lid, skip_w in cfg.qronos_skip_layers:
                if int(skip_lid) == layer_id and str(skip_w).lower() == weight.lower():
                    qronos_enabled = False
                    print(f"[qronos] {layer_id}.{weight} in qronos_skip_layers, using standard quantization", flush=True)
                    break
            # Check weight-type skip list (applies to all layers)
            if qronos_enabled and weight.lower() in [w.lower() for w in cfg.qronos_skip_weights]:
                qronos_enabled = False
                print(f"[qronos] {weight} in qronos_skip_weights, using standard quantization", flush=True)
            # Check QKV prefix skip (skip wq/wk/wv in first N layers)
            if qronos_enabled and cfg.qronos_skip_qkv_prefix > 0:
                if weight.lower() in ("wq", "wk", "wv") and layer_id < cfg.qronos_skip_qkv_prefix:
                    qronos_enabled = False
                    print(f"[qronos] layer {layer_id} < qronos_skip_qkv_prefix={cfg.qronos_skip_qkv_prefix}, skipping Qronos for {weight}", flush=True)
            # Auto-skip based on min diagonal of Σ_{X,X̂}
            if qronos_enabled and cfg.qronos_auto_skip_min_diag > 0 and qronos_stats is not None:
                min_diag = qronos_stats.sigma_x_xhat.diag().abs().min().item()
                if min_diag < cfg.qronos_auto_skip_min_diag:
                    qronos_enabled = False
                    print(f"[qronos] {layer_id}.{weight}: min_diag(Σ_{{X,X̂}})={min_diag:.2e} < {cfg.qronos_auto_skip_min_diag:.2e}, skipping Qronos", flush=True)

        # Compute residual compensation stats for wo/w2 layers if enabled
        # Requires both unquant and quant models/caches (but not necessarily qronos mode)
        residual_stats = None
        residual_comp_enabled = (
            cfg.residual_compensation
            and weight.lower() in ("wo", "w2")
            and model_unquant is not None
            and cache_unquant is not None
            and layer_id >= cfg.rescomp_skip_prefix  # Skip first N layers
        )
        if cfg.residual_compensation and layer_id < cfg.rescomp_skip_prefix and weight.lower() in ("wo", "w2"):
            print(f"[residual] layer {layer_id} < rescomp_skip_prefix ({cfg.rescomp_skip_prefix}), skipping residual compensation", flush=True)
        if residual_comp_enabled:
            from quant_layerwise.hessian_runtime import compute_residual_stats_cached

            residual_stats = compute_residual_stats_cached(
                model_unquant=model_unquant,
                model_quant=model,
                cache_unquant=cache_unquant,
                cache_quant=cache,
                layer_id=layer_id,
                weight_type=weight.lower(),
                normalize=True,
                normalize_by="tokens",
                dtype=torch.float32,
                verbose=True,
                batch_size=cfg.hessian_batch_size,
            )
            print(f"[residual] computed Σ_{{ΔR,X̂}} for {weight} layer (shape={residual_stats.sigma_delta_r_xhat.shape})", flush=True)

        # Save Qronos stats (and residual stats if available) to pkl file
        if qronos_stats is not None:
            import pickle
            qronos_dir = run_dir / "qronos_stats"
            qronos_dir.mkdir(parents=True, exist_ok=True)
            qronos_path = qronos_dir / f"{safe_stem(module_name)}.pkl"
            # Save as simple dict: just load with pickle.load() and access keys
            qronos_dict = {
                "Sig_X": qronos_stats.sigma_x.cpu(),      # E[X X^T] - unquantized
                "Sig_hX": qronos_stats.sigma_xhat.cpu(),  # E[X̂ X̂^T] - quantized
                "Sig_X_hX": qronos_stats.sigma_x_xhat.cpu(),  # E[X X̂^T] - cross
                "module_name": module_name,
                "layer_id": layer_id,
            }
            # Include residual stats for wo/w2 layers when residual_compensation is enabled
            if residual_stats is not None:
                qronos_dict["Sig_delta_R_Xhat"] = residual_stats.sigma_delta_r_xhat.cpu()
                print(f"[qronos] including Σ_{{ΔR,X̂}} in stats (residual_compensation enabled)", flush=True)
            with open(qronos_path, "wb") as f:
                pickle.dump(qronos_dict, f)
            print(f"[qronos] saved stats to {qronos_path}", flush=True)

        # Determine if we need Qronos mode in ZSIC config
        # Only enable when qronos_enabled (full Qronos targeting)
        # Residual compensation can work without Qronos mode (simplified formula)
        use_qronos_mode = qronos_enabled

        if rate_ctrl is not None:
            if cfg.zsic is None:
                raise ValueError("rate_control enabled but cfg.zsic is None")
            target_x_used, rc_info = rate_ctrl.suggest_target_x(module_name)

            # Use the target rate from budget controller
            # Binary search (if enabled) will find the right internal target to achieve this rate
            zsic_cfg_layer = replace(
                cfg.zsic,
                target_rate_bits=float(target_x_used),
                qronos=use_qronos_mode,
                residual_compensation=residual_comp_enabled,
            )
            print(
                "[rate] wtype={} target={:.4f} remaining_budget={:.2f}".format(
                    rc_info.get("weight_type", "?"),
                    float(target_x_used),
                    float(rc_info.get("remaining_budget_bits", 0)),
                ),
                flush=True,
            )
        elif use_qronos_mode and cfg.zsic is not None:
            # Enable qronos mode in config (for Qronos targeting)
            zsic_cfg_layer = replace(
                cfg.zsic,
                qronos=True,
                residual_compensation=residual_comp_enabled,
            )
        elif residual_comp_enabled and cfg.zsic is not None:
            # Residual compensation without Qronos (simplified formula)
            zsic_cfg_layer = replace(
                cfg.zsic,
                qronos=False,
                residual_compensation=True,
            )

        # Quantize layer and build artifact.
        # Pass Qronos stats only when qronos_enabled (actual Qronos targeting)
        # Residual compensation works independently - just needs Σ_{ΔR,X̂}
        art = quantize_one_layer(
            model=model,
            module_name=module_name,
            weight_name=weight_name,
            H=H,  # Standard Hessian (used as fallback in non-Qronos mode)
            cfg=cfg,
            zsic_cfg=zsic_cfg_layer,
            Sig_X=qronos_stats.sigma_x if qronos_enabled and qronos_stats is not None else None,
            Sig_hX=qronos_stats.sigma_xhat if qronos_enabled and qronos_stats is not None else None,
            Sig_X_hX=qronos_stats.sigma_x_xhat if qronos_enabled and qronos_stats is not None else None,
            Sig_delta_R_Xhat=residual_stats.sigma_delta_r_xhat if residual_comp_enabled and residual_stats is not None else None,
        )

        # Save artifact.
        relpath = _artifact_relpath(module_name, method=art.method)
        art.save(run_dir / relpath)

        # Update manifest.
        manifest.add(module_name, relpath)
        manifest.save(manifest_path)

        # Apply to model so subsequent Hessians see it.
        apply_layer_artifact(model, art)

        # Update rate controller with achieved rate.
        actual_rate = float(art.payload.get("entropy", 0.0)) + float(art.payload.get("rate_overhead", 0.0))
        if rate_ctrl is not None:
            # When binary search was used, get the actual target it found
            bs_target = art.payload.get("binary_search_target_used", None)
            if bs_target is not None:
                rate_ctrl.update(module_name, target_x=float(bs_target), actual_rate=float(actual_rate))
            elif target_x_used is not None:
                rate_ctrl.update(module_name, target_x=float(target_x_used), actual_rate=float(actual_rate))
            rate_ctrl.save_json(str(run_dir / "rate_control_state.json"))

        # Append per-layer log line.
        line = {
            "module": module_name,
            "weight": weight_name,
            "shape": list(art.shape),
            "method": art.method,
            "nseq": int(nseq_used),
            "ntokens": int(ntokens_used),
            "actual_rate": float(actual_rate),
            "target_rate_bits_used": None if target_x_used is None else float(target_x_used),
            "rate_control": rc_info,
            "payload": {
                "loss": float(art.payload.get("loss", 0.0)),
                "entropy": float(art.payload.get("entropy", 0.0)),
                "rate_overhead": float(art.payload.get("rate_overhead", 0.0)),
                "qronos": bool(art.payload.get("qronos", False)),
                "binary_search_target_used": art.payload.get("binary_search_target_used", None),
                "binary_search_desired": art.payload.get("binary_search_desired", None),
                "binary_search_final_diff": art.payload.get("binary_search_final_diff", None),
            },
            "ts": time.time(),
        }
        with open(log_path, "a") as f:
            f.write(json.dumps(line) + "\n")

        # Advance cache through this block if we've processed all weights in it.
        # This propagates activations through the now-quantized block.
        weights_in_block = layer_weights_map.get(layer_id, [])
        is_last_weight_in_block = (weight == weights_in_block[-1]) if weights_in_block else False
        if is_last_weight_in_block and cache.current_block_idx == layer_id:
            print(f"[cache] advancing through block {layer_id} (all weights quantized)", flush=True)
            cache.advance_through_block(layer_id, batch_size=cfg.hessian_batch_size)
            # Qronos/unquant_hessians: also advance unquantized cache
            if cache_unquant is not None:
                print(f"[unquant-cache] advancing through block {layer_id}", flush=True)
                cache_unquant.advance_through_block(layer_id, batch_size=cfg.hessian_batch_size)

        torch.cuda.empty_cache()

    # Report resulting rate after the model (requested layers) is done.
    rate_summary = compute_run_rate_summary(run_dir, manifest)
    (run_dir / "rate_summary.json").write_text(json.dumps(rate_summary, indent=2))
    print(
        "\n[done] run_dir={}  avg_rate_bits_per_param={:.4f}  total_params={}".format(
            run_dir,
            float(rate_summary.get("avg_rate_bits_per_param", float("nan"))),
            int(rate_summary.get("total_params", 0)),
        )
    )

    # Generate activation MSE plot if requested and stats were collected
    if cfg.plot_activation_mse and (cfg.qronos or cfg.collect_qronos_stats):
        qronos_dir = run_dir / "qronos_stats"
        if qronos_dir.exists() and any(qronos_dir.glob("*.pkl")):
            try:
                from scripts.plot_activation_mse import plot_activation_mse
                plot_path = run_dir / "activation_mse.png"
                print(f"\n[plot] generating activation MSE plot...", flush=True)
                plot_activation_mse(
                    run_dir=run_dir,
                    output_path=plot_path,
                    title=f"Activation Drift: {cfg.model_name} ({cfg.method}, r={rate_summary.get('avg_rate_bits_per_param', 0):.2f})",
                    show_correlation=False,  # Single panel scatter plot
                )
                print(f"[plot] saved to {plot_path}", flush=True)
            except Exception as e:
                print(f"[plot] warning: failed to generate plot: {e}", flush=True)
        else:
            print(f"[plot] no qronos_stats found, skipping activation MSE plot", flush=True)

    return run_dir


def compute_run_rate_summary(run_dir: Path, manifest: RunManifest) -> Dict[str, Any]:
    """Aggregate the achieved bit rate over every artifact recorded in a run.

    Each artifact listed in ``manifest.artifacts`` is loaded from ``run_dir``
    (with ``map_location="cpu"``), visited in sorted module-name order. A
    layer's rate is its payload ``"entropy"`` plus ``"rate_overhead"`` (each
    treated as 0.0 when absent), and it is weighted by the number of weights
    ``shape[0] * shape[1]``.

    Args:
        run_dir: Directory that the manifest's relative artifact paths are
            resolved against.
        manifest: Run manifest supplying ``artifacts`` (module name ->
            relative path), ``model_name``, ``method`` and ``run_id``.

    Returns:
        A dict with the run identity, the number of quantized layers, the
        total parameter and bit counts, the parameter-weighted average rate
        (``None`` when no parameters were counted), and a ``per_layer`` list
        with one record per artifact.
    """
    layer_rows: List[Dict[str, Any]] = []
    bits_acc = 0.0
    params_acc = 0

    for name in sorted(manifest.artifacts.keys()):
        artifact = LayerArtifact.load(run_dir / manifest.artifacts[name], map_location="cpu")

        # Only the first two dimensions of the stored shape are used.
        rows = int(artifact.shape[0])
        cols = int(artifact.shape[1])
        n_weights = rows * cols

        payload_entropy = float(artifact.payload.get("entropy", 0.0))
        payload_overhead = float(artifact.payload.get("rate_overhead", 0.0))
        layer_rate = float(payload_entropy + payload_overhead)

        # Accumulate one layer at a time so the float summation order is fixed.
        bits_acc += layer_rate * float(n_weights)
        params_acc += int(n_weights)

        layer_rows.append(
            dict(
                module=name,
                method=artifact.method,
                shape=[rows, cols],
                entropy=payload_entropy,
                rate_overhead=payload_overhead,
                rate=layer_rate,
                numel=int(n_weights),
            )
        )

    # With nothing counted there is no meaningful average; report None.
    if params_acc > 0:
        mean_rate = float(bits_acc / float(params_acc))
    else:
        mean_rate = None

    summary: Dict[str, Any] = {}
    summary["model_name"] = manifest.model_name
    summary["method"] = manifest.method
    summary["run_id"] = manifest.run_id
    summary["n_layers_quantized"] = int(len(manifest.artifacts))
    summary["total_params"] = int(params_acc)
    summary["total_bits"] = float(bits_acc)
    summary["avg_rate_bits_per_param"] = mean_rate
    summary["per_layer"] = layer_rows
    return summary


def build_layers(
    *,
    layer_ids: Iterable[int],
    weights: Sequence[str] = ("wq", "wk", "wv", "wo", "w1", "w2", "w3"),
) -> List[Tuple[int, str]]:
    """Expand layer ids and weight names into an ordered list of pairs.

    Args:
        layer_ids: Layer indices to include; each is coerced with ``int``.
        weights: Weight names emitted for every layer, in the given order;
            each is coerced with ``str``.

    Returns:
        ``(layer_id, weight)`` tuples, with layer ids as the outer loop and
        weights as the inner loop.
    """
    return [
        (int(layer_idx), str(weight_key))
        for layer_idx in layer_ids
        for weight_key in weights
    ]
