"""
Compute pseudo-rank of attention matrix products (W_QK and W_VP) for LLaMA-style models.

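This script computes the pseudo-rank metric, per attention head, for:
- Key-Query product: W_QK = W_K^T @ W_Q
- Value-Projection product: W_VP = P @ W_V, where P is the head's slice of the
  output projection (o_proj)

across the layers and heads of LLaMA-style models. This includes models with GQA
(Grouped Query Attention), where several query heads share one key/value head.

Usage:
    python compute_pseudo_rank.py --model <huggingface-model-name>
    python compute_pseudo_rank.py --model <model1> <model2> <model3>
    python compute_pseudo_rank.py --model <model> --sample-heads 3
    python compute_pseudo_rank.py --model <model> --layers 0,5,10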
"""

import argparse
import json
import time
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import Optional

import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoConfig

from transformer_pseudo_rank import (
    compute_svd_and_pseudo_rank,
    compute_rank_ratio,
    compute_relative_reduction,
    AttentionConfig,
    HeadResult,
    ModelResult,
    print_computation_summary,
    print_model_summary,
    save_results,
    aggregate_head_results,
    sample_head_indices,
)


# =============================================================================
# LLaMA-specific Weight Extraction (with GQA support)
# =============================================================================

def extract_attention_weights_per_head(
    layer,
    attn_config: "AttentionConfig"
) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]:
    """
    Extract per-head Q, K, V, and O projection weights from a LLaMA layer.

    For GQA models, query head h is paired with key/value head h // num_groups.
    Consecutive query heads therefore share a KV head, and the K and V lists
    repeat entries so there is exactly one per query head.

    The returned tensors are detached from autograd. They are slices of the
    layer's weights and share storage with the model when the weights are
    contiguous.

    Args:
        layer: A transformer layer (model.layers[i]) exposing
            self_attn.{q,k,v,o}_proj.weight
        attn_config: Attention configuration; num_attention_heads,
            num_key_value_heads, head_dim and num_groups are read

    Returns:
        Tuple of (W_Q_heads, W_K_heads, W_V_heads, W_O_heads)
        Each is a list of tensors, one per attention (query) head.
        - W_Q[h]: shape (head_dim, hidden_size)
        - W_K[h]: shape (head_dim, hidden_size)
        - W_V[h]: shape (head_dim, hidden_size)
        - W_O[h]: shape (hidden_size, head_dim)

    Raises:
        ValueError: If num_attention_heads != num_key_value_heads * num_groups,
            or a projection weight's head axis does not match attn_config.
    """
    attn = layer.self_attn
    # detach(): these tensors are only inspected. Without it, every product a
    # caller forms from them would be recorded by autograd.
    W_Q_full = attn.q_proj.weight.detach()  # (n_heads * head_dim, hidden_size)
    W_K_full = attn.k_proj.weight.detach()  # (n_kv_heads * head_dim, hidden_size)
    W_V_full = attn.v_proj.weight.detach()  # (n_kv_heads * head_dim, hidden_size)
    W_O_full = attn.o_proj.weight.detach()  # (hidden_size, n_heads * head_dim)

    n_heads = attn_config.num_attention_heads
    n_kv_heads = attn_config.num_key_value_heads
    head_dim = attn_config.head_dim
    num_groups = attn_config.num_groups

    # Fail loudly on a config/weight mismatch. Otherwise the "-1" in the reshapes
    # below would silently yield wrongly-sized heads, or the GQA mapping would
    # index past the last KV head.
    if n_heads != n_kv_heads * num_groups:
        raise ValueError(
            f"num_attention_heads ({n_heads}) != num_key_value_heads ({n_kv_heads}) "
            f"* num_groups ({num_groups})"
        )
    checks = {
        "q_proj rows": (W_Q_full.shape[0], n_heads * head_dim),
        "k_proj rows": (W_K_full.shape[0], n_kv_heads * head_dim),
        "v_proj rows": (W_V_full.shape[0], n_kv_heads * head_dim),
        "o_proj columns": (W_O_full.shape[1], n_heads * head_dim),
    }
    for what, (actual, expected) in checks.items():
        if actual != expected:
            raise ValueError(
                f"{what}: expected {expected} (heads * head_dim={head_dim}), got {actual}"
            )

    # (n_heads * head_dim, hidden_size) -> (n_heads, head_dim, hidden_size).
    # reshape (not view) also works if a weight is non-contiguous.
    W_Q_reshaped = W_Q_full.reshape(n_heads, head_dim, -1)
    # (n_kv_heads * head_dim, hidden_size) -> (n_kv_heads, head_dim, hidden_size)
    W_K_reshaped = W_K_full.reshape(n_kv_heads, head_dim, -1)
    W_V_reshaped = W_V_full.reshape(n_kv_heads, head_dim, -1)
    # (hidden_size, n_heads * head_dim) -> (hidden_size, n_heads, head_dim)
    # -> permute to (n_heads, hidden_size, head_dim)
    W_O_reshaped = W_O_full.reshape(-1, n_heads, head_dim).permute(1, 0, 2)

    W_Q_heads = [W_Q_reshaped[h] for h in range(n_heads)]  # Each: (head_dim, hidden_size)
    W_O_heads = [W_O_reshaped[h] for h in range(n_heads)]  # Each: (hidden_size, head_dim)

    # GQA: query head h reads KV head h // num_groups (identity mapping for MHA)
    kv_for_head = [h // num_groups for h in range(n_heads)]
    W_K_heads = [W_K_reshaped[kv] for kv in kv_for_head]  # Each: (head_dim, hidden_size)
    W_V_heads = [W_V_reshaped[kv] for kv in kv_for_head]  # Each: (head_dim, hidden_size)

    return W_Q_heads, W_K_heads, W_V_heads, W_O_heads


# =============================================================================
# Analysis Functions
# =============================================================================

def analyze_single_layer(
    model,
    layer_idx: int,
    attn_config: AttentionConfig,
    head_indices: list[int],
    threshold: float = 0.95,
    store_singular_values: bool = False
) -> list[HeadResult]:
    """
    Analyze pseudo-rank for specified heads in a single layer.

    For each head h this forms two (hidden_size, hidden_size) products:
    W_QK = W_K[h]^T @ W_Q[h] and W_VP = W_O[h] @ W_V[h]. Each product is passed
    to compute_svd_and_pseudo_rank.

    Args:
        model: The transformer model; the layer is read from model.model.layers
        layer_idx: Index of the layer to analyze
        attn_config: Attention configuration
        head_indices: List of (query) head indices to analyze
        threshold: Threshold for pseudo-rank computation
        store_singular_values: Whether to store full singular value arrays

    Returns:
        List of HeadResult for each analyzed head, in head_indices order
    """
    layer = model.model.layers[layer_idx]
    results = []

    # The products are only inspected, never trained. Model weight Parameters
    # require grad by default, so without no_grad() autograd would record a
    # graph for every (hidden_size, hidden_size) product and for any torch ops
    # applied to it downstream. model.eval() does not turn this off.
    with torch.no_grad():
        W_Q_heads, W_K_heads, W_V_heads, W_O_heads = extract_attention_weights_per_head(layer, attn_config)

        for h in head_indices:
            # Compute W_QK = W_K^T @ W_Q
            W_QK = W_K_heads[h].T @ W_Q_heads[h]  # (hidden_size, hidden_size)

            # Compute W_VP = P @ W_V, with P = this head's slice of o_proj
            W_VP = W_O_heads[h] @ W_V_heads[h]  # (hidden_size, hidden_size)

            pr_qk, sv_qk = compute_svd_and_pseudo_rank(W_QK, threshold)
            pr_vp, sv_vp = compute_svd_and_pseudo_rank(W_VP, threshold)

            results.append(HeadResult(
                layer=layer_idx,
                head=h,
                pseudo_rank_qk=pr_qk,
                pseudo_rank_vp=pr_vp,
                singular_values_qk=sv_qk.tolist() if store_singular_values else None,
                singular_values_vp=sv_vp.tolist() if store_singular_values else None
            ))

    return results


def _select_layers(layers: Optional[list[int]], num_layers: int) -> list[int]:
    """
    Resolve requested layer indices against the model's layer count.

    Returns every layer when ``layers`` is None. Otherwise returns the in-range
    indices in the order given, with duplicates removed, because a repeated
    layer would be counted twice in the aggregate statistics. Out-of-range
    indices are dropped with a warning.

    Raises:
        ValueError: If no requested index is in range (including an empty list),
            since there would be nothing to analyze.
    """
    if layers is None:
        return list(range(num_layers))

    dropped = sorted({idx for idx in layers if not 0 <= idx < num_layers})
    if dropped:
        warnings.warn(
            f"Ignoring out-of-range layer indices {dropped}; "
            f"valid indices are 0..{num_layers - 1}",
            stacklevel=3,
        )

    selected = list(dict.fromkeys(idx for idx in layers if 0 <= idx < num_layers))
    if not selected:
        raise ValueError(
            f"No valid layer indices in {layers}; valid indices are 0..{num_layers - 1}"
        )
    return selected


def _print_model_config(config, attn_config: AttentionConfig) -> None:
    """Print the model/attention shape parameters used by the analysis."""
    print(f"  model_type: {config.model_type}")
    print(f"  hidden_size: {attn_config.hidden_size}")
    print(f"  num_attention_heads: {attn_config.num_attention_heads}")
    print(f"  num_key_value_heads: {attn_config.num_key_value_heads}")
    print(f"  head_dim: {attn_config.head_dim}")
    print(f"  num_layers: {config.num_hidden_layers}")
    if attn_config.is_gqa:
        print(f"  attention_type: GQA (ratio: {attn_config.num_groups})")
    else:
        print("  attention_type: MHA")


def analyze_model(
    model_name: str,
    threshold: float = 0.95,
    layers: Optional[list[int]] = None,
    sample_heads: Optional[int] = None,
    store_singular_values: bool = False,
    verbose: bool = True
) -> ModelResult:
    """
    Analyze pseudo-rank for the selected layers and heads of a model.

    The model is loaded on CPU in float32. Every selected layer is analyzed with
    the same set of head indices, and the per-head results are aggregated into
    global statistics relative to the config's baseline pseudo-ranks.

    Args:
        model_name: HuggingFace model name
        threshold: Threshold for pseudo-rank computation
        layers: Specific layers to analyze (None = all layers). Out-of-range
            indices are ignored with a warning; duplicates are analyzed once.
        sample_heads: Number of heads to sample per layer (None = all heads)
        store_singular_values: Whether to store full singular value arrays
        verbose: Print progress

    Returns:
        ModelResult with all analysis results

    Raises:
        ValueError: If ``layers`` contains no index within the model's layer range.
    """
    if verbose:
        print(f"Loading model: {model_name}")

    # Read the config first so an invalid layer selection fails before the
    # (potentially multi-GB) weights are loaded.
    config = AutoConfig.from_pretrained(model_name)
    num_layers = config.num_hidden_layers
    layers_to_analyze = _select_layers(layers, num_layers)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        device_map="cpu"
    )
    model.eval()

    attn_config = AttentionConfig.from_model_config(config)
    baseline_qk, baseline_vp = attn_config.get_baseline_pseudo_ranks(threshold)

    if verbose:
        _print_model_config(config, attn_config)

    # The same head indices are used for every analyzed layer
    head_indices = sample_head_indices(attn_config.num_attention_heads, sample_heads)

    if verbose:
        print_computation_summary(
            n_layers=len(layers_to_analyze),
            layer_indices=layers_to_analyze,
            n_heads=attn_config.num_attention_heads,
            matrix_dim=attn_config.hidden_size,
            sample_heads=sample_heads,
            baseline_qk=baseline_qk,
            baseline_vp=baseline_vp
        )
        if sample_heads:
            print(f"  Sampled head indices: {head_indices}\n")

    all_head_results = []
    total_start_time = time.perf_counter()

    for layer_idx in layers_to_analyze:
        layer_start_time = time.perf_counter()
        if verbose:
            print(f"  Analyzing layer {layer_idx}/{num_layers-1}...", end=" ", flush=True)

        head_results = analyze_single_layer(
            model, layer_idx, attn_config, head_indices, threshold, store_singular_values
        )
        all_head_results.extend(head_results)

        layer_time = time.perf_counter() - layer_start_time
        if verbose:
            avg_qk = np.mean([r.pseudo_rank_qk for r in head_results])
            avg_vp = np.mean([r.pseudo_rank_vp for r in head_results])
            ratio_qk = compute_rank_ratio(avg_qk, baseline_qk)
            ratio_vp = compute_rank_ratio(avg_vp, baseline_vp)
            print(f"done in {layer_time:.2f}s (QK={avg_qk:.4f} [{ratio_qk:.0%}], VP={avg_vp:.4f} [{ratio_vp:.0%}])")

    total_time = time.perf_counter() - total_start_time

    (layer_results, global_mean_qk, global_mean_vp, global_std_qk, global_std_vp,
     global_min_qk, global_max_qk, global_min_vp, global_max_vp) = \
        aggregate_head_results(all_head_results, baseline_qk, baseline_vp, sample_heads)

    rank_ratio_qk = compute_rank_ratio(global_mean_qk, baseline_qk)
    rank_ratio_vp = compute_rank_ratio(global_mean_vp, baseline_vp)
    relative_reduction_qk = compute_relative_reduction(global_mean_qk, baseline_qk)
    relative_reduction_vp = compute_relative_reduction(global_mean_vp, baseline_vp)

    result = ModelResult(
        model_name=model_name,
        hidden_size=attn_config.hidden_size,
        num_layers=num_layers,
        num_attention_heads=attn_config.num_attention_heads,
        num_key_value_heads=attn_config.num_key_value_heads,
        head_dim=attn_config.head_dim,
        head_results=[asdict(r) for r in all_head_results],
        layer_results=[asdict(r) for r in layer_results],
        global_mean_qk=global_mean_qk,
        global_mean_vp=global_mean_vp,
        global_std_qk=global_std_qk,
        global_std_vp=global_std_vp,
        global_min_qk=global_min_qk,
        global_max_qk=global_max_qk,
        global_min_vp=global_min_vp,
        global_max_vp=global_max_vp,
        sample_heads=sample_heads,
        baseline_qk=baseline_qk,
        baseline_vp=baseline_vp,
        rank_ratio_qk=rank_ratio_qk,
        rank_ratio_vp=rank_ratio_vp,
        relative_reduction_qk=relative_reduction_qk,
        relative_reduction_vp=relative_reduction_vp
    )

    if verbose:
        print("\n  === RESULTS ===")
        print(f"  Total analysis time: {total_time:.2f}s")
        # layers_to_analyze is guaranteed non-empty by _select_layers
        print(f"  Average time per layer: {total_time / len(layers_to_analyze):.2f}s")
        print(f"  Baseline pseudo-rank: QK={baseline_qk:.4f}, VP={baseline_vp:.4f}")
        print(f"  Measured pseudo-rank: QK={global_mean_qk:.4f}, VP={global_mean_vp:.4f}")
        print(f"  Rank ratio:           QK={rank_ratio_qk:.2%}, VP={rank_ratio_vp:.2%}")
        print(f"  Relative reduction:   QK={relative_reduction_qk:.1%}, VP={relative_reduction_vp:.1%}")

    # Drop this function's reference to the model so its weights can be
    # reclaimed before main() loads the next model.
    del model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return result


# =============================================================================
# Main CLI
# =============================================================================

def main():
    """
    Command-line entry point.

    Parses arguments, analyzes each requested model, and prints and saves its
    results. When more than one model was given, it also prints a side-by-side
    comparison table.
    """
    parser = argparse.ArgumentParser(
        description="Compute pseudo-rank of attention matrix products for LLaMA-style models."
    )

    parser.add_argument(
        "--model", "-m",
        type=str,
        nargs='+',
        required=True,
        help="HuggingFace model name(s) to analyze. Can specify multiple models."
    )

    parser.add_argument(
        "--layers", "-l",
        type=str,
        default=None,
        help="Comma-separated list of layer indices to analyze (e.g., '5' or '0,5,10'). Default: all layers."
    )

    parser.add_argument(
        "--sample-heads", "-s",
        type=int,
        default=None,
        help="Number of heads to randomly sample per layer (default: all heads). Use e.g. 3 for faster estimates."
    )

    parser.add_argument(
        "--threshold", "-t",
        type=float,
        default=0.95,
        help="Threshold for pseudo-rank computation (default: 0.95)"
    )

    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        default="results",
        help="Directory to save results (default: results)"
    )

    parser.add_argument(
        "--store-singular-values",
        action="store_true",
        help="Store full singular value arrays in output (increases file size)"
    )

    parser.add_argument(
        "--quiet", "-q",
        action="store_true",
        help="Suppress progress output"
    )

    args = parser.parse_args()

    # Validate CLI values before any model is downloaded. parser.error() prints
    # usage and exits with status 2 instead of a raw traceback.
    if args.sample_heads is not None and args.sample_heads < 1:
        parser.error(f"--sample-heads must be a positive integer, got {args.sample_heads}")

    layers = None
    if args.layers:
        # Skip empty items so inputs like "0,5," or "0, ,5" are accepted
        tokens = [tok.strip() for tok in args.layers.split(",") if tok.strip()]
        try:
            layers = [int(tok) for tok in tokens]
        except ValueError:
            parser.error(f"--layers must be a comma-separated list of integers, got {args.layers!r}")
        if not layers:
            parser.error(f"--layers contains no layer indices: {args.layers!r}")

    output_dir = Path(args.output_dir)
    verbose = not args.quiet

    all_results = []
    for model_name in args.model:
        if verbose:
            print(f"\n{'#'*60}")
            print(f"# Analyzing: {model_name}")
            print(f"{'#'*60}")

        result = analyze_model(
            model_name=model_name,
            threshold=args.threshold,
            layers=layers,
            sample_heads=args.sample_heads,
            store_singular_values=args.store_singular_values,
            verbose=verbose
        )

        print_model_summary(result)
        save_results(result, output_dir)
        all_results.append(result)

    # Print comparison summary if multiple models
    if len(all_results) > 1:
        short_names = [r.model_name.split("/")[-1] for r in all_results]
        # Widen the name column for long model names so the numeric columns stay aligned
        name_width = max(45, max(len(name) for name in short_names))

        print(f"\n{'='*100}")
        print("COMPARISON SUMMARY")
        print(f"{'='*100}")
        print(f"{'Model':<{name_width}} {'QK Mean':<10} {'QK Ratio':<10} {'VP Mean':<10} {'VP Ratio':<10}")
        print("-" * 100)

        for name, r in zip(short_names, all_results):
            print(f"{name:<{name_width}} {r.global_mean_qk:<10.4f} {r.rank_ratio_qk:<10.2%} "
                  f"{r.global_mean_vp:<10.4f} {r.rank_ratio_vp:<10.2%}")


# Run the CLI only when executed as a script, not when imported
if __name__ == "__main__":
    main()
