#!/usr/bin/env python3
"""
ManifoldKV 64K Windowed Experiments
ICML 2026 - Reproduces Table 2 (64K Recovery Results)

Key Finding: at 64K context, WindowedManifoldKV recovers accuracy from 35.2%
(global ManifoldKV) to 84.3% by using local, per-window centroids instead of
a single global centroid, which fixes centroid dilution.

Expected Results:
    - Global ManifoldKV:  35.2% (centroid dilution)
    - Windowed-4K:        84.3% (+49.1 points recovery)
    - Windowed-8K:        83.9%
    - Windowed-16K:       82.4%
    - KeyDiff:            81.1%
"""

import argparse
import json
import os
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from kvpress import KVPressTextGenerationPipeline
from kvpress.presses.manifold_press import (
    ManifoldKVPress,
    WindowedManifoldKVPress,
    HybridManifoldKVPress,
    NormalizedManifoldKVPress,
)
from kvpress.presses.keydiff_press import KeyDiffPress
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct", device_map="auto"):
    """Build a KVPress text-generation pipeline around a causal language model.

    The weights are loaded in float16 with the Flash Attention 2 backend.

    Args:
        model_name: Hugging Face model id (or local path). The same id is used
            to load the tokenizer.
        device_map: Forwarded to ``from_pretrained`` to control device placement.

    Returns:
        A ``KVPressTextGenerationPipeline`` wrapping the model and tokenizer.
    """
    print(f"Loading model: {model_name}...")
    # NOTE: trust_remote_code=True lets the model repository execute its own
    # Python code; only point this at repositories you trust.
    load_kwargs = {
        "torch_dtype": torch.float16,
        "attn_implementation": "flash_attention_2",
        "device_map": device_map,
        "trust_remote_code": True,
    }
    lm = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return KVPressTextGenerationPipeline(model=lm, tokenizer=tok)


def _answer_matches(generated, expected):
    """Return True if any non-empty reference answer occurs, case-insensitively, in ``generated``."""
    text = generated.strip().lower()
    if expected is None:
        return False
    # A reference may be a single value or a list/tuple of acceptable answers.
    if isinstance(expected, str) or not isinstance(expected, (list, tuple)):
        expected = [expected]
    references = (str(ans).lower() for ans in expected)
    # Skip empty references: "" is a substring of every string and would
    # otherwise score any generation as correct.
    return any(ref and ref in text for ref in references)


def evaluate_method(pipe, dataset, press, max_samples=None, *, max_new_tokens=50):
    """Evaluate a compression press on a question-answering dataset.

    Each sample is generated with ``press`` applied. It counts as correct when
    any reference answer appears, case-insensitively, as a substring of the
    generated answer.

    Args:
        pipe: Pipeline called as ``pipe(context, question=..., answer_prefix=...,
            press=..., max_new_tokens=...)``. It is expected to return a mapping
            with an ``"answer"`` key.
        dataset: Iterable of samples. Each sample is a mapping with ``"context"``,
            ``"question"`` and ``"answer"`` keys, plus an optional
            ``"answer_prefix"``. ``"answer"`` may be a single value or a list of
            acceptable answers.
        press: Compression press forwarded to the pipeline.
        max_samples: If not None, evaluate only the first ``max_samples`` samples.
        max_new_tokens: Generation budget per sample (keyword-only).

    Returns:
        Accuracy in [0, 1]. Samples whose generation or scoring raises an
        exception are counted as incorrect. Returns 0.0 when no samples are
        evaluated.
    """
    samples = list(dataset)
    # Compare against None rather than using truthiness, so that
    # max_samples=0 evaluates nothing instead of the entire dataset.
    if max_samples is not None:
        samples = samples[:max_samples]

    correct = 0
    progress = tqdm(samples, desc=f"Evaluating {press.__class__.__name__}")
    for idx, sample in enumerate(progress):
        try:
            output = pipe(
                sample["context"],
                question=sample["question"],
                answer_prefix=sample.get("answer_prefix", ""),
                press=press,
                max_new_tokens=max_new_tokens,
            )
            if _answer_matches(output["answer"], sample["answer"]):
                correct += 1
        except Exception as e:
            # Best effort: a failing sample counts as wrong but does not abort
            # the run. Report on stderr so the message does not garble stdout
            # output, and include the index for reproduction.
            print(f"Error on sample {idx}: {e}", file=sys.stderr)

    # Every evaluated sample contributes to the denominator, including failures.
    total = len(samples)
    return correct / total if total > 0 else 0.0


# KeyDiff accuracy at 64K reported in the paper (Table 2). It is used as the
# "vs KeyDiff" baseline only when KeyDiff itself is not part of the run.
_PAPER_KEYDIFF_ACC = 0.811

# Window sizes (tokens) of the named windowed variants accepted by --method.
_WINDOWED_VARIANTS = {
    "windowed_4k": 4096,
    "windowed_8k": 8192,
    "windowed_16k": 16384,
}

# Window size for ``--method windowed`` when --window_size is not given.
_DEFAULT_WINDOW_SIZE = 4096


def _windowed_name(window_size):
    """Return the result key for a windowed run (4096 -> "windowed_4k", 3000 -> "windowed_3000")."""
    if window_size % 1024 == 0:
        return f"windowed_{window_size // 1024}k"
    return f"windowed_{window_size}"


