import torch
import numpy as np
import transformers
from transformers import AutoModelForCausalLM
import matplotlib.pyplot as plt
import gc

# ==========================================
# 1. Configuration
# ==========================================

# Decoder layer and projection module to compare across checkpoints.
TARGET_LAYER_IDX = 12
# Projection name inside the decoder layer.
# Attention: q_proj, k_proj, v_proj, o_proj. MLP: gate_proj, up_proj, down_proj.
TARGET_MODULE = "q_proj"

# Pair 1 -- RL run. Hypothesis: high update sparsity, low rotation.
# Base DeepSeek-R1-Distill-Qwen-1.5B, tuned into NVIDIA Nemotron-Research-Reasoning.
RL_PAIR = {
    "name": "RL Run (DeepSeek -> NVIDIA)",
    "base": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    "tuned": "nvidia/Nemotron-Research-Reasoning-Qwen-1.5B",
}

# Pair 2 -- SFT run. Hypothesis: low update sparsity, high rotation.
# Base Qwen2.5-Math-1.5B, tuned into DeepSeek-R1-Distill-Qwen-1.5B.
SFT_PAIR = {
    "name": "SFT Run (Qwen-Math -> DeepSeek)",
    "base": "Qwen/Qwen2.5-Math-1.5B",
    "tuned": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
}

# Tolerances for check_sparsity: an entry is "unchanged" when
#     |w1 - w0| <= ATOL + RTOL * max(|w0|, |w1|).
# Two distinct bfloat16 values always differ by at least 2**-8 (~3.9e-3)
# relative to the larger magnitude. So RTOL = 1e-3 with ATOL = 0 counts an
# entry as unchanged only when its bf16 value is bit-identical (+0 and -0
# compare equal). Raise RTOL above ~4e-3 to tolerate one-ulp bf16 changes.
RTOL, ATOL = 1e-3, 0.0

# ==========================================
# 2. Geometric Analysis Functions
# ==========================================

def principal_angles(U1, U2):
    """Return the principal angles (radians) between two column subspaces.

    U1 and U2 must be NumPy arrays with orthonormal columns and the same
    number of rows. The cosines of the angles are the singular values of
    U1^T U2. NumPy returns them in descending order, so the angles come back
    ascending, with min(U1.shape[1], U2.shape[1]) entries.
    """
    # Double precision keeps the SVD of the overlap matrix well conditioned.
    left = U1.astype(np.float64)
    right = U2.astype(np.float64)
    cosines = np.linalg.svd(left.T @ right, compute_uv=False)
    # Round-off can nudge a cosine just past 1, where arccos is undefined.
    return np.arccos(np.clip(cosines, -1.0, 1.0))

def full_svd(A, k=128):
    """Return the top-k singular triplets of A as (U_k, s_k, Vt_k).

    A thin SVD of A (promoted to float64) is computed in full and then
    sliced. Asking for more components than exist simply returns all of
    them. For a 2-D A of shape (m, n) and k' = min(k, m, n), the shapes are
    U_k (m, k'), s_k (k',) and Vt_k (k', n). Singular values are in
    descending order.
    """
    left, sigma, right_t = np.linalg.svd(A.astype(np.float64), full_matrices=False)
    keep = min(k, len(sigma))
    return left[:, :keep], sigma[:keep], right_t[:keep, :]

