"""Plot actual activation difference between two quantized runs.

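Computes the relative error ||X_A - X_B|| / ||X_B|| of the input activations X
of every attention and feed-forward linear layer, by running both models on the
same WikiText-2 test sequences.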

Usage:
    python scripts/plot_activation_diff.py --run_a /path/to/run_a --run_b /path/to/run_b
    python scripts/plot_activation_diff.py --run_a /path/to/run_a --run_b /path/to/run_b --output diff.png
"""

import argparse
import json
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from quant_layerwise.data import get_wikitext2, split_dataset, take_nseq
from quant_layerwise.partial_model import load_and_apply_manifest
from quant_layerwise.pipeline import ensure_single_process_distributed, load_model_and_tokenizer
from quant_layerwise.storage import RunManifest


def get_layer_modules(model) -> Dict[str, torch.nn.Module]:
    """Collect the sub-modules whose input activations should be tracked.

    Walks ``model.layers`` and, for each layer, picks the attention
    projections (``wq``, ``wk``, ``wv``, ``wo``) followed by the feed-forward
    projections (``w1``, ``w2``, ``w3``).

    Returns:
        Dict keyed by ``"layers.<idx>.<block>.<weight>"`` in layer order, with
        the weights of each layer in the order listed above.
    """
    # (block attribute, weight attributes) in the order they must appear in the result.
    tracked = (
        ("attention", ("wq", "wk", "wv", "wo")),
        ("feed_forward", ("w1", "w2", "w3")),
    )
    return {
        f"layers.{idx}.{block_name}.{weight_name}": getattr(getattr(layer, block_name), weight_name)
        for idx, layer in enumerate(model.layers)
        for block_name, weight_names in tracked
        for weight_name in weight_names
    }


class ActivationCapture:
    """Forward hook that records the input activations of a single module.

    Every forward call of the hooked module appends the (detached) first
    positional input to ``activations``. At most one hook is installed per
    instance at any time.
    """

    def __init__(self):
        # One tensor per forward call of the hooked module, in call order.
        self.activations: List[torch.Tensor] = []
        # Handle returned by register_forward_hook; None while no hook is installed.
        self.handle = None

    def hook(self, module, input, output):
        """Forward-hook callback: store the first positional input, detached."""
        # input is a tuple, first element is the activation
        x = input[0].detach()
        self.activations.append(x)

    def register(self, module: torch.nn.Module):
        """Install the hook on ``module``, replacing any hook installed earlier."""
        # Without this, a second register() would overwrite the handle, leave the
        # first hook installed forever, and record every forward call twice.
        self.remove()
        self.handle = module.register_forward_hook(self.hook)

    def remove(self):
        """Uninstall the hook if one is installed; safe to call repeatedly."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def get_concatenated(self) -> Optional[torch.Tensor]:
        """Concatenate all captured activations along dim 0.

        Returns:
            The concatenated tensor, or ``None`` when nothing was captured.
        """
        if not self.activations:
            return None
        return torch.cat(self.activations, dim=0)

    def clear(self):
        """Forget all captured activations (the hook itself stays installed)."""
        self.activations = []


@torch.no_grad()
def capture_activations(
    model,
    eval_tokens: torch.Tensor,
    module_names: List[str],
    modules: Dict[str, torch.nn.Module],
    batch_size: int = 4,
) -> Dict[str, torch.Tensor]:
    """Run ``model`` over ``eval_tokens`` and record inputs of selected modules.

    Args:
        model: Model to run; it is switched to eval mode and called once per batch.
        eval_tokens: Input tensor, sliced along dim 0 into batches.
        module_names: Names of the modules to track. Names missing from
            ``modules`` are skipped; repeated names are hooked only once.
        modules: Mapping from name to the module instance to hook.
        batch_size: Number of rows of ``eval_tokens`` per forward call.

    Returns:
        Dict from module name to the captured inputs concatenated along dim 0
        (``None`` for a module that was never called).
    """
    captures = {}
    try:
        # Set up hooks. The capture is stored before registering so the
        # ``finally`` clause below can always clean it up.
        for name in module_names:
            if name in modules and name not in captures:
                captures[name] = ActivationCapture()
                captures[name].register(modules[name])

        # Run inference on whatever device the model parameters live on.
        model.eval()
        device = next(model.parameters()).device

        nsamples = eval_tokens.shape[0]
        for start in range(0, nsamples, batch_size):
            batch = eval_tokens[start:start + batch_size].to(device)
            model(batch)

        # Collect activations
        return {name: capture.get_concatenated() for name, capture in captures.items()}
    finally:
        # Always uninstall the hooks, even if the forward pass raised, so the
        # model is not left recording activations on later calls.
        for capture in captures.values():
            capture.remove()


def compute_relative_diff(x_a: torch.Tensor, x_b: torch.Tensor) -> float:
    """Return the relative L2 distance ``||x_a - x_b|| / ||x_b||``.

    Both tensors are flattened and promoted to float64 before the norms are
    taken. When the norm of ``x_b`` is not strictly positive the result is
    ``0.0`` rather than a division by zero.
    """
    reference = x_b.reshape(-1).double()
    delta = x_a.reshape(-1).double() - reference

    numerator = torch.norm(delta).item()
    denominator = torch.norm(reference).item()

    # "not > 0" (rather than "<= 0") keeps a NaN denominator on the 0.0 path.
    if not denominator > 0:
        return 0.0
    return numerator / denominator


def parse_layer_name(name: str) -> Tuple[int, str, str]:
    """Split a module name into ``(layer_id, block_type, weight_type)``.

    ``"layers.0.attention.wq"`` becomes ``(0, "attention", "wq")``; components
    after the fourth are ignored. Any name that does not follow the
    ``layers.<int>.<block>.<weight>`` pattern (including a non-integer layer
    index) yields the fallback ``(-1, "", name)``.
    """
    fallback = (-1, "", name)
    parts = name.split(".")
    if len(parts) < 4 or parts[0] != "layers":
        return fallback
    try:
        layer_id = int(parts[1])
    except ValueError:
        # A non-numeric index is just another unparseable name; use the same
        # fallback instead of crashing.
        return fallback
    return layer_id, parts[2], parts[3]


def get_sort_key(name: str) -> Tuple[int, int, int]:
    """Build the ``(layer, block, weight)`` ordering key for a module name.

    Layers sort by index, attention sorts before every other block type, and
    within a block the weights follow wq, wk, wv, wo / w1, w2, w3. Weight types
    outside that list sort last (rank 99).
    """
    layer_id, block_type, weight_type = parse_layer_name(name)

    attention_ranks = {"wq": 0, "wk": 1, "wv": 2, "wo": 3}
    ffn_ranks = {"w1": 0, "w2": 1, "w3": 2}
    # The lookup is by weight name only, independent of the block it came from.
    weight_rank = {**attention_ranks, **ffn_ranks}.get(weight_type, 99)

    block_rank = int(block_type != "attention")
    return (layer_id, block_rank, weight_rank)


def get_merged_weight_type(weight_type: str) -> str:
    """Map a weight type onto the group it is plotted under.

    ``wq``/``wk``/``wv`` collapse to ``"qkv"`` and ``w1``/``w3`` to ``"w1w3"``;
    every other value is returned unchanged.
    """
    merge_groups = (
        ("qkv", ("wq", "wk", "wv")),
        ("w1w3", ("w1", "w3")),
    )
    for merged_name, members in merge_groups:
        if weight_type in members:
            return merged_name
    return weight_type


def get_weight_type_color(weight_type: str) -> str:
    """Return the matplotlib colour used for a weight type.

    Covers both the merged groups (``qkv``, ``w1w3``) and the individual
    weights, so each series stays distinguishable when merging is disabled.
    Unknown types fall back to ``"tab:gray"``.
    """
    color_map = {
        # Merged groups (merge_inputs=True)
        "qkv": "tab:blue",
        "w1w3": "tab:orange",
        # Types present in both modes
        "wo": "tab:cyan",
        "w2": "tab:green",
        # Individual weights (merge_inputs=False); previously these all fell
        # through to gray and could not be told apart in the plot.
        "wq": "tab:blue",
        "wk": "tab:purple",
        "wv": "tab:pink",
        "w1": "tab:orange",
        "w3": "tab:red",
    }
    return color_map.get(weight_type, "tab:gray")


def plot_activation_diff(
    run_a: Path,
    run_b: Path,
    output_path: Optional[Path] = None,
    label_a: Optional[str] = None,
    label_b: Optional[str] = None,
    title: Optional[str] = None,
    figsize: Tuple[float, float] = (22, 6.5),
    merge_inputs: bool = True,
    nsamples: int = 64,
    batch_size: int = 4,
    seqlen: int = 2048,
    local_rank: int = 0,
    init_dist: bool = False,
    master_port_base: int = 29500,
):
    """Plot ||X_A - X_B|| / ||X_B|| for each layer.

    Both runs are applied on top of the same base model, both models are fed
    the same WikiText-2 test sequences, and the relative difference of the
    input activations is computed for every tracked module. The values are
    drawn as a scatter plot, summarised on stdout, and written to a JSON file.

    Args:
        run_a: Directory of run A; ``manifest.json`` is read from it.
        run_b: Directory of run B (the reference in the denominator);
            ``manifest.json`` is read from it.
        output_path: Where to save the figure. When ``None`` the figure is
            shown instead and the JSON goes to ``run_a/activation_diff.json``;
            otherwise the JSON is written next to the figure with a ``.json``
            suffix. Missing parent directories are created.
        label_a: Display label for run A (default: directory name).
        label_b: Display label for run B (default: directory name).
        title: Plot title (default: derived from the two labels).
        figsize: Figure size as ``(width, height)``.
        merge_inputs: When true, ``wq``/``wk``/``wv`` are averaged into one
            ``qkv`` point per layer and ``w1``/``w3`` into ``w1w3``.
        nsamples: Number of sequences requested from ``take_nseq``.
        batch_size: Batch size used for the forward passes.
        seqlen: Sequence length passed to ``split_dataset``.
        local_rank: Forwarded to model loading and, with ``init_dist``, to the
            distributed initialisation.
        init_dist: Call ``ensure_single_process_distributed`` first.
        master_port_base: Master port passed to the distributed initialisation.

    Raises:
        ValueError: If the two manifests name different models, or if no
            module produced activations in both runs.
    """

    if init_dist:
        ensure_single_process_distributed(local_rank=local_rank, master_port=master_port_base)

    # Load manifests
    manifest_a = RunManifest.load(run_a / "manifest.json")
    manifest_b = RunManifest.load(run_b / "manifest.json")

    if manifest_a.model_name != manifest_b.model_name:
        raise ValueError(f"Runs use different models: {manifest_a.model_name} vs {manifest_b.model_name}")

    model_name = manifest_a.model_name

    # Default labels
    if label_a is None:
        label_a = run_a.name
    if label_b is None:
        label_b = run_b.name

    print("Comparing activations:")
    print(f"  Run A ({label_a}): {run_a}")
    print(f"  Run B ({label_b}): {run_b}")
    print(f"  Model: {model_name}")

    # Load both models simultaneously
    print(f"\nLoading model A ({label_a})...")
    model_a, tokenizer = load_model_and_tokenizer(model_name, local_rank=local_rank)
    load_and_apply_manifest(model_a, run_a)
    modules_a = get_layer_modules(model_a)

    print(f"Loading model B ({label_b})...")
    model_b, _ = load_model_and_tokenizer(model_name, local_rank=local_rank)
    load_and_apply_manifest(model_b, run_b)
    modules_b = get_layer_modules(model_b)

    # Prepare eval data
    print(f"Preparing evaluation data (nsamples={nsamples}, seqlen={seqlen})...")
    eval_tokens = split_dataset(get_wikitext2(tokenizer, split="test"), seqlen)
    eval_tokens = take_nseq(eval_tokens, nsamples)
    print(f"Using {eval_tokens.shape[0]} samples")

    # Capture activations from both models
    print("\nCapturing activations from both models...")
    module_names = list(modules_a.keys())
    activations_a = capture_activations(model_a, eval_tokens, module_names, modules_a, batch_size=batch_size)
    activations_b = capture_activations(model_b, eval_tokens, module_names, modules_b, batch_size=batch_size)

    # Free models. The module dicts reference the layers' sub-modules, so they
    # must be dropped together with the models or the weights stay alive.
    del model_a, model_b, modules_a, modules_b
    torch.cuda.empty_cache()

    # Compute relative differences
    print("\nComputing relative differences...")
    results = []
    for name in module_names:
        x_a = activations_a.get(name)
        x_b = activations_b.get(name)
        if x_a is None or x_b is None:
            continue

        layer_id, _block_type, weight_type = parse_layer_name(name)
        results.append({
            "name": name,
            "rel_diff": compute_relative_diff(x_a, x_b),
            "sort_key": get_sort_key(name),
            "weight_type": weight_type,
            "merged_type": get_merged_weight_type(weight_type) if merge_inputs else weight_type,
            "layer_id": layer_id,
        })

    # The raw activations are no longer needed once the scalars are computed.
    del activations_a, activations_b, x_a, x_b
    torch.cuda.empty_cache()

    if not results:
        # Fail with a clear message instead of writing an empty plot and then
        # crashing in min()/max() on an empty sequence.
        raise ValueError(
            "No activations were captured in both runs; nothing to compare "
            "(check that the model has layers and that evaluation data is non-empty)"
        )

    results.sort(key=lambda r: r["sort_key"])

    # Merge if requested
    if merge_inputs:
        grouped = defaultdict(list)
        for r in results:
            grouped[(r["layer_id"], r["merged_type"])].append(r)

        # Position of each merged group within a layer; anything else goes last.
        merged_rank = {"qkv": 0, "wo": 1, "w1w3": 2}
        plot_results = [
            {
                "short_name": f"L{layer_id}_{merged_type}",
                "rel_diff": float(np.mean([r["rel_diff"] for r in items])),
                "weight_type": merged_type,
                "layer_id": layer_id,
                "sort_key": (layer_id, merged_rank.get(merged_type, 3)),
            }
            for (layer_id, merged_type), items in grouped.items()
        ]
        plot_results.sort(key=lambda r: r["sort_key"])
    else:
        for r in results:
            r["short_name"] = f"L{r['layer_id']}_{r['weight_type']}"
        plot_results = results

    # Extract data for plotting
    names = [r["short_name"] for r in plot_results]
    rel_diffs = [r["rel_diff"] * 100 for r in plot_results]  # Convert to percentage
    weight_types = [r["weight_type"] for r in plot_results]

    # Create figure
    fig, ax = plt.subplots(1, 1, figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    x = np.arange(len(names))

    # Plot by weight type, one scatter series (and legend entry) per type
    weight_type_order = ["qkv", "wo", "w1w3", "w2"] if merge_inputs else ["wq", "wk", "wv", "wo", "w1", "w2", "w3"]

    for wt in weight_type_order:
        indices = [i for i, w in enumerate(weight_types) if w == wt]
        if indices:
            ax.scatter(
                [x[i] for i in indices],
                [rel_diffs[i] for i in indices],
                c=get_weight_type_color(wt),
                s=55,
                alpha=0.8,
                label=wt,
                edgecolors='none',
                zorder=3,
            )

    # Style
    ax.grid(True, which="major", axis="both", alpha=0.3, linewidth=1, color='lightgray')
    ax.set_axisbelow(True)

    ax.set_ylabel("||X_A - X_B|| / ||X_B|| (%)", fontsize=13)
    ax.set_xlabel("Layer", fontsize=13)
    ax.set_title(
        title or f"Activation Difference: {label_a} vs {label_b}",
        fontsize=15,
        fontweight='normal',
    )

    ax.set_xticks(x)
    ax.set_xticklabels(names, rotation=45, ha="right", fontsize=10)
    ax.tick_params(axis='y', labelsize=10)

    for spine in ax.spines.values():
        spine.set_linewidth(0.8)
        spine.set_color('black')

    ax.legend(loc="upper right", frameon=True, fancybox=True,
              framealpha=0.95, edgecolor='lightgray', fontsize=10)

    plt.tight_layout()

    if output_path:
        # Create the target directory so a long run does not fail at the very
        # end because of a missing folder.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_path, dpi=150, bbox_inches='tight', facecolor='white')
        print(f"\nSaved plot to {output_path}")
    else:
        plt.show()

    plt.close(fig)

    # Print summary
    print("\nSummary ||X_A - X_B|| / ||X_B||:")
    print(f"  Min: {min(rel_diffs):.2f}%")
    print(f"  Max: {max(rel_diffs):.2f}%")
    print(f"  Mean: {np.mean(rel_diffs):.2f}%")

    # High difference layers: ten largest entries above 5%
    high_diff = [(r["short_name"], r["rel_diff"] * 100) for r in plot_results if r["rel_diff"] > 0.05]
    if high_diff:
        print("\n  High difference layers (>5%):")
        for name, diff in sorted(high_diff, key=lambda item: -item[1])[:10]:
            print(f"    {name}: {diff:.2f}%")

    # Save to JSON
    json_path = output_path.with_suffix('.json') if output_path else run_a / "activation_diff.json"
    diff_data = {
        "run_a": str(run_a),
        "run_b": str(run_b),
        "label_a": label_a,
        "label_b": label_b,
        "nsamples": nsamples,
        "seqlen": seqlen,
        "layers": [{
            "name": r["short_name"],
            "rel_diff_pct": r["rel_diff"] * 100,
        } for r in plot_results],
        "summary": {
            "min_pct": float(min(rel_diffs)),
            "max_pct": float(max(rel_diffs)),
            "mean_pct": float(np.mean(rel_diffs)),
        }
    }
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(diff_data, f, indent=2)
    print(f"Saved JSON data to {json_path}")


def main():
    """Command-line entry point: parse arguments and run the comparison.

    Exits with status 2 (via ``argparse``) on malformed arguments and with
    status 1 when one of the run directories does not exist.
    """
    parser = argparse.ArgumentParser(description="Plot activation difference between two quantized runs")
    parser.add_argument("--run_a", type=str, required=True,
                        help="Path to first run directory")
    parser.add_argument("--run_b", type=str, required=True,
                        help="Path to second run directory")
    parser.add_argument("--label_a", type=str, default=None,
                        help="Label for run A (default: directory name)")
    parser.add_argument("--label_b", type=str, default=None,
                        help="Label for run B (default: directory name)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output path for plot (default: {run_a}/activation_diff.png)")
    parser.add_argument("--title", type=str, default=None,
                        help="Custom plot title")
    parser.add_argument("--no_merge", action="store_true",
                        help="Don't merge weights with identical inputs")
    parser.add_argument("--figsize", type=str, default="22,6.5",
                        help="Figure size as 'width,height' (default: 22,6.5)")
    parser.add_argument("--nsamples", type=int, default=64,
                        help="Number of samples for activation capture (default: 64)")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for inference (default: 4)")
    parser.add_argument("--seqlen", type=int, default=2048,
                        help="Sequence length (default: 2048)")
    parser.add_argument("--local_rank", type=int, default=0,
                        help="Local rank forwarded to model loading and distributed init (default: 0)")
    parser.add_argument("--init_dist", action="store_true",
                        help="Initialize distributed environment")
    parser.add_argument("--master_port_base", type=int, default=29500,
                        help="Master port for distributed init")

    args = parser.parse_args()

    # Validate --figsize up front: exactly two positive numbers. Both a
    # non-numeric part and a wrong number of parts raise ValueError here.
    try:
        width, height = (float(part) for part in args.figsize.split(","))
    except ValueError:
        parser.error(f"--figsize must be 'width,height' (two numbers), got {args.figsize!r}")
    if not (width > 0 and height > 0):
        parser.error(f"--figsize values must be positive, got {args.figsize!r}")
    figsize = (width, height)

    run_a = Path(args.run_a)
    run_b = Path(args.run_b)

    # Report a missing run directory on stderr and stop before any model is loaded.
    for run_dir in (run_a, run_b):
        if not run_dir.is_dir():
            print(f"Error: Run directory not found: {run_dir}", file=sys.stderr)
            sys.exit(1)

    output_path = Path(args.output) if args.output else run_a / "activation_diff.png"

    plot_activation_diff(
        run_a=run_a,
        run_b=run_b,
        output_path=output_path,
        label_a=args.label_a,
        label_b=args.label_b,
        title=args.title,
        figsize=figsize,
        merge_inputs=not args.no_merge,
        nsamples=args.nsamples,
        batch_size=args.batch_size,
        seqlen=args.seqlen,
        local_rank=args.local_rank,
        init_dist=args.init_dist,
        master_port_base=args.master_port_base,
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
