"""Compute per-layer activation MSE between unquantized and quantized models.

Loads quantized model artifacts and, by running inference, computes the relative
error ||X - X̂|| / ||X|| and the cosine similarity of the inputs to each tracked
layer. No pre-computed qronos_stats needed.

Usage:
    # Compare WaterSICR (with residual compensation) vs WaterSIC (without)
    python scripts/compute_activation_mse.py \
        --run_a /path/to/rescomp_run \
        --run_b /path/to/no_rescomp_run \
        --label_a "WaterSICR" \
        --label_b "WaterSIC" \
        --output watersicr_vs_watersic.png \
        --init_dist
"""

import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from quant_layerwise.data import get_wikitext2, split_dataset, take_nseq
from quant_layerwise.partial_model import load_and_apply_manifest
from quant_layerwise.pipeline import ensure_single_process_distributed, load_model_and_tokenizer
from quant_layerwise.storage import RunManifest


def get_layer_modules(model) -> Dict[str, torch.nn.Module]:
    """Map dotted layer names to the projection modules being tracked.

    For every entry of ``model.layers`` the attention projections
    (``wq``, ``wk``, ``wv``, ``wo``) and the feed-forward projections
    (``w1``, ``w2``, ``w3``) are registered, in that order, under the key
    ``layers.<idx>.<block>.<weight>``.
    """
    tracked = (
        ("attention", ("wq", "wk", "wv", "wo")),
        ("feed_forward", ("w1", "w2", "w3")),
    )
    return {
        f"layers.{idx}.{block_name}.{weight_name}": getattr(getattr(block, block_name), weight_name)
        for idx, block in enumerate(model.layers)
        for block_name, weight_names in tracked
        for weight_name in weight_names
    }


class ActivationCapture:
    """Forward hook that records the first positional input of a module.

    Every forward call of the hooked module appends ``input[0]`` (detached and
    moved to CPU) to ``self.activations``.
    """

    def __init__(self):
        # One tensor per forward call of the hooked module, in call order.
        self.activations: List[torch.Tensor] = []
        # Handle returned by ``register_forward_hook``; ``None`` while unhooked.
        self.handle = None

    def hook(self, module, input, output):
        """Forward-hook callback: store the detached first input on CPU."""
        x = input[0].detach()
        self.activations.append(x.cpu())

    def register(self, module: torch.nn.Module):
        """Attach the hook to ``module``, replacing any earlier registration."""
        # Without this, a second register() would orphan the first handle:
        # the old hook could never be removed and each forward pass would be
        # recorded twice.
        self.remove()
        self.handle = module.register_forward_hook(self.hook)

    def remove(self):
        """Detach the hook if one is registered; safe to call repeatedly."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def get_concatenated(self) -> Optional[torch.Tensor]:
        """Return all captures concatenated along dim 0, or ``None`` if empty."""
        if not self.activations:
            return None
        return torch.cat(self.activations, dim=0)

    def clear(self):
        """Drop every captured activation."""
        self.activations = []


class OnlineMetricsCapture:
    """Accumulate relative-error and cosine statistics on the fly.

    A flattened reference activation is supplied through ``set_reference``.
    Each time the hook fires, the hooked module's first input is flattened and
    compared against the next equally sized slice of that reference, so only
    running sums are kept rather than the quantized activations.
    """

    def __init__(self):
        self.handle = None
        # Running stats for MSE: sum((x_ref - x_q)^2), sum(x_ref^2)
        self.mse_num = 0.0
        self.ref_norm_sq = 0.0
        # Running stats for cosine: sum(x_ref * x_q), sum(x_ref^2), sum(x_q^2)
        self.dot_product = 0.0
        self.quant_norm_sq = 0.0
        self.n_elements = 0
        # Flattened float64 reference and the read cursor into it; both are
        # (re)initialised by set_reference().
        self._x_ref: Optional[torch.Tensor] = None
        self._ref_idx = 0

    def set_reference(self, x_ref: torch.Tensor, device: torch.device):
        """Store reference activation for comparison (pre-moved to device).

        The reference is flattened and converted to float64, and the read
        cursor is reset to the start.
        """
        self._x_ref = x_ref.flatten().double().to(device)
        self._ref_idx = 0

    def hook(self, module, input, output):
        """Forward-hook callback: fold ``input[0]`` into the running sums.

        Raises:
            RuntimeError: if no reference has been set, or if the remaining
                reference is shorter than the incoming activation.
        """
        if self._x_ref is None:
            raise RuntimeError(
                "OnlineMetricsCapture.set_reference() must be called before the hooked module runs"
            )

        x_q = input[0].detach().flatten().double()
        # Get corresponding slice of reference (already on same device)
        n = x_q.numel()
        x_ref = self._x_ref[self._ref_idx:self._ref_idx + n]
        if x_ref.numel() != n:
            # A short slice would either fail with an opaque shape error or,
            # for single-element inputs, silently broadcast and add zeros.
            raise RuntimeError(
                f"Reference exhausted: need {n} elements at offset {self._ref_idx}, "
                f"but only {x_ref.numel()} remain"
            )
        self._ref_idx += n

        # Update running stats (reduced on the tensors' device, then moved
        # to Python scalars)
        self.mse_num += torch.sum((x_ref - x_q) ** 2).item()
        self.ref_norm_sq += torch.sum(x_ref ** 2).item()
        self.dot_product += torch.sum(x_ref * x_q).item()
        self.quant_norm_sq += torch.sum(x_q ** 2).item()
        self.n_elements += n

    def register(self, module: torch.nn.Module):
        """Attach the hook to ``module``, replacing any earlier registration."""
        # Remove a previous hook first so it is not left attached with no
        # handle to detach it (which would double-count every forward pass).
        self.remove()
        self.handle = module.register_forward_hook(self.hook)

    def remove(self):
        """Detach the hook if one is registered; safe to call repeatedly."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def get_metrics(self) -> Tuple[float, float]:
        """Return (relative_mse, cosine_similarity).

        ``relative_mse`` is ``sqrt(sum((x_ref - x_q)^2) / sum(x_ref^2))`` and
        is ``0.0`` when the reference energy is zero. The cosine similarity is
        ``1.0`` when either accumulated norm is zero.
        """
        if self.ref_norm_sq > 0:
            rel_mse = float(np.sqrt(self.mse_num / self.ref_norm_sq))
        else:
            rel_mse = 0.0

        norm_ref = np.sqrt(self.ref_norm_sq)
        norm_quant = np.sqrt(self.quant_norm_sq)
        if norm_ref > 0 and norm_quant > 0:
            cos_sim = float(self.dot_product / (norm_ref * norm_quant))
        else:
            cos_sim = 1.0

        return rel_mse, cos_sim


