"""Run one layerwise quantization job.

This script is designed to be used in two ways:
  1) CLI: `python -m scripts.run_pipeline_job ...`
  2) As a callable from your GPU scheduler (mimicking your baseline runners):
        from scripts.run_pipeline_job import run_pipeline_job

If you're running without torchrun (WORLD_SIZE=1), you can pass `init_dist=True`
(CLI: `--init_dist`) so that `parallel.start.start` can initialize NCCL using a
per-GPU master port (`master_port_base + local_rank`).
This only works when your ckpt_dir has a *single* checkpoint shard.
"""

from __future__ import annotations

import argparse
import os
from typing import List, Tuple

import torch

from quant_layerwise.pipeline import (
    PipelineConfig,
    build_layers,
    ensure_single_process_distributed,
    run_pipeline,
)
from quant_layerwise.rate_control import RateControlConfig
from quant_layerwise.methods.gptq import GPTQConfig
from quant_layerwise.methods.zsic import ZSICConfig


def run_pipeline_job(
    model_name: str,
    method: str,
    target_rate: float,
    *,
    # layer selection
    layer_begin: int = 0,
    layer_end: int = 32,
    weights: str = "wq,wk,wv,wo,w1,w2,w3",
    # calib
    seqlen: int = 2048,
    calib_nsamples: int | None = None,  # None means use all available samples
    hessian_batch_size: int = 1,  # Batch size for Hessian computation (higher = faster)
    # hadamard
    hadamard: bool = False,
    hadamard_type: str = "row",  # "none", "row", "column", "row_column"
    hadamard_seed: int = 0,
    # GPTQ
    groupsize: int = -1,
    blocksize: int = 128,
    percdamp: float = 0.0,
    actorder: bool = False,
    gptq_maxq: int | None = None,  # Override maxq (if None, compute from target_rate)
    overhead_bits_per_param: int = 16,
    unquant_hessians: bool = False,  # Use Hessians from unquantized model
    # ZSIC
    zsic_apply_tgamma: bool = True,
    zsic_tgamma_ridge: float = 0.0,
    zsic_tgamma_max_iter: int = 500,
    zsic_tgamma_tol: float = 1e-3,
    zsic_percdamp: float | None = None,
    zsic_binary_search: bool = False,
    zsic_binary_search_iters: int = 20,
    zsic_binary_search_row_fraction: float = 0.1,
    # Rate control (mainly for ZSIC)
    rate_control: bool = False,
    global_rate_bits: float | None = None,
    rate_xmin: float = 0.05,
    rate_xmax: float = 16.0,
    rate_weight_budgets: str = "",  # Format: "wk:1.5,wq:1.25" to give wk 50% more bits, wq 25% more
    # Skip quantization for specific layers (store full precision)
    skip_quantize: str = "",  # Format: "0.wq,0.wk,1.wq,1.wk" to skip layers 0-1 wq/wk
    # Qronos: compute and save Σ_X̂ and Σ_XX̂ statistics
    qronos: bool = False,
    # Qronos layer range: only apply Qronos targeting to layers in [min, max)
    qronos_layer_min: int | None = None,
    qronos_layer_max: int | None = None,
    # Skip Qronos targeting for specific (layer_id, weight) pairs
    qronos_skip_layers: str = "",  # e.g., "2.wo,3.wo" to skip L2_wo and L3_wo
    # Skip Qronos targeting for specific weight types globally (all layers)
    qronos_skip_weights: str = "",  # e.g., "wq,wk,wv" to skip all Q/K/V
    qronos_skip_qkv_prefix: int = 0,  # Skip Qronos for wq/wk/wv in first N layers
    qronos_auto_skip_min_diag: float = 0.0,  # Auto-skip Qronos if min(diag(Σ_{X,X̂})) < threshold
    # Collect Qronos stats for diagnostics (no targeting) + plot activation MSE
    collect_qronos_stats: bool = False,
    plot_activation_mse: bool = False,
    # Residual stream compensation for wo/w2 layers (requires qronos=True)
    residual_compensation: bool = False,
    # Skip residual compensation on the first N layers (0 = apply to all)
    rescomp_skip_prefix: int = 0,
    # output
    run_root: str = "quant_runs",
    run_id: str = "",
    resume: bool = True,
    # distributed init (single process)
    init_dist: bool = False,
    master_port_base: int = 29500,
    local_rank: int | None = None,
):
    if local_rank is None:
        # Prefer CUDA device (scheduler sets it). Fallback to LOCAL_RANK env.
        if torch.cuda.is_available():
            local_rank = int(torch.cuda.current_device())
        else:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if init_dist:
        # Use a different port per GPU to avoid collisions when multiple single-process
        # jobs run on the same machine.
        ensure_single_process_distributed(local_rank=local_rank, master_port=master_port_base + int(local_rank))

    wlist = [w.strip() for w in weights.split(",") if w.strip()]
    layers: List[Tuple[int, str]] = build_layers(layer_ids=range(int(layer_begin), int(layer_end)), weights=wlist)

    method_l = method.lower()
    gptq_cfg = None
    zsic_cfg = None
    rate_cfg = None

    if method_l == "gptq":
        gptq_cfg = GPTQConfig(
            target_rate=float(target_rate),
            groupsize=int(groupsize),
            blocksize=int(blocksize),
            percdamp=float(percdamp),
            actorder=bool(actorder),
            overhead_bits_per_param=int(overhead_bits_per_param),
            maxq=int(gptq_maxq) if gptq_maxq is not None else None,
        )
    elif method_l in ("zsic", "sic"):
        # For ZSIC, use percdamp=0.0001 by default for numerical stability
        # (especially important for Qronos mode with Sigma_hatX)
        if zsic_percdamp is None:
            if percdamp > 0.0:
                zsic_percdamp = float(percdamp)
            else:
                zsic_percdamp = 0.0001  # Safe default for ZSIC
        zsic_cfg = ZSICConfig(
            target_rate_bits=float(target_rate),
            sic_variant="compress_w2q",
            apply_tgamma=bool(zsic_apply_tgamma),
            tgamma_ridge=float(zsic_tgamma_ridge),
            tgamma_max_iter=int(zsic_tgamma_max_iter),
            tgamma_tol=float(zsic_tgamma_tol),
            overhead_bits_per_param=int(overhead_bits_per_param),
            percdamp=float(zsic_percdamp),
            binary_search=bool(zsic_binary_search),
            binary_search_iters=int(zsic_binary_search_iters),
            binary_search_row_fraction=float(zsic_binary_search_row_fraction),
            qronos=bool(qronos),  # Enable Qronos mode if --qronos flag is set
        )

        # Rate control: if enabled, we try to hit a *global* average rate.
        if bool(rate_control) or global_rate_bits is not None:
            g = float(target_rate) if global_rate_bits is None else float(global_rate_bits)

            # Parse weight budget multipliers: "wk:1.5,wq:1.25" -> {"wk": 1.5, "wq": 1.25}
            weight_mults = None
            if rate_weight_budgets:
                weight_mults = {}
                for item in rate_weight_budgets.split(","):
                    item = item.strip()
                    if not item:
                        continue
                    parts = item.split(":")
                    if len(parts) != 2:
                        raise ValueError(f"Invalid rate_weight_budgets format: '{item}'. Expected 'weight:multiplier' (e.g., 'wk:1.5')")
                    wtype = parts[0].strip()
                    mult = float(parts[1].strip())
                    weight_mults[wtype] = mult

            rate_cfg = RateControlConfig(
                enabled=True,
                global_target_rate_bits=float(g),
                xmin=float(rate_xmin),
                xmax=float(rate_xmax),
                weight_budget_multipliers=weight_mults,
            )
    else:
        raise ValueError(f"Unknown method: {method}")

    # Parse skip_quantize: "0.wq,0.wk,1.wq,1.wk" -> [(0, "wq"), (0, "wk"), ...]
    skip_layers: List[Tuple[int, str]] = []
    if skip_quantize:
        for item in skip_quantize.split(","):
            item = item.strip()
            if not item:
                continue
            parts = item.split(".")
            if len(parts) != 2:
                raise ValueError(f"Invalid skip_quantize format: '{item}'. Expected 'layer_id.weight' (e.g., '0.wq')")
            layer_id = int(parts[0])
            weight = parts[1].strip()
            skip_layers.append((layer_id, weight))

    # Parse qronos_skip_layers: "2.wo,3.wo" -> [(2, "wo"), (3, "wo")]
    qronos_skip: List[Tuple[int, str]] = []
    if qronos_skip_layers:
        for item in qronos_skip_layers.split(","):
            item = item.strip()
            if not item:
                continue
            parts = item.split(".")
            if len(parts) != 2:
                raise ValueError(f"Invalid qronos_skip_layers format: '{item}'. Expected 'layer_id.weight' (e.g., '2.wo')")
            layer_id = int(parts[0])
            weight = parts[1].strip()
            qronos_skip.append((layer_id, weight))

    # Parse qronos_skip_weights: "wq,wk,wv" -> ["wq", "wk", "wv"]
    qronos_skip_w: List[str] = []
    if qronos_skip_weights:
        for item in qronos_skip_weights.split(","):
            item = item.strip()
            if item:
                qronos_skip_w.append(item)

    cfg = PipelineConfig(
        model_name=str(model_name),
        method=str(method_l),
        layers=layers,
        seqlen=int(seqlen),
        calib_nsamples=None if calib_nsamples is None else int(calib_nsamples),
        hessian_batch_size=int(hessian_batch_size),
        hadamard=bool(hadamard),
        hadamard_type=str(hadamard_type),
        hadamard_seed=int(hadamard_seed),
        run_root=str(run_root),
        run_id=str(run_id),
        resume=bool(resume),
        gptq=gptq_cfg,
        zsic=zsic_cfg,
        rate_control=rate_cfg,
        skip_quantize_layers=skip_layers,
        qronos=bool(qronos),
        qronos_layer_min=qronos_layer_min,
        qronos_layer_max=qronos_layer_max,
        qronos_skip_layers=qronos_skip,
        qronos_skip_weights=qronos_skip_w,
        qronos_skip_qkv_prefix=int(qronos_skip_qkv_prefix),
        qronos_auto_skip_min_diag=float(qronos_auto_skip_min_diag),
        collect_qronos_stats=bool(collect_qronos_stats),
        plot_activation_mse=bool(plot_activation_mse),
        unquant_hessians=bool(unquant_hessians),
        residual_compensation=bool(residual_compensation),
        rescomp_skip_prefix=int(rescomp_skip_prefix),
    )

    return run_pipeline(cfg, local_rank=int(local_rank))