def _build_methods(method, compression, window_size=None):
    """Return the ``(result_name, press)`` pairs to evaluate for a ``--method`` choice.

    If ``window_size`` is not None, it overrides the window of a single
    windowed method. It is ignored for ``"all"``, whose sweep always covers
    the 4K, 8K and 16K windows.

    Raises:
        ValueError: if ``method`` is not a known choice.
    """

    def windowed(size):
        press = WindowedManifoldKVPress(compression_ratio=compression, window_size=size)
        return (_windowed_name(size), press)

    def hybrids():
        return [
            ("hybrid_03", HybridManifoldKVPress(compression_ratio=compression, global_weight=0.3)),
            ("hybrid_05", HybridManifoldKVPress(compression_ratio=compression, global_weight=0.5)),
        ]

    if method == "all":
        return [
            ("global_manifold", ManifoldKVPress(compression_ratio=compression)),
            ("keydiff", KeyDiffPress(compression_ratio=compression)),
            *(windowed(size) for size in _WINDOWED_VARIANTS.values()),
            *hybrids(),
            ("normalized", NormalizedManifoldKVPress(compression_ratio=compression)),
        ]
    if method == "global":
        return [("global_manifold", ManifoldKVPress(compression_ratio=compression))]
    if method == "keydiff":
        return [("keydiff", KeyDiffPress(compression_ratio=compression))]
    if method in _WINDOWED_VARIANTS:
        # The named variant sets the size unless the user explicitly overrides it.
        size = window_size if window_size is not None else _WINDOWED_VARIANTS[method]
        return [windowed(size)]
    if method == "windowed":
        size = window_size if window_size is not None else _DEFAULT_WINDOW_SIZE
        return [windowed(size)]
    if method == "hybrid":
        return hybrids()
    if method == "normalized":
        return [("normalized", NormalizedManifoldKVPress(compression_ratio=compression))]
    raise ValueError(f"Unknown method: {method!r}")


def _save_results(results, path):
    """Write the ``{method_name: accuracy}`` mapping to ``path`` as indented JSON."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


def _print_summary(results):
    """Print the accuracy table (best first) and the headline comparisons."""
    keydiff_measured = "keydiff" in results
    keydiff_acc = results["keydiff"] if keydiff_measured else _PAPER_KEYDIFF_ACC

    print("\n" + "=" * 60)
    print("64K CONTEXT RESULTS SUMMARY")
    print("=" * 60)
    print(f"{'Method':<25} {'Accuracy':>10} {'vs KeyDiff':>12}")
    print("-" * 60)
    for method, acc in sorted(results.items(), key=lambda item: item[1], reverse=True):
        delta = (acc - keydiff_acc) * 100
        # The '+' sign option prints an explicit sign for both positive and negative deltas.
        print(f"{method:<25} {acc*100:>9.2f}% {delta:>+12.1f}")
    if not keydiff_measured:
        print(f"(KeyDiff not run: 'vs KeyDiff' uses the paper value {_PAPER_KEYDIFF_ACC*100:.1f}%)")

    if "global_manifold" in results and "windowed_4k" in results:
        recovery = (results["windowed_4k"] - results["global_manifold"]) * 100
        print("\n" + "=" * 60)
        print("KEY FINDINGS:")
        print(f"  Centroid Dilution Recovery: {recovery:+.1f} points")
        if keydiff_measured:
            improvement = (results["windowed_4k"] - results["keydiff"]) * 100
            print(f"  Improvement over KeyDiff:   {improvement:+.1f} points")
    print("=" * 60)


def main():
    """Run the 64K windowed ManifoldKV experiments and print a summary table.

    Parses CLI arguments, loads the model pipeline and the 64K RULER test
    split, and evaluates each selected press. ``{method_name: accuracy}`` is
    written to ``<output_dir>/results.json`` after every method, so partial
    results survive an interrupted run, and once more at the end.
    """
    parser = argparse.ArgumentParser(description="64K Windowed ManifoldKV Experiments")
    parser.add_argument("--model", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct")
    parser.add_argument("--compression", type=float, default=0.25)
    parser.add_argument("--max_samples", type=int, default=None, help="Limit samples for quick testing")
    parser.add_argument("--full", action="store_true", help="Run all samples (overrides --max_samples)")
    parser.add_argument(
        "--method",
        type=str,
        default="all",
        choices=[
            "all", "global", "windowed_4k", "windowed_8k", "windowed_16k",
            "windowed", "keydiff", "hybrid", "normalized",
        ],
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=None,
        help=(
            "Window size (tokens) for a single windowed method. Overrides the size implied "
            f"by windowed_4k/8k/16k; default for --method windowed: {_DEFAULT_WINDOW_SIZE}"
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="../results/64k",
        help="Directory for results.json (relative paths resolve against the current directory)",
    )
    args = parser.parse_args()

    if args.window_size is not None and args.window_size <= 0:
        parser.error("--window_size must be a positive integer")
    if args.max_samples is not None and args.max_samples < 0:
        parser.error("--max_samples must be non-negative")

    max_samples = None if args.full else args.max_samples
    # Build the presses and output directory before the expensive model load,
    # so configuration errors surface immediately.
    methods_to_run = _build_methods(args.method, args.compression, args.window_size)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    results_path = output_dir / "results.json"

    pipe = load_model(args.model)

    print("Loading 64K RULER dataset...")
    ds = load_dataset("simonjegou/ruler", data_dir="65536", split="test")

    results = {}
    for name, press in methods_to_run:
        print(f"\n=== Testing {name} ===")
        acc = evaluate_method(pipe, ds, press, max_samples)
        results[name] = acc
        print(f"{name}: {acc*100:.2f}%")
        _save_results(results, results_path)

    _print_summary(results)
    _save_results(results, results_path)
    print(f"\nResults saved to: {results_path}")


if __name__ == "__main__":
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())
