"""Compare activation MSE between two quantization runs (e.g., Qronos vs no-Qronos).

Plots the relative activation MSE (in percent) of both runs side-by-side, prints a
summary to stdout, and saves the per-layer numbers as a JSON file.

Usage:
    python scripts/plot_activation_mse_compare.py --run_a /path/to/qronos_run --run_b /path/to/noqronos_run
    python scripts/plot_activation_mse_compare.py --run_a /path/to/run1 --run_b /path/to/run2 --output compare.png
"""

import argparse
import json
import os
import pickle
import sys
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np

# Prepend the parent of this script's directory to the import path.
# NOTE(review): presumably this is the repository root, so project modules can
# be imported when the script is run directly; nothing visible in this file
# relies on it — confirm before removing.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def load_qronos_stats(qronos_dir: Path) -> Dict[str, Dict]:
    """Unpickle every ``*.pkl`` file directly inside ``qronos_dir``.

    Args:
        qronos_dir: Directory to scan (non-recursive) for ``*.pkl`` files.

    Returns:
        A dict mapping each file's stem (file name without ``.pkl``) to the
        unpickled object, inserted in sorted path order. Empty when the
        directory holds no ``*.pkl`` files.

    Note:
        ``pickle.load`` can execute arbitrary code embedded in a file, so this
        must only be pointed at run directories from a trusted source.
    """
    loaded: Dict[str, Dict] = {}
    for path in sorted(qronos_dir.glob("*.pkl")):
        with open(path, "rb") as handle:
            loaded[path.stem] = pickle.load(handle)
    return loaded


def compute_activation_mse(stats: Dict) -> Tuple[float, float, float]:
    """Derive error metrics from the traces of three matrices in ``stats``.

    Args:
        stats: Mapping with keys ``"Sig_X"``, ``"Sig_hX"`` and ``"Sig_X_hX"``.
            Each value must support ``.double()`` and ``.trace().item()``
            (presumably square torch tensors holding second-moment statistics
            of the original and quantized activations — confirm against the
            code that writes the pickles).

    Returns:
        (mse, rel_mse, correlation):
            * mse: ``tr(Sig_X) - 2*tr(Sig_X_hX) + tr(Sig_hX)``, clamped at 0.
            * rel_mse: ``sqrt(mse) / sqrt(tr(Sig_X))``, or 0.0 when
              ``tr(Sig_X)`` is not positive.
            * correlation: ``tr(Sig_X_hX) / sqrt(tr(Sig_X) * tr(Sig_hX))``,
              or 0.0 when either trace is not positive.
    """
    # Traces are taken in float64 to limit cancellation error below.
    trace_x, trace_hx, trace_cross = (
        stats[key].double().trace().item()
        for key in ("Sig_X", "Sig_hX", "Sig_X_hX")
    )

    # E[||X - X̂||²] expanded into traces; rounding can push it slightly
    # negative, hence the clamp.
    mse = max(0, trace_x - 2 * trace_cross + trace_hx)

    rel_mse = np.sqrt(mse) / np.sqrt(trace_x) if trace_x > 0 else 0.0

    correlation = 0.0
    if trace_x > 0 and trace_hx > 0:
        correlation = trace_cross / np.sqrt(trace_x * trace_hx)

    return mse, rel_mse, correlation


def parse_layer_name(name: str) -> Tuple[int, str, str]:
    """Split a name like 'layers.0.attention.wq' into its components.

    Args:
        name: Dotted layer name (the stem of a stats file).

    Returns:
        ``(layer_id, block_type, weight_type)`` taken from the 2nd, 3rd and
        4th dot-separated fields when the name starts with ``layers.`` and has
        at least four fields with an integer layer index. Any other name is
        returned as ``(-1, "", name)`` so callers can still sort and label it.
    """
    parts = name.split(".")
    if len(parts) < 4 or parts[0] != "layers":
        return -1, "", name
    try:
        layer_id = int(parts[1])
    except ValueError:
        # A non-numeric index (e.g. 'layers.final.norm.w') is treated like any
        # other unrecognised name instead of aborting the whole run.
        return -1, "", name
    return layer_id, parts[2], parts[3]


