#!/usr/bin/env python3
"""Debug script to inspect a single artifact and test dequantization."""

import sys
import torch
from pathlib import Path

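# Put the parent of this script's directory on sys.path so the
# `quant_layerwise` imports below resolve when run as a script.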
sys.path.insert(0, str(Path(__file__).parent.parent))

from quant_layerwise.storage import LayerArtifact, RunManifest
from quant_layerwise.methods.zsic import dequantize_zsic
from quant_layerwise.hadamard import BlockRandomHadamard, HadamardType, inverse_hadamard_weight


def _print_tensor_stats(t: torch.Tensor, indent: str = "    ") -> None:
    """Print min/max/mean/std and NaN/Inf flags of *t*, one value per line.

    Empty tensors are reported as such instead of raising, because
    ``min()``/``max()`` on an empty tensor error out.
    """
    if t.numel() == 0:
        print(f"{indent}(empty tensor)")
        return
    print(f"{indent}min: {t.min().item():.6f}")
    print(f"{indent}max: {t.max().item():.6f}")
    print(f"{indent}mean: {t.mean().item():.6f}")
    print(f"{indent}std: {t.std().item():.6f}")
    print(f"{indent}has_nan: {t.isnan().any().item()}")
    print(f"{indent}has_inf: {t.isinf().any().item()}")


def debug_artifact(artifact_path: str, compare_model: bool = False):
    """Inspect a single layer artifact and dequantize its weight.

    Prints the artifact metadata and every payload entry (min/max for tensor
    entries). It then dequantizes the ZSIC payload with ``dtype=torch.float32``.
    When the artifact's ``hadamard`` config is enabled, it applies the inverse
    Hadamard transform and reports how much that changed the weights.

    Args:
        artifact_path: Path to the artifact file. It is loaded with
            ``LayerArtifact.load(..., map_location="cpu")``.
        compare_model: Unused; accepted only so existing callers keep working.

    Returns:
        The dequantized weight tensor, after the inverse Hadamard transform
        when one is configured.

    Raises:
        KeyError: If the payload lacks the required ``"Z"`` or ``"alpha"``.
    """
    print(f"Loading artifact: {artifact_path}")
    art = LayerArtifact.load(artifact_path, map_location="cpu")

    print("\n=== Artifact Info ===")
    print(f"  method: {art.method}")
    print(f"  module_name: {art.module_name}")
    print(f"  shape: {art.shape}")
    print(f"  hadamard: {art.hadamard}")

    payload = art.payload
    print("\n=== Payload Keys ===")
    for key, value in payload.items():
        if not isinstance(value, torch.Tensor):
            print(f"  {key}: {value}")
            continue
        print(f"  {key}: tensor shape={value.shape}, dtype={value.dtype}, device={value.device}")
        # min()/max() raise on empty tensors; don't let one odd entry abort the dump.
        if value.numel() > 0:
            print(f"       min={value.min().item():.6f}, max={value.max().item():.6f}")
        else:
            print("       (empty tensor)")

    # Dequantize. "Z" and "alpha" are required; everything else is optional.
    print("\n=== Dequantization ===")
    apply_tgamma = bool(payload.get("apply_tgamma", False))
    print(f"  apply_tgamma: {apply_tgamma}")

    W_hat = dequantize_zsic(
        payload["Z"], payload["alpha"],
        alpha_base=payload.get("alpha_base", None),
        zero_point=payload.get("zero_point", None),
        apply_tgamma=apply_tgamma,
        t_vec=payload.get("t_vec", None),
        g_vec=payload.get("g_vec", None),
        dtype=torch.float32,
    )

    print("  W_hat after ZSIC dequant:")
    print(f"    shape: {W_hat.shape}")
    _print_tensor_stats(W_hat)

    # Apply the inverse Hadamard transform if the artifact recorded one.
    had_cfg = art.hadamard or {"enabled": False}
    if not had_cfg.get("enabled", False):
        print("\n  Hadamard not enabled, skipping inverse")
        print("\n=== Done ===")
        return W_hat

    type_map = {
        "none": HadamardType.NONE,
        "row": HadamardType.ROW,
        "column": HadamardType.COLUMN,
        "row_column": HadamardType.ROW_COLUMN,
    }
    type_str = had_cfg.get("type", "row").lower()

    print("\n=== Hadamard Inverse ===")
    hadamard_type = type_map.get(type_str)
    if hadamard_type is None:
        # Keep the historical ROW fallback, but make it visible: a silent
        # fallback would hide a mis-recorded config while debugging.
        print(f"  WARNING: unknown hadamard type {type_str!r}; falling back to ROW")
        hadamard_type = HadamardType.ROW
    print(f"  hadamard_type: {hadamard_type}")
    print(f"  seed: {had_cfg.get('seed', 0)}")

    if hadamard_type == HadamardType.NONE:
        print("  hadamard_type is NONE, nothing to invert")
    else:
        out_dim, in_dim = W_hat.shape
        seed = int(had_cfg.get("seed", 0))

        # NOTE(review): ROW uses an in_dim-sized transform passed as had_col.
        # COLUMN uses an out_dim-sized one passed as had_row. This is presumably
        # what inverse_hadamard_weight expects — TODO confirm against hadamard.py.
        had_row = None
        had_col = None
        if hadamard_type in (HadamardType.COLUMN, HadamardType.ROW_COLUMN):
            had_row = BlockRandomHadamard(out_dim, seed=seed, device="cpu", dtype=W_hat.dtype)
        if hadamard_type in (HadamardType.ROW, HadamardType.ROW_COLUMN):
            had_col = BlockRandomHadamard(in_dim, seed=seed, device="cpu", dtype=W_hat.dtype)

        print(f"  had_row created: {had_row is not None}")
        print(f"  had_col created: {had_col is not None}")

        W_before = W_hat.clone()
        W_hat = inverse_hadamard_weight(W_hat, hadamard_type, had_row=had_row, had_col=had_col)

        print("\n  W_hat after inverse Hadamard:")
        _print_tensor_stats(W_hat)

        # A zero difference here would mean the inverse transform was a no-op.
        diff = (W_hat - W_before).abs()
        print("\n  Difference from before:")
        if diff.numel() > 0:
            print(f"    max_diff: {diff.max().item():.6f}")
            print(f"    mean_diff: {diff.mean().item():.6f}")
        else:
            print("    (empty tensor)")

    print("\n=== Done ===")
    return W_hat


def compare_with_original(run_dir: str, artifact_path: str):
    """Compare dequantized weights with the original model weights.

    Steps:

    1. Load the model named in ``<run_dir>/manifest.json``.
    2. Extract the weight of the module the artifact belongs to
       (``art.module_name``).
    3. Dequantize the artifact via ``debug_artifact``.
    4. Print absolute, relative and Frobenius-relative error, plus cosine
       similarity.

    Args:
        run_dir: Run directory containing ``manifest.json``.
        artifact_path: Path to a single layer artifact file.

    Returns:
        None. If the module is missing, has no weight tensor, or its shape
        differs from the dequantized weight, an error is printed and the
        function returns early.
    """
    from quant_layerwise.pipeline import load_model_and_tokenizer, ensure_single_process_distributed

    # Load the artifact first (cheap, no distributed setup needed) so a bad
    # path fails before the expensive model load.
    art = LayerArtifact.load(artifact_path, map_location="cpu")
    module_name = art.module_name

    # Initialize distributed for single process
    ensure_single_process_distributed(local_rank=0, master_port=29599)

    run_dir = Path(run_dir)
    manifest = RunManifest.load(run_dir / "manifest.json")

    print(f"\n=== Loading Original Model: {manifest.model_name} ===")
    model, _ = load_model_and_tokenizer(manifest.model_name, local_rank=0)

    module = dict(model.named_modules()).get(module_name)
    if module is None:
        print(f"ERROR: Module {module_name} not found in model")
        return
    weight = getattr(module, "weight", None)
    if not isinstance(weight, torch.Tensor):
        print(f"ERROR: Module {module_name} has no weight tensor")
        return

    orig_weight = weight.detach().float().cpu()
    # Only this one weight is needed from here on; drop the other references
    # so the rest of the model can be reclaimed before dequantization.
    del module, weight, model

    print(f"\n=== Original Weight: {module_name} ===")
    print(f"  shape: {orig_weight.shape}")
    print(f"  min: {orig_weight.min().item():.6f}")
    print(f"  max: {orig_weight.max().item():.6f}")
    print(f"  mean: {orig_weight.mean().item():.6f}")
    print(f"  std: {orig_weight.std().item():.6f}")

    # Get dequantized weight
    W_hat = debug_artifact(artifact_path)

    # Without this check, mismatched shapes would either broadcast into
    # meaningless statistics or fail with an opaque error.
    if W_hat.shape != orig_weight.shape:
        print(
            f"ERROR: shape mismatch: dequantized {tuple(W_hat.shape)} "
            f"vs original {tuple(orig_weight.shape)}"
        )
        return

    print("\n=== Comparison ===")
    diff = (W_hat - orig_weight).abs()
    print(f"  max_diff: {diff.max().item():.6f}")
    print(f"  mean_diff: {diff.mean().item():.6f}")
    # Mean elementwise ratio: dominated by near-zero original weights.
    print(f"  relative_error: {(diff / (orig_weight.abs() + 1e-8)).mean().item():.6f}")
    # ||W_hat - W|| / ||W||: a scale-aware aggregate that small weights can't dominate.
    rel_fro = torch.linalg.norm(diff) / (torch.linalg.norm(orig_weight) + 1e-8)
    print(f"  relative_frobenius_error: {rel_fro.item():.6f}")

    # Check cosine similarity
    cos_sim = torch.nn.functional.cosine_similarity(
        W_hat.flatten().unsqueeze(0),
        orig_weight.flatten().unsqueeze(0)
    ).item()
    print(f"  cosine_similarity: {cos_sim:.6f}")


def _print_usage() -> None:
    """Print command-line usage and examples."""
    print("Usage: python debug_artifact.py <artifact_path> [--compare <run_dir>]")
    print("Example: python debug_artifact.py quant_runs/model.zsic.r3.00/layers/layers.0.attention.wq.zsic.pt")
    print("Example with comparison: python debug_artifact.py quant_runs/.../layers/layers.0.attention.wq.zsic.pt --compare quant_runs/.../")


def main(argv=None) -> int:
    """Parse the command line and run the requested debug mode.

    Accepts ``<artifact_path> [--compare <run_dir>]``. The ``--compare``
    pair may appear before or after the artifact path.

    Args:
        argv: Full argument vector including the program name. Defaults to
            ``sys.argv``.

    Returns:
        Process exit status: 0 on success, 1 on a usage error.
    """
    args = list(sys.argv[1:] if argv is None else argv[1:])
    if not args:
        _print_usage()
        return 1

    run_dir = None
    if "--compare" in args:
        idx = args.index("--compare")
        if idx + 1 >= len(args):
            print("ERROR: --compare requires a <run_dir> argument")
            _print_usage()
            return 1
        run_dir = args[idx + 1]
        # Remove the flag and its value so whatever remains is the artifact
        # path, regardless of argument order.
        del args[idx:idx + 2]

    if not args:
        print("ERROR: missing <artifact_path>")
        _print_usage()
        return 1

    artifact_path = args[0]
    if run_dir is not None:
        compare_with_original(run_dir, artifact_path)
    else:
        debug_artifact(artifact_path)
    return 0


if __name__ == "__main__":
    sys.exit(main())
