#!/usr/bin/env python3
"""
Script to evaluate checkpoints from replay experiences on multiple datasets.

This script evaluates:
- apibench-replay checkpoints on: apibench
- mllm-replay checkpoints on: apibench, mllm
- olympus-1-replay checkpoints on: apibench, mllm, olympus-1
- olympus-2-replay checkpoints on: apibench, mllm, olympus-1, olympus-2
"""

import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional


# Directory containing this script; the path constants below derive from it.
_SCRIPT_DIR = Path(__file__).resolve().parent

# Replay experience name -> datasets its checkpoints are evaluated on.
EVALUATION_MAPPING: Dict[str, List[str]] = {
    "apibench-replay": ["apibench"],
    "mllm-replay": ["apibench", "mllm"],
    "olympus-1-replay": ["apibench", "mllm", "olympus-1"],
    "olympus-2-replay": ["apibench", "mllm", "olympus-1", "olympus-2"],
}

# Default experiments location: the "experiments" folder next to this script.
EXPERIMENTS_DIR = _SCRIPT_DIR / "experiments"
# Parent of the script's directory.
PROJECT_ROOT = _SCRIPT_DIR.parent


def find_checkpoints(experience_dir: Path) -> List[str]:
    """
    Find all checkpoint directories in the given experience directory.

    Only sub-directories whose name starts with "checkpoint-" are considered.

    Args:
        experience_dir: Path to the experience directory

    Returns:
        List of checkpoint directory names sorted by ascending step number
        (e.g., ["checkpoint-310", "checkpoint-500"]). Names whose suffix is
        not a plain decimal number sort as step 0, and ties are broken by
        name so the order never depends on filesystem listing order.
        Returns an empty list if experience_dir is missing or is not a
        directory.
    """
    # is_dir() rather than exists(): a regular file at this path must yield
    # an empty result instead of NotADirectoryError from iterdir().
    if not experience_dir.is_dir():
        return []

    def sort_key(name: str):
        step = name.split("-")[1]
        # isascii() filters out Unicode digits (e.g. "²") that satisfy
        # isdigit() but make int() raise ValueError.
        number = int(step) if step.isascii() and step.isdigit() else 0
        return (number, name)

    names = [
        item.name
        for item in experience_dir.iterdir()
        if item.is_dir() and item.name.startswith("checkpoint-")
    ]
    return sorted(names, key=sort_key)


def find_latest_checkpoint(experience_dir: Path) -> Optional[str]:
    """
    Return the name of the most recent checkpoint in an experience directory.

    The checkpoints are obtained from find_checkpoints(), and the last entry
    of that ordered list is taken as the latest one.

    Args:
        experience_dir: Path to the experience directory

    Returns:
        Name of the latest checkpoint directory, or None when the directory
        holds no checkpoints
    """
    ordered = find_checkpoints(experience_dir)
    if not ordered:
        return None
    return ordered[-1]


def find_experience_dir(experience_name: str, experiments_root: Path) -> Optional[Path]:
    """
    Find the experience directory by name (flexible matching).

    Resolution order:
      1. A directory named exactly ``experience_name`` under the root.
      2. Otherwise, directories whose name contains ``experience_name``
         (case-insensitive). Among several candidates a case-insensitive
         exact match wins, then the shortest name, then alphabetical order,
         so the result is deterministic regardless of listing order.

    Args:
        experience_name: Name of the experience (e.g., "apibench-replay")
        experiments_root: Root directory containing experiments

    Returns:
        Path to the experience directory or None if not found (including
        when experiments_root is missing or is not a directory)
    """
    # is_dir() rather than exists(): iterdir() below would raise on a file.
    if not experiments_root.is_dir():
        return None

    # Try exact match first; it must be a directory, not a stray file.
    exact_path = experiments_root / experience_name
    if exact_path.is_dir():
        return exact_path

    # Fall back to case-insensitive partial matching.
    wanted = experience_name.lower()
    candidates = [
        item
        for item in experiments_root.iterdir()
        if item.is_dir() and wanted in item.name.lower()
    ]
    if not candidates:
        return None

    # Prefer the closest name so the pick does not depend on iterdir() order.
    return min(
        candidates,
        key=lambda item: (item.name.lower() != wanted, len(item.name), item.name),
    )