def main():
    """Command-line entry point: parse flags and forward them to ``run_pipeline_job``.

    Each flag maps onto the ``run_pipeline_job`` keyword argument of the same name
    (``--model`` -> ``model_name``). ``--local_rank`` / ``--local-rank`` is accepted so runs
    under the legacy ``torch.distributed.launch``, which appends that flag, parse cleanly
    instead of failing with "unrecognized arguments". When omitted, the rank is auto-detected.
    """
    p = argparse.ArgumentParser(description="Run one layerwise quantization job.")
    p.add_argument("--model", required=True)
    p.add_argument("--method", required=True, choices=["gptq", "zsic"])
    p.add_argument("--target_rate", required=True, type=float)

    p.add_argument("--layer_begin", type=int, default=0)
    p.add_argument("--layer_end", type=int, default=32)
    p.add_argument("--weights", type=str, default="wq,wk,wv,wo,w1,w2,w3")

    p.add_argument("--seqlen", type=int, default=2048)
    p.add_argument("--calib_nsamples", type=int, default=None, help="Number of calibration samples (default: all available)")
    p.add_argument("--hessian_batch_size", type=int, default=1, help="Batch size for Hessian computation (higher = faster, more memory)")

    p.add_argument("--hadamard", action="store_true")
    p.add_argument("--hadamard_type", type=str, default="row",
                   choices=["none", "row", "column", "row_column"],
                   help="Type of Hadamard transform: row (W@H), column (H@W), row_column (both)")
    p.add_argument("--hadamard_seed", type=int, default=0)

    # GPTQ
    p.add_argument("--groupsize", type=int, default=-1)
    p.add_argument("--blocksize", type=int, default=128)
    p.add_argument("--percdamp", type=float, default=0.0)
    p.add_argument("--actorder", action="store_true")
    p.add_argument("--gptq_maxq", type=int, default=None,
                   help="Override maxq for GPTQ (default: compute from target_rate as 2^(rate+1)-1)")

    # ZSIC (tgamma is on by default; --no_zsic_apply_tgamma turns it off)
    p.add_argument("--zsic_apply_tgamma", action="store_true")
    p.add_argument("--no_zsic_apply_tgamma", dest="zsic_apply_tgamma", action="store_false")
    p.set_defaults(zsic_apply_tgamma=True)
    p.add_argument("--zsic_tgamma_ridge", type=float, default=0.0)
    p.add_argument("--zsic_tgamma_max_iter", type=int, default=500)
    p.add_argument("--zsic_tgamma_tol", type=float, default=1e-3)

    # Optional: override percdamp used inside ZSIC (defaults to --percdamp)
    p.add_argument("--zsic_percdamp", type=float, default=None)

    # Optional: binary search to hit target rate more precisely
    p.add_argument("--zsic_binary_search", action="store_true",
                   help="Binary search for target_rate that achieves desired actual rate")
    p.add_argument("--zsic_binary_search_iters", type=int, default=20,
                   help="Number of binary search iterations")
    p.add_argument("--zsic_binary_search_row_fraction", type=float, default=0.1,
                   help="Fraction of rows to use during binary search (e.g., 0.1 for 10%%). Default 0.1 for speed.")

    # Optional: rate control (mainly for ZSIC)
    p.add_argument("--rate_control", action="store_true", help="enable global rate budget tracking")
    p.add_argument("--global_rate_bits", type=float, default=None, help="global avg target bits/param (default: --target_rate)")
    p.add_argument("--rate_xmin", type=float, default=0.05, help="minimum allowed target rate")
    p.add_argument("--rate_xmax", type=float, default=16.0, help="maximum allowed target rate")
    p.add_argument("--rate_weight_budgets", type=str, default="",
                   help="Weight-type budget multipliers. Format: 'wk:1.5,wq:1.25' gives wk 50%% more bits, wq 25%% more")

    p.add_argument("--skip_quantize", type=str, default="",
                   help="Skip quantization for specific layers (store full precision). Format: '0.wq,0.wk,1.wq,1.wk'")

    p.add_argument("--qronos", action="store_true",
                   help="Compute and save Qronos statistics (Σ_X̂ and Σ_XX̂) for each layer")
    p.add_argument("--qronos_layer_min", type=int, default=None,
                   help="Only apply Qronos targeting to layers >= this (default: all layers)")
    p.add_argument("--qronos_layer_max", type=int, default=None,
                   help="Only apply Qronos targeting to layers < this (default: all layers)")
    p.add_argument("--qronos_skip_layers", type=str, default="",
                   help="Skip Qronos targeting for specific (layer.weight) pairs. E.g., '2.wo,3.wo' to skip L2_wo and L3_wo")
    p.add_argument("--qronos_skip_weights", type=str, default="",
                   help="Skip Qronos targeting for specific weight types globally. E.g., 'wq,wk,wv' to skip all Q/K/V")
    p.add_argument("--qronos_skip_qkv_prefix", type=int, default=0,
                   help="Skip Qronos for wq/wk/wv in the first N layers (default: 0 = no skip)")
    p.add_argument("--qronos_auto_skip_min_diag", type=float, default=0.0,
                   help="Auto-skip Qronos if min(diag(Σ_{X,X̂})) < threshold (default: 0 = disabled, recommended: 1e-5)")

    p.add_argument("--collect_qronos_stats", action="store_true",
                   help="Collect Qronos stats for diagnostics (no Qronos targeting)")
    p.add_argument("--plot_activation_mse", action="store_true",
                   help="Plot activation MSE at end of run (requires --qronos or --collect_qronos_stats)")

    p.add_argument("--residual_compensation", action="store_true",
                   help="Enable residual stream compensation for wo/w2 layers (automatically enables Qronos mode for wo/w2)")
    p.add_argument("--rescomp_skip_prefix", type=int, default=0,
                   help="Skip residual compensation on the first N layers (0 = apply to all)")

    p.add_argument("--unquant_hessians", action="store_true",
                   help="Use Hessians from unquantized model (avoids error propagation)")

    p.add_argument("--overhead_bits_per_param", type=int, default=16)

    # Resume is on by default; --no_resume turns it off.
    p.add_argument("--run_root", type=str, default="quant_runs")
    p.add_argument("--run_id", type=str, default="")
    p.add_argument("--resume", action="store_true")
    p.add_argument("--no_resume", dest="resume", action="store_false")
    p.set_defaults(resume=True)

    p.add_argument("--init_dist", action="store_true", help="init NCCL env for single-process jobs")
    p.add_argument("--master_port_base", type=int, default=29500)
    # Both spellings: legacy torch.distributed.launch passes --local_rank (older) or
    # --local-rank (newer) to the script.
    p.add_argument("--local_rank", "--local-rank", dest="local_rank", type=int, default=None,
                   help="Local rank / CUDA device index (default: auto-detect)")

    args = p.parse_args()

    run_pipeline_job(
        args.model,
        args.method,
        args.target_rate,
        layer_begin=args.layer_begin,
        layer_end=args.layer_end,
        weights=args.weights,
        seqlen=args.seqlen,
        calib_nsamples=args.calib_nsamples,
        hessian_batch_size=args.hessian_batch_size,
        hadamard=args.hadamard,
        hadamard_type=args.hadamard_type,
        hadamard_seed=args.hadamard_seed,
        groupsize=args.groupsize,
        blocksize=args.blocksize,
        percdamp=args.percdamp,
        actorder=args.actorder,
        gptq_maxq=args.gptq_maxq,
        overhead_bits_per_param=args.overhead_bits_per_param,
        zsic_apply_tgamma=args.zsic_apply_tgamma,
        zsic_tgamma_ridge=args.zsic_tgamma_ridge,
        zsic_tgamma_max_iter=args.zsic_tgamma_max_iter,
        zsic_tgamma_tol=args.zsic_tgamma_tol,
        zsic_percdamp=args.zsic_percdamp,
        zsic_binary_search=args.zsic_binary_search,
        zsic_binary_search_iters=args.zsic_binary_search_iters,
        zsic_binary_search_row_fraction=args.zsic_binary_search_row_fraction,
        rate_control=args.rate_control,
        global_rate_bits=args.global_rate_bits,
        rate_xmin=args.rate_xmin,
        rate_xmax=args.rate_xmax,
        rate_weight_budgets=args.rate_weight_budgets,
        skip_quantize=args.skip_quantize,
        qronos=args.qronos,
        qronos_layer_min=args.qronos_layer_min,
        qronos_layer_max=args.qronos_layer_max,
        qronos_skip_layers=args.qronos_skip_layers,
        qronos_skip_weights=args.qronos_skip_weights,
        qronos_skip_qkv_prefix=args.qronos_skip_qkv_prefix,
        qronos_auto_skip_min_diag=args.qronos_auto_skip_min_diag,
        collect_qronos_stats=args.collect_qronos_stats,
        plot_activation_mse=args.plot_activation_mse,
        residual_compensation=args.residual_compensation,
        rescomp_skip_prefix=args.rescomp_skip_prefix,
        unquant_hessians=args.unquant_hessians,
        run_root=args.run_root,
        run_id=args.run_id,
        resume=args.resume,
        init_dist=args.init_dist,
        master_port_base=args.master_port_base,
        local_rank=args.local_rank,
    )


if __name__ == "__main__":
    # Run the CLI only when executed directly (e.g. `python -m scripts.run_pipeline_job ...`),
    # not when imported by a scheduler that calls run_pipeline_job itself.
    main()