def get_sort_key(name: str) -> Tuple[int, int, int]:
    """Build the tuple used to order layers for display.

    The key is ``(layer_id, block_order, weight_order)``: layers ascend by
    index, the ``attention`` block sorts before any other block, and weights
    follow a fixed order inside their block (unknown weight types sort last
    with rank 99).
    """
    layer_id, block_type, weight_type = parse_layer_name(name)

    # Attention and feed-forward ranks share one table; the block rank keeps
    # the two groups apart.
    rank_in_block = {
        "wq": 0,
        "wk": 1,
        "wv": 2,
        "wo": 3,
        "w1": 0,
        "w2": 1,
        "w3": 2,
    }
    block_order = int(block_type != "attention")
    return (layer_id, block_order, rank_in_block.get(weight_type, 99))


def get_merged_weight_type(weight_type: str) -> str:
    """Map a weight type to the group it is reported under.

    ``wq``/``wk``/``wv`` collapse to ``"qkv"`` and ``w1``/``w3`` to
    ``"w1w3"``; every other weight type is returned unchanged.
    """
    group_of = {
        "wq": "qkv",
        "wk": "qkv",
        "wv": "qkv",
        "w1": "w1w3",
        "w3": "w1w3",
    }
    return group_of.get(weight_type, weight_type)


def get_weight_type_color(weight_type: str, variant: str = "a") -> str:
    """Return the plot color for a weight type.

    Both merged group names (``qkv``, ``wo``, ``w1w3``, ``w2``) and the
    individual weight names (``wq``, ``wk``, ``wv``, ``w1``, ``w3``) are
    accepted; an individual weight gets the color of the group it belongs to,
    so un-merged plots are colored consistently instead of falling back to
    gray.

    Args:
        weight_type: Merged or individual weight type name.
        variant: ``"a"`` selects the solid palette; any other value selects
            the lighter shades.

    Returns:
        A matplotlib color string, ``"tab:gray"`` for unknown weight types.
    """
    # Individual weights share the color of their merged group.
    group_of = {
        "wq": "qkv",
        "wk": "qkv",
        "wv": "qkv",
        "w1": "w1w3",
        "w3": "w1w3",
    }
    group = group_of.get(weight_type, weight_type)

    if variant == "a":
        color_map = {
            "qkv": "tab:blue",
            "wo": "tab:cyan",
            "w1w3": "tab:orange",
            "w2": "tab:green",
        }
    else:
        color_map = {
            "qkv": "#6baed6",   # lighter blue
            "wo": "#9ecae1",    # lighter cyan
            "w1w3": "#fdae6b",  # lighter orange
            "w2": "#a1d99b",    # lighter green
        }
    return color_map.get(group, "tab:gray")


def process_run(run_dir: Path, merge_inputs: bool = True) -> List[Dict]:
    """Turn the stats stored under ``run_dir / "qronos_stats"`` into plot rows.

    Args:
        run_dir: Run directory containing a ``qronos_stats`` sub-directory.
        merge_inputs: When True, rows are averaged per
            ``(layer_id, merged weight type)`` (e.g. wq/wk/wv become one
            ``qkv`` row). When False, one row per stats file is returned.

    Returns:
        A list of dicts sorted by their ``"sort_key"``. Every row carries
        ``"short_name"``, ``"rel_mse"``, ``"correlation"``, ``"weight_type"``,
        ``"layer_id"`` and ``"sort_key"``; un-merged rows additionally carry
        ``"name"``, ``"mse"`` and ``"merged_type"``.

    Raises:
        FileNotFoundError: If the ``qronos_stats`` directory is missing or
            holds no stats files.
    """
    stats_dir = run_dir / "qronos_stats"
    if not stats_dir.exists():
        raise FileNotFoundError(f"No qronos_stats directory found in {run_dir}")

    stats_by_name = load_qronos_stats(stats_dir)
    if not stats_by_name:
        raise FileNotFoundError(f"No Qronos stats files found in {stats_dir}")

    # One row per stats file.
    per_weight = []
    for name, stats in stats_by_name.items():
        mse, rel_mse, corr = compute_activation_mse(stats)
        layer_id, block_type, weight_type = parse_layer_name(name)
        if merge_inputs:
            merged_type = get_merged_weight_type(weight_type)
        else:
            merged_type = weight_type
        per_weight.append({
            "name": name,
            "mse": mse,
            "rel_mse": rel_mse,
            "correlation": corr,
            "sort_key": get_sort_key(name),
            "weight_type": weight_type,
            "merged_type": merged_type,
            "layer_id": layer_id,
        })
    per_weight = sorted(per_weight, key=lambda row: row["sort_key"])

    if not merge_inputs:
        for row in per_weight:
            row["short_name"] = f"L{row['layer_id']}_{row['weight_type']}"
        return per_weight

    # Bucket rows by (layer, merged type), keeping first-seen bucket order.
    buckets = {}
    for row in per_weight:
        buckets.setdefault((row["layer_id"], row["merged_type"]), []).append(row)

    # Display order of merged groups inside a layer; anything else goes last.
    group_rank = {"qkv": 0, "wo": 1, "w1w3": 2}
    merged_rows = [
        {
            "short_name": f"L{layer_id}_{merged_type}",
            "rel_mse": np.mean([row["rel_mse"] for row in members]),
            "correlation": np.mean([row["correlation"] for row in members]),
            "weight_type": merged_type,
            "layer_id": layer_id,
            "sort_key": (layer_id, group_rank.get(merged_type, 3)),
        }
        for (layer_id, merged_type), members in buckets.items()
    ]
    merged_rows.sort(key=lambda row: row["sort_key"])
    return merged_rows


