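"""
Shared primitives for computing the pseudo-rank of attention weight-matrix
products (W_QK and W_VP) in transformer models.

This module provides:
- SVD computation and pseudo-rank calculation
- Baseline pseudo-rank computation (the value expected at random
  initialization, limited by the per-head rank bottleneck)
- Common data classes for results
- Attention-configuration and head-sampling helpers
- Output formatting, JSON saving, and result aggregation utilities
"""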

import json
import random
import time
from pathlib import Path
from dataclasses import dataclass, asdict, field
from typing import Optional

import torch
import numpy as np


# =============================================================================
# Pseudo-Rank Computation
# =============================================================================

def compute_pseudo_rank(singular_values: np.ndarray, threshold: float = 0.95) -> float:
    """
    Compute pseudo-rank from singular values.

    Pseudo-rank is defined as k/n, where n is the number of singular values
    and k is the smallest number of the *largest* singular values whose sum
    reaches `threshold` (e.g., 95%) of the total sum.

    Args:
        singular_values: 1-D array (or sequence) of non-negative singular
            values. Any order is accepted; values are sorted in descending
            order internally.
        threshold: Fraction of the total sum to capture (default 0.95).
            Values above 1 behave like 1 (all singular values are needed).

    Returns:
        Pseudo-rank in (0, 1], or 0.0 when the input is empty or sums to zero.
    """
    # Sort descending so "smallest k capturing the mass" is well defined even
    # if the caller passes values in another order.
    values = np.sort(np.asarray(singular_values, dtype=np.float64).ravel())[::-1]
    n = values.size
    if n == 0:
        return 0.0

    cumsum = np.cumsum(values)
    # Use the last cumulative value as the total, not a separate sum():
    # np.sum uses pairwise summation, so the two can differ in the last bits.
    # With this choice ratios[-1] is exactly 1.0.
    total_sum = cumsum[-1]
    if total_sum == 0:
        return 0.0

    ratios = cumsum / total_sum

    # First index where the captured fraction reaches the threshold (ratios
    # are non-decreasing because singular values are non-negative).
    k = int(np.searchsorted(ratios, threshold, side='left')) + 1

    # A threshold above 1 would otherwise yield k = n + 1 (pseudo-rank > 1).
    return min(k, n) / n


def compute_svd_and_pseudo_rank(matrix: torch.Tensor, threshold: float = 0.95) -> tuple[float, np.ndarray]:
    """
    Compute singular values and the pseudo-rank of a matrix.

    Args:
        matrix: Input matrix on any device and in any floating dtype (assumed
            2-D). It is copied to CPU and upcast to float64 for numerical
            stability; the input tensor itself is not modified.
        threshold: Threshold passed to `compute_pseudo_rank`.

    Returns:
        Tuple of (pseudo_rank, singular_values), where singular_values is a
        float64 numpy array in descending order (as returned by
        `torch.linalg.svdvals`).
    """
    # Detach and move to CPU *before* upcasting:
    # - avoids allocating a float64 copy on the accelerator;
    # - float64 is unsupported on some backends (e.g. Apple MPS), so the
    #   cast must happen on CPU;
    # - detaching first keeps autograd from recording the SVD when the
    #   matrix is built from parameters that require grad.
    matrix_f64 = matrix.detach().cpu().double()

    # Singular values only: cheaper than a full SVD.
    singular_values = torch.linalg.svdvals(matrix_f64).numpy()

    pseudo_rank = compute_pseudo_rank(singular_values, threshold)

    return pseudo_rank, singular_values


# =============================================================================
# Baseline Pseudo-Rank Computation
# =============================================================================

def compute_baseline_pseudo_rank(
    n_attention_heads: int,
    n_kv_heads: int,
    hidden_size: int,
    threshold: float = 0.95,
    head_dim: Optional[int] = None,
) -> tuple[float, float]:
    """
    Compute the theoretical baseline pseudo-rank at random initialization.

    For each head, W_QK and W_VP are (hidden_size x hidden_size) products
    that factor through the head dimension, so their rank is at most
    ``min(head_dim, hidden_size)``. At random initialization we approximate
    that ~``threshold`` of those non-zero singular values are needed to
    capture ``threshold`` of the mass. The baseline is therefore
    ``threshold * min(head_dim, hidden_size) / hidden_size`` and is the same
    for W_QK and W_VP.

    Args:
        n_attention_heads: Number of query/attention heads.
        n_kv_heads: Number of key-value heads. Kept for backward
            compatibility; it does not affect the result. In standard MHA/GQA,
            key/value heads have the same width as query heads: GQA uses
            *fewer* KV heads, not wider ones.
        hidden_size: Model hidden dimension (d_model).
        threshold: Cumulative singular value threshold (default 0.95).
        head_dim: Per-head dimension. Defaults to
            ``hidden_size // n_attention_heads``. Pass it explicitly for models
            whose config sets ``head_dim`` independently of that ratio.

    Returns:
        Tuple of (baseline_qk, baseline_vp).
    """
    if head_dim is None:
        head_dim = hidden_size // n_attention_heads

    # Deriving the KV width as hidden_size // n_kv_heads overstates it under
    # GQA. Query and KV heads share head_dim, so a single bottleneck applies
    # to both products. It can never exceed the d_model x d_model matrix size.
    bottleneck = min(head_dim, hidden_size)

    # W_QK (query/key product) and W_VP (value/output product) share the same
    # per-head bottleneck, hence the same baseline.
    baseline_qk = (threshold * bottleneck) / hidden_size
    baseline_vp = (threshold * bottleneck) / hidden_size

    return baseline_qk, baseline_vp


def compute_rank_ratio(pseudo_rank: float, baseline: float) -> float:
    """
    Return the measured pseudo-rank as a multiple of the baseline.

    Args:
        pseudo_rank: Pseudo-rank measured on the trained model
        baseline: Baseline pseudo-rank expected at random initialization

    Returns:
        pseudo_rank / baseline. Values below 1 mean the trained model has a
        lower pseudo-rank than the baseline; values above 1 mean a higher one.
        With a zero baseline the result is inf for a positive pseudo-rank and
        1.0 otherwise.
    """
    if baseline != 0:
        return pseudo_rank / baseline

    # No finite ratio exists: any positive measurement counts as "infinitely
    # larger", while a non-positive one is reported as "no change".
    if pseudo_rank > 0:
        return float('inf')
    return 1.0


def compute_relative_reduction(pseudo_rank: float, baseline: float) -> float:
    """
    Express the change in pseudo-rank as a fraction of the baseline.

    Args:
        pseudo_rank: Pseudo-rank measured on the trained model
        baseline: Baseline pseudo-rank expected at random initialization

    Returns:
        (baseline - pseudo_rank) / baseline. Positive values mean the rank
        shrank, negative values mean it grew. A zero baseline yields 0.0.
    """
    if baseline != 0:
        gap = baseline - pseudo_rank
        return gap / baseline
    # Nothing to measure a reduction against: report no change.
    return 0.0


# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class HeadResult:
    """Result for a single attention head.

    Holds the pseudo-ranks of one head's W_QK and W_VP matrix products and,
    optionally, the singular values they were computed from.
    """
    layer: int  # index of the layer this head belongs to
    head: int  # index of the head within its layer
    pseudo_rank_qk: float  # pseudo-rank of the head's W_QK product
    pseudo_rank_vp: float  # pseudo-rank of the head's W_VP product
    singular_values_qk: Optional[list] = None  # optional raw singular values of W_QK (None when not stored)
    singular_values_vp: Optional[list] = None  # optional raw singular values of W_VP (None when not stored)


@dataclass
class LayerResult:
    """Aggregated result for a single layer.

    Summary statistics over the per-head pseudo-ranks of one layer, for both
    the W_QK and the W_VP products. ``aggregate_head_results`` fills these in
    using numpy's mean, std (population std, default ``ddof=0``), min and max.
    """
    layer: int  # layer index
    mean_pseudo_rank_qk: float  # mean W_QK pseudo-rank over the layer's heads
    mean_pseudo_rank_vp: float  # mean W_VP pseudo-rank over the layer's heads
    std_pseudo_rank_qk: float  # standard deviation of the W_QK pseudo-ranks
    std_pseudo_rank_vp: float  # standard deviation of the W_VP pseudo-ranks
    min_pseudo_rank_qk: float = 0.0
    max_pseudo_rank_qk: float = 0.0
    min_pseudo_rank_vp: float = 0.0
    max_pseudo_rank_vp: float = 0.0
    n_heads_sampled: int = 0  # number of heads the statistics were computed over
    # Baseline comparison
    mean_rank_ratio_qk: float = 0.0  # mean_pseudo_rank_qk / baseline_qk (see compute_rank_ratio)
    mean_rank_ratio_vp: float = 0.0  # mean_pseudo_rank_vp / baseline_vp (see compute_rank_ratio)


@dataclass
class ModelResult:
    """Complete result for a model.

    Bundles the model's attention geometry, the per-head and per-layer
    results, global statistics over all analysed heads, and the comparison
    against the random-initialization baseline. ``save_results`` serializes
    it with ``dataclasses.asdict``.
    """
    model_name: str  # model identifier; also used to build the output filename
    hidden_size: int
    num_layers: int
    num_attention_heads: int
    num_key_value_heads: int
    head_dim: int
    head_results: list  # per-head results (presumably HeadResult or their dict form — confirm with caller)
    layer_results: list  # per-layer results; print_model_summary indexes entries like dicts (lr['layer'])
    global_mean_qk: float
    global_mean_vp: float
    global_std_qk: float
    global_std_vp: float
    global_min_qk: float = 0.0
    global_max_qk: float = 0.0
    global_min_vp: float = 0.0
    global_max_vp: float = 0.0
    sample_heads: Optional[int] = None  # heads sampled per layer; None/0 means all heads were analysed
    # Baseline comparison
    baseline_qk: float = 0.0  # expected W_QK pseudo-rank at random initialization
    baseline_vp: float = 0.0  # expected W_VP pseudo-rank at random initialization
    rank_ratio_qk: float = 0.0  # global_mean_qk / baseline_qk
    rank_ratio_vp: float = 0.0  # global_mean_vp / baseline_vp
    relative_reduction_qk: float = 0.0  # 1 - rank_ratio_qk
    relative_reduction_vp: float = 0.0  # 1 - rank_ratio_vp


# =============================================================================
# Attention Configuration
# =============================================================================

@dataclass
class AttentionConfig:
    """Configuration for attention weight extraction.

    Describes the attention geometry (widths and head counts) needed to slice
    per-head weights. Usually built from a model config via
    ``from_model_config``.
    """
    hidden_size: int  # model width (d_model)
    num_attention_heads: int  # Number of query heads
    num_key_value_heads: int  # Number of KV heads (may differ for GQA)
    head_dim: int  # per-head dimension
    num_groups: int  # Query heads per KV head (1 for MHA, >1 for GQA)

    @classmethod
    def from_model_config(cls, config) -> "AttentionConfig":
        """Build an AttentionConfig from a model configuration object.

        Reads ``hidden_size`` and ``num_attention_heads`` (required), plus
        ``num_key_value_heads`` and ``head_dim`` (optional). An optional
        attribute that is missing *or set to None* falls back to MHA
        (``num_key_value_heads = num_attention_heads``) or to
        ``hidden_size // num_attention_heads``, respectively.
        """
        hidden_size = config.hidden_size
        num_attention_heads = config.num_attention_heads

        # A config may define these attributes with the value None. A plain
        # getattr default would then return None and the integer division
        # below would raise TypeError, so treat None as "absent".
        num_key_value_heads = getattr(config, "num_key_value_heads", None)
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = hidden_size // num_attention_heads

        num_groups = num_attention_heads // num_key_value_heads

        return cls(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            num_groups=num_groups
        )

    @property
    def is_gqa(self) -> bool:
        """Check if this is Grouped Query Attention."""
        return self.num_groups > 1

    def get_baseline_pseudo_ranks(self, threshold: float = 0.95) -> tuple[float, float]:
        """Get baseline pseudo-ranks (qk, vp) for this architecture.

        Delegates to ``compute_baseline_pseudo_rank`` with this config's head
        counts and hidden size.
        """
        return compute_baseline_pseudo_rank(
            self.num_attention_heads,
            self.num_key_value_heads,
            self.hidden_size,
            threshold
        )


# =============================================================================
# Sampling Utilities
# =============================================================================

def sample_head_indices(n_heads: int, sample_size: Optional[int], seed: int = 42) -> list[int]:
    """
    Sample head indices for analysis.

    Args:
        n_heads: Total number of heads
        sample_size: Number of heads to sample (None, or any value >= n_heads,
            means all heads)
        seed: Random seed for reproducibility

    Returns:
        Sorted list of head indices to analyze.
    """
    if sample_size is None or sample_size >= n_heads:
        return list(range(n_heads))

    # Use a private generator instead of random.seed(). Reseeding the global
    # RNG would silently reset the random state of every other user of the
    # `random` module in the process. Random(seed) yields the same sample that
    # random.seed(seed) followed by random.sample() would.
    rng = random.Random(seed)
    return sorted(rng.sample(range(n_heads), sample_size))


# =============================================================================
# Output Utilities
# =============================================================================

def print_computation_summary(
    n_layers: int,
    layer_indices: list[int],
    n_heads: int,
    matrix_dim: int,
    sample_heads: Optional[int] = None,
    baseline_qk: Optional[float] = None,
    baseline_vp: Optional[float] = None
):
    """
    Print a summary of the computation that will be performed.

    Args:
        n_layers: Number of layers to analyze.
        layer_indices: Indices of those layers (first and last are shown; may be empty).
        n_heads: Number of heads per layer.
        matrix_dim: Side length of each (square) W_QK / W_VP matrix.
        sample_heads: Heads sampled per layer. None/0, or a value >= n_heads,
            means all heads are analyzed.
        baseline_qk: Optional baseline W_QK pseudo-rank to display.
        baseline_vp: Optional baseline W_VP pseudo-rank to display.
    """
    # A sample size >= n_heads analyzes every head (see sample_head_indices),
    # so report the head count that will actually be used.
    is_sampling = bool(sample_heads) and sample_heads < n_heads
    heads_per_layer = sample_heads if is_sampling else n_heads
    n_matrices = n_layers * heads_per_layer
    total_svd_ops = 2 * n_matrices  # 2 matrices per head (W_QK and W_VP)

    print("\n  === COMPUTATION SUMMARY ===")
    if layer_indices:
        print(f"  Layers to analyze: {n_layers} (indices: {layer_indices[0]}..{layer_indices[-1]})")
    else:
        # No indices to show; avoid IndexError on an empty list.
        print(f"  Layers to analyze: {n_layers}")
    if is_sampling:
        print(f"  Heads per layer: {sample_heads} (sampled from {n_heads})")
    else:
        print(f"  Heads per layer: {n_heads}")
    print(f"  Total SVD computations: {total_svd_ops}")
    print(f"    - {n_matrices} W_QK matrices ({matrix_dim}x{matrix_dim})")
    print(f"    - {n_matrices} W_VP matrices ({matrix_dim}x{matrix_dim})")
    print(f"  Matrix dimension: {matrix_dim} x {matrix_dim}")
    print(f"  Singular values per matrix: {matrix_dim}")
    if baseline_qk is not None and baseline_vp is not None:
        print(f"  Baseline pseudo-rank (random init): QK={baseline_qk:.4f}, VP={baseline_vp:.4f}")
    print("  =============================\n")


def print_model_summary(result: ModelResult):
    """
    Print a formatted summary of analysis results with baseline comparison.

    The output lists the architecture, the global baseline comparison
    (baseline vs. measured pseudo-rank, rank ratio, relative reduction), and a
    per-layer table followed by a global row.

    Args:
        result: Completed analysis result. Entries of ``result.layer_results``
            may be ``LayerResult`` instances or their dict form (e.g. from
            ``dataclasses.asdict`` or reloaded JSON).
    """
    print(f"\n{'='*100}")
    print(f"Model: {result.model_name}")
    print(f"{'='*100}")
    print(f"Architecture: {result.num_layers} layers, {result.num_attention_heads} heads")
    print(f"Hidden size: {result.hidden_size}, Head dim: {result.head_dim}")

    gqa_ratio = result.num_attention_heads // result.num_key_value_heads
    if gqa_ratio > 1:
        print(f"KV heads: {result.num_key_value_heads} (GQA ratio: {gqa_ratio})")
    else:
        print(f"KV heads: {result.num_key_value_heads} (MHA)")

    if result.sample_heads:
        print(f"Sampling: {result.sample_heads} heads per layer")

    # Baseline comparison summary
    print("\n--- Baseline Comparison ---")
    print(f"Baseline pseudo-rank (random init): QK={result.baseline_qk:.4f}, VP={result.baseline_vp:.4f}")
    print(f"Measured pseudo-rank:               QK={result.global_mean_qk:.4f}, VP={result.global_mean_vp:.4f}")
    print(f"Rank ratio (measured/baseline):     QK={result.rank_ratio_qk:.2%}, VP={result.rank_ratio_vp:.2%}")
    print(f"Relative reduction:                 QK={result.relative_reduction_qk:.1%}, VP={result.relative_reduction_vp:.1%}")

    # Per-layer table with rank ratios
    print("\n--- Per-Layer Results ---")
    print(f"{'Layer':<6} {'QK Mean':<10} {'QK Ratio':<10} {'QK Std':<10} | "
          f"{'VP Mean':<10} {'VP Ratio':<10} {'VP Std':<10}")
    print("-" * 80)

    for lr in result.layer_results:
        # aggregate_head_results produces LayerResult objects, which are not
        # subscriptable; normalize to a dict so both forms print the same.
        row = lr if isinstance(lr, dict) else asdict(lr)
        qk_ratio = row['mean_rank_ratio_qk']
        vp_ratio = row['mean_rank_ratio_vp']
        print(f"{row['layer']:<6} "
              f"{row['mean_pseudo_rank_qk']:<10.4f} {qk_ratio:<10.2%} {row['std_pseudo_rank_qk']:<10.4f} | "
              f"{row['mean_pseudo_rank_vp']:<10.4f} {vp_ratio:<10.2%} {row['std_pseudo_rank_vp']:<10.4f}")

    print("-" * 80)
    print(f"{'Global':<6} "
          f"{result.global_mean_qk:<10.4f} {result.rank_ratio_qk:<10.2%} {result.global_std_qk:<10.4f} | "
          f"{result.global_mean_vp:<10.4f} {result.rank_ratio_vp:<10.2%} {result.global_std_vp:<10.4f}")
    print()


def _json_default(obj):
    """Convert numpy values that ``json`` cannot serialize natively."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # numpy scalars such as np.float32 / np.int64 (np.float64 already
        # subclasses float and never reaches this hook).
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_results(result: ModelResult, output_dir: Path) -> Path:
    """
    Save results to ``<output_dir>/<safe model name>_pseudo_rank.json``.

    Args:
        result: Analysis result to serialize (via ``dataclasses.asdict``).
        output_dir: Destination directory; created if it does not exist.

    Returns:
        Path of the written JSON file.
    """
    import re

    output_dir.mkdir(parents=True, exist_ok=True)

    # Keep only word characters, '.' and '-'. Typical hub names map exactly
    # as before ("org/model@rev" -> "org_model_rev"). Backslashes and other
    # separators, which could escape output_dir on Windows or be invalid
    # filename characters, are neutralized too.
    safe_name = re.sub(r"[^\w.\-]", "_", result.model_name)
    output_path = output_dir / f"{safe_name}_pseudo_rank.json"

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(asdict(result), f, indent=2, default=_json_default)

    print(f"Results saved to: {output_path}")
    return output_path


def _summary_stats(values: list[float]) -> tuple[float, float, float, float]:
    """Return (mean, population std, min, max) of ``values`` as Python floats."""
    arr = np.asarray(values, dtype=np.float64)
    return float(arr.mean()), float(arr.std()), float(arr.min()), float(arr.max())


def aggregate_head_results(
    head_results: list[HeadResult],
    baseline_qk: float,
    baseline_vp: float,
    sample_heads: Optional[int] = None
) -> tuple[list[LayerResult], float, float, float, float, float, float, float, float]:
    """
    Aggregate head results into layer results and global statistics.

    Args:
        head_results: Per-head results, possibly spanning several layers in
            any order.
        baseline_qk: Baseline W_QK pseudo-rank used for per-layer rank ratios.
        baseline_vp: Baseline W_VP pseudo-rank used for per-layer rank ratios.
        sample_heads: Not used by the aggregation; kept for call-site
            compatibility.

    Returns:
        Tuple of (layer_results, global_mean_qk, global_mean_vp, global_std_qk, global_std_vp,
                  global_min_qk, global_max_qk, global_min_vp, global_max_vp).
        layer_results is sorted by layer index. Standard deviations are
        population std (ddof=0).

    Raises:
        ValueError: If ``head_results`` is empty.
    """
    if not head_results:
        # Without this check numpy warns "mean of empty slice" and then fails
        # with an opaque zero-size reduction error in min(); fail clearly instead.
        raise ValueError("aggregate_head_results() requires at least one head result")

    # Group by layer, keeping heads in their original order within a layer.
    by_layer: dict = {}
    for hr in head_results:
        by_layer.setdefault(hr.layer, []).append(hr)

    layer_results = []
    for layer_idx in sorted(by_layer):
        heads = by_layer[layer_idx]
        mean_qk, std_qk, min_qk, max_qk = _summary_stats([h.pseudo_rank_qk for h in heads])
        mean_vp, std_vp, min_vp, max_vp = _summary_stats([h.pseudo_rank_vp for h in heads])

        layer_results.append(LayerResult(
            layer=layer_idx,
            mean_pseudo_rank_qk=mean_qk,
            mean_pseudo_rank_vp=mean_vp,
            std_pseudo_rank_qk=std_qk,
            std_pseudo_rank_vp=std_vp,
            min_pseudo_rank_qk=min_qk,
            max_pseudo_rank_qk=max_qk,
            min_pseudo_rank_vp=min_vp,
            max_pseudo_rank_vp=max_vp,
            n_heads_sampled=len(heads),
            mean_rank_ratio_qk=compute_rank_ratio(mean_qk, baseline_qk),
            mean_rank_ratio_vp=compute_rank_ratio(mean_vp, baseline_vp)
        ))

    # Global statistics over every head, regardless of layer.
    g_mean_qk, g_std_qk, g_min_qk, g_max_qk = _summary_stats([h.pseudo_rank_qk for h in head_results])
    g_mean_vp, g_std_vp, g_min_vp, g_max_vp = _summary_stats([h.pseudo_rank_vp for h in head_results])

    return (
        layer_results,
        g_mean_qk,
        g_mean_vp,
        g_std_qk,
        g_std_vp,
        g_min_qk,
        g_max_qk,
        g_min_vp,
        g_max_vp
    )