def run_evaluation(
    experience_name: str,
    checkpoint_name: str,
    eval_datasets: List[str],
    config_file: Optional[str] = None,
    output_base_dir: Optional[str] = None,
    dry_run: bool = False
) -> Dict[str, bool]:
    """
    Run evaluation for a checkpoint on multiple datasets.

    One ``cco-eval`` command is built per dataset and executed with
    PROJECT_ROOT as the working directory. A failure on one dataset is
    reported and recorded, and the remaining datasets are still evaluated.

    Args:
        experience_name: Name of the experience (e.g., "apibench-replay")
        checkpoint_name: Name of the checkpoint (e.g., "checkpoint-310")
        eval_datasets: List of datasets to evaluate on
        config_file: Path to eval config YAML file (optional)
        output_base_dir: Base directory for output (optional). When set, an
            ``--output_name`` of the form
            "<experience>_<checkpoint>_<dataset>" is added to the command;
            the directory value itself is not forwarded.
        dry_run: If True, only print commands without executing

    Returns:
        Dictionary mapping dataset names to success status (always True for
        a dry run)
    """
    results: Dict[str, bool] = {}

    # Adapter identifier handed to cco-eval via --lora_adapters.
    adapter_path = f"{experience_name}/{checkpoint_name}"

    for dataset in eval_datasets:
        print(f"\n{'='*80}")
        print(f"Evaluating {experience_name}/{checkpoint_name} on {dataset}")
        print(f"{'='*80}")

        # Build command - use cco-eval entrypoint (defined in pyproject.toml)
        cmd = [
            "cco-eval",
            "--experience_name", dataset,
            "--lora_adapters", adapter_path
        ]

        if config_file:
            cmd.extend(["--config", config_file])

        if output_base_dir:
            output_name = f"{experience_name}_{checkpoint_name}_{dataset}"
            cmd.extend(["--output_name", output_name])

        # Quote each argument so the printed command can be pasted into a
        # shell even when a path contains spaces or special characters.
        printable_cmd = " ".join(shlex.quote(part) for part in cmd)
        print(f"Command: {printable_cmd}")

        if dry_run:
            print("[DRY RUN] Would execute command above")
            results[dataset] = True
            continue

        try:
            # check=True turns any non-zero exit status into
            # CalledProcessError, so reaching the else clause means success.
            subprocess.run(
                cmd,
                cwd=PROJECT_ROOT,
                check=True,
                capture_output=False
            )
        except subprocess.CalledProcessError as e:
            print(f"✗ Error evaluating on {dataset}: {e}")
            results[dataset] = False
        except Exception as e:
            # Deliberately broad (e.g. cco-eval not installed): record the
            # failure and keep going with the remaining datasets.
            print(f"✗ Unexpected error evaluating on {dataset}: {e}")
            results[dataset] = False
        else:
            print(f"✓ Successfully evaluated on {dataset}")
            results[dataset] = True

    return results


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description="Evaluate checkpoints from replay experiences on multiple datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Evaluate all replay experiences with latest checkpoints
  python evaluate_replay_checkpoints.py
  
  # Evaluate specific experience with latest checkpoint
  python evaluate_replay_checkpoints.py --experience apibench-replay
  
  # Evaluate specific experience with specific checkpoint
  python evaluate_replay_checkpoints.py --experience apibench-replay --checkpoint checkpoint-310
  
  # Dry run to see what would be executed
  python evaluate_replay_checkpoints.py --dry-run
        """
    )

    parser.add_argument(
        "--experience",
        type=str,
        choices=list(EVALUATION_MAPPING.keys()),
        help="Specific experience to evaluate (default: all)",
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        help="Specific checkpoint to use (default: latest checkpoint)",
    )

    parser.add_argument(
        "--experiments-dir",
        type=str,
        default=str(EXPERIMENTS_DIR),
        help=f"Directory containing experiments (default: {EXPERIMENTS_DIR})",
    )

    parser.add_argument(
        "--config",
        type=str,
        help="Path to eval config YAML file (optional)",
    )

    parser.add_argument(
        "--output-base-dir",
        type=str,
        help="Base directory for output results (optional)",
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print commands without executing them",
    )

    parser.add_argument(
        "--list-checkpoints",
        action="store_true",
        help="List available checkpoints for each experience and exit",
    )

    return parser


def _print_available_checkpoints(experience_names: List[str], experiments_root: Path) -> None:
    """Print the checkpoints found for each experience, marking the latest."""
    print("Available checkpoints for each experience:")
    print("=" * 80)
    for exp_name in experience_names:
        exp_dir = find_experience_dir(exp_name, experiments_root)
        if not exp_dir:
            print(f"\n{exp_name}: Directory not found")
            continue

        checkpoints = find_checkpoints(exp_dir)
        print(f"\n{exp_name}:")
        if not checkpoints:
            print("  No checkpoints found")
            continue
        for cp in checkpoints:
            marker = " (latest)" if cp == checkpoints[-1] else ""
            print(f"  - {cp}{marker}")


def main():
    """
    Command-line entry point.

    Parses the arguments, then either lists the available checkpoints
    (--list-checkpoints) or evaluates the selected experiences on the
    datasets given by EVALUATION_MAPPING and prints a summary.

    The process exits with status 0 only when every requested experience
    was evaluated and every dataset evaluation succeeded. Experiences that
    had to be skipped (directory or checkpoint not found) are listed in the
    summary and make the exit status 1, so a run that evaluated nothing is
    never reported as a success.
    """
    args = _build_arg_parser().parse_args()

    experiments_root = Path(args.experiments_dir)

    # Determine which experiences to evaluate
    if args.experience:
        experiences_to_eval = [args.experience]
    else:
        experiences_to_eval = list(EVALUATION_MAPPING.keys())

    # List checkpoints if requested
    if args.list_checkpoints:
        _print_available_checkpoints(experiences_to_eval, experiments_root)
        return

    # Run evaluations
    all_results: Dict[str, Dict[str, bool]] = {}
    # Experience name -> reason it could not be evaluated at all.
    skipped: Dict[str, str] = {}

    for experience_name in experiences_to_eval:
        print(f"\n{'#'*80}")
        print(f"Processing experience: {experience_name}")
        print(f"{'#'*80}")

        # Find experience directory
        exp_dir = find_experience_dir(experience_name, experiments_root)
        if not exp_dir:
            print(f"✗ Experience directory not found: {experience_name}")
            print(f"  Searched in: {experiments_root}")
            skipped[experience_name] = "experience directory not found"
            continue

        print(f"Found experience directory: {exp_dir}")

        # Determine checkpoint to use
        if args.checkpoint:
            checkpoint_name = args.checkpoint
            checkpoint_path = exp_dir / checkpoint_name
            if not checkpoint_path.exists():
                print(f"✗ Checkpoint not found: {checkpoint_name}")
                skipped[experience_name] = f"checkpoint not found: {checkpoint_name}"
                continue
        else:
            checkpoint_name = find_latest_checkpoint(exp_dir)
            if not checkpoint_name:
                print(f"✗ No checkpoints found in {exp_dir}")
                skipped[experience_name] = "no checkpoints found"
                continue
            print(f"Using latest checkpoint: {checkpoint_name}")

        # Get datasets to evaluate on
        eval_datasets = EVALUATION_MAPPING[experience_name]
        print(f"Will evaluate on datasets: {', '.join(eval_datasets)}")

        all_results[experience_name] = run_evaluation(
            experience_name=experience_name,
            checkpoint_name=checkpoint_name,
            eval_datasets=eval_datasets,
            config_file=args.config,
            output_base_dir=args.output_base_dir,
            dry_run=args.dry_run
        )

    # Print summary
    print(f"\n{'#'*80}")
    print("EVALUATION SUMMARY")
    print(f"{'#'*80}")

    for exp_name, results in all_results.items():
        print(f"\n{exp_name}:")
        for dataset, success in results.items():
            status = "✓" if success else "✗"
            print(f"  {status} {dataset}")

    if skipped:
        print("\nSkipped (not evaluated):")
        for exp_name, reason in skipped.items():
            print(f"  ✗ {exp_name}: {reason}")

    # Exit with error if any evaluation failed or any experience was
    # skipped; all() over an empty result set would otherwise report success.
    all_success = not skipped and all(
        success
        for results in all_results.values()
        for success in results.values()
    )

    sys.exit(0 if all_success else 1)


if __name__ == "__main__":
    main()

