#!/bin/bash
# ManifoldKV RULER Benchmark Reproduction
# ICML 2026 - Reproduces Table 1 (Main Results)
#
# Expected Results:
#   AdaKV + ManifoldKV: 95.73%
#   AdaKV + KeyDiff:    95.66%
#   SnapKV:             83.97%

# Strict mode:
#   -e           abort on the first failing command
#   -u           treat use of an unset variable as an error
#   -o pipefail  a pipeline fails if ANY stage fails, so a crashed evaluation
#                piped into `tee` is not reported as success
set -euo pipefail

# The paths below (../results, ../evaluation) are relative, so resolve them
# against the directory containing this script rather than the caller's
# working directory. Falls back to $0 when BASH_SOURCE is unavailable
# (e.g. the script is read from stdin).
cd "$(dirname "${BASH_SOURCE[0]:-$0}")"

# Configuration
MODEL="meta-llama/Meta-Llama-3.1-8B-Instruct"
OUTPUT_DIR="../results/ruler"
COMPRESSION=0.20   # default --compression_ratio used by run_experiment

echo "=============================================="
echo "ManifoldKV RULER Benchmark Reproduction"
echo "=============================================="
echo "Model: $MODEL"
echo "Compression: $COMPRESSION"
echo "Output: $OUTPUT_DIR"
echo ""

# Create the output directory up front; per-run logs are written into it.
mkdir -p "$OUTPUT_DIR"

#######################################
# Run a single evaluation and append its combined stdout/stderr to a log file.
# Globals:
#   MODEL       (read) value passed as --model
#   OUTPUT_DIR  (read) value passed as --output_dir; the log is written here
#   COMPRESSION (read) default compression ratio
# Arguments:
#   $1 - press/method name (passed as --press_name)
#   $2 - context length (passed as --max_context_length)
#   $3 - GPU index exported as CUDA_VISIBLE_DEVICES (default: 0)
#   $4 - compression ratio (default: $COMPRESSION)
# Outputs:
#   Progress messages and the evaluation output on stdout; a failure notice
#   on stderr. The log file name depends only on method and context, so
#   repeated runs with the same pair append to the same file.
# Returns:
#   0 on success, otherwise the non-zero status of the evaluation pipeline.
#######################################
run_experiment() {
    local method=$1
    local context=$2
    local gpu=${3:-0}
    local compression=${4:-$COMPRESSION}
    local log_file="$OUTPUT_DIR/log_${method}_${context}.txt"
    local rc=0

    echo "[GPU $gpu] Running $method at ${context} context..."

    # Run the pipeline in a subshell with pipefail enabled so that a failing
    # evaluation is not hidden behind tee's (successful) exit status, no
    # matter how the calling shell is configured. `|| rc=$?` captures the
    # status without tripping `set -e` in the caller.
    # NOTE(review): --data_dir is fixed at 4096 for every context length;
    # confirm this is intended for the 8192/16384/32768 runs.
    (
        set -o pipefail
        CUDA_VISIBLE_DEVICES="$gpu" python ../evaluation/evaluate.py \
            --dataset ruler \
            --data_dir 4096 \
            --model "$MODEL" \
            --press_name "$method" \
            --compression_ratio "$compression" \
            --max_context_length "$context" \
            --output_dir "$OUTPUT_DIR" \
            2>&1 | tee -a "$log_file"
    ) || rc=$?

    if (( rc != 0 )); then
        echo "[GPU $gpu] FAILED: $method at $context (exit $rc), see $log_file" >&2
        return "$rc"
    fi

    echo "Done: $method at $context"
}

echo ""
echo "=== MAIN RESULTS (Table 1) ==="
echo ""

# Each AdaKV-wrapped press is evaluated at every context length on GPU 0.
# Entries are "<section title>:<press name>":
#   - ManifoldKV (our method)
#   - KeyDiff baseline
#   - SnapKV baseline
context_lengths=(4096 8192 16384 32768)
adakv_sweeps=(
    "ManifoldKV with AdaKV:adakv_manifold_kv"
    "KeyDiff Baseline:adakv_keydiff"
    "SnapKV Baseline:adakv_snapkv"
)
for sweep in "${adakv_sweeps[@]}"; do
    # Text before the first ':' is the title, text after it is the press name.
    echo "--- ${sweep%%:*} ---"
    for ctx_len in "${context_lengths[@]}"; do
        run_experiment "${sweep#*:}" "$ctx_len" 0
    done
done

# Standalone methods are only run at the 4096 context length.
echo "--- Standalone Methods ---"
for press in keydiff snapkv manifold_kv; do
    run_experiment "$press" 4096 0
done

echo ""
echo "=== COMPRESSION RATIO ABLATION (Table 3) ==="
echo ""

# Sweep the compression ratio for the AdaKV + ManifoldKV press at 32768 context.
ablation_ratios=(0.10 0.15 0.20 0.25 0.30 0.40 0.50)
for ratio in "${ablation_ratios[@]}"; do
    echo "Running compression=$ratio..."
    run_experiment "adakv_manifold_kv" 32768 0 "$ratio"
done

echo ""
echo "=============================================="
echo "All experiments complete!"
echo "Results saved in: $OUTPUT_DIR"
echo "=============================================="

#######################################
# Print one summary line per result directory that contains a metrics.json.
# Arguments:
#   $1 - directory whose immediate children are scanned
# Outputs:
#   "<method> ctx=<context> -> <average>%" lines on stdout, sorted by
#   directory name.
# Returns:
#   Exit status of the Python interpreter.
#######################################
print_summary() {
    # The directory is passed as an argument and the program is fed through a
    # quoted here-document, so the shell never interpolates the path into
    # Python source (a quote character in the path cannot break the program).
    python - "$1" <<'PY'
import json
import sys
from pathlib import Path

results_dir = Path(sys.argv[1])
for run_dir in sorted(results_dir.iterdir()):
    metrics_file = run_dir / 'metrics.json'
    if not metrics_file.exists():
        continue
    with open(metrics_file) as f:
        metrics = json.load(f)
    # Average 'string_match' over the dict-valued entries only, and divide by
    # the number of entries actually summed so that any non-dict entry in the
    # file does not dilute the score.
    scores = [v.get('string_match', 0) for v in metrics.values() if isinstance(v, dict)]
    if not scores:
        # Nothing to average (e.g. an empty metrics file): skip the run
        # instead of dividing by zero.
        continue
    avg = sum(scores) / len(scores)
    # The directory name is split on '__'; the 4th field is taken as the
    # method and a trailing 'max_context<N>' field as the context length.
    # NOTE(review): this layout is assumed to be what evaluate.py produces;
    # confirm against the evaluation script.
    name = run_dir.name.split('__')
    method = name[3] if len(name) > 3 else 'unknown'
    ctx = name[-1].replace('max_context', '') if 'max_context' in name[-1] else '4096'
    print(f'{method:25s} ctx={ctx:6s} -> {avg:.2f}%')
PY
}

# Print summary
echo ""
echo "=== RESULTS SUMMARY ==="
print_summary "$OUTPUT_DIR"
