#!/usr/bin/env python3
"""
ManifoldKV Multi-Key Retrieval Ablation
ICML 2026 - Reproduces Table 4 (Multi-Key Advantage)

Key Finding: ManifoldKV outperforms KeyDiff by +15.4 points on niah_multikey_3
at 50% compression due to directional collision prevention.

Expected results (accuracy at compression ratio 0.50):
    niah_multikey_3: ManifoldKV 92.4% vs KeyDiff 77.0% (+15.4)
    niah_multikey_2: ManifoldKV 99.8% vs KeyDiff 92.6% (+7.2)
"""

import argparse
import json
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm

# Put the directory two levels above this script on the import path, presumably
# so the in-repo `kvpress` package imported below is found without installing it.
# resolve() first: on an unresolved relative __file__ (e.g. "x.py"), .parent.parent
# collapses to "." instead of the intended directory, and symlinked scripts would
# point at the link's location rather than the real checkout.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from kvpress import KVPressTextGenerationPipeline, AdaKVPress, ManifoldKVSnapKVScorerPress
from kvpress.presses.keydiff_press import KeyDiffPress
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct"):
    """Build a KV-press text-generation pipeline around a causal language model.

    The model is loaded in float16 with FlashAttention-2 and ``device_map="auto"``;
    the tokenizer is fetched from the same identifier.

    Args:
        model_name: Identifier passed to both ``AutoModelForCausalLM.from_pretrained``
            and ``AutoTokenizer.from_pretrained``.

    Returns:
        A ``KVPressTextGenerationPipeline`` wrapping the loaded model and tokenizer.
    """
    print(f"Loading model: {model_name}...")
    model_kwargs = dict(
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        device_map="auto",
        trust_remote_code=True,
    )
    llm = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return KVPressTextGenerationPipeline(model=llm, tokenizer=tok)


def _answer_matches(generated, expected):
    """Return True if any expected answer occurs in ``generated`` (case-insensitive).

    ``expected`` may be a single string or an iterable of answers; non-string
    answers are compared via ``str()``. ``generated`` is stripped and lower-cased
    before the substring check.
    """
    answers = [expected] if isinstance(expected, str) else expected
    text = generated.strip().lower()
    return any(str(ans).lower() in text for ans in answers)


def evaluate_task(pipe, dataset, press, task_filter=None, max_samples=None):
    """Score a KV-cache press on RULER-style samples, grouped by task.

    Each sample is run through ``pipe`` with ``press`` applied. A prediction is
    correct when any expected answer appears (case-insensitively) as a substring
    of the generated answer.

    Args:
        pipe: Pipeline called as ``pipe(context, question=..., answer_prefix=...,
            press=..., max_new_tokens=50)``; its result is indexed with ``"answer"``.
        dataset: Iterable of sample dicts with ``"context"``, ``"question"`` and
            ``"answer"`` (a string or a list of strings), plus optional ``"task"``
            and ``"answer_prefix"``. Must provide ``.filter(fn)`` when
            ``task_filter`` is given (as a Hugging Face ``datasets.Dataset`` does).
        press: Press object forwarded unchanged to ``pipe``.
        task_filter: Optional collection of task names to keep.
        max_samples: Optional cap on the number of samples evaluated *per task*.
            ``None`` (or a non-positive value) evaluates every sample.

    Returns:
        Dict mapping task name to accuracy in [0, 1]. Samples whose generation or
        scoring raised an exception count as incorrect and are reported on stdout.
    """
    from collections import defaultdict

    if task_filter:
        dataset = dataset.filter(lambda x: x.get("task", "") in task_filter)

    # Keep the historical "falsy means no cap" behaviour for max_samples=0.
    cap = max_samples if max_samples and max_samples > 0 else None

    # Group first, then cap per task. Capping the flat dataset instead (as a
    # prefix slice) can leave later tasks with no samples at all when the data
    # is ordered by task, and those tasks would then be reported as 0% upstream.
    task_samples = defaultdict(list)
    for sample in dataset:
        bucket = task_samples[sample.get("task", "unknown")]
        if cap is None or len(bucket) < cap:
            bucket.append(sample)

    results = {}
    for task, task_data in task_samples.items():
        correct = 0
        failed = 0
        for sample in tqdm(task_data, desc=f"{task}"):
            try:
                output = pipe(
                    sample["context"],
                    question=sample["question"],
                    answer_prefix=sample.get("answer_prefix", ""),
                    press=press,
                    max_new_tokens=50,
                )
                if _answer_matches(output["answer"], sample["answer"]):
                    correct += 1
            except Exception as err:
                # Best-effort sweep: a failing sample is scored as incorrect, but
                # it is reported so systematic failures (a misconfigured press,
                # OOM, ...) do not silently masquerade as low accuracy.
                failed += 1
                tqdm.write(f"[{task}] sample failed: {type(err).__name__}: {err}")
        if failed:
            tqdm.write(
                f"[{task}] {failed}/{len(task_data)} samples failed and were scored as incorrect"
            )
        # Buckets are created only when a sample is appended, so this is never 0.
        results[task] = correct / len(task_data) if task_data else 0.0

    return results