@torch.no_grad()
def capture_activations(
    model,
    eval_tokens: torch.Tensor,
    module_names: List[str],
    modules: Dict[str, torch.nn.Module],
    batch_size: int = 32,
) -> Dict[str, Optional[torch.Tensor]]:
    """Run model and capture activations at specified modules.

    Args:
        model: Model invoked as ``model(batch, start_pos=0)``.
        eval_tokens: Token tensor; it is sliced along dim 0 into batches.
        module_names: Names to capture; names missing from ``modules`` are
            skipped.
        modules: Mapping from name to the module to hook.
        batch_size: Number of rows of ``eval_tokens`` per forward pass.

    Returns:
        Mapping from each hooked name to its captured inputs concatenated
        along dim 0, or ``None`` for a module that never ran.
    """
    captures = {}
    try:
        for name in module_names:
            if name in modules:
                captures[name] = ActivationCapture()
                captures[name].register(modules[name])

        model.eval()
        device = next(model.parameters()).device

        # Resize KV cache if needed (model may have been initialized with smaller batch size)
        if hasattr(model, 'resize_kv_caches'):
            model.resize_kv_caches(batch_size)
        elif hasattr(model, 'layers') and hasattr(model.layers[0], 'attention'):
            # Manual resize for Llama-style models
            for layer in model.layers:
                attn = layer.attention
                if hasattr(attn, 'cache_k') and attn.cache_k.shape[0] < batch_size:
                    old_shape = attn.cache_k.shape
                    new_shape = (batch_size, old_shape[1], old_shape[2], old_shape[3])
                    attn.cache_k = torch.zeros(new_shape, device=attn.cache_k.device, dtype=attn.cache_k.dtype)
                    attn.cache_v = torch.zeros(new_shape, device=attn.cache_v.device, dtype=attn.cache_v.dtype)

        nsamples = eval_tokens.shape[0]
        for i in range(0, nsamples, batch_size):
            batch = eval_tokens[i:i+batch_size].to(device)
            _ = model(batch, start_pos=0)

        return {name: capture.get_concatenated() for name, capture in captures.items()}
    finally:
        # Always detach the hooks, even when a forward pass raises; otherwise
        # they would stay on the caller's model and keep recording.
        for capture in captures.values():
            capture.remove()


