#!/usr/bin/env python3
"""
ManifoldKV Sanity Check
ICML 2026 - Quick validation that everything works

Run this first to verify your installation before running full experiments.
"""

import sys
from pathlib import Path

# Banner marking the start of the sanity-check output.
print("="*60)
print("ManifoldKV Sanity Check")
print("="*60)

# Check 1: Imports
# All dependencies are imported inside a single try block so that a missing
# package is reported with a fix hint and exit status 1. Only ImportError is
# handled here; any other exception raised while importing propagates.
# The names bound by these imports are reused by the later checks.
print("\n[1/5] Checking imports...")
try:
    import torch
    print(f"  ✓ PyTorch {torch.__version__}")
    
    from transformers import AutoModelForCausalLM, AutoTokenizer
    print("  ✓ Transformers")
    
    # Put the grandparent directory of this file at the front of sys.path
    # before importing kvpress. NOTE(review): presumably this lets the
    # in-repo kvpress package be found without installation — confirm
    # against the repository layout.
    sys.path.insert(0, str(Path(__file__).parent.parent))
    
    from kvpress import (
        KVPressTextGenerationPipeline,
        ManifoldKVPress,
        ManifoldKVSnapKVScorerPress,
        WindowedManifoldKVPress,
        AdaKVPress,
        KeyDiffPress,
    )
    print("  ✓ KVPress with ManifoldKV")
except ImportError as e:
    # Report the failing import, suggest the install command, and stop:
    # none of the later checks can run without these modules.
    print(f"  ✗ Import error: {e}")
    print("\n  Fix: pip install -e . (from icml_code_repo directory)")
    sys.exit(1)

# Check 2: CUDA
# Informational only: a missing GPU is reported but does not abort the run.
print("\n[2/5] Checking CUDA...")
if not torch.cuda.is_available():
    print("  ✗ CUDA not available")
    print("  Note: ManifoldKV requires GPU for reasonable performance")
else:
    # Describe device 0: its name first, then its total memory in GB (1e9 bytes).
    device_name = torch.cuda.get_device_name(0)
    print(f"  ✓ CUDA available: {device_name}")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"  ✓ Memory: {total_gb:.1f} GB")

# Check 3: ManifoldKV Press Instantiation
# Construct each press variant once; any exception raised by a constructor or
# attribute access is reported and ends the script with exit status 1.
print("\n[3/5] Testing ManifoldKV press creation...")
try:
    # Plain ManifoldKV press.
    manifold_press = ManifoldKVPress(compression_ratio=0.2)
    print(f"  ✓ ManifoldKVPress: {manifold_press}")

    # AdaKV wrapped around the ManifoldKV scorer; the ratio is assigned after
    # construction and read back for the report.
    adakv_press = AdaKVPress(ManifoldKVSnapKVScorerPress())
    adakv_press.compression_ratio = 0.2
    adakv_ratio = adakv_press.compression_ratio
    print(f"  ✓ AdaKV + ManifoldKV: compression_ratio={adakv_ratio}")

    # Windowed variant with a 4096-token window.
    windowed_press = WindowedManifoldKVPress(compression_ratio=0.25, window_size=4096)
    reported_window = windowed_press.window_size
    print(f"  ✓ WindowedManifoldKV: window_size={reported_window}")

    # Baseline press; only successful construction is checked.
    keydiff_press = KeyDiffPress(compression_ratio=0.2)
    print("  ✓ KeyDiffPress (baseline)")

except Exception as err:
    print(f"  ✗ Press creation error: {err}")
    sys.exit(1)

