"""Run quantization sweep over multiple rates.

Usage:
    python -m scripts.run_quant_sweep --model 3-8B --method zsic --rate_min 0.5 --rate_max 3.5 --rate_step 0.5
    python -m scripts.run_quant_sweep --model 3.2-1B --method gptq --rate_min 1.0 --rate_max 4.0 --rate_step 0.5

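Writes a sweep manifest file that can be consumed by run_eval_sweep.py:
    $QUANT_BUCKET/{run_root}/{model}/sweeps/{sweep_id}.json
where run_root defaults to "quant_runs" and sweep_id has the form
    sweep_{method}[_qronos][_rescomp][_had_{type}]_{timestamp}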
"""

from __future__ import annotations

import argparse
import json
import multiprocessing as mp
import random
import socket
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import GPUtil


def find_free_port(start: int = 30000, end: int = 65000) -> int:
    """Return a TCP port in ``[start, end]`` that can be bound on 127.0.0.1.

    The scan starts at a random offset (to reduce collisions between
    concurrent sweeps) and then walks the whole inclusive range once,
    wrapping around from ``end`` back to ``start``.

    The probe socket is closed before returning, so the port is only known to
    have been free at the time of the call; nothing reserves it afterwards.

    Args:
        start: Lowest candidate port (inclusive).
        end: Highest candidate port (inclusive).

    Returns:
        The first candidate port that could be bound.

    Raises:
        ValueError: If ``start > end`` (empty range).
        RuntimeError: If no port in the range could be bound.
    """
    if start > end:
        raise ValueError(f"Empty port range: start={start} > end={end}")

    # Number of candidates, counting both endpoints.
    span = end - start + 1
    first_offset = random.randrange(span)
    for step in range(span):
        port = start + (first_offset + step) % span
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(("127.0.0.1", port))
        except OSError:
            # Port is in use (or otherwise unbindable); try the next one.
            continue
        return port
    raise RuntimeError(f"Could not find a free port in range {start}-{end}")

# NOTE: We intentionally avoid importing torch or any module that imports torch
# at module load time. This ensures CUDA_VISIBLE_DEVICES isolation works correctly
# in spawned processes.


def _get_bucket_path() -> Path:
    """Return the bucket root directory named by ``$QUANT_BUCKET``.

    The environment lookup is done here with a function-local import so this
    module does not need to import anything that would pull in torch.

    Raises:
        RuntimeError: If ``QUANT_BUCKET`` is unset or empty.
    """
    from os import environ

    bucket = environ.get("QUANT_BUCKET")
    if bucket:
        return Path(bucket)
    raise RuntimeError(
        "QUANT_BUCKET environment variable is not set. "
        "Example: export QUANT_BUCKET=/path/to/quant-bucket"
    )


# Model configs: model_name -> num_layers
# The keys are the values accepted by --model. The layer count is passed to the
# pipeline job as ``layer_end`` and recorded as ``num_layers`` in the manifest.
MODEL_CONFIGS: Dict[str, int] = {
    "3.2-1B": 16,
    "2-7B": 32,
    "3-8B": 32,
    "2-13B": 40,
    "3-70B": 80,
    "qwen3-8B": 36,
}


def generate_sweep_id(
    method: str,
    *,
    qronos: bool = False,
    residual_compensation: bool = False,
    hadamard: bool = False,
    hadamard_type: str = "row",
) -> str:
    """Build a sweep identifier of the form ``sweep_{method}[_tags]_{timestamp}``.

    Optional tags, in order: ``qronos``, ``rescomp`` and ``had_{hadamard_type}``
    (the latter only when ``hadamard`` is set and the type is not ``"none"``).
    The timestamp is the local time formatted as ``YYYYmmdd_HHMMSS``, so two
    sweeps with the same options started within the same second get the same id.
    """
    candidate_tags = (
        (qronos, "qronos"),
        (residual_compensation, "rescomp"),
        (hadamard and hadamard_type != "none", f"had_{hadamard_type}"),
    )
    tags = [tag for enabled, tag in candidate_tags if enabled]
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return "_".join(["sweep", method, *tags, stamp])


def _worker_fn(gpu: int, task_id: int, port_base: int, func_module: str, func_name: str, args: tuple, kwargs: dict) -> None:
    """Worker function that runs in spawned process.

    Restricts the process to a single GPU, resolves the job callable by name,
    runs it, and then tears down torch.distributed / CUDA state.

    Args:
        gpu: Physical GPU id; written to ``CUDA_VISIBLE_DEVICES`` so that
            ``cuda:0`` inside this process refers to it.
        task_id: Index of the task within the sweep; used for log prefixes and
            to derive the task's port.
        port_base: Base port of the sweep; this task uses
            ``port_base + task_id * 10``.
        func_module: Dotted module path imported with ``importlib``.
        func_name: Attribute of ``func_module`` to call.
        args: Positional arguments for the callable.
        kwargs: Keyword arguments for the callable; ``master_port_base`` is
            overwritten with the task-specific port before the call.

    Exits the process with status 1 (via ``sys.exit``) if setup or the job
    raised an ``Exception``; returns normally otherwise.
    """
    # Imports are local so that nothing torch-related is loaded before
    # CUDA_VISIBLE_DEVICES is set below.
    import importlib
    import os
    import sys
    import traceback

    # CRITICAL: Set CUDA_VISIBLE_DEVICES so that this process only sees one GPU
    # This makes "cuda:0" in this process map to the actual GPU we want.
    # This is necessary because parallel/start.py does torch.cuda.set_device(LOCAL_RANK)
    # and silences output for LOCAL_RANK > 0.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

    # Force unbuffered output for spawned processes
    os.environ["PYTHONUNBUFFERED"] = "1"
    sys.stdout.reconfigure(line_buffering=True)
    sys.stderr.reconfigure(line_buffering=True)

    # Compute task-specific port
    task_port = port_base + task_id * 10
    print(f"[worker GPU {gpu}, task {task_id}] starting (CUDA_VISIBLE_DEVICES={gpu}, port={task_port})", flush=True)

    exit_code = 0
    try:
        import torch
        # Now cuda:0 maps to our target GPU
        torch.cuda.set_device(0)

        # Print GPU diagnostics to confirm correct assignment
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0)
        print(f"[worker GPU {gpu}, task {task_id}] torch sees {device_count} device(s), using: {device_name}", flush=True)

        # Override master_port_base in kwargs to use task-specific port
        # This avoids NCCL port collisions between tasks
        kwargs["master_port_base"] = task_port

        # Resolve the job callable by name inside the child process.
        module = importlib.import_module(func_module)
        func = getattr(module, func_name)
        func(*args, **kwargs)
        print(f"[worker GPU {gpu}, task {task_id}] completed successfully", flush=True)
    except Exception as e:
        # Print error to both stdout and stderr to ensure visibility
        error_msg = f"\n{'='*60}\n[worker GPU {gpu}, task {task_id}] ERROR: {e}\n{'='*60}\n"
        print(error_msg, flush=True)
        print(error_msg, file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stdout)
        sys.stdout.flush()
        traceback.print_exc(file=sys.stderr)
        sys.stderr.flush()
        # Remember the failure; the actual exit happens after cleanup below.
        exit_code = 1
    finally:
        # Clean up distributed process group and CUDA resources
        try:
            import torch.distributed as dist
            if dist.is_initialized():
                dist.destroy_process_group()
                print(f"[worker GPU {gpu}, task {task_id}] destroyed process group", flush=True)
        except Exception:
            pass  # Ignore cleanup errors

        try:
            import torch
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        except Exception:
            pass  # Ignore cleanup errors

    # Report failure to the parent through the process exit code.
    if exit_code != 0:
        sys.exit(exit_code)


def run_task(task, gpu: int, task_id: int, port_base: int):
    """Start ``task`` in a new ``spawn``-context process and return that process.

    ``task`` is a ``(func_module, func_name, args, kwargs)`` tuple; its parts
    are forwarded, together with ``gpu``, ``task_id`` and ``port_base``, to
    ``_worker_fn``. The process is started but not joined.
    """
    module_path, callable_name, call_args, call_kwargs = task
    worker = mp.get_context("spawn").Process(
        target=_worker_fn,
        args=(gpu, task_id, port_base, module_path, callable_name, call_args, call_kwargs),
    )
    worker.start()
    return worker


def run_tasks(tasks: List, gpu_list: Optional[List[int]] = None) -> List[int]:
    """Run tasks on available GPUs using a simple scheduler.

    Each GPU runs at most one task at a time. The scheduler polls every 0.5 s,
    reaps finished worker processes, and starts at most one new task per poll.

    Args:
        tasks: ``(func_module, func_name, args, kwargs)`` tuples, as accepted
            by ``run_task``.
        gpu_list: GPU ids to schedule on. When ``None``, GPUs are selected with
            ``GPUtil.getAvailable`` using ``maxLoad=0.05`` and
            ``maxMemory=0.05``.

    Returns:
        The ids (indices into ``tasks``) of tasks whose worker process exited
        with a non-zero exit code; empty when every task succeeded.

    Raises:
        RuntimeError: If no GPU is available.
    """
    if gpu_list is None:
        gpu_list = GPUtil.getAvailable(
            order="first",
            maxLoad=0.05,
            maxMemory=0.05,
            limit=1000,
        )
    if not gpu_list:
        raise RuntimeError("No free GPUs available")

    # Find a free port range for this sweep (avoids collisions with previous crashed processes)
    # Each task needs ~10 ports, so reserve enough for all tasks
    port_base = find_free_port(start=30000, end=60000)
    print(f"Detected free GPUs: {gpu_list}")
    print(f"Using port base: {port_base} (range {port_base}-{port_base + len(tasks) * 10})")

    # Track: gpu_id -> (process, task_id), or None when the GPU is idle
    gpu_status: Dict[int, Optional[Tuple[mp.Process, int]]] = {i: None for i in gpu_list}
    next_task = 0
    finished_tasks = 0
    failed_tasks: List[int] = []

    def cleanup_all():
        """Terminate and join all running processes."""
        for gpu_id, status in gpu_status.items():
            if status is None:
                continue
            proc, tid = status
            if proc.is_alive():
                print(f"Terminating task {tid} on GPU {gpu_id}...")
                proc.terminate()
                proc.join(timeout=5)
                if proc.is_alive():
                    print(f"Force killing task {tid} on GPU {gpu_id}...")
                    proc.kill()
                    proc.join(timeout=2)
            else:
                proc.join(timeout=1)  # Reap zombie

    try:
        while finished_tasks < len(tasks):
            for gpu_id in list(gpu_status.keys()):
                status = gpu_status[gpu_id]
                if status is not None:
                    proc, task_id = status
                    if not proc.is_alive():
                        # Process finished - join it to clean up resources
                        proc.join(timeout=5)
                        exitcode = proc.exitcode
                        if exitcode != 0:
                            print(f"WARNING: Task {task_id} on GPU {gpu_id} failed with exitcode {exitcode}")
                            failed_tasks.append(task_id)
                        else:
                            print(f"Task {task_id} on GPU {gpu_id} completed successfully")
                        gpu_status[gpu_id] = None
                        finished_tasks += 1

                if gpu_status[gpu_id] is None and next_task < len(tasks):
                    print(f"Allocating task {next_task}/{len(tasks)} to GPU {gpu_id}")
                    proc = run_task(tasks[next_task], gpu_id, task_id=next_task, port_base=port_base)
                    gpu_status[gpu_id] = (proc, next_task)
                    next_task += 1
                    # Don't start another task in the same loop iteration
                    # Wait for model loading before starting more tasks
                    break
            time.sleep(0.5)

    except BaseException as exc:
        # Whatever aborts the scheduler (Ctrl-C, a failure while spawning the
        # next worker, ...), stop the workers that are already running instead
        # of leaving them on the GPUs, then let the error propagate.
        if isinstance(exc, KeyboardInterrupt):
            print("\nInterrupted! Cleaning up processes...")
        else:
            print(f"\nScheduler aborted ({type(exc).__name__}: {exc}). Cleaning up processes...")
        cleanup_all()
        raise

    # The loop above only ends once every started task has been reaped and
    # joined, so no worker is left to clean up here.
    if failed_tasks:
        print(f"\nWARNING: {len(failed_tasks)} task(s) failed: {failed_tasks}")
    print(f"All {len(tasks)} tasks processed ({len(tasks) - len(failed_tasks)} succeeded, {len(failed_tasks)} failed)")
    return failed_tasks


def make_run_id(
    model: str,
    method: str,
    rate: float,
    *,
    qronos: bool = False,
    residual_compensation: bool = False,
    rescomp_skip_prefix: int = 0,
    hadamard: bool = False,
    hadamard_type: str = "row",
    rate_weight_budgets: str = "",
    qronos_skip_layers: str = "",
    qronos_skip_weights: str = "",
    qronos_skip_qkv_prefix: int = 0,
    qronos_auto_skip_min_diag: float = 0.0,
) -> str:
    """Build the dot-separated run id for one quantization run.

    Layout: ``{model}.{method}[.tags...].r{rate:.2f}``. Tags, in order:

    * ``qronos`` - zsic with ``qronos``.
    * ``rescomp`` / ``rescomp_from{N}`` - zsic with ``residual_compensation``
      (``N`` = ``rescomp_skip_prefix`` when positive), followed by the optional
      skip tags ``skip{...}``, ``skipw_{...}``, ``qkvskip{N}`` and
      ``autoskip{threshold}``.
    * ``had_{hadamard_type}`` - ``hadamard`` set and type not ``"none"``.
    * ``wb_{names}`` - sorted, de-duplicated weight names taken from
      ``rate_weight_budgets`` (the multipliers are not part of the id).
    """
    tokens = [model, method]
    is_zsic = method == "zsic"

    if is_zsic and qronos:
        tokens.append("qronos")

    if is_zsic and residual_compensation:
        # "rescomp_from8" means layers 0-7 are skipped and layer 8+ is compensated.
        tokens.append(
            f"rescomp_from{rescomp_skip_prefix}" if rescomp_skip_prefix > 0 else "rescomp"
        )

        # NOTE(review): the skip tags below are only emitted together with
        # residual_compensation. Runs that use the qronos skip options without
        # it share the id of the corresponding non-skip run - confirm that this
        # is intended.

        # Per-(layer, weight) skips: "2.wo,3.wo" -> "skip2wo_3wo"
        if qronos_skip_layers:
            pairs = [
                entry.strip().replace(".", "")
                for entry in qronos_skip_layers.split(",")
                if entry.strip()
            ]
            if pairs:
                tokens.append("skip" + "_".join(pairs))

        # Weight-type skips: "wq,wk,wv" -> "skipw_wq_wk_wv"
        if qronos_skip_weights:
            weight_types = [
                entry.strip() for entry in qronos_skip_weights.split(",") if entry.strip()
            ]
            if weight_types:
                tokens.append("skipw_" + "_".join(weight_types))

        # QKV prefix skip: 4 -> "qkvskip4"
        if qronos_skip_qkv_prefix > 0:
            tokens.append(f"qkvskip{qronos_skip_qkv_prefix}")

        # Auto-skip threshold: 1e-5 -> "autoskip1e-05"
        if qronos_auto_skip_min_diag > 0:
            tokens.append(f"autoskip{qronos_auto_skip_min_diag:.0e}")

    if hadamard and hadamard_type != "none":
        tokens.append(f"had_{hadamard_type}")

    # Weight budgets: "wo:1.15,w2:1.15" -> "wb_w2_wo" (entries without ":" are ignored)
    if rate_weight_budgets:
        budget_names = {
            entry.partition(":")[0].strip()
            for entry in rate_weight_budgets.split(",")
            if ":" in entry
        }
        if budget_names:
            tokens.append("wb_" + "_".join(sorted(budget_names)))

    tokens.append(f"r{rate:.2f}")
    return ".".join(tokens)


def build_quant_tasks(
    model: str,
    method: str,
    rates: List[float],
    run_root: str,
    hessian_batch_size: int = 10,
    # GPTQ options
    groupsize: int = -1,
    percdamp: float = 0.1,
    gptq_maxq: int | None = None,
    unquant_hessians: bool = False,
    # ZSIC options
    qronos: bool = False,
    zsic_percdamp: float | None = None,
    qronos_layer_min: int | None = None,
    qronos_layer_max: int | None = None,
    qronos_skip_layers: str = "",  # e.g., "2.wo,3.wo" to skip L2_wo and L3_wo
    qronos_skip_weights: str = "",  # e.g., "wq,wk,wv" to skip all Q/K/V
    qronos_skip_qkv_prefix: int = 0,  # Skip Qronos for wq/wk/wv in first N layers
    qronos_auto_skip_min_diag: float = 0.0,  # Auto-skip if min(diag(Σ_{X,X̂})) < threshold
    rate_weight_budgets: str = "",  # e.g., "wo:1.15,w2:1.15"
    # Diagnostics: collect stats and plot activation MSE
    collect_qronos_stats: bool = False,
    plot_activation_mse: bool = False,
    # Residual stream compensation for wo/w2 layers
    residual_compensation: bool = False,
    rescomp_skip_prefix: int = 0,  # Skip first N layers
    # Hadamard options
    hadamard: bool = False,
    hadamard_type: str = "row",
    hadamard_seed: int = 0,
    # Resume
    resume: bool = True,
) -> Tuple[List[Tuple], List[Dict[str, Any]]]:
    """Create one pipeline-job task per target rate.

    The job callable is referenced by module path and name (not imported) so
    that the main process never imports torch.

    Returns:
        tasks: List of (func_module, func_name, args, kwargs) tuples
        run_infos: List of dicts (``rate``, ``run_id``, ``run_dir``) with run
            metadata for the sweep manifest

    Raises:
        ValueError: If ``model`` is not a key of ``MODEL_CONFIGS``.
        RuntimeError: If ``QUANT_BUCKET`` is not set.
    """
    if model not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {model}. Supported: {list(MODEL_CONFIGS.keys())}")

    layer_count = MODEL_CONFIGS[model]
    bucket_root = _get_bucket_path()

    # Everything except the rate and run id is identical for all runs, so the
    # optional keyword arguments are assembled once. Diagnostics flags come
    # first (they apply to every method), method-specific options after.
    shared_extras: Dict[str, Any] = {}
    if collect_qronos_stats:
        shared_extras["collect_qronos_stats"] = True
    if plot_activation_mse:
        shared_extras["plot_activation_mse"] = True

    if method == "zsic":
        shared_extras.update(
            zsic_binary_search=True,
            rate_control=True,
            qronos=qronos,
            residual_compensation=residual_compensation,
            rescomp_skip_prefix=rescomp_skip_prefix,
        )
        # (key, value, include?) - options left at their "off" value are omitted.
        zsic_optional = (
            ("zsic_percdamp", zsic_percdamp, zsic_percdamp is not None),
            ("qronos_layer_min", qronos_layer_min, qronos_layer_min is not None),
            ("qronos_layer_max", qronos_layer_max, qronos_layer_max is not None),
            ("qronos_skip_layers", qronos_skip_layers, bool(qronos_skip_layers)),
            ("qronos_skip_weights", qronos_skip_weights, bool(qronos_skip_weights)),
            ("qronos_skip_qkv_prefix", qronos_skip_qkv_prefix, qronos_skip_qkv_prefix > 0),
            ("qronos_auto_skip_min_diag", qronos_auto_skip_min_diag, qronos_auto_skip_min_diag > 0),
            ("rate_weight_budgets", rate_weight_budgets, bool(rate_weight_budgets)),
        )
        for key, value, include in zsic_optional:
            if include:
                shared_extras[key] = value
    elif method == "gptq":
        shared_extras.update(
            percdamp=percdamp,
            groupsize=groupsize,
            unquant_hessians=unquant_hessians,
        )
        if gptq_maxq is not None:
            shared_extras["gptq_maxq"] = gptq_maxq

    # Use string references to avoid importing torch in main process
    job_ref = ("scripts.run_pipeline_job", "run_pipeline_job")

    tasks: List[Tuple] = []
    run_infos: List[Dict[str, Any]] = []
    for target_rate in rates:
        run_id = make_run_id(
            model,
            method,
            target_rate,
            qronos=qronos,
            residual_compensation=residual_compensation,
            rescomp_skip_prefix=rescomp_skip_prefix,
            hadamard=hadamard,
            hadamard_type=hadamard_type,
            rate_weight_budgets=rate_weight_budgets,
            qronos_skip_layers=qronos_skip_layers,
            qronos_skip_weights=qronos_skip_weights,
            qronos_skip_qkv_prefix=qronos_skip_qkv_prefix,
            qronos_auto_skip_min_diag=qronos_auto_skip_min_diag,
        )

        # A fresh dict per run: tasks must not share their kwargs object.
        job_kwargs: Dict[str, Any] = {
            "model_name": model,
            "method": method,
            "target_rate": target_rate,
            "layer_begin": 0,
            "layer_end": layer_count,
            "hessian_batch_size": hessian_batch_size,
            "hadamard": hadamard,
            "hadamard_type": hadamard_type,
            "hadamard_seed": hadamard_seed,
            "run_root": run_root,
            "run_id": run_id,
            "resume": resume,
            "init_dist": True,
            **shared_extras,
        }

        tasks.append((*job_ref, (), job_kwargs))
        run_infos.append({
            "rate": target_rate,
            "run_id": run_id,
            "run_dir": str(bucket_root / run_root / model / run_id),
        })

    return tasks, run_infos


def save_sweep_manifest(
    sweep_id: str,
    model: str,
    method: str,
    rates: List[float],
    run_infos: List[Dict[str, Any]],
    run_root: str,
    *,
    groupsize: int = -1,
    percdamp: float = 0.1,
    gptq_maxq: int | None = None,
    qronos: bool = False,
    zsic_percdamp: float | None = None,
    qronos_layer_min: int | None = None,
    qronos_layer_max: int | None = None,
    qronos_skip_layers: str = "",
    qronos_skip_weights: str = "",
    qronos_skip_qkv_prefix: int = 0,
    qronos_auto_skip_min_diag: float = 0.0,
    rate_weight_budgets: str = "",
    collect_qronos_stats: bool = False,
    plot_activation_mse: bool = False,
    residual_compensation: bool = False,
    rescomp_skip_prefix: int = 0,
    hadamard: bool = False,
    hadamard_type: str = "row",
    hadamard_seed: int = 0,
) -> Path:
    """Write the sweep manifest as JSON and return its path.

    The file is ``$QUANT_BUCKET/{run_root}/{model}/sweeps/{sweep_id}.json``;
    the directory is created if needed and an existing file is overwritten.
    Besides the common fields, only the options of the selected ``method`` are
    recorded, and optional settings left at their "off" value are omitted.

    Raises:
        KeyError: If ``model`` is not a key of ``MODEL_CONFIGS``.
        RuntimeError: If ``QUANT_BUCKET`` is not set.
    """
    manifest: Dict[str, Any] = {
        "sweep_id": sweep_id,
        "model": model,
        "method": method,
        "num_layers": MODEL_CONFIGS[model],
        "rates": rates,
        "runs": run_infos,
        "created_at": datetime.now().isoformat(),
    }

    # Method-specific options
    if method == "gptq":
        manifest.update(groupsize=groupsize, percdamp=percdamp)
        if gptq_maxq is not None:
            manifest["gptq_maxq"] = gptq_maxq
    elif method == "zsic":
        manifest["qronos"] = qronos
        rescomp_on = bool(residual_compensation)
        # (key, value, include?) in the order the keys appear in the file.
        zsic_optional = (
            ("zsic_percdamp", zsic_percdamp, zsic_percdamp is not None),
            ("qronos_layer_min", qronos_layer_min, qronos_layer_min is not None),
            ("qronos_layer_max", qronos_layer_max, qronos_layer_max is not None),
            ("qronos_skip_layers", qronos_skip_layers, bool(qronos_skip_layers)),
            ("qronos_skip_weights", qronos_skip_weights, bool(qronos_skip_weights)),
            ("qronos_skip_qkv_prefix", qronos_skip_qkv_prefix, qronos_skip_qkv_prefix > 0),
            ("qronos_auto_skip_min_diag", qronos_auto_skip_min_diag, qronos_auto_skip_min_diag > 0),
            ("rate_weight_budgets", rate_weight_budgets, bool(rate_weight_budgets)),
            ("residual_compensation", residual_compensation, rescomp_on),
            # The skip prefix is only meaningful when compensation is enabled.
            ("rescomp_skip_prefix", rescomp_skip_prefix, rescomp_on and rescomp_skip_prefix > 0),
        )
        for key, value, include in zsic_optional:
            if include:
                manifest[key] = value

    # Diagnostics are recorded as a pair as soon as either flag is set.
    if collect_qronos_stats or plot_activation_mse:
        manifest["diagnostics"] = {
            "collect_qronos_stats": collect_qronos_stats,
            "plot_activation_mse": plot_activation_mse,
        }

    # Hadamard settings, only when enabled.
    if hadamard:
        manifest["hadamard"] = {
            "enabled": True,
            "type": hadamard_type,
            "seed": hadamard_seed,
        }

    # Save to $QUANT_BUCKET/run_root/model/sweeps/
    target_dir = _get_bucket_path() / run_root / model / "sweeps"
    target_dir.mkdir(parents=True, exist_ok=True)
    manifest_path = target_dir / f"{sweep_id}.json"
    with manifest_path.open("w") as handle:
        json.dump(manifest, handle, indent=2)

    return manifest_path


def _build_rates(explicit: Optional[str], rate_min: float, rate_max: float, rate_step: float) -> List[float]:
    """Return the target rates of the sweep.

    ``explicit`` (a comma-separated list; blank entries are ignored) wins over
    the ``rate_min``/``rate_max``/``rate_step`` grid. Grid points are computed
    as ``rate_min + i * rate_step`` (no accumulated floating-point drift),
    include ``rate_max`` up to a 1e-9 tolerance, and are rounded to 2 decimals.

    Raises:
        ValueError: On unparsable rates, a non-finite bound or step, a
            non-positive step (which would never terminate), or an empty result.
    """
    import math

    if explicit:
        try:
            rates = [float(tok) for tok in (part.strip() for part in explicit.split(",")) if tok]
        except ValueError as err:
            raise ValueError(f"--rates must be a comma-separated list of numbers, got {explicit!r}") from err
    else:
        if not all(math.isfinite(v) for v in (rate_min, rate_max, rate_step)):
            raise ValueError("--rate_min, --rate_max and --rate_step must be finite numbers")
        if rate_step <= 0:
            raise ValueError(f"--rate_step must be positive, got {rate_step}")
        rates = []
        index = 0
        while True:
            value = rate_min + index * rate_step
            if value > rate_max + 1e-9:
                break
            rates.append(round(value, 2))
            index += 1

    if not rates:
        raise ValueError("No rates to sweep (check --rates or --rate_min/--rate_max)")
    return rates


def _parse_gpu_ids(text: str) -> List[int]:
    """Parse a comma-separated list of GPU ids; blank entries are ignored.

    Raises:
        ValueError: If an entry is not an integer or no id is left.
    """
    try:
        gpu_ids = [int(tok) for tok in (part.strip() for part in text.split(",")) if tok]
    except ValueError as err:
        raise ValueError(f"--gpus must be a comma-separated list of integers, got {text!r}") from err
    if not gpu_ids:
        raise ValueError(f"--gpus does not contain any GPU id: {text!r}")
    return gpu_ids


def main():
    """Command-line entry point: parse options, write the manifest, run the sweep.

    Steps: build the rate list, print the sweep configuration, build one task
    per rate, save the sweep manifest (before running, so it exists even if the
    sweep is interrupted), then schedule the tasks on the GPUs.

    Invalid ``--rates``, ``--rate_*`` or ``--gpus`` values are reported through
    ``ArgumentParser.error`` (usage message, exit status 2) before anything is
    written.
    """
    p = argparse.ArgumentParser(description="Run quantization sweep over multiple rates")
    p.add_argument("--model", required=True, choices=list(MODEL_CONFIGS.keys()),
                   help="Model to quantize")
    p.add_argument("--method", required=True, choices=["zsic", "gptq"],
                   help="Quantization method")
    p.add_argument("--rate_min", type=float, default=1.0,
                   help="Minimum target rate (default: 1.0)")
    p.add_argument("--rate_max", type=float, default=4.0,
                   help="Maximum target rate (default: 4.0)")
    p.add_argument("--rate_step", type=float, default=0.5,
                   help="Rate step (default: 0.5)")
    p.add_argument("--rates", type=str, default=None,
                   help="Explicit list of rates, comma-separated (overrides min/max/step)")
    p.add_argument("--run_root", type=str, default="quant_runs",
                   help="Output directory for runs")
    p.add_argument("--hessian_batch_size", type=int, default=32,
                   help="Batch size for Hessian computation")

    # GPTQ options
    p.add_argument("--groupsize", type=int, default=-1,
                   help="GPTQ group size (-1 = per-channel, default: -1)")
    p.add_argument("--percdamp", type=float, default=0.1,
                   help="GPTQ Hessian damping (default: 0.1)")
    p.add_argument("--gptq_maxq", type=int, default=None,
                   help="GPTQ maxq override (default: compute from target_rate as 2^(rate+1)-1)")
    p.add_argument("--unquant_hessians", action="store_true",
                   help="Use Hessians from unquantized model (avoids error propagation)")

    # ZSIC options
    p.add_argument("--qronos", action="store_true",
                   help="Enable Qronos mode for ZSIC (default: off)")
    p.add_argument("--zsic_percdamp", type=float, default=None,
                   help="ZSIC Hessian damping (default: 0.0001)")
    p.add_argument("--qronos_layer_min", type=int, default=None,
                   help="Only apply Qronos targeting to layers >= this (default: all)")
    p.add_argument("--qronos_layer_max", type=int, default=None,
                   help="Only apply Qronos targeting to layers < this (default: all)")
    p.add_argument("--qronos_skip_layers", type=str, default="",
                   help="Skip Qronos targeting for specific (layer.weight) pairs. E.g., '2.wo,3.wo'")
    p.add_argument("--qronos_skip_weights", type=str, default="",
                   help="Skip Qronos targeting for all layers of specific weight types. E.g., 'wq,wk,wv'")
    p.add_argument("--qronos_skip_qkv_prefix", type=int, default=0,
                   help="Skip Qronos for wq/wk/wv in the first N layers (default: 0 = no skip)")
    p.add_argument("--qronos_auto_skip_min_diag", type=float, default=0.0,
                   help="Auto-skip Qronos if min(diag(Σ_{X,X̂})) < threshold (default: 0 = disabled, recommended: 1e-5)")
    p.add_argument("--rate_weight_budgets", type=str, default="",
                   help="Weight-type budget multipliers. Format: 'wo:1.15,w2:1.15' gives wo/w2 15%% more bits")

    # Diagnostics options
    p.add_argument("--collect_qronos_stats", action="store_true",
                   help="Collect Qronos stats for diagnostics (no Qronos targeting)")
    p.add_argument("--plot_activation_mse", action="store_true",
                   help="Plot activation MSE at end of each run (requires --qronos or --collect_qronos_stats)")
    p.add_argument("--residual_compensation", action="store_true",
                   help="Enable residual stream compensation for wo/w2 layers (automatically enables Qronos mode for wo/w2)")
    p.add_argument("--rescomp_skip_prefix", type=int, default=0,
                   help="Skip residual compensation on the first N layers (0 = apply to all)")

    # Hadamard options
    p.add_argument("--hadamard", action="store_true",
                   help="Enable Hadamard rotation (default: off)")
    p.add_argument("--hadamard_type", type=str, default="row",
                   choices=["none", "row", "column", "row_column"],
                   help="Type of Hadamard transform (default: row)")
    p.add_argument("--hadamard_seed", type=int, default=0,
                   help="Hadamard random seed (default: 0)")

    p.add_argument("--gpus", type=str, default=None,
                   help="Comma-separated list of GPU IDs to use (default: auto-detect free GPUs)")
    p.add_argument("--no_resume", action="store_true",
                   help="Force fresh runs (don't resume from existing artifacts)")

    args = p.parse_args()

    # Validate the rate grid and the GPU list up front, so that bad input is
    # rejected before a manifest is written or any worker is started.
    gpu_list = None
    try:
        rates = _build_rates(args.rates, args.rate_min, args.rate_max, args.rate_step)
        if args.gpus:
            gpu_list = _parse_gpu_ids(args.gpus)
    except ValueError as exc:
        p.error(str(exc))

    # Generate sweep ID
    sweep_id = generate_sweep_id(args.method, qronos=args.qronos, residual_compensation=args.residual_compensation, hadamard=args.hadamard, hadamard_type=args.hadamard_type)

    print("Sweep config:")
    print(f"  Sweep ID: {sweep_id}")
    print(f"  Model: {args.model} ({MODEL_CONFIGS[args.model]} layers)")
    print(f"  Method: {args.method}")
    print(f"  Rates: {rates}")
    if args.method == "gptq":
        print(f"  Groupsize: {args.groupsize}")
        print(f"  Percdamp: {args.percdamp}")
        if args.gptq_maxq is not None:
            print(f"  Maxq override: {args.gptq_maxq}")
        print(f"  Unquant Hessians: {args.unquant_hessians}")
    if args.method == "zsic":
        print(f"  Qronos: {args.qronos}")
        if args.zsic_percdamp is not None:
            print(f"  ZSIC Percdamp: {args.zsic_percdamp}")
        if args.qronos and (args.qronos_layer_min is not None or args.qronos_layer_max is not None):
            print(f"  Qronos layer range: [{args.qronos_layer_min}, {args.qronos_layer_max})")
        if args.qronos and args.qronos_skip_layers:
            print(f"  Qronos skip layers: {args.qronos_skip_layers}")
        if args.qronos and args.qronos_skip_weights:
            print(f"  Qronos skip weights: {args.qronos_skip_weights}")
        if args.qronos and args.qronos_skip_qkv_prefix > 0:
            print(f"  Qronos skip QKV prefix: {args.qronos_skip_qkv_prefix} layers")
        if args.qronos and args.qronos_auto_skip_min_diag > 0:
            print(f"  Qronos auto-skip min_diag: {args.qronos_auto_skip_min_diag:.2e}")
        if args.rate_weight_budgets:
            print(f"  Weight budgets: {args.rate_weight_budgets}")
    if args.hadamard:
        print(f"  Hadamard: {args.hadamard_type} (seed={args.hadamard_seed})")
    if args.collect_qronos_stats or args.plot_activation_mse:
        print(f"  Collect Qronos stats: {args.collect_qronos_stats}")
        print(f"  Plot activation MSE: {args.plot_activation_mse}")
    if args.residual_compensation:
        skip_str = f" (skip first {args.rescomp_skip_prefix} layers)" if args.rescomp_skip_prefix > 0 else ""
        print(f"  Residual compensation: {args.residual_compensation}{skip_str}")
    print(f"  Resume: {not args.no_resume}")
    print(f"  Output: {args.run_root}")

    # Build tasks
    tasks, run_infos = build_quant_tasks(
        model=args.model,
        method=args.method,
        rates=rates,
        run_root=args.run_root,
        hessian_batch_size=args.hessian_batch_size,
        groupsize=args.groupsize,
        percdamp=args.percdamp,
        gptq_maxq=args.gptq_maxq,
        unquant_hessians=args.unquant_hessians,
        qronos=args.qronos,
        zsic_percdamp=args.zsic_percdamp,
        qronos_layer_min=args.qronos_layer_min,
        qronos_layer_max=args.qronos_layer_max,
        qronos_skip_layers=args.qronos_skip_layers,
        qronos_skip_weights=args.qronos_skip_weights,
        qronos_skip_qkv_prefix=args.qronos_skip_qkv_prefix,
        qronos_auto_skip_min_diag=args.qronos_auto_skip_min_diag,
        rate_weight_budgets=args.rate_weight_budgets,
        collect_qronos_stats=args.collect_qronos_stats,
        plot_activation_mse=args.plot_activation_mse,
        residual_compensation=args.residual_compensation,
        rescomp_skip_prefix=args.rescomp_skip_prefix,
        hadamard=args.hadamard,
        hadamard_type=args.hadamard_type,
        hadamard_seed=args.hadamard_seed,
        resume=not args.no_resume,
    )

    print(f"\nBuilt {len(tasks)} quantization tasks")

    # Save sweep manifest before running (so eval can find it even if interrupted)
    manifest_path = save_sweep_manifest(
        sweep_id=sweep_id,
        model=args.model,
        method=args.method,
        rates=rates,
        run_infos=run_infos,
        run_root=args.run_root,
        groupsize=args.groupsize,
        percdamp=args.percdamp,
        gptq_maxq=args.gptq_maxq,
        qronos=args.qronos,
        zsic_percdamp=args.zsic_percdamp,
        qronos_layer_min=args.qronos_layer_min,
        qronos_layer_max=args.qronos_layer_max,
        qronos_skip_layers=args.qronos_skip_layers,
        qronos_skip_weights=args.qronos_skip_weights,
        qronos_skip_qkv_prefix=args.qronos_skip_qkv_prefix,
        qronos_auto_skip_min_diag=args.qronos_auto_skip_min_diag,
        rate_weight_budgets=args.rate_weight_budgets,
        collect_qronos_stats=args.collect_qronos_stats,
        plot_activation_mse=args.plot_activation_mse,
        residual_compensation=args.residual_compensation,
        rescomp_skip_prefix=args.rescomp_skip_prefix,
        hadamard=args.hadamard,
        hadamard_type=args.hadamard_type,
        hadamard_seed=args.hadamard_seed,
    )
    print(f"Saved sweep manifest: {manifest_path}")

    # Run tasks
    run_tasks(tasks, gpu_list=gpu_list)

    print("\nSweep complete. To evaluate and plot:")
    print(f"  python -m scripts.run_eval_sweep --sweep {manifest_path} --eval --plot")


# Script entry point: run the sweep only when this file is executed directly
# (e.g. ``python -m scripts.run_quant_sweep``), not when it is imported.
if __name__ == "__main__":
    main()
