#!/usr/bin/env python3
"""
Phase 1: Priority Empirical Experiments
Launch all experiments in parallel on 8 GPUs

Estimated time: 4-5 hours for full RULER benchmark
"""

import os
import shlex
import subprocess
import sys
import time
from pathlib import Path

# Configuration
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Every path below is anchored at the repository root, taken to be the
# directory one level above the directory that contains this file.
REPO_ROOT = Path(__file__).parent.parent.resolve()
_RESULTS_PATH = REPO_ROOT / "results" / "phase1"
OUTPUT_DIR = str(_RESULTS_PATH)
EVAL_SCRIPT = str(REPO_ROOT / "evaluation" / "evaluate.py")

# The results directory is created as soon as this module is loaded
_RESULTS_PATH.mkdir(parents=True, exist_ok=True)

def run_experiment(gpu, method, context, compression=0.20, suffix=""):
    """Build the shell command for a single experiment on the given GPU.

    Nothing is executed here: the command is only assembled and returned,
    and running it is left to the caller. The command changes into the
    repository root, activates ``.venv``, and runs ``evaluation/evaluate.py``
    with ``CUDA_VISIBLE_DEVICES`` set to ``gpu``, piping stdout and stderr
    through ``tee`` into a log file under ``OUTPUT_DIR``.

    Args:
        gpu: GPU index, used for ``CUDA_VISIBLE_DEVICES`` and the log name.
        method: Value passed to ``--press_name``.
        context: Value passed to ``--max_context_length``.
        compression: Value passed to ``--compression_ratio``.
        suffix: Extra text appended to the log file name before ``.txt``.

    Returns:
        The command as one string, with a leading and a trailing newline.
        It relies on ``source``, so it should be run with bash.
    """
    # Quote everything that is interpolated as a shell word, so a checkout
    # located under a path with spaces (or other shell metacharacters) still
    # yields a valid command. Values made only of safe characters are
    # returned unchanged by shlex.quote.
    repo_root = shlex.quote(str(REPO_ROOT))
    output_dir = shlex.quote(OUTPUT_DIR)
    model = shlex.quote(MODEL)
    press_name = shlex.quote(str(method))
    log_file = shlex.quote(
        f"{OUTPUT_DIR}/log_gpu{gpu}_{method}_{context}{suffix}.txt"
    )

    # NOTE(review): --data_dir is fixed at 4096 regardless of `context`;
    # confirm that this is intended.
    # A backslash at the end of a line inside this (non-raw) literal is a
    # Python line continuation, so the resulting command is a single line.
    cmd = f"""
cd {repo_root} && \
source .venv/bin/activate && \
CUDA_VISIBLE_DEVICES={gpu} python evaluation/evaluate.py \
    --dataset ruler \
    --data_dir 4096 \
    --model {model} \
    --press_name {press_name} \
    --compression_ratio {compression} \
    --max_context_length {context} \
    --output_dir {output_dir} \
    2>&1 | tee {log_file}
"""
    return cmd

# Experiment batches per GPU.
# Maps a GPU index to an ordered list of
# (press_name, max_context_length, compression_ratio) tuples.
_BASE_RATIO = 0.20
_ALL_CONTEXTS = (4096, 8192, 16384, 32768)
_SHORT_CONTEXTS = (4096, 8192)

GPU_EXPERIMENTS = {
    # GPU 0: ManifoldKV with AdaKV - All contexts
    0: [("adakv_manifold_kv", ctx, _BASE_RATIO) for ctx in _ALL_CONTEXTS],

    # GPU 1: KeyDiff with AdaKV - All contexts
    1: [("adakv_keydiff", ctx, _BASE_RATIO) for ctx in _ALL_CONTEXTS],

    # GPU 2: SnapKV with AdaKV - All contexts
    2: [("adakv_snapkv", ctx, _BASE_RATIO) for ctx in _ALL_CONTEXTS],

    # GPU 3: Standalone methods
    3: [
        (press, ctx, _BASE_RATIO)
        for press in ("keydiff", "snapkv", "manifold_kv")
        for ctx in _SHORT_CONTEXTS
    ],

    # GPU 4: Compression ratio ablation
    4: [
        ("adakv_manifold_kv", 32768, ratio)
        for ratio in (0.10, 0.15, 0.25, 0.30, 0.40, 0.50)
    ],

    # GPU 5: Multi-key experiments at aggressive compression
    5: [
        (press, 8192, ratio)
        for press in ("adakv_manifold_kv", "adakv_keydiff")
        for ratio in (0.40, 0.50)
    ],

    # GPU 6: Additional context length tests
    6: [
        (press, 65536, 0.25)
        for press in ("adakv_manifold_kv", "adakv_keydiff")
    ],

    # GPU 7: Distance metric ablations
    7: [
        (press, 8192, _BASE_RATIO)
        for press in (
            "adakv_manifold_kv_l1",
            "adakv_manifold_kv_linf",
            "manifold_kv_l1",
            "manifold_kv_linf",
        )
    ],
}