# Check 4: Core Algorithm Test
# Re-implements the two scoring rules on random CPU tensors and validates the
# result shapes. Validation uses explicit raises rather than `assert`, so the
# check still runs when Python is started with -O (which strips asserts).
print("\n[4/5] Testing core algorithm...")
try:
    import torch.nn.functional as F

    # Simulate key vectors laid out as (batch, heads, sequence, head_dim).
    bsz, heads, seq_len, head_dim = 1, 8, 1000, 128
    keys = torch.randn(bsz, heads, seq_len, head_dim)

    # ManifoldKV scoring: L2 distance of each key from the per-head centroid
    # taken over the sequence axis.
    mu = keys.mean(dim=2, keepdim=True)
    scores_manifold = torch.norm(keys - mu, dim=-1)

    # KeyDiff scoring: negated cosine similarity between each key and the
    # mean of the unit-normalised keys.
    keys_norm = F.normalize(keys, dim=-1)
    anchor = keys_norm.mean(dim=2, keepdim=True)
    scores_keydiff = -F.cosine_similarity(keys, anchor, dim=-1)

    # Verify shapes: both scorers must yield one score per (batch, head, token).
    expected_shape = (bsz, heads, seq_len)
    for scores in (scores_manifold, scores_keydiff):
        if tuple(scores.shape) != expected_shape:
            raise ValueError(f"Wrong shape: {scores.shape}")

    # Top-k selection: keep 80% of the tokens and confirm that exactly k
    # indices per head come back, so the success line below is backed by a check.
    k = int(seq_len * 0.8)  # Keep 80%
    top_indices = torch.topk(scores_manifold, k, dim=-1).indices
    if tuple(top_indices.shape) != (bsz, heads, k):
        raise ValueError(f"Wrong top-k shape: {top_indices.shape}")

    print(f"  ✓ ManifoldKV scores: mean={scores_manifold.mean():.3f}, std={scores_manifold.std():.3f}")
    print(f"  ✓ KeyDiff scores: mean={scores_keydiff.mean():.3f}, std={scores_keydiff.std():.3f}")
    print(f"  ✓ Top-{k} selection works")

except Exception as e:
    print(f"  ✗ Algorithm test error: {e}")
    sys.exit(1)

# Check 5: Model loading (skipped by default, takes time) and final summary.
# The --full flag is parsed BEFORE anything is reported, so the "[5/5]" line
# and the final verdict reflect what actually ran, and a failed model check
# ends the script with a non-zero exit status.
def parse_cli_args(argv=None):
    """Parse the command-line options of the sanity check.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace; its ``full`` attribute is True when ``--full``
        was given.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--full", action="store_true", help="Run full check including model loading")
    return parser.parse_args(argv)


def run_model_loading_check():
    """Load the Llama model and run one short generation with ManifoldKV.

    Progress and errors are printed. Any exception is caught and reported
    together with a login hint.

    Returns:
        True when loading and inference both succeeded, False otherwise.
    """
    model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    try:
        print(f"  Loading {model_name}...")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        pipe = KVPressTextGenerationPipeline(model=model, tokenizer=tokenizer)
        print("  ✓ Model loaded successfully")

        # Quick inference test
        press = ManifoldKVPress(compression_ratio=0.2)
        output = pipe(
            "The secret password is ABC123. What is the password?",
            questions=["What is the password?"],
            answer_prefix="The password is",
            press=press,
            max_new_tokens=20,
        )
        print(f"  ✓ Inference test: {output['answers'][0][:50]}...")
    except Exception as e:
        print(f"  ✗ Model loading error: {e}")
        print("  → Make sure you have access to Llama models (huggingface-cli login)")
        return False
    return True


# Only read argv when executed as a script; when this file is imported, the
# command line belongs to the importing program and the full check stays off.
run_full = parse_cli_args().full if __name__ == "__main__" else False

if run_full:
    print("\n[5/5] Testing model loading...")
    full_check_ok = run_model_loading_check()
else:
    print("\n[5/5] Model loading (skipped - run with --full for full check)")
    print("  → To test model loading: python sanity_check.py --full")
    full_check_ok = True

# Summary
print("\n" + "="*60)
if not full_check_ok:
    print("SANITY CHECK FAILED (model loading)")
    print("="*60)
    sys.exit(1)
print("SANITY CHECK PASSED!")
print("="*60)
print("\nYou're ready to run experiments:")
print("  1. Main results:    ./scripts/run_ruler_experiments.sh")
print("  2. 64K recovery:    python scripts/run_64k_windowed.py")
print("  3. Multi-model:     ./scripts/run_multimodel_experiments.sh")
print("\nFor quick testing, add --max_samples 10 to limit samples.")