def compute_relative_mse(x_ref: torch.Tensor, x_quant: torch.Tensor) -> float:
    """Return the relative error norm ``||X - X̂|| / ||X||`` over all elements.

    Both tensors are flattened and converted to float64 before the sums of
    squares are taken. A reference whose squared norm is not positive yields
    ``0.0``.
    """
    reference = x_ref.reshape(-1).double()
    approx = x_quant.reshape(-1).double()

    err_energy = (reference - approx).pow(2).sum().item()
    ref_energy = reference.pow(2).sum().item()

    if not ref_energy > 0:
        return 0.0
    return np.sqrt(err_energy / ref_energy)


def compute_cosine_similarity(x_ref: torch.Tensor, x_quant: torch.Tensor) -> float:
    """Return the cosine similarity ``(X · X̂) / (||X|| ||X̂||)``.

    Both tensors are flattened and converted to float64 first. When either
    norm is not positive the function returns ``1.0`` instead of dividing.
    """
    ref_vec = x_ref.reshape(-1).double()
    quant_vec = x_quant.reshape(-1).double()

    inner = (ref_vec * quant_vec).sum().item()
    ref_len = ref_vec.pow(2).sum().sqrt().item()
    quant_len = quant_vec.pow(2).sum().sqrt().item()

    # Degenerate (zero-norm) inputs are reported as perfectly aligned.
    if not (ref_len > 0 and quant_len > 0):
        return 1.0
    return inner / (ref_len * quant_len)


def parse_layer_name(name: str) -> Tuple[int, str, str]:
    """Split ``layers.<id>.<block>.<weight>[...]`` into its components.

    Returns ``(layer_id, block_type, weight_type)``. Names that do not start
    with ``layers`` or have fewer than four dotted parts give
    ``(-1, "", name)``.
    """
    pieces = name.split(".")
    if len(pieces) < 4 or pieces[0] != "layers":
        return -1, "", name
    _, index_text, block_type, weight_type = pieces[:4]
    return int(index_text), block_type, weight_type


def get_sort_key(name: str) -> Tuple[int, int, int]:
    """Build a sort key ``(layer, block, weight)`` for a dotted layer name.

    Attention blocks sort before every other block type. Within a block the
    weights follow wq/wk/wv/wo and w1/w2/w3 order; unknown weight names sort
    last (rank 99).
    """
    layer_id, block_type, weight_type = parse_layer_name(name)
    weight_rank = {"wq": 0, "wk": 1, "wv": 2, "wo": 3, "w1": 0, "w2": 1, "w3": 2}.get(weight_type, 99)
    block_rank = int(block_type != "attention")
    return layer_id, block_rank, weight_rank


def get_merged_weight_type(weight_type: str) -> str:
    """Collapse weight names into their merged plotting group.

    ``wq``/``wk``/``wv`` become ``"qkv"``, ``w1``/``w3`` become ``"w1w3"``,
    and every other name is returned unchanged.
    """
    merged_groups = {
        "wq": "qkv",
        "wk": "qkv",
        "wv": "qkv",
        "w1": "w1w3",
        "w3": "w1w3",
    }
    return merged_groups.get(weight_type, weight_type)


def get_weight_type_color(weight_type: str) -> str:
    """Return the matplotlib color for a merged weight type (gray if unknown)."""
    palette = dict(qkv="tab:blue", wo="tab:cyan", w1w3="tab:orange", w2="tab:green")
    try:
        return palette[weight_type]
    except KeyError:
        return "tab:gray"


def get_cache_path(run_dir: Path, nsamples: int, seqlen: int) -> Path:
    """Return the metrics cache file inside ``run_dir`` for this sample count and sequence length."""
    cache_name = f"activation_metrics_cache_n{nsamples}_s{seqlen}.pt"
    return run_dir.joinpath(cache_name)


def load_cached_metrics(run_dir: Path, nsamples: int, seqlen: int) -> Optional[Tuple[Dict, Dict]]:
    """Return the cached ``(mse, cos_sim)`` dicts for a run, or ``None``.

    ``None`` is returned when the cache file does not exist or when reading
    it fails for any reason; a failure is reported on stdout, not raised.
    """
    cache_file = get_cache_path(run_dir, nsamples, seqlen)
    if not cache_file.exists():
        return None

    try:
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects, so only point this at run directories you trust.
        payload = torch.load(cache_file, map_location="cpu", weights_only=False)
        print(f"  Loaded cached metrics from {cache_file.name}")
        return payload["mse"], payload["cos_sim"]
    except Exception as e:
        print(f"  Cache load failed: {e}")
        return None