def analyze_geometry(W1, W2, rank=128):
    """Compare the leading spectral structure of two same-shaped weight matrices.

    The top-`rank` SVD of each matrix is taken via full_svd, and three
    numbers come back:
      * max |s1_i - s2_i| over the kept singular values (spectral drift);
      * max principal angle, in degrees, between the kept left singular
        subspaces (U, the output space);
      * the same for the right singular subspaces (V, the input space).

    Interpretation caveats:
      * If `rank` reaches the smaller matrix dimension, the truncated basis
        in that space spans all of it. Its principal angles are then ~0 by
        construction.
      * The largest angle of a truncated subspace can be inflated when the
        singular values at the cut-off are nearly equal (the k-th and
        (k+1)-th directions can swap between checkpoints).
    """
    print(f"   [Geometry] Computing SVD (Rank {rank})...")
    left1, sigma1, right1_t = full_svd(W1, k=rank)
    left2, sigma2, right2_t = full_svd(W2, k=rank)

    to_degrees = 180.0 / np.pi
    max_sv_diff = np.max(np.abs(sigma1 - sigma2))
    # Right singular vectors are the rows of Vt; transpose to column bases.
    max_angle_U = np.max(principal_angles(left1, left2) * to_degrees)
    max_angle_V = np.max(principal_angles(right1_t.T, right2_t.T) * to_degrees)

    return max_sv_diff, max_angle_U, max_angle_V

def check_sparsity(w0, w1, atol=0, rtol=1e-3):
    """Return the percentage of entries left unchanged between w0 and w1.

    An entry counts as unchanged when

        |w1 - w0| <= atol + rtol * max(|w0|, |w1|)

    This is a symmetric relative tolerance scaled by the larger of the two
    magnitudes.

    Note on bfloat16 inputs (e.g. bf16 weights upcast with ``.float()``): two
    distinct bf16 values always differ by at least 2**-8 (~3.9e-3) relative
    to the larger magnitude. With atol=0 and rtol below that, the test
    therefore reduces to exact bf16 equality.

    Args:
        w0: Reference values (array-like).
        w1: Updated values (array-like) with exactly the same shape as w0.
        atol: Absolute tolerance.
        rtol: Relative tolerance.

    Returns:
        Percentage of unchanged entries in [0, 100], or NaN for empty inputs.

    Raises:
        ValueError: If the shapes differ. NumPy broadcasting would otherwise
            compare mismatched tensors silently and could report more than
            100%.
    """
    w0 = np.asarray(w0)
    w1 = np.asarray(w1)
    if w0.shape != w1.shape:
        raise ValueError(f"check_sparsity: shape mismatch {w0.shape} vs {w1.shape}")

    total_elements = w0.size
    if total_elements == 0:
        # Nothing to compare; avoid a 0/0 RuntimeWarning.
        return float("nan")

    delta = np.abs(w1 - w0)
    # Scale by the larger magnitude so the test is symmetric in (w0, w1).
    threshold = atol + rtol * np.maximum(np.abs(w0), np.abs(w1))

    num_unchanged = np.count_nonzero(delta <= threshold)
    return (num_unchanged / total_elements) * 100.0

# ==========================================
# 3. Main Execution
# ==========================================

def get_layer_weight(model_name, layer_idx, module_name):
    """Load a checkpoint, return one projection weight as float32, free the model.

    Args:
        model_name: Hugging Face model id passed to
            ``AutoModelForCausalLM.from_pretrained``.
        layer_idx: Index into ``model.model.layers``. This assumes the
            Qwen2/Llama-style layout used by the checkpoints in this script;
            confirm for other architectures.
        module_name: Attribute name of the projection. It is searched first
            on the layer's ``self_attn`` block and then on its ``mlp`` block,
            so both attention projections (q/k/v/o_proj) and MLP projections
            (gate/up/down_proj) resolve.

    Returns:
        ``numpy.ndarray`` holding the module's weight. The model is loaded in
        bfloat16 and upcast to float32, which is exact.

    Raises:
        AttributeError: If neither block of the layer has ``module_name``.
        IndexError: If ``layer_idx`` is out of range.
    """
    print(f"   Loading {model_name}...")
    # Load to CPU to avoid OOM in smaller review environments.
    # NOTE(review): trust_remote_code=True executes Python code shipped with
    # the checkpoint repo. These appear to be Qwen2-architecture models,
    # which transformers supports natively; consider disabling it (confirm).
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
        trust_remote_code=True,
    )

    try:
        layer = model.model.layers[layer_idx]

        module = None
        for block_name in ("self_attn", "mlp"):
            block = getattr(layer, block_name, None)
            if block is not None and hasattr(block, module_name):
                module = getattr(block, module_name)
                break
        if module is None:
            raise AttributeError(
                f"Layer {layer_idx} of {model_name} has no module "
                f"'{module_name}' in self_attn or mlp"
            )

        # .float() creates a separate float32 tensor, so the returned array
        # does not keep the bf16 model alive.
        weight = module.weight.detach().float().numpy()
    finally:
        # Release the full model even if extraction fails, before the next load.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    return weight