def generate_gpu_script(gpu, experiments):
    """Generate the text of a bash script that runs a batch on one GPU.

    The script changes into the repository root, activates ``.venv``,
    exports ``CUDA_VISIBLE_DEVICES`` for ``gpu``, and then runs
    ``evaluation/evaluate.py`` once per experiment, in the order given.
    Each run's stdout and stderr are piped through ``tee`` into a log file
    under ``OUTPUT_DIR``. The script does not stop when a run fails.

    Args:
        gpu: GPU index, used for ``CUDA_VISIBLE_DEVICES``, the echoed
            progress messages and the log file names.
        experiments: Iterable of ``(method, context, compression)`` tuples,
            passed to ``--press_name``, ``--max_context_length`` and
            ``--compression_ratio`` respectively.

    Returns:
        The complete script as a single string. Nothing is written to disk.
    """
    # Quote values interpolated as shell words so that a repository located
    # under a path with spaces (or other shell metacharacters) still gives a
    # valid script. Values made only of safe characters are left unchanged
    # by shlex.quote.
    repo_root = shlex.quote(str(REPO_ROOT))
    output_dir = shlex.quote(OUTPUT_DIR)
    model = shlex.quote(MODEL)

    # Collect the pieces and join once at the end instead of growing a
    # string with repeated concatenation.
    parts = [f"""#!/bin/bash
# GPU {gpu} Experiments
cd {repo_root}
source .venv/bin/activate
export CUDA_VISIBLE_DEVICES={gpu}

echo "=== GPU {gpu} Starting at $(date) ==="
"""]

    for method, context, compression in experiments:
        # Only non-default compression ratios are encoded in the log name
        suffix = f"_cr{compression}" if compression != 0.20 else ""
        log_file = shlex.quote(
            f"{OUTPUT_DIR}/log_gpu{gpu}_{method}_{context}{suffix}.txt"
        )
        press_name = shlex.quote(str(method))

        # NOTE(review): --data_dir is fixed at 4096 regardless of `context`;
        # confirm that this is intended.
        parts.append(f"""
echo "[GPU {gpu}] Running {method} at {context} context, compression={compression}..."
python evaluation/evaluate.py \\
    --dataset ruler \\
    --data_dir 4096 \\
    --model {model} \\
    --press_name {press_name} \\
    --compression_ratio {compression} \\
    --max_context_length {context} \\
    --output_dir {output_dir} \\
    2>&1 | tee {log_file}
echo "[GPU {gpu}] Done: {method} at {context}"
""")

    parts.append(f"""
echo "=== GPU {gpu} Completed at $(date) ==="
""")
    return "".join(parts)

if __name__ == "__main__":
    rule = "=" * 60

    # Show the plan: one section per GPU, in ascending GPU order
    print(rule)
    print("Phase 1: Empirical Experiments Launch Plan")
    print(rule)

    for gpu_id in sorted(GPU_EXPERIMENTS):
        batch = GPU_EXPERIMENTS[gpu_id]
        print(f"\nGPU {gpu_id}: {len(batch)} experiments")
        for press_name, ctx_len, ratio in batch:
            print(f"  - {press_name} @ {ctx_len} ctx, cr={ratio}")

    print(f"\n{rule}")
    print("Generating GPU scripts...")

    # Write one executable run_gpu<N>.sh per GPU under <OUTPUT_DIR>/scripts
    scripts_dir = Path(OUTPUT_DIR, "scripts")
    scripts_dir.mkdir(parents=True, exist_ok=True)

    for gpu_id, batch in GPU_EXPERIMENTS.items():
        target = scripts_dir / f"run_gpu{gpu_id}.sh"
        target.write_text(generate_gpu_script(gpu_id, batch))
        target.chmod(0o755)
        print(f"  Created: {target}")

    # Tell the user how to start every generated script in the background
    print(f"\n{rule}")
    print("To launch all experiments:")
    print(f"  cd {OUTPUT_DIR}/scripts")
    print("  for i in {0..7}; do nohup ./run_gpu$i.sh &> ../nohup_gpu$i.log & done")
    print(rule)