def save_cached_metrics(run_dir: Path, nsamples: int, seqlen: int, mse: Dict, cos_sim: Dict):
    """Write the ``mse`` and ``cos_sim`` dicts to the run's metrics cache file."""
    target = get_cache_path(run_dir, nsamples, seqlen)
    payload = {"mse": mse, "cos_sim": cos_sim}
    torch.save(payload, target)
    print(f"  Saved metrics cache to {target.name}")


def _ensure_kv_cache_capacity(model, batch_size: int) -> None:
    """Make sure ``model``'s KV caches can hold ``batch_size`` sequences."""
    if hasattr(model, 'resize_kv_caches'):
        model.resize_kv_caches(batch_size)
        return
    if not (hasattr(model, 'layers') and hasattr(model.layers[0], 'attention')):
        return
    # Manual fallback: reallocate zeroed caches whose first dim is too small.
    for layer in model.layers:
        attn = layer.attention
        if hasattr(attn, 'cache_k') and attn.cache_k.shape[0] < batch_size:
            old_shape = attn.cache_k.shape
            new_shape = (batch_size, old_shape[1], old_shape[2], old_shape[3])
            attn.cache_k = torch.zeros(new_shape, device=attn.cache_k.device, dtype=attn.cache_k.dtype)
            attn.cache_v = torch.zeros(new_shape, device=attn.cache_v.device, dtype=attn.cache_v.dtype)


def _forward_with_input_capture(model, modules, module_names, batch):
    """Run one forward pass and return the per-module input captures."""
    caps = {name: ActivationCapture() for name in module_names}
    try:
        for name in module_names:
            caps[name].register(modules[name])
        _ = model(batch, start_pos=0)
    finally:
        # Detach the hooks even if the forward pass raises, so they do not
        # stay on the model (model_ref is owned and reused by the caller).
        for cap in caps.values():
            cap.remove()
    return caps