def run_comparison(pair_config, rank=512, layer_idx=None, module_name=None):
    """Compare one (base, tuned) checkpoint pair on a single projection weight.

    Args:
        pair_config: Dict with "name", "base" and "tuned" keys (like RL_PAIR).
        rank: Number of leading singular components used by the geometric
            comparison (default 512, the value this function always used).
        layer_idx: Decoder layer to inspect; None means TARGET_LAYER_IDX.
        module_name: Projection to inspect; None means TARGET_MODULE.

    Returns:
        Tuple (sparsity_percent, max_singular_value_diff, max_rotation_U_deg).
        The V-space rotation is printed but not returned, which keeps the
        3-tuple shape callers unpack.
    """
    # Resolve at call time so later changes to the module constants still apply.
    if layer_idx is None:
        layer_idx = TARGET_LAYER_IDX
    if module_name is None:
        module_name = TARGET_MODULE

    print(f"\n{'='*60}")
    print(f"Running Analysis: {pair_config['name']}")
    print(f"{'='*60}")

    # Load sequentially: each call frees its model before the next one loads.
    w_base = get_layer_weight(pair_config['base'], layer_idx, module_name)
    w_tuned = get_layer_weight(pair_config['tuned'], layer_idx, module_name)

    print("\n   [Analysis] Comparing Weights...")

    sparsity = check_sparsity(w_base, w_tuned, atol=ATOL, rtol=RTOL)
    sv_diff, rot_u, rot_v = analyze_geometry(w_base, w_tuned, rank=rank)

    print(f"\n   >>> RESULTS for {module_name} (Layer {layer_idx}) <<<")
    print("   --------------------------------------------------")
    print(f"   BF16-Aware Sparsity:       {sparsity:.4f}%")
    print(f"   Max Singular Value Diff:   {sv_diff:.6f}")
    print(f"   Max Principal Rotation:    {rot_u:.4f}° (U) / {rot_v:.4f}° (V)")
    print("   --------------------------------------------------")

    return sparsity, sv_diff, rot_u

if __name__ == "__main__":
    print("Starting Geometric Verification Script...")

    # The SFT baseline downloads and loads two more checkpoints; set this to
    # True to include it. When skipped, its summary column prints "n/a".
    RUN_SFT_BASELINE = False

    # Run SFT Baseline
    if RUN_SFT_BASELINE:
        sft_sparsity, sft_sv, sft_rot = run_comparison(SFT_PAIR)
    else:
        sft_sparsity = sft_sv = sft_rot = None

    # Run RL Verification
    rl_sparsity, rl_sv, rl_rot = run_comparison(RL_PAIR)

    def _cell(value, spec):
        """Format one left-aligned, 15-wide table cell; None renders as 'n/a'."""
        text = "n/a" if value is None else format(value, spec)
        return f"{text:<15}"

    print("\n\n" + "="*60)
    print("FINAL SUMMARY COMPARISON")
    print("="*60)
    print(f"{'Metric':<25} | {'SFT Run':<15} | {'RL Run':<15}")
    print("-" * 60)
    print(f"{'Sparsity (%)':<25} | {_cell(sft_sparsity, '.2f')} | {_cell(rl_sparsity, '.2f')}")
    print(f"{'Max Spectral Drift':<25} | {_cell(sft_sv, '.4f')} | {_cell(rl_sv, '.4f')}")
    print(f"{'Max Rotation (deg)':<25} | {_cell(sft_rot, '.2f')} | {_cell(rl_rot, '.2f')}")
    print("="*60)
    print("Verification Complete.")