"""
MIRROR Ablation Evaluation Script

This script runs the LLM pragmatic benchmark using MIRROR ablation variants:
- Cognitive Controller Only (no threads)
- Threads Only (no cognitive controller)
"""

import os
import sys
import argparse
from datetime import datetime
import subprocess
import concurrent.futures

def run_scenario_subprocess(scenario, args):
    """Run a single scenario in a child process of this script.

    Re-invokes this very file with ``--scenario <scenario>`` so that several
    scenarios can be evaluated concurrently. The child's output is not
    captured; it streams straight to the parent's console.

    Args:
        scenario: Scenario number forwarded to the child via ``--scenario``.
        args: Parsed command-line namespace. Reads ``ablation_type``,
            ``mirror_model``, ``results_dir``, ``model_prefix`` and
            ``max_examples``.

    Returns:
        A ``(scenario, success, message)`` tuple. ``success`` is ``True`` when
        the child exited with status 0; otherwise ``message`` describes the
        failure (non-zero exit code or failure to launch the child at all).
    """
    print(f"Starting Scenario {scenario} (ablation: {args.ablation_type})...")

    # Resolve this script's own location so the child runs the same file,
    # regardless of the caller's working directory or the file's name.
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)

    # Use the interpreter that is running the parent so the child sees the
    # same virtualenv / installed packages; fall back to "python3" only if
    # the interpreter path is unavailable.
    cmd = [
        sys.executable or "python3", script_path,
        "--ablation-type", args.ablation_type,
        "--mirror-model", args.mirror_model,
        "--scenario", str(scenario),
        "--results-dir", args.results_dir,
        "--model-prefix", args.model_prefix,
        # Don't pass --parallel-scenarios to avoid infinite recursion
    ]

    # Add max-examples if provided
    if args.max_examples:
        cmd.extend(["--max-examples", str(args.max_examples)])

    # Print what we're running
    print(f"[Scenario {scenario}] Running {args.ablation_type} ablation with model: {args.mirror_model}")

    # Run as subprocess and let output stream to console
    try:
        subprocess.run(
            cmd,
            check=True,
            env=os.environ.copy(),  # Pass current environment variables
            cwd=script_dir,  # Run from the directory containing this script
        )
    except subprocess.CalledProcessError as e:
        print(f"[Scenario {scenario}] Failed with exit code: {e.returncode}")
        return scenario, False, f"Error exit code: {e.returncode}"
    except OSError as e:
        # The child could not even be started (e.g. interpreter not found).
        # Report it through the normal return value instead of crashing the
        # caller that is collecting results from several scenarios.
        print(f"[Scenario {scenario}] Failed to launch subprocess: {e}")
        return scenario, False, f"Error launching subprocess: {e}"

    print(f"[Scenario {scenario}] Completed successfully")
    return scenario, True, "Success"

def _build_parser():
    """Create the argument parser for the ablation evaluation CLI."""
    parser = argparse.ArgumentParser(description="Run MIRROR ablation evaluation")

    # Model configuration
    parser.add_argument("--mirror-model", type=str, default="openai/gpt-4o",
                        help="Model to use for MIRROR ablation (default: openai/gpt-4o)")

    # Ablation type
    parser.add_argument("--ablation-type", type=str, required=True,
                        choices=["cognitive_only", "threads_only"],
                        help="Type of ablation to run")

    # Evaluation configuration
    parser.add_argument("--scenario", type=int, default=None,
                        help="Run specific scenario (1-5)")
    parser.add_argument("--max-examples", type=int, default=None,
                        help="Maximum examples per scenario")
    parser.add_argument("--parallel", action="store_true",
                        help="Enable parallel processing")
    parser.add_argument("--workers", type=int, default=6,
                        help="Number of parallel workers")
    parser.add_argument("--parallel-scenarios", action="store_true",
                        help="Run scenarios in parallel (one worker per scenario)")

    # Output configuration
    parser.add_argument("--results-dir", type=str, default='',
                        help="Directory to save results")
    parser.add_argument("--model-prefix", type=str, default='',
                        help="Prefix for result files")

    return parser


def _normalize_model_prefix(model_prefix, ablation_type):
    """Return a result-file prefix that always starts with ``"mirror-"``.

    The operation is idempotent, so a prefix normalised by the parent can be
    forwarded to a child run of this script without being altered again.
    """
    if not model_prefix:
        return f"mirror-ablation-{ablation_type}-"
    if not model_prefix.startswith("mirror-"):
        # Force mirror prefix for ablation runs
        return f"mirror-{model_prefix}"
    return model_prefix


