#!/bin/bash
# ManifoldKV Multi-Model Experiments
# ICML 2026 - Reproduces Table 6 (Cross-Architecture Generalization)
#
# Expected Results:
#   Gemma-3-12B:   95.2% (4K), 94.4% (8K), 95.2% (16K)
#   Qwen3-8B:      95.0% (4K), 94.5% (8K), 95.0% (16K)
#   Ministral-8B:  95.5% (4K), 94.9% (8K), 95.2% (16K)

# Strict mode:
#   -e           abort on the first command that fails
#   -u           treat expansion of an unset variable as an error
#   -o pipefail  a pipeline fails if ANY stage fails. This matters here because
#                every evaluation run is piped through `tee`; without pipefail
#                the pipeline's status is tee's, so a crashed evaluation would
#                be silently ignored and the sweep would carry on.
set -euo pipefail

# Directory that receives the per-run logs and is handed to the evaluator as
# --output_dir. The path is relative, so it is resolved against the directory
# the script is launched from.
OUTPUT_DIR="../results/multimodel"
# Value passed to the evaluator's --compression_ratio flag for every run.
COMPRESSION=0.20

echo "=============================================="
echo "ManifoldKV Multi-Model Experiments"
echo "=============================================="

# Quoted, and guarded with `--`, so the path is always taken as a single
# operand even if it is later changed to contain spaces or a leading dash.
mkdir -p -- "$OUTPUT_DIR"

# Models to test (hub-style "<org>/<name>" identifiers, passed to --model).
MODELS=(
    "google/gemma-3-12b-it"
    "Qwen/Qwen3-8B"
    "mistralai/Ministral-8B-Instruct-2410"
)

# Context lengths, passed one at a time to --max_context_length.
CONTEXTS=(4096 8192 16384)

# Methods, passed one at a time to --press_name.
METHODS=("adakv_manifold_kv" "adakv_keydiff" "adakv_snapkv")

# Full sweep: every model x method x context length, run sequentially with
# only GPU 0 visible to the evaluator.
for model in "${MODELS[@]}"; do
    echo ""
    echo "=== Testing $model ==="

    # Drop everything up to and including the last '/', leaving a short name
    # that is safe to embed in the log file name.
    model_short=${model##*/}

    for method in "${METHODS[@]}"; do
        for ctx in "${CONTEXTS[@]}"; do
            echo "[$(date)] Running $method on $model_short at ${ctx}..."

            # NOTE(review): --data_dir stays at 4096 while --max_context_length
            # varies with $ctx. Confirm against the evaluator whether the
            # dataset subset is meant to follow the context length as well.
            #
            # stderr is merged into stdout so that the log captures both; the
            # log is opened in append mode, so re-runs extend the same file.
            CUDA_VISIBLE_DEVICES=0 python ../evaluation/evaluate.py \
                --dataset ruler \
                --data_dir 4096 \
                --model "$model" \
                --press_name "$method" \
                --compression_ratio "$COMPRESSION" \
                --max_context_length "$ctx" \
                --output_dir "$OUTPUT_DIR" \
                2>&1 | tee -a "$OUTPUT_DIR/log_${model_short}_${method}_${ctx}.txt"
        done
    done
done

echo ""
echo "=============================================="
echo "Multi-model experiments complete!"
echo "Results saved in: $OUTPUT_DIR"
echo "=============================================="

#######################################
# Print a per-model / per-method accuracy table built from the metrics.json
# files found in the immediate sub-directories of a results directory.
#
# Sub-directory names are split on "__": component 2 is read as the model
# (with "--" turned back into "/"), component 3 as the method, and the
# component starting with "max_context" supplies the context length
# (defaulting to 4096 when no such component exists).
#
# Arguments: $1 - results directory to scan
# Outputs:   table on stdout; error message on stderr if $1 is not a directory
# Returns:   exit status of the Python interpreter
#######################################
summarize_results() {
    # The directory is passed as an argument and the heredoc delimiter is
    # quoted, so the shell never interpolates anything into the Python source.
    python - "$1" <<'PY'
import json
import sys
from collections import defaultdict
from pathlib import Path

CONTEXTS = ['4096', '8192', '16384']
CTX_PREFIX = 'max_context'

results = defaultdict(lambda: defaultdict(dict))
results_dir = Path(sys.argv[1])

if not results_dir.is_dir():
    sys.exit(f'Results directory not found: {results_dir}')

for d in sorted(results_dir.iterdir()):
    metrics_file = d / 'metrics.json'
    if not (d.is_dir() and metrics_file.exists()):
        continue
    with open(metrics_file) as f:
        m = json.load(f)
    if not isinstance(m, dict):
        continue

    # Average over exactly the entries that contribute a score, so that
    # non-dict entries do not deflate the mean and an empty metrics file
    # cannot trigger a division by zero.
    scores = [v.get('string_match', 0) for v in m.values() if isinstance(v, dict)]
    if not scores:
        continue
    avg = sum(scores) / len(scores)

    parts = d.name.split('__')
    model = parts[2].replace('--', '/') if len(parts) > 2 else 'unknown'
    method = parts[3] if len(parts) > 3 else 'unknown'
    # Look at every component rather than only the last one, so that an extra
    # trailing component does not hide the context length.
    ctx = next(
        (p[len(CTX_PREFIX):] for p in parts if p.startswith(CTX_PREFIX)),
        '4096',
    )

    results[model][method][ctx] = avg

print()
# Concatenate instead of passing several arguments to print(): the separator
# spaces print() inserts would shift the header out of line with the rows.
print('Model'.ljust(25) + '4K'.rjust(8) + '8K'.rjust(8) + '16K'.rjust(8))
print('-' * 55)
for model, methods in sorted(results.items()):
    for method, ctxs in sorted(methods.items()):
        row = f'{model[:20]}_{method[:8]}'.ljust(25)
        for ctx in CONTEXTS:
            acc = ctxs.get(ctx)
            # A missing run is shown as n/a so it cannot be mistaken for a
            # genuine 0.0% score.
            row += 'n/a'.rjust(8) if acc is None else f'{acc:>7.1f}%'
        print(row)
PY
}

# Generate summary
echo ""
echo "=== RESULTS SUMMARY ==="
summarize_results "$OUTPUT_DIR"