def plot_comparison(
    run_a: Path,
    run_b: Path,
    output_path: Path = None,
    label_a: str = None,
    label_b: str = None,
    title: str = None,
    figsize: Tuple[float, float] = (22, 7),
    merge_inputs: bool = True,
):
    """Plot, summarise and export the activation-MSE comparison of two runs.

    Run A is drawn with filled markers and run B with hollow markers, slightly
    offset left/right of each layer tick and joined by a thin gray line. A text
    summary is printed and the per-layer numbers are written to a JSON file.

    Args:
        run_a: Directory of the first run (must contain ``qronos_stats``).
        run_b: Directory of the second run (must contain ``qronos_stats``).
        output_path: Where to save the figure. When omitted the figure is
            shown interactively instead. The JSON file is written next to the
            figure with a ``.json`` suffix, or to
            ``run_a.parent / "activation_mse_compare.json"`` when no
            ``output_path`` is given.
        label_a: Legend/summary label for run A (default: directory name).
        label_b: Legend/summary label for run B (default: directory name).
        title: Custom plot title.
        figsize: Figure size as ``(width, height)``.
        merge_inputs: Forwarded to ``process_run``; merges wq/wk/wv and w1/w3.

    Raises:
        FileNotFoundError: Propagated from ``process_run`` when a run has no
            usable stats.
        ValueError: If the two runs have no layer in common.
    """
    # Accept plain strings as well as Path objects.
    run_a = Path(run_a)
    run_b = Path(run_b)
    output_path = Path(output_path) if output_path else None

    results_a = process_run(run_a, merge_inputs=merge_inputs)
    results_b = process_run(run_b, merge_inputs=merge_inputs)

    if label_a is None:
        label_a = run_a.name
    if label_b is None:
        label_b = run_b.name

    # Restrict the comparison to layers present in both runs (by short_name).
    names_a = {r["short_name"] for r in results_a}
    names_b = {r["short_name"] for r in results_b}
    common_names = names_a & names_b
    if not common_names:
        # Fail before plotting: the statistics below are undefined for an
        # empty set (max()/min() of an empty list raise an opaque error).
        raise ValueError(
            f"Runs {run_a} and {run_b} have no layers in common; nothing to compare."
        )
    if names_a != names_b:
        print(f"Warning: Runs have different layers. Using {len(common_names)} common layers.")
        results_a = [r for r in results_a if r["short_name"] in common_names]

    b_lookup = {r["short_name"]: r for r in results_b}

    # Run A defines the layer order; values are converted to percent.
    names = [r["short_name"] for r in results_a]
    rel_mses_a = [r["rel_mse"] * 100 for r in results_a]
    rel_mses_b = [b_lookup[name]["rel_mse"] * 100 for name in names]
    weight_types = [r["weight_type"] for r in results_a]
    diffs = [a - b for a, b in zip(rel_mses_a, rel_mses_b)]

    fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor='white')
    try:
        ax.set_facecolor('white')

        x = np.arange(len(names))
        offset = 0.15  # horizontal shift of run A (left) / run B (right)

        if merge_inputs:
            weight_type_order = ["qkv", "wo", "w1w3", "w2"]
        else:
            weight_type_order = ["wq", "wk", "wv", "wo", "w1", "w2", "w3"]
        # Weight types outside the known list (e.g. unparsed layer names) are
        # appended so they still get markers, not just a tick and a line.
        weight_type_order += sorted(set(weight_types) - set(weight_type_order))

        indices_by_type = {
            wt: [i for i, w in enumerate(weight_types) if w == wt]
            for wt in weight_type_order
        }

        # (x shift, values, run label, filled?) — A is filled, B is hollow.
        series = (
            (-offset, rel_mses_a, label_a, True),
            (offset, rel_mses_b, label_b, False),
        )
        for shift, values, run_label, filled in series:
            for wt in weight_type_order:
                indices = indices_by_type[wt]
                if not indices:
                    continue
                color = get_weight_type_color(wt, "a")
                if filled:
                    marker_style = {"c": color, "edgecolors": 'none'}
                else:
                    marker_style = {"c": 'none', "edgecolors": color, "linewidths": 1.5}
                ax.scatter(
                    [x[i] + shift for i in indices],
                    [values[i] for i in indices],
                    s=60,
                    alpha=0.9,
                    label=f"{wt} ({run_label})",
                    marker='o',
                    zorder=3,
                    **marker_style,
                )

        # Connect corresponding points with thin lines
        for xi, value_a, value_b in zip(x, rel_mses_a, rel_mses_b):
            ax.plot([xi - offset, xi + offset], [value_a, value_b],
                    color='gray', alpha=0.3, linewidth=0.8, zorder=1)

        # Style
        ax.grid(True, which="major", axis="both", alpha=0.3, linewidth=1, color='lightgray')
        ax.set_axisbelow(True)

        ax.set_ylabel("Relative Activation MSE (%)", fontsize=13)
        ax.set_xlabel("Layer", fontsize=13)
        ax.set_title(
            title or f"Activation MSE Comparison: {label_a} vs {label_b}",
            fontsize=15,
            fontweight='normal',
        )

        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=45, ha="right", fontsize=10)
        ax.tick_params(axis='y', labelsize=10)

        for spine in ax.spines.values():
            spine.set_linewidth(0.8)
            spine.set_color('black')

        # Simplified legend: only distinguishes run A (filled) from run B
        # (hollow); the per-weight-type scatter labels are not shown.
        from matplotlib.lines import Line2D
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', markerfacecolor='tab:blue',
                   markersize=8, label=label_a),
            Line2D([0], [0], marker='o', color='w', markerfacecolor='none',
                   markeredgecolor='tab:blue', markeredgewidth=1.5, markersize=8,
                   label=label_b),
        ]
        ax.legend(handles=legend_elements, loc="upper right", frameon=True, fancybox=True,
                  framealpha=0.95, edgecolor='lightgray', fontsize=11)

        plt.tight_layout()

        if output_path:
            plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
            print(f"Saved plot to {output_path}")
        else:
            plt.show()
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)

    _print_comparison_summary(names, rel_mses_a, rel_mses_b, diffs, label_a, label_b)

    if output_path:
        json_path = output_path.with_suffix('.json')
    else:
        json_path = run_a.parent / "activation_mse_compare.json"
    _write_comparison_json(
        json_path, run_a, run_b, label_a, label_b, names, rel_mses_a, rel_mses_b, diffs
    )