def _run_parallel_scenarios(args):
    """Run scenarios 1-5 concurrently, each via ``run_scenario_subprocess``.

    Returns:
        Dict mapping every failed scenario number to a short failure message;
        empty when all scenarios succeeded.
    """
    scenarios_to_run = list(range(1, 6))  # Run all 5 scenarios
    total_scenarios = len(scenarios_to_run)

    print(f"\n{'='*60}")
    print("MIRROR Ablation Evaluation - Parallel Scenarios")
    print(f"{'='*60}")
    print(f"Ablation Type: {args.ablation_type}")
    print(f"Model: {args.mirror_model}")
    print(f"Running {total_scenarios} scenarios in parallel with {args.workers} workers")
    print(f"Results Directory: {args.results_dir or 'default'}")
    print(f"{'='*60}\n")

    failures = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=args.workers) as executor:
        futures = {
            executor.submit(run_scenario_subprocess, scenario, args): scenario
            for scenario in scenarios_to_run
        }

        completed = 0
        for future in concurrent.futures.as_completed(futures):
            scenario = futures[future]
            try:
                _, success, output = future.result()
            except Exception as e:
                # A crashed worker must not hide the outcome of the other
                # scenarios; record it as a failure and keep collecting.
                success, output = False, f"Worker raised {type(e).__name__}: {e}"
            if not success:
                failures[scenario] = output
            completed += 1
            print(f"\nProgress: {completed}/{total_scenarios} scenarios completed")

    print(f"\n{'='*60}")
    if failures:
        print(f"Completed with failures: {len(failures)}/{total_scenarios} scenarios failed")
        for scenario in sorted(failures):
            print(f"  Scenario {scenario}: {failures[scenario]}")
    else:
        print("All scenarios completed!")
    print(f"Type: {args.ablation_type}")
    print(f"Results saved with prefix: {args.model_prefix}")
    print(f"{'='*60}\n")

    return failures


def main():
    """Command-line entry point for the MIRROR ablation evaluation.

    Parses the command line, then either fans out one child process per
    scenario (``--parallel-scenarios``) or runs the benchmark in-process via
    ``eval.run_benchmark``.

    Exits with status 0 on success, 1 when the benchmark raises or any
    parallel scenario fails, and 2 (argparse usage error) on invalid
    arguments.
    """
    parser = _build_parser()
    args = parser.parse_args()

    # Reject a non-positive worker count up front with a usage error instead
    # of letting the executor raise a bare ValueError later on.
    if (args.parallel or args.parallel_scenarios) and args.workers < 1:
        parser.error("--workers must be at least 1")

    # Normalise the prefix before branching so the parallel summary reports
    # the prefix that the child runs actually use.
    args.model_prefix = _normalize_model_prefix(args.model_prefix, args.ablation_type)

    # Handle parallel scenario execution
    if args.parallel_scenarios:
        if args.scenario is not None:
            print("Warning: --scenario is ignored when using --parallel-scenarios")

        failures = _run_parallel_scenarios(args)
        # Propagate failures through the exit status so shells/CI notice them.
        sys.exit(1 if failures else 0)

    print(f"\n{'='*60}")
    print("MIRROR Ablation Evaluation")
    print(f"{'='*60}")
    print(f"Ablation Type: {args.ablation_type}")
    print(f"Model: {args.mirror_model}")
    print(f"Model Prefix: {args.model_prefix}")
    print(f"Results Directory: {args.results_dir or 'default'}")
    if args.scenario is not None:
        print(f"Scenario: {args.scenario}")
    if args.max_examples:
        print(f"Max Examples: {args.max_examples}")
    print(f"Parallel Processing: {'Enabled' if args.parallel else 'Disabled'}")
    if args.parallel:
        print(f"Workers: {args.workers}")
    print(f"{'='*60}\n")

    # Imported lazily so argument errors and the parallel dispatcher do not
    # need the benchmark module; aliased so it does not shadow built-in eval().
    import eval as benchmark_eval

    try:
        # Run the benchmark - pass ablation_type directly as parameter
        benchmark_eval.run_benchmark(
            mirror_internal_model_arg=args.mirror_model,
            parallel=args.parallel,
            max_workers=args.workers if args.parallel else None,
            single_scenario=args.scenario,
            results_dir=args.results_dir,
            model_prefix=args.model_prefix,
            max_examples=args.max_examples,
            ablation_type=args.ablation_type,
        )
    except Exception as e:
        print(f"\nError running ablation evaluation: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    print(f"\n{'='*60}")
    print("Ablation evaluation complete!")
    print(f"Type: {args.ablation_type}")
    print(f"Results saved with prefix: {args.model_prefix}")
    print(f"{'='*60}\n")

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main() 