def _print_comparison(compression, tasks, results_manifold, results_keydiff):
    """Print a per-task ManifoldKV vs. KeyDiff accuracy table for one ratio.

    Accuracies are fractions in [0, 1] shown as percentages. A task absent from
    either result dict (no samples were evaluated for it) is shown as ``n/a``
    instead of a misleading ``0.0%``.
    """
    print(f"\nResults at {compression} compression:")
    print(f"{'Task':<20} {'ManifoldKV':>12} {'KeyDiff':>12} {'Δ':>10}")
    print("-" * 60)
    for task in tasks:
        m_frac = results_manifold.get(task)
        k_frac = results_keydiff.get(task)
        m_col = f"{m_frac * 100:>11.1f}%" if m_frac is not None else f"{'n/a':>12}"
        k_col = f"{k_frac * 100:>11.1f}%" if k_frac is not None else f"{'n/a':>12}"
        if m_frac is None or k_frac is None:
            delta_col = f"{'n/a':>10}"
        else:
            # Width 10 so the signed delta lines up under the 10-wide 'Δ' header.
            delta_col = f"{m_frac * 100 - k_frac * 100:>+10.1f}"
        print(f"{task:<20} {m_col} {k_col} {delta_col}")


def main():
    """Run the multi-key RULER ablation and save per-ratio accuracies to JSON.

    For each compression ratio in ``COMPRESSION_RATIOS``, evaluates ManifoldKV
    (``AdaKVPress(ManifoldKVSnapKVScorerPress())``) and KeyDiff
    (``AdaKVPress(KeyDiffPress())``) on the ``niah_multikey_{1,2,3}`` tasks of
    the ``simonjegou/ruler`` dataset, prints a comparison table, and writes
    ``multikey_results.json`` into ``--output_dir``. The JSON file is rewritten
    after every ratio, so a crash in a later ratio keeps earlier results.
    """
    parser = argparse.ArgumentParser(description="Multi-Key Retrieval Ablation")
    parser.add_argument("--context", type=int, default=8192,
                        help="RULER context length; used as the dataset data_dir")
    parser.add_argument("--max_samples", type=int, default=None,
                        help="optional sample cap forwarded to evaluate_task")
    parser.add_argument("--output_dir", type=str, default="../results/multikey",
                        help="directory for multikey_results.json (created if missing)")
    args = parser.parse_args()

    # Target tasks
    MULTIKEY_TASKS = ["niah_multikey_1", "niah_multikey_2", "niah_multikey_3"]
    COMPRESSION_RATIOS = [0.30, 0.40, 0.50]

    pipe = load_model()

    # Load dataset
    print("Loading RULER dataset...")
    ds = load_dataset("simonjegou/ruler", data_dir=str(args.context), split="test")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    results_path = output_dir / "multikey_results.json"

    all_results = {}

    for compression in COMPRESSION_RATIOS:
        print(f"\n=== Compression = {compression} ===")

        # ManifoldKV
        press_manifold = AdaKVPress(ManifoldKVSnapKVScorerPress())
        press_manifold.compression_ratio = compression
        results_manifold = evaluate_task(pipe, ds, press_manifold, MULTIKEY_TASKS, args.max_samples)

        # KeyDiff
        press_keydiff = AdaKVPress(KeyDiffPress())
        press_keydiff.compression_ratio = compression
        results_keydiff = evaluate_task(pipe, ds, press_keydiff, MULTIKEY_TASKS, args.max_samples)

        all_results[f"compression_{compression}"] = {
            "manifold_kv": results_manifold,
            "keydiff": results_keydiff,
        }

        _print_comparison(compression, MULTIKEY_TASKS, results_manifold, results_keydiff)

        # Checkpoint after every ratio: each ratio is a full evaluation pass, and
        # a failure in a later one should not discard results already computed.
        with open(results_path, "w") as f:
            json.dump(all_results, f, indent=2)

    # Summary
    print("\n" + "="*60)
    print("MULTI-KEY ABLATION SUMMARY")
    print("="*60)
    print("\nKey Finding: ManifoldKV's L2 distance preserves magnitude,")
    print("preventing 'directional collision' when multiple important")
    print("tokens point in similar directions.")
    print(f"\nResults saved to: {results_path}")


if __name__ == "__main__":
    # Run the ablation only when executed as a script; importing this module
    # does not call main() (no model load, no dataset download, no evaluation).
    main()
