"""
Compute pseudo-rank of attention matrix products (W_QK and W_VP) for models with MHA.

This script computes the pseudo-rank metric for:
- Key-Query product: W_QK = W_K^T @ W_Q
- Value-Projection product: W_VP = P @ W_V (P is the per-head output projection, `o_proj`)

across all layers and heads of models using standard Multi-Head Attention (MHA).

Usage:
    python compute_pseudo_rank_olmo2.py --model <huggingface-model-name>
    python compute_pseudo_rank_olmo2.py --model <model> --revision <branch>
    python compute_pseudo_rank_olmo2.py --model <model> --layers 5
    python compute_pseudo_rank_olmo2.py --model <model> --sample-heads 3
"""

import argparse
import time
from pathlib import Path
from dataclasses import asdict
from typing import Optional

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoConfig

from transformer_pseudo_rank import (
    compute_svd_and_pseudo_rank,
    compute_rank_ratio,
    compute_relative_reduction,
    AttentionConfig,
    HeadResult,
    ModelResult,
    print_computation_summary,
    print_model_summary,
    save_results,
    aggregate_head_results,
    sample_head_indices,
)


# =============================================================================
# MHA Weight Extraction
# =============================================================================

def extract_attention_weights_per_head(
    layer,
    attn_config: AttentionConfig
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """
    Extract per-head Q, K, V, and O projection weights from a layer with MHA.

    Standard Multi-Head Attention (MHA), where each Q head has its
    own K and V head. The returned tensors are views into the layer's
    weights (no data is copied).

    Args:
        layer: A transformer layer (model.layers[i]) exposing
            self_attn.{q_proj,k_proj,v_proj,o_proj}.weight
        attn_config: Attention configuration; only num_attention_heads and
            head_dim are read here

    Returns:
        Tuple of (W_Q_heads, W_K_heads, W_V_heads, W_O_heads)
        Each is a list of tensors, one per attention head.
        - W_Q[h]: shape (head_dim, hidden_size)
        - W_K[h]: shape (head_dim, hidden_size)
        - W_V[h]: shape (head_dim, hidden_size)
        - W_O[h]: shape (hidden_size, head_dim)

    Raises:
        ValueError: If any projection does not have exactly
            num_attention_heads * head_dim rows (q/k/v) or columns (o),
            e.g. because the model uses grouped-query attention.
    """
    n_heads = attn_config.num_attention_heads
    head_dim = attn_config.head_dim
    attn_width = n_heads * head_dim

    # Get weight matrices
    attn = layer.self_attn
    W_Q_full = attn.q_proj.weight  # (n_heads * head_dim, hidden_size)
    W_K_full = attn.k_proj.weight  # (n_heads * head_dim, hidden_size)
    W_V_full = attn.v_proj.weight  # (n_heads * head_dim, hidden_size)
    W_O_full = attn.o_proj.weight  # (hidden_size, n_heads * head_dim)

    # Validate before reshaping: with fewer K/V rows than Q rows (GQA/MQA),
    # view(n_heads, head_dim, -1) can still succeed by folding rows into the
    # last axis, which would silently produce wrong per-head matrices.
    checks = (
        ("q_proj", W_Q_full, 0),
        ("k_proj", W_K_full, 0),
        ("v_proj", W_V_full, 0),
        ("o_proj", W_O_full, 1),
    )
    for proj_name, weight, axis in checks:
        if weight.shape[axis] != attn_width:
            raise ValueError(
                f"{proj_name}.weight has shape {tuple(weight.shape)}, expected size "
                f"{attn_width} (= {n_heads} heads * {head_dim} head_dim) along dim {axis}; "
                "only standard multi-head attention is supported"
            )

    # Reshape Q/K/V weights:
    # (n_heads * head_dim, hidden_size) -> (n_heads, head_dim, hidden_size)
    W_Q_reshaped = W_Q_full.view(n_heads, head_dim, -1)
    W_K_reshaped = W_K_full.view(n_heads, head_dim, -1)
    W_V_reshaped = W_V_full.view(n_heads, head_dim, -1)

    # Reshape O weights: (hidden_size, n_heads * head_dim) -> (hidden_size, n_heads, head_dim)
    # Then permute to get (n_heads, hidden_size, head_dim)
    W_O_reshaped = W_O_full.view(-1, n_heads, head_dim).permute(1, 0, 2)

    # Extract per-head weights (MHA: one-to-one mapping)
    W_Q_heads = [W_Q_reshaped[h] for h in range(n_heads)]  # Each: (head_dim, hidden_size)
    W_K_heads = [W_K_reshaped[h] for h in range(n_heads)]  # Each: (head_dim, hidden_size)
    W_V_heads = [W_V_reshaped[h] for h in range(n_heads)]  # Each: (head_dim, hidden_size)
    W_O_heads = [W_O_reshaped[h] for h in range(n_heads)]  # Each: (hidden_size, head_dim)

    return W_Q_heads, W_K_heads, W_V_heads, W_O_heads


# =============================================================================
# Analysis Functions
# =============================================================================

def analyze_single_layer(
    model,
    layer_idx: int,
    attn_config: AttentionConfig,
    head_indices: list[int],
    threshold: float = 0.95,
    store_singular_values: bool = False
) -> list[HeadResult]:
    """
    Analyze pseudo-rank for specified heads in a single layer.

    For every requested head, the products W_QK = W_K^T @ W_Q and
    W_VP = W_O @ W_V are formed and handed to compute_svd_and_pseudo_rank.

    Args:
        model: The transformer model (layers are read from model.model.layers)
        layer_idx: Index of the layer to analyze
        attn_config: Attention configuration
        head_indices: List of head indices to analyze
        threshold: Threshold for pseudo-rank computation
        store_singular_values: Whether to store full singular value arrays

    Returns:
        List of HeadResult, one per entry of head_indices, in the same order
    """
    layer = model.model.layers[layer_idx]
    W_Q_heads, W_K_heads, W_V_heads, W_O_heads = extract_attention_weights_per_head(layer, attn_config)

    results = []

    # This is pure analysis: the inputs are model weights, which normally
    # track gradients, so without no_grad every hidden_size x hidden_size
    # product would also record autograd state that is never used.
    with torch.no_grad():
        for h in head_indices:
            # Compute W_QK = W_K^T @ W_Q
            W_QK = W_K_heads[h].T @ W_Q_heads[h]  # (hidden_size, hidden_size)

            # Compute W_VP = P @ W_V
            W_VP = W_O_heads[h] @ W_V_heads[h]  # (hidden_size, hidden_size)

            # Compute pseudo-ranks
            pr_qk, sv_qk = compute_svd_and_pseudo_rank(W_QK, threshold)
            pr_vp, sv_vp = compute_svd_and_pseudo_rank(W_VP, threshold)

            results.append(
                HeadResult(
                    layer=layer_idx,
                    head=h,
                    pseudo_rank_qk=pr_qk,
                    pseudo_rank_vp=pr_vp,
                    singular_values_qk=sv_qk.tolist() if store_singular_values else None,
                    singular_values_vp=sv_vp.tolist() if store_singular_values else None
                )
            )

    return results


def analyze_model(
    model_name: str,
    revision: str = "main",
    threshold: float = 0.95,
    layers: Optional[list[int]] = None,
    sample_heads: Optional[int] = None,
    store_singular_values: bool = False,
    verbose: bool = True
) -> ModelResult:
    """
    Analyze pseudo-rank for all layers and heads of a model with MHA.

    The configuration is validated before the (potentially large) model
    weights are loaded, so unsupported models and bad layer selections fail
    fast.

    Args:
        model_name: HuggingFace model name
        revision: Model revision/branch (default: "main")
        threshold: Threshold for pseudo-rank computation
        layers: Specific layers to analyze (None = all layers); indices
            outside [0, num_hidden_layers) are ignored
        sample_heads: Number of heads to sample per layer (None = all heads)
        store_singular_values: Whether to store full singular value arrays
        verbose: Print progress

    Returns:
        ModelResult with all analysis results

    Raises:
        ValueError: If the model config reports grouped-query attention, or
            if no valid layer remains to analyze.
    """
    if verbose:
        print(f"Loading model: {model_name} (revision: {revision})")

    # Load the config first: everything needed for validation lives there.
    config = AutoConfig.from_pretrained(model_name, revision=revision)
    attn_config = AttentionConfig.from_model_config(config)
    num_layers = config.num_hidden_layers

    # The per-head extraction assumes one K/V head per Q head. Analysing a
    # GQA model with it would give meaningless results, so refuse up front.
    if attn_config.is_gqa:
        raise ValueError(
            f"{model_name} uses grouped-query attention "
            f"(num_attention_heads={attn_config.num_attention_heads}, "
            f"num_key_value_heads={attn_config.num_key_value_heads}); "
            "this analysis only supports standard multi-head attention"
        )

    # Determine which layers to analyze
    if layers is None:
        layers_to_analyze = list(range(num_layers))
    else:
        layers_to_analyze = [idx for idx in layers if 0 <= idx < num_layers]
        ignored_layers = [idx for idx in layers if not 0 <= idx < num_layers]
        if ignored_layers and verbose:
            print(f"  Ignoring out-of-range layer indices: {ignored_layers} "
                  f"(valid range: 0-{num_layers - 1})")

    # An empty selection would otherwise surface later as a division by zero
    # in the timing summary, after the whole model has been loaded.
    if not layers_to_analyze:
        raise ValueError(
            f"No valid layers to analyze (requested: {layers}, "
            f"model has {num_layers} layers)"
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        revision=revision,
        torch_dtype=torch.float32,
        device_map="cpu"
    )
    model.eval()

    # Compute baseline pseudo-ranks
    baseline_qk, baseline_vp = attn_config.get_baseline_pseudo_ranks(threshold)

    if verbose:
        print(f"  model_type: {config.model_type}")
        print(f"  hidden_size: {attn_config.hidden_size}")
        print(f"  num_attention_heads: {attn_config.num_attention_heads}")
        print(f"  num_key_value_heads: {attn_config.num_key_value_heads}")
        print(f"  head_dim: {attn_config.head_dim}")
        print(f"  num_layers: {config.num_hidden_layers}")
        print("  attention_type: MHA")

    # Determine which heads to sample
    head_indices = sample_head_indices(attn_config.num_attention_heads, sample_heads)

    # Print computation summary
    if verbose:
        print_computation_summary(
            n_layers=len(layers_to_analyze),
            layer_indices=layers_to_analyze,
            n_heads=attn_config.num_attention_heads,
            matrix_dim=attn_config.hidden_size,
            sample_heads=sample_heads,
            baseline_qk=baseline_qk,
            baseline_vp=baseline_vp
        )
        if sample_heads:
            print(f"  Sampled head indices: {head_indices}\n")

    # Analyze each layer (perf_counter is monotonic, unlike wall-clock time)
    all_head_results = []
    total_start_time = time.perf_counter()

    for layer_idx in layers_to_analyze:
        layer_start_time = time.perf_counter()
        if verbose:
            print(f"  Analyzing layer {layer_idx}/{num_layers-1}...", end=" ", flush=True)

        head_results = analyze_single_layer(
            model, layer_idx, attn_config, head_indices, threshold, store_singular_values
        )
        all_head_results.extend(head_results)

        layer_time = time.perf_counter() - layer_start_time
        if verbose:
            avg_qk = np.mean([r.pseudo_rank_qk for r in head_results])
            avg_vp = np.mean([r.pseudo_rank_vp for r in head_results])
            ratio_qk = compute_rank_ratio(avg_qk, baseline_qk)
            ratio_vp = compute_rank_ratio(avg_vp, baseline_vp)
            print(f"done in {layer_time:.2f}s (QK={avg_qk:.4f} [{ratio_qk:.0%}], VP={avg_vp:.4f} [{ratio_vp:.0%}])")

    # Aggregate results
    (layer_results, global_mean_qk, global_mean_vp, global_std_qk, global_std_vp,
     global_min_qk, global_max_qk, global_min_vp, global_max_vp) = \
        aggregate_head_results(all_head_results, baseline_qk, baseline_vp, sample_heads)

    # Compute global rank ratios and reductions
    rank_ratio_qk = compute_rank_ratio(global_mean_qk, baseline_qk)
    rank_ratio_vp = compute_rank_ratio(global_mean_vp, baseline_vp)
    relative_reduction_qk = compute_relative_reduction(global_mean_qk, baseline_qk)
    relative_reduction_vp = compute_relative_reduction(global_mean_vp, baseline_vp)

    # Include revision in model name for result tracking
    full_model_name = f"{model_name}@{revision}" if revision != "main" else model_name

    result = ModelResult(
        model_name=full_model_name,
        hidden_size=attn_config.hidden_size,
        num_layers=num_layers,
        num_attention_heads=attn_config.num_attention_heads,
        num_key_value_heads=attn_config.num_key_value_heads,
        head_dim=attn_config.head_dim,
        head_results=[asdict(r) for r in all_head_results],
        layer_results=[asdict(r) for r in layer_results],
        global_mean_qk=global_mean_qk,
        global_mean_vp=global_mean_vp,
        global_std_qk=global_std_qk,
        global_std_vp=global_std_vp,
        global_min_qk=global_min_qk,
        global_max_qk=global_max_qk,
        global_min_vp=global_min_vp,
        global_max_vp=global_max_vp,
        sample_heads=sample_heads,
        baseline_qk=baseline_qk,
        baseline_vp=baseline_vp,
        rank_ratio_qk=rank_ratio_qk,
        rank_ratio_vp=rank_ratio_vp,
        relative_reduction_qk=relative_reduction_qk,
        relative_reduction_vp=relative_reduction_vp
    )

    total_time = time.perf_counter() - total_start_time
    if verbose:
        print("\n  === RESULTS ===")
        print(f"  Total analysis time: {total_time:.2f}s")
        print(f"  Average time per layer: {total_time / len(layers_to_analyze):.2f}s")
        print(f"  Baseline pseudo-rank: QK={baseline_qk:.4f}, VP={baseline_vp:.4f}")
        print(f"  Measured pseudo-rank: QK={global_mean_qk:.4f}, VP={global_mean_vp:.4f}")
        print(f"  Rank ratio:           QK={rank_ratio_qk:.2%}, VP={rank_ratio_vp:.2%}")
        print(f"  Relative reduction:   QK={relative_reduction_qk:.1%}, VP={relative_reduction_vp:.1%}")

    # Clean up: drop the local reference to the model before returning
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result


# =============================================================================
# Main CLI
# =============================================================================

def _parse_layer_indices(spec: str) -> Optional[list[int]]:
    """Parse the --layers value ('5' or '0,5,10') into a list of ints.

    An empty value yields None, meaning "all layers". Any token that is not
    an integer raises argparse.ArgumentTypeError, which argparse reports as a
    normal usage error instead of a traceback.
    """
    if not spec.strip():
        return None

    indices = []
    for token in spec.split(","):
        token = token.strip()
        try:
            indices.append(int(token))
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"invalid layer index {token!r} in {spec!r}; "
                "expected comma-separated integers (e.g. '5' or '0,5,10')"
            ) from None
    return indices