def compute_metrics_for_run(
    model_ref,
    run_dir: Path,
    eval_tokens: torch.Tensor,
    local_rank: int = 0,
    batch_size: int = 32,
    nsamples: int = 64,
    seqlen: int = 2048,
    use_cache: bool = True,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """Compute per-layer metrics for a quantized run vs reference model.

    Processes batch-by-batch to avoid storing all activations.

    Args:
        model_ref: Unquantized reference model, invoked as
            ``model_ref(batch, start_pos=0)``.
        run_dir: Run directory containing ``manifest.json``; the metrics
            cache file is read from and written to this directory.
        eval_tokens: Token tensor, sliced along dim 0 into batches.
        local_rank: Forwarded to ``load_model_and_tokenizer``.
        batch_size: Rows of ``eval_tokens`` per forward pass.
        nsamples: Used only to build the cache file name.
        seqlen: Used only to build the cache file name.
        use_cache: Return cached metrics when present and save new ones.

    Returns:
        ``(mse, cos_sim)``: two dicts keyed by layer name. ``mse`` holds
        ``sqrt(sum((x_ref - x_q)^2) / sum(x_ref^2))`` and ``cos_sim`` the
        cosine similarity, both accumulated over all batches.
    """
    # Check cache first
    if use_cache:
        cached = load_cached_metrics(run_dir, nsamples, seqlen)
        if cached is not None:
            return cached

    manifest = RunManifest.load(run_dir / "manifest.json")

    # Load and quantize model
    print(f"  Loading quantized model from {run_dir.name}...")
    model_q, _ = load_model_and_tokenizer(manifest.model_name, local_rank=local_rank)
    load_and_apply_manifest(model_q, run_dir)

    modules_ref = get_layer_modules(model_ref)
    modules_q = get_layer_modules(model_q)
    module_names = list(modules_ref.keys())

    model_q.eval()
    model_ref.eval()
    device = next(model_q.parameters()).device

    # Resize KV caches for both models. Called explicitly (no loop variable)
    # so that no stray local keeps model_q alive past the cleanup below.
    _ensure_kv_cache_capacity(model_ref, batch_size)
    _ensure_kv_cache_capacity(model_q, batch_size)

    # Accumulate metrics across batches
    running_stats = {name: {"mse_num": 0.0, "ref_norm_sq": 0.0, "dot": 0.0, "q_norm_sq": 0.0}
                     for name in module_names}

    print(f"  Computing metrics batch-by-batch ({len(module_names)} modules)...")
    n_total = eval_tokens.shape[0]
    n_batches = (n_total + batch_size - 1) // batch_size

    with torch.no_grad():
        for batch_idx in range(0, n_total, batch_size):
            batch = eval_tokens[batch_idx:batch_idx+batch_size].to(device)

            # Capture ref activations for this batch
            ref_caps = _forward_with_input_capture(model_ref, modules_ref, module_names, batch)

            # Clear intermediate tensors before quant run
            torch.cuda.empty_cache()

            # Capture quant activations for this batch
            q_caps = _forward_with_input_capture(model_q, modules_q, module_names, batch)

            # Compute metrics for this batch and accumulate
            for name in module_names:
                x_ref = ref_caps[name].get_concatenated()
                x_q = q_caps[name].get_concatenated()
                if x_ref is not None and x_q is not None:
                    x_ref = x_ref.flatten().double().to(device)
                    x_q = x_q.flatten().double().to(device)
                    stats = running_stats[name]
                    stats["mse_num"] += torch.sum((x_ref - x_q) ** 2).item()
                    stats["ref_norm_sq"] += torch.sum(x_ref ** 2).item()
                    stats["dot"] += torch.sum(x_ref * x_q).item()
                    stats["q_norm_sq"] += torch.sum(x_q ** 2).item()

            if (batch_idx // batch_size + 1) % 2 == 0:
                print(f"    Batch {batch_idx // batch_size + 1}/{n_batches}")

    # Compute final metrics
    mse_results = {}
    cos_sim_results = {}
    for name, stats in running_stats.items():
        if stats["ref_norm_sq"] > 0:
            mse_results[name] = np.sqrt(stats["mse_num"] / stats["ref_norm_sq"])
        else:
            mse_results[name] = 0.0

        norm_ref = np.sqrt(stats["ref_norm_sq"])
        norm_q = np.sqrt(stats["q_norm_sq"])
        if norm_ref > 0 and norm_q > 0:
            cos_sim_results[name] = stats["dot"] / (norm_ref * norm_q)
        else:
            cos_sim_results[name] = 1.0

    # Free quantized model. modules_q holds references to its submodules, so
    # it must be dropped as well or empty_cache() has nothing to release.
    del modules_q
    del model_q
    torch.cuda.empty_cache()

    # Save to cache
    if use_cache:
        save_cached_metrics(run_dir, nsamples, seqlen, mse_results, cos_sim_results)

    return mse_results, cos_sim_results


def plot_comparison(
    run_a: Path,
    run_b: Path,
    output_path: Optional[Path] = None,
    label_a: Optional[str] = None,
    label_b: Optional[str] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (22, 7),
    merge_inputs: bool = True,
    nsamples: int = 64,
    batch_size: int = 32,
    seqlen: int = 2048,
    local_rank: int = 0,
    init_dist: bool = False,
    master_port_base: int = 29500,
    use_cache: bool = True,
) -> None:
    """Compare per-layer activation MSE between two quantized runs.

    Obtains per-layer relative error and cosine similarity for both runs
    (from the per-run cache files when available, otherwise by loading the
    reference model and calling ``compute_metrics_for_run``), draws two
    scatter plots (relative MSE and cosine distance), prints a summary and
    writes the plotted numbers to a JSON file.

    Args:
        run_a: Directory of the first run; must contain ``manifest.json``.
        run_b: Directory of the second run; must contain ``manifest.json``.
        output_path: Path of the relative-MSE figure. The cosine figure is
            saved next to it with ``_cosine`` appended to the stem, and the
            JSON uses the same path with a ``.json`` suffix. When ``None``
            the figures are shown with ``plt.show()`` and the JSON goes to
            ``run_a / "activation_metrics_compare.json"``.
        label_a: Legend label for run A; defaults to ``run_a.name``.
        label_b: Legend label for run B; defaults to ``run_b.name``.
        title: Title override; applied to the relative-MSE plot only.
        figsize: Figure size passed to ``plt.subplots``.
        merge_inputs: Average wq/wk/wv into ``qkv`` and w1/w3 into ``w1w3``
            per layer before plotting.
        nsamples: Number of evaluation sequences requested; also part of the
            cache file name.
        batch_size: Batch size forwarded to ``compute_metrics_for_run``.
        seqlen: Sequence length used to split the dataset; also part of the
            cache file name.
        local_rank: Forwarded to the model loader and distributed setup.
        init_dist: Call ``ensure_single_process_distributed`` first.
        master_port_base: Port passed to that call as ``master_port``.
        use_cache: Read and write the per-run metrics cache.

    Raises:
        ValueError: if the two manifests name different models.
    """

    if init_dist:
        ensure_single_process_distributed(local_rank=local_rank, master_port=master_port_base)

    manifest_a = RunManifest.load(run_a / "manifest.json")
    manifest_b = RunManifest.load(run_b / "manifest.json")

    # Both runs are compared against one shared reference model, so they must
    # come from the same base model.
    if manifest_a.model_name != manifest_b.model_name:
        raise ValueError(f"Different models: {manifest_a.model_name} vs {manifest_b.model_name}")

    model_name = manifest_a.model_name
    if label_a is None:
        label_a = run_a.name
    if label_b is None:
        label_b = run_b.name

    print(f"Computing activation MSE comparison:")
    print(f"  Run A ({label_a}): {run_a}")
    print(f"  Run B ({label_b}): {run_b}")
    print(f"  Model: {model_name}")

    # Check if both runs have cached metrics (skip loading ref model if so)
    cached_a = load_cached_metrics(run_a, nsamples, seqlen) if use_cache else None
    cached_b = load_cached_metrics(run_b, nsamples, seqlen) if use_cache else None

    if cached_a is not None and cached_b is not None:
        print("Both runs have cached metrics, skipping model loading.")
        mse_a, cos_sim_a = cached_a
        mse_b, cos_sim_b = cached_b
    else:
        # Need to compute at least one - load reference model
        print(f"\nLoading reference (unquantized) model...")
        model_ref, tokenizer = load_model_and_tokenizer(model_name, local_rank=local_rank)

        # Prepare eval data: WikiText-2 test split, cut into seqlen-sized
        # pieces, limited to nsamples of them
        print(f"Preparing evaluation data (nsamples={nsamples}, seqlen={seqlen})...")
        eval_tokens = split_dataset(get_wikitext2(tokenizer, split="test"), seqlen)
        eval_tokens = take_nseq(eval_tokens, nsamples)
        print(f"Using {eval_tokens.shape[0]} samples")

        # Compute metrics for run A
        if cached_a is not None:
            mse_a, cos_sim_a = cached_a
        else:
            print(f"\nComputing metrics for {label_a}...")
            mse_a, cos_sim_a = compute_metrics_for_run(
                model_ref, run_a, eval_tokens, local_rank, batch_size, nsamples, seqlen, use_cache
            )

        # Compute metrics for run B
        if cached_b is not None:
            mse_b, cos_sim_b = cached_b
        else:
            print(f"\nComputing metrics for {label_b}...")
            mse_b, cos_sim_b = compute_metrics_for_run(
                model_ref, run_b, eval_tokens, local_rank, batch_size, nsamples, seqlen, use_cache
            )

        # Free reference model
        del model_ref
        torch.cuda.empty_cache()

    # Build results with both metrics. Only layers present in all four dicts
    # are kept, so results_a and results_b stay index-aligned.
    results_a = []
    results_b = []
    for name in sorted(mse_a.keys(), key=get_sort_key):
        if name in mse_b and name in cos_sim_a and name in cos_sim_b:
            layer_id, block_type, weight_type = parse_layer_name(name)
            results_a.append({
                "name": name,
                "rel_mse": mse_a[name],
                "cos_sim": cos_sim_a[name],
                "weight_type": weight_type,
                "merged_type": get_merged_weight_type(weight_type),
                "layer_id": layer_id,
            })
            results_b.append({
                "name": name,
                "rel_mse": mse_b[name],
                "cos_sim": cos_sim_b[name],
                "weight_type": weight_type,
                "merged_type": get_merged_weight_type(weight_type),
                "layer_id": layer_id,
            })

    # Merge if requested
    if merge_inputs:
        def merge_results(results):
            # Group entries by (layer, merged weight type) and average both
            # metrics within each group.
            grouped = defaultdict(list)
            for r in results:
                key = (r["layer_id"], r["merged_type"])
                grouped[key].append(r)

            merged = []
            for (layer_id, merged_type), items in grouped.items():
                avg_mse = np.mean([r["rel_mse"] for r in items])
                avg_cos_sim = np.mean([r["cos_sim"] for r in items])
                merged.append({
                    "short_name": f"L{layer_id}_{merged_type}",
                    "rel_mse": avg_mse,
                    "cos_sim": avg_cos_sim,
                    "weight_type": merged_type,
                    "layer_id": layer_id,
                    # Order within a layer: qkv, wo, w1w3, then anything else.
                    "sort_key": (layer_id, 0 if merged_type == "qkv" else 1 if merged_type == "wo" else 2 if merged_type == "w1w3" else 3),
                })
            merged.sort(key=lambda x: x["sort_key"])
            return merged

        results_a = merge_results(results_a)
        results_b = merge_results(results_b)
    else:
        for r in results_a:
            r["short_name"] = f"L{r['layer_id']}_{r['weight_type']}"
        for r in results_b:
            r["short_name"] = f"L{r['layer_id']}_{r['weight_type']}"

    # Extract plotting data; both metrics are scaled to percent
    names = [r["short_name"] for r in results_a]
    rel_mses_a = [r["rel_mse"] * 100 for r in results_a]
    rel_mses_b = [r["rel_mse"] * 100 for r in results_b]
    # Cosine distance = 1 - cos_sim (so lower is better, like MSE)
    cos_dist_a = [(1 - r["cos_sim"]) * 100 for r in results_a]
    cos_dist_b = [(1 - r["cos_sim"]) * 100 for r in results_b]
    weight_types = [r["weight_type"] for r in results_a]

    # Extract rate from run directory name (e.g., "model.zsic.rescomp.r2.00").
    # Only run_a's name is inspected; the string is empty when nothing matches.
    import re
    rate_match = re.search(r'\.r(\d+\.\d+)', run_a.name)
    rate_str = f"Rate={rate_match.group(1)}" if rate_match else ""

    from matplotlib.lines import Line2D
    x = np.arange(len(names))
    # Plot groups in this order; weight types outside this list are not
    # drawn as scatter points.
    weight_type_order = ["qkv", "wo", "w1w3", "w2"]

    def make_comparison_plot(values_a, values_b, ylabel, plot_title, save_path):
        # Draw one A-vs-B scatter figure and either save it to save_path or
        # show it interactively when save_path is falsy.
        fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor='white')
        ax.set_facecolor('white')

        # Plot A (filled) and B (hollow) - side by side
        for wt in weight_type_order:
            indices = [i for i, w in enumerate(weight_types) if w == wt]
            if indices:
                color = get_weight_type_color(wt)
                ax.scatter([x[i] - 0.15 for i in indices], [values_a[i] for i in indices],
                           c=color, s=60, alpha=0.9, edgecolors='none', marker='o', zorder=3)
                ax.scatter([x[i] + 0.15 for i in indices], [values_b[i] for i in indices],
                           c='none', s=60, alpha=0.9, edgecolors=color, linewidths=2, marker='o', zorder=3)

        # Connect A and B points
        for i in range(len(names)):
            ax.plot([x[i] - 0.15, x[i] + 0.15], [values_a[i], values_b[i]],
                    color='gray', alpha=0.3, linewidth=0.8, zorder=1)

        ax.grid(True, which="major", axis="both", alpha=0.3, linewidth=1, color='lightgray')
        ax.set_axisbelow(True)
        ax.set_ylabel(ylabel, fontsize=13)
        ax.set_xlabel("Layer", fontsize=13)
        ax.set_title(plot_title, fontsize=14)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha="right", fontsize=10)

        # Hand-built legend: two entries for the run markers, an empty
        # spacer entry, then one entry per weight-type color.
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', markersize=8, label=f"{label_a} (filled)"),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='none', markeredgecolor='gray',
                   markeredgewidth=2, markersize=8, label=f"{label_b} (hollow)"),
            Line2D([0], [0], marker='', color='none', label=''),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:blue', markersize=8, label='qkv'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:cyan', markersize=8, label='wo'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:orange', markersize=8, label='w1w3'),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:green', markersize=8, label='w2'),
        ]
        ax.legend(handles=legend_elements, loc="upper right", frameon=True, fontsize=10, ncol=2)
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight', facecolor='white')
            print(f"Saved: {save_path}")
        else:
            plt.show()
        plt.close(fig)

    # Plot 1: Relative MSE (the only plot that honors the `title` override)
    mse_title = title or f"Relative MSE: {label_a} vs {label_b}\n{model_name} {rate_str}"
    mse_path = output_path if output_path else None
    make_comparison_plot(rel_mses_a, rel_mses_b, "Relative Activation MSE (%)", mse_title, mse_path)

    # Plot 2: Cosine Distance, saved next to plot 1 as "<stem>_cosine"
    cos_title = f"Cosine Distance: {label_a} vs {label_b}\n{model_name} {rate_str}"
    cos_path = output_path.with_stem(output_path.stem + "_cosine") if output_path else None
    make_comparison_plot(cos_dist_a, cos_dist_b, "Cosine Distance (1 - cos_sim) %", cos_title, cos_path)

    # Summary
    # NOTE(review): max() raises ValueError here if no layer is common to both
    # runs (empty lists) — confirm that cannot happen for valid runs.
    print(f"\nSummary (Relative MSE %):")
    print(f"  {label_a}: mean={np.mean(rel_mses_a):.2f}%, max={max(rel_mses_a):.2f}%")
    print(f"  {label_b}: mean={np.mean(rel_mses_b):.2f}%, max={max(rel_mses_b):.2f}%")
    diffs_mse = [b - a for a, b in zip(rel_mses_a, rel_mses_b)]
    print(f"  Diff ({label_b} - {label_a}): mean={np.mean(diffs_mse):.2f}%")

    print(f"\nSummary (Cosine Distance %):")
    print(f"  {label_a}: mean={np.mean(cos_dist_a):.4f}%, max={max(cos_dist_a):.4f}%")
    print(f"  {label_b}: mean={np.mean(cos_dist_b):.4f}%, max={max(cos_dist_b):.4f}%")
    diffs_cos = [b - a for a, b in zip(cos_dist_a, cos_dist_b)]
    print(f"  Diff ({label_b} - {label_a}): mean={np.mean(diffs_cos):.4f}%")

    # Save JSON with both metrics (same percent-scaled values as plotted)
    json_path = output_path.with_suffix('.json') if output_path else run_a / "activation_metrics_compare.json"
    data = {
        "run_a": str(run_a), "run_b": str(run_b),
        "label_a": label_a, "label_b": label_b,
        "nsamples": nsamples, "seqlen": seqlen,
        "layers": [{
            "name": names[i],
            "rel_mse_a_pct": rel_mses_a[i], "rel_mse_b_pct": rel_mses_b[i],
            "cos_dist_a_pct": cos_dist_a[i], "cos_dist_b_pct": cos_dist_b[i],
        } for i in range(len(names))],
    }
    with open(json_path, "w") as f:
        json.dump(data, f, indent=2)
    print(f"Saved JSON: {json_path}")