def _print_comparison_summary(
    names: List[str],
    rel_mses_a: List[float],
    rel_mses_b: List[float],
    diffs: List[float],
    label_a: str,
    label_b: str,
) -> None:
    """Print per-run mean/max, the A-B difference stats and the top-10 winners per run."""
    print("\nSummary:")
    print(f"  {label_a}: mean={np.mean(rel_mses_a):.2f}%, max={max(rel_mses_a):.2f}%")
    print(f"  {label_b}: mean={np.mean(rel_mses_b):.2f}%, max={max(rel_mses_b):.2f}%")

    print(f"\n  Difference ({label_a} - {label_b}):")
    print(f"    mean={np.mean(diffs):.2f}%, min={min(diffs):.2f}%, max={max(diffs):.2f}%")

    # A negative difference means run A has the lower relative MSE.
    a_better = [(name, diff) for name, diff in zip(names, diffs) if diff < -1]
    if a_better:
        print(f"\n  Layers where {label_a} is better (>1% lower MSE):")
        for name, diff in sorted(a_better, key=lambda item: item[1])[:10]:
            print(f"    {name}: {diff:.2f}%")

    b_better = [(name, diff) for name, diff in zip(names, diffs) if diff > 1]
    if b_better:
        print(f"\n  Layers where {label_b} is better (>1% lower MSE):")
        for name, diff in sorted(b_better, key=lambda item: -item[1])[:10]:
            print(f"    {name}: {diff:.2f}%")