def main(argv: Optional[list[str]] = None):
    """Command-line entry point: parse arguments, run the analysis, save results.

    Args:
        argv: Argument list to parse; None (the default) makes argparse read
            sys.argv[1:], so existing ``main()`` calls behave as before.
    """
    parser = argparse.ArgumentParser(
        description="Compute pseudo-rank of attention matrix products for models with MHA."
    )

    parser.add_argument(
        "--model", "-m",
        type=str,
        required=True,
        help="HuggingFace model name"
    )

    parser.add_argument(
        "--revision", "-r",
        type=str,
        default="main",
        help="Model revision/branch (default: main)"
    )

    parser.add_argument(
        "--layers", "-l",
        type=_parse_layer_indices,
        default=None,
        help="Comma-separated list of layer indices to analyze (e.g., '5' or '0,5,10'). Default: all layers."
    )

    parser.add_argument(
        "--sample-heads", "-s",
        type=int,
        default=None,
        help="Number of heads to randomly sample per layer (default: all heads). Use e.g. 3 for faster estimates."
    )

    parser.add_argument(
        "--threshold", "-t",
        type=float,
        default=0.95,
        help="Threshold for pseudo-rank computation (default: 0.95)"
    )

    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default="results",
        help="Directory to save results (default: results)"
    )

    parser.add_argument(
        "--store-singular-values",
        action="store_true",
        help="Store full singular value arrays in output (increases file size)"
    )

    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Suppress progress output"
    )

    args = parser.parse_args(argv)

    # --layers has already been converted by _parse_layer_indices
    layers = args.layers

    output_dir = Path(args.output_dir)
    verbose = not args.quiet

    print(f"\n{'#'*60}")
    print(f"# Analyzing: {args.model}")
    if args.revision != "main":
        print(f"# Revision: {args.revision}")
    print(f"{'#'*60}")

    result = analyze_model(
        model_name=args.model,
        revision=args.revision,
        threshold=args.threshold,
        layers=layers,
        sample_heads=args.sample_heads,
        store_singular_values=args.store_singular_values,
        verbose=verbose
    )

    print_model_summary(result)
    save_results(result, output_dir)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