def main():
    """Parse command-line options and run ``plot_comparison``.

    ``--run_a`` and ``--run_b`` are required. When ``--output`` is omitted
    the figure is written to ``<run_a>/activation_mse_compare.png``.
    ``--figsize`` must be ``WIDTH,HEIGHT``; a malformed value is rejected
    through ``parser.error`` before any model is loaded.
    """
    parser = argparse.ArgumentParser(description="Compute and compare per-layer activation MSE")
    parser.add_argument("--run_a", type=str, required=True)
    parser.add_argument("--run_b", type=str, required=True)
    parser.add_argument("--label_a", type=str, default=None)
    parser.add_argument("--label_b", type=str, default=None)
    parser.add_argument("--output", type=str, default=None)
    parser.add_argument("--title", type=str, default=None)
    parser.add_argument("--no_merge", action="store_true")
    parser.add_argument("--figsize", type=str, default="22,7")
    parser.add_argument("--nsamples", type=int, default=64)
    parser.add_argument("--batch_size", type=int, default=32,
                        help="Batch size for activation capture (larger = faster but more VRAM)")
    parser.add_argument("--seqlen", type=int, default=2048)
    parser.add_argument("--init_dist", action="store_true")
    parser.add_argument("--master_port_base", type=int, default=29500)
    parser.add_argument("--no_cache", action="store_true",
                        help="Disable caching of metrics (recompute even if cache exists)")

    args = parser.parse_args()

    # Validate --figsize up front: the value is only consumed when plotting,
    # i.e. after all model inference, so a typo would otherwise surface late.
    try:
        figsize = tuple(map(float, args.figsize.split(",")))
    except ValueError:
        figsize = ()
    if len(figsize) != 2:
        parser.error(
            f"--figsize expects WIDTH,HEIGHT (two comma-separated numbers), got {args.figsize!r}"
        )

    run_a = Path(args.run_a)
    run_b = Path(args.run_b)
    output_path = Path(args.output) if args.output else run_a / "activation_mse_compare.png"

    plot_comparison(
        run_a=run_a, run_b=run_b, output_path=output_path,
        label_a=args.label_a, label_b=args.label_b, title=args.title,
        figsize=figsize, merge_inputs=not args.no_merge,
        nsamples=args.nsamples, batch_size=args.batch_size, seqlen=args.seqlen,
        init_dist=args.init_dist, master_port_base=args.master_port_base,
        use_cache=not args.no_cache,
    )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