def _write_comparison_json(
    json_path: Path,
    run_a: Path,
    run_b: Path,
    label_a: str,
    label_b: str,
    names: List[str],
    rel_mses_a: List[float],
    rel_mses_b: List[float],
    diffs: List[float],
) -> None:
    """Write the per-layer percentages and their means to ``json_path``."""
    comparison_data = {
        "run_a": str(run_a),
        "run_b": str(run_b),
        "label_a": label_a,
        "label_b": label_b,
        "layers": [
            {
                "name": name,
                "rel_mse_a_pct": float(value_a),
                "rel_mse_b_pct": float(value_b),
                "diff_pct": float(diff),
            }
            for name, value_a, value_b, diff in zip(names, rel_mses_a, rel_mses_b, diffs)
        ],
        "summary": {
            "mean_a_pct": float(np.mean(rel_mses_a)),
            "mean_b_pct": float(np.mean(rel_mses_b)),
            "mean_diff_pct": float(np.mean(diffs)),
        }
    }
    with open(json_path, "w") as f:
        json.dump(comparison_data, f, indent=2)
    print(f"\nSaved JSON data to {json_path}")


def main():
    """Command-line entry point: parse arguments and run ``plot_comparison``.

    Exits with status 1 when a run directory does not exist, and with an
    argparse usage error when ``--figsize`` is not two positive numbers.
    The plot is written to ``--output`` or, by default, to
    ``{run_a}/activation_mse_compare.png``.
    """
    parser = argparse.ArgumentParser(description="Compare activation MSE between two quantization runs")
    parser.add_argument("--run_a", type=str, required=True,
                        help="Path to first run directory (e.g., Qronos run)")
    parser.add_argument("--run_b", type=str, required=True,
                        help="Path to second run directory (e.g., no-Qronos run)")
    parser.add_argument("--label_a", type=str, default=None,
                        help="Label for run A (default: directory name)")
    parser.add_argument("--label_b", type=str, default=None,
                        help="Label for run B (default: directory name)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output path for plot (default: {run_a}/activation_mse_compare.png)")
    parser.add_argument("--title", type=str, default=None,
                        help="Custom plot title")
    parser.add_argument("--no_merge", action="store_true",
                        help="Don't merge weights with identical inputs")
    parser.add_argument("--figsize", type=str, default="22,7",
                        help="Figure size as 'width,height' (default: 22,7)")

    args = parser.parse_args()

    run_a = Path(args.run_a)
    run_b = Path(args.run_b)

    if not run_a.exists():
        print(f"Error: Run directory not found: {run_a}")
        sys.exit(1)
    if not run_b.exists():
        print(f"Error: Run directory not found: {run_b}")
        sys.exit(1)

    output_path = Path(args.output) if args.output else run_a / "activation_mse_compare.png"

    # Validate --figsize here so a typo yields a usage message rather than a
    # traceback or an obscure matplotlib error later on.
    try:
        figsize = tuple(float(value) for value in args.figsize.split(","))
    except ValueError:
        parser.error(f"--figsize must be 'width,height' with numeric values, got {args.figsize!r}")
    if len(figsize) != 2 or not all(dim > 0 for dim in figsize):
        parser.error(f"--figsize must be two positive numbers 'width,height', got {args.figsize!r}")

    plot_comparison(
        run_a=run_a,
        run_b=run_b,
        output_path=output_path,
        label_a=args.label_a,
        label_b=args.label_b,
        title=args.title,
        figsize=figsize,
        merge_inputs=not args.no_merge,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
