#!/usr/bin/env python3
"""
ManifoldKV Manifold Dimension Analysis
ICML 2026 - Reproduces Table 5 (Universal ~9D Manifold)

Key Finding: Key vectors occupy a universal ~9-dimensional manifold
regardless of architecture (8.2-8.9 across all models).

This validates the theoretical foundation of ManifoldKV.
"""

import argparse
import json
import os
import sys
from pathlib import Path
from datetime import datetime

import torch
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import cdist
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(model_name: str, device: str = "cuda"):
    """Load a causal language model and its tokenizer in inference mode.

    Args:
        model_name: Identifier or path handed to both ``from_pretrained`` calls.
        device: Value forwarded as ``device_map`` when loading the model.

    Returns:
        A ``(model, tokenizer)`` tuple; the model has had ``eval()`` called.
    """
    print(f"Loading model: {model_name}")

    # NOTE(review): trust_remote_code=True presumably lets the checkpoint
    # repository supply code that runs locally - confirm this is intended
    # for every model name passed in.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Half-precision weights and the "sdpa" attention implementation are
    # requested for every model.
    model_options = dict(
        torch_dtype=torch.float16,
        device_map=device,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_options)
    model.eval()

    return model, tokenizer


def extract_keys(model, tokenizer, text: str, max_length: int = 4096):
    """Run one forward pass and collect the cached key tensor of every layer.

    Args:
        model: Callable model exposing ``.device``; its output must carry
            ``past_key_values``.
        tokenizer: Callable that turns ``text`` into a mapping of tensors
            containing at least ``input_ids``.
        text: Prompt to encode (truncated to ``max_length`` tokens).
        max_length: Truncation limit forwarded to the tokenizer.

    Returns:
        ``(all_keys, seq_len)`` where ``all_keys`` maps layer index to the
        detached CPU key tensor and ``seq_len`` is ``input_ids.shape[1]``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=False, use_cache=True)

    # Only tuple-shaped cache entries are used; element 0 is taken as the keys,
    # expected as (batch, num_heads, seq_len, head_dim). Anything else is skipped.
    all_keys = {
        layer_idx: entry[0].detach().cpu()
        for layer_idx, entry in enumerate(outputs.past_key_values)
        if entry is not None and isinstance(entry, tuple) and len(entry) >= 1
    }

    seq_len = inputs['input_ids'].shape[1]
    return all_keys, seq_len


def pca_dimension(keys: np.ndarray, threshold: float = 0.95) -> dict:
    """PCA-based effective dimension.

    Counts how many principal components are needed for the cumulative
    explained-variance ratio to reach ``threshold`` (and, separately, 0.99).

    Args:
        keys: Array whose last axis is the feature axis. Arrays of rank 3 or
            higher are flattened to ``(n_vectors, dim)``.
        threshold: Cumulative explained-variance level for the main estimate.
            The result is stored under ``effective_dim_95`` whatever value is
            passed here.

    Returns:
        Dict with ``effective_dim_95``, ``effective_dim_99``, ``ambient_dim``
        and ``ratio`` (effective / ambient). When fewer than 10 finite rows
        remain, the dimension fields are ``-1``, ``ratio`` is ``-1.0`` and
        ``ambient_dim`` is 0 if no rows remain at all; the legacy
        ``effective_dim`` key is also kept in that case.
    """
    # Collapse every leading axis so that each row is one key vector.
    if keys.ndim >= 3:
        keys = keys.reshape(-1, keys.shape[-1])

    # Drop rows containing NaN or +/-inf.
    keys = keys[np.isfinite(keys).all(axis=1)]

    if len(keys) < 10:
        # Same schema as the normal return so callers can index
        # "effective_dim_95" unconditionally; -1 marks "not enough data".
        return {
            "effective_dim": -1,
            "effective_dim_95": -1,
            "effective_dim_99": -1,
            "ambient_dim": int(keys.shape[-1]) if len(keys) > 0 else 0,
            "ratio": -1.0,
        }

    ambient_dim = int(keys.shape[-1])

    pca = PCA()
    pca.fit(keys)
    cumsum = np.cumsum(pca.explained_variance_ratio_)

    def components_for(level: float) -> int:
        # searchsorted returns len(cumsum) when rounding keeps the running sum
        # just below `level`; clamp so the count never exceeds the number of
        # fitted components.
        return min(int(np.searchsorted(cumsum, level)) + 1, len(cumsum))

    effective_dim = components_for(threshold)

    return {
        "effective_dim_95": effective_dim,
        "effective_dim_99": components_for(0.99),
        "ambient_dim": ambient_dim,
        "ratio": float(effective_dim / ambient_dim),
    }


def two_nn_dimension(keys: np.ndarray, n_samples: int = 5000) -> dict:
    """Two-NN intrinsic dimension estimator.

    For every point the ratio ``mu = r2 / r1`` of its second to first
    nearest-neighbour distance is formed, and the estimate is
    ``N / sum(log(mu))`` over the ``N`` points with a non-degenerate ``r1``.

    Args:
        keys: Array whose last axis is the feature axis. Arrays of rank 3 or
            higher are flattened to ``(n_vectors, dim)``.
        n_samples: Maximum number of rows used; larger inputs are subsampled
            without replacement using NumPy's global random state.

    Returns:
        Dict with ``two_nn_dim`` and ``mu_mean``. Both are ``-1.0`` when fewer
        than 10 finite rows are available. ``two_nn_dim`` is NaN when no usable
        ratio exists or every ratio equals 1 (the estimate would be infinite).
    """
    # Collapse every leading axis so that each row is one key vector.
    if keys.ndim >= 3:
        keys = keys.reshape(-1, keys.shape[-1])

    # Drop rows containing NaN or +/-inf.
    keys = keys[np.isfinite(keys).all(axis=1)]

    if len(keys) < 10:
        # Same keys as the normal return; -1.0 marks "not enough data".
        return {"two_nn_dim": -1.0, "mu_mean": -1.0}

    if len(keys) > n_samples:
        indices = np.random.choice(len(keys), n_samples, replace=False)
        keys = keys[indices]

    # Three neighbours are requested because the query set is the fitted set:
    # column 0 is the zero-distance self match, columns 1 and 2 are r1 and r2.
    nn = NearestNeighbors(n_neighbors=3, algorithm='ball_tree')
    nn.fit(keys)
    distances, _ = nn.kneighbors(keys)

    r1 = distances[:, 1]
    r2 = distances[:, 2]

    # Skip (near-)duplicate points, whose ratio would divide by ~0.
    valid = r1 > 1e-10
    mu = r2[valid] / r1[valid]

    if len(mu) == 0:
        return {"two_nn_dim": float('nan'), "mu_mean": float('nan')}

    # mu >= 1, so the log-sum is non-negative. It is exactly 0 when all ratios
    # are 1; report NaN instead of dividing by zero and returning inf.
    log_sum = float(np.sum(np.log(mu)))
    intrinsic_dim = len(mu) / log_sum if log_sum > 0 else float('nan')

    return {
        "two_nn_dim": float(intrinsic_dim),
        "mu_mean": float(np.mean(mu)),
    }


def mle_dimension(keys: np.ndarray, k: int = 10, n_samples: int = 5000) -> dict:
    """MLE-based intrinsic dimension estimator.

    For every point the local estimate is
    ``(k - 1) / sum_j log(d_k / d_j)`` over its first ``k - 1`` neighbour
    distances ``d_j``, with ``d_k`` the distance to the k-th neighbour. The
    mean and standard deviation of the local estimates are returned.

    Args:
        keys: Array whose last axis is the feature axis. Arrays of rank 3 or
            higher are flattened to ``(n_vectors, dim)``.
        k: Number of neighbours per point.
        n_samples: Maximum number of rows used; larger inputs are subsampled
            without replacement using NumPy's global random state.

    Returns:
        Dict with ``mle_dim_mean`` and ``mle_dim_std``. When fewer than
        ``k + 10`` finite rows are available they are ``-1.0`` and ``0.0``.
        Both are NaN when no point has a usable (non-zero) denominator.
    """
    # Collapse every leading axis so that each row is one key vector.
    if keys.ndim >= 3:
        keys = keys.reshape(-1, keys.shape[-1])

    # Drop rows containing NaN or +/-inf.
    keys = keys[np.isfinite(keys).all(axis=1)]

    if len(keys) < k + 10:
        # Same keys as the normal return; -1.0 marks "not enough data".
        return {"mle_dim_mean": -1.0, "mle_dim_std": 0.0}

    if len(keys) > n_samples:
        indices = np.random.choice(len(keys), n_samples, replace=False)
        keys = keys[indices]

    # k + 1 neighbours are requested because the query set is the fitted set:
    # column 0 is the zero-distance self match and is dropped below.
    nn = NearestNeighbors(n_neighbors=k+1, algorithm='ball_tree')
    nn.fit(keys)
    distances, _ = nn.kneighbors(keys)

    distances = distances[:, 1:]
    # The small offset keeps log() finite for zero distances (duplicates).
    log_distances = np.log(distances + 1e-10)
    log_dk = log_distances[:, -1:]

    denominators = np.sum(log_dk - log_distances[:, :-1], axis=1)

    # Neighbour distances are ascending, so the denominator is non-negative.
    # It is zero when all k neighbours are equidistant (e.g. duplicates);
    # such rows would yield an infinite local estimate and are left out.
    usable = denominators > 0
    if not np.any(usable):
        return {"mle_dim_mean": float('nan'), "mle_dim_std": float('nan')}

    local_dims = (k - 1) / denominators[usable]

    return {
        "mle_dim_mean": float(np.mean(local_dims)),
        "mle_dim_std": float(np.std(local_dims)),
    }


def _positive_finite(values):
    """Return the finite, strictly positive entries of *values* as floats."""
    return [float(v) for v in values if np.isfinite(v) and v > 0]


def analyze_model(model_name: str, output_dir: Path, context_length: int = 4096):
    """Run full manifold analysis for a model.

    Loads the model, extracts the cached keys of every layer for a synthetic
    prompt, runs the PCA, Two-NN and MLE estimators per layer, writes
    ``manifold_results.json`` into ``output_dir`` and prints a summary.

    Args:
        model_name: Identifier or path passed to ``load_model``.
        output_dir: Directory for the JSON results (created if missing).
        context_length: Token limit used when encoding the prompt.

    Returns:
        The results dict that was written to disk, with the keys ``model``,
        ``context_length``, ``timestamp``, ``by_layer`` and ``summary``.
    """
    model, tokenizer = load_model(model_name)

    # Generate sample text
    sentences = [
        "The transformer architecture revolutionized NLP.",
        "Key-value caching enables efficient generation.",
        "Attention mechanisms compute weighted sums.",
        "Large language models exhibit emergent capabilities.",
    ]
    # One repetition holds 21 whitespace-separated words. 500 repetitions are
    # kept as the floor (unchanged text for moderate context lengths) and the
    # count grows with context_length so that longer requested contexts can
    # still be filled - assuming the tokenizer emits at least one token per
    # word.
    repeats = max(500, context_length // 16 + 1)
    sample_text = " ".join(sentences * repeats)

    # Extract keys
    print(f"Extracting keys for context length {context_length}...")
    keys_dict, actual_length = extract_keys(model, tokenizer, sample_text, max_length=context_length)
    print(f"Actual sequence length: {actual_length}")

    # Analyze each layer
    results = {
        "model": model_name,
        "context_length": actual_length,
        "timestamp": datetime.now().isoformat(),
        "by_layer": {}
    }

    for layer_idx, keys in tqdm(sorted(keys_dict.items()), desc="Analyzing layers"):
        # Cast on the torch side so the estimators always receive float32,
        # whatever floating dtype the cached keys use.
        keys_np = keys.float().numpy()

        results["by_layer"][f"layer_{layer_idx}"] = {
            "pca": pca_dimension(keys_np),
            "two_nn": two_nn_dimension(keys_np),
            "mle": mle_dimension(keys_np),
            "shape": list(keys_np.shape),
        }

    # Compute summary
    layers = list(results["by_layer"].values())

    # .get() tolerates estimator results that lack a field (e.g. a
    # "not enough data" return), and the filter drops -1 sentinels, NaN and
    # inf so that one degenerate layer cannot poison the mean and std.
    pca_dims = _positive_finite(r["pca"].get("effective_dim_95", -1) for r in layers)
    two_nn_dims = _positive_finite(r["two_nn"].get("two_nn_dim", -1.0) for r in layers)
    mle_dims = _positive_finite(r["mle"].get("mle_dim_mean", -1.0) for r in layers)

    # First positive ambient dimension reported by any layer; 128 is the
    # fallback when no layer reports one.
    ambient_dim = next(
        (r["pca"]["ambient_dim"] for r in layers if r["pca"].get("ambient_dim", 0) > 0),
        128,
    )

    results["summary"] = {
        "pca_dim_mean": float(np.mean(pca_dims)) if pca_dims else -1.0,
        "pca_dim_std": float(np.std(pca_dims)) if pca_dims else 0.0,
        "two_nn_dim_mean": float(np.mean(two_nn_dims)) if two_nn_dims else -1.0,
        "two_nn_dim_std": float(np.std(two_nn_dims)) if two_nn_dims else 0.0,
        "mle_dim_mean": float(np.mean(mle_dims)) if mle_dims else -1.0,
        "mle_dim_std": float(np.std(mle_dims)) if mle_dims else 0.0,
        "ambient_dim": ambient_dim,
        "num_layers": len(keys_dict),
    }
    summary = results["summary"]

    # Save results
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "manifold_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    # Print summary
    print("\n" + "="*60)
    print("MANIFOLD DIMENSION ANALYSIS")
    print("="*60)
    print(f"Model: {model_name}")
    print(f"Context Length: {actual_length}")
    print(f"Ambient Dimension: {ambient_dim}")
    print(f"\nIntrinsic Dimension Estimates:")
    print(f"  PCA (95% var): {summary['pca_dim_mean']:.1f} ± {summary['pca_dim_std']:.1f}")
    print(f"  Two-NN:        {summary['two_nn_dim_mean']:.1f} ± {summary['two_nn_dim_std']:.1f}")
    print(f"  MLE:           {summary['mle_dim_mean']:.1f} ± {summary['mle_dim_std']:.1f}")
    print(f"\nPCA Ratio: {summary['pca_dim_mean']/ambient_dim:.1%}")
    print("="*60)

    return results


def main():
    """Parse command-line options and run the analysis for one model.

    Results are written to ``<output_dir>/<last component of --model>``.
    """
    parser = argparse.ArgumentParser(description="Manifold Dimension Analysis")
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct")
    parser.add_argument("--context_length", type=int, default=4096)
    parser.add_argument("--output_dir", type=str, default="../results/manifold")
    parser.add_argument("--gpu", type=int, default=0)
    args = parser.parse_args()

    # Restrict the visible devices before the model is loaded.
    # NOTE(review): torch is already imported at this point; this presumably
    # still takes effect because CUDA has not been initialised yet - confirm.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

    # Strip trailing slashes first: a local path such as "models/llama/" would
    # otherwise give an empty name and write into the shared base directory.
    model_name_short = args.model.rstrip("/").split("/")[-1]
    output_dir = Path(args.output_dir) / model_name_short

    analyze_model(args.model, output_dir, args.context_length)

    print(f"\nResults saved to: {output_dir}")


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
