#!/usr/bin/env python3
"""
Run cco-eval over all ablations with the specified evaluation pattern.

For each ablation:
- apibench adapter → test on apibench
- mllm adapter → test on apibench, mllm
- hugging-bench-1 adapter → test on apibench, mllm, hugging-bench-1
- hugging-bench-2 adapter → test on apibench, mllm, hugging-bench-1, hugging-bench-2

Results are organized by ablation folder with naming:
- exp:{adapter_latest_experience}_test_set:{tested_on_experience_name}_answers.jsonl
- exp:{adapter_latest_experience}_test_set:{tested_on_experience_name}_metrics.jsonl
"""

import subprocess
import sys
import os
import json
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import argparse
from datetime import datetime

# Constants
# Directory containing this script; every other path below hangs off it, and
# it is also used as the working directory for the evaluation subprocess.
SCRIPT_DIR = Path(__file__).parent.absolute()
# Adapter directories (each holding checkpoint-* subfolders) are looked up here.
EXPERIMENTS_ROOT = SCRIPT_DIR / "cco" / "experiments"
# NOTE(review): not referenced by any code visible in this file.
ABLATION_CONFIGS_DIR = SCRIPT_DIR / "configurations_carve" / "ablations_mllm_icml2026"
# One sub-folder per ablation (optionally suffixed with the seed) is created here.
RESULTS_BASE = SCRIPT_DIR / "results" / "ablations"

# Experience sequence
# Order matters: the adapter at position i is evaluated on EXPERIENCES[:i + 1].
EXPERIENCES = ["apibench", "mllm", "hugging-bench-1", "hugging-bench-2"]

# Ablation definitions (matching run_ablation_full.py)
# Keyed by the ablation number accepted on the command line via --ablation.
# Within this file only two fields are read:
#   - "name": used to build adapter directory names and result folder names
#   - "joint_training": switches to the single-adapter / all-test-sets pattern
# The remaining fields are not consumed by any code visible here; presumably
# they mirror the training configuration in run_ablation_full.py (TODO confirm).
ABLATIONS = {
    1: {
        "name": "ablate_replay_domain_05",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.05,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    2: {
        "name": "ablate_replay_domain_10",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    3: {
        "name": "ablate_replay_domain_20",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.20,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    4: {
        "name": "ablate_replay_random_05",
        "replay_strategy": "random",
        "replay_percentage": 0.05,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    5: {
        "name": "ablate_replay_random_10",
        "replay_strategy": "random",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    6: {
        "name": "ablate_replay_random_20",
        "replay_strategy": "random",
        "replay_percentage": 0.20,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    7: {
        "name": "ablate_no_emb_anchor",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": False,
        "proj_anchor": True,
    },
    8: {
        "name": "ablate_no_proj_anchor",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": False,
    },
    9: {
        "name": "ablate_cumulative_training",
        "replay_strategy": "random",
        "replay_percentage": 1.00,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
        "cumulative": True,
    },
    10: {
        "name": "ablate_from_scratch_training",
        "replay_strategy": "random",
        "replay_percentage": 1.00,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
        "from_scratch": True,
    },
    11: {
        "name": "ablate_joint_training",
        "replay_strategy": "random",
        "replay_percentage": None,
        "two_phase": False,
        "emb_anchor": False,
        "proj_anchor": False,
        "joint_training": True,
    },
    12: {
        "name": "ablate_replay_domain_10_alt_repo",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
        "repo_id": "deepseek-ai/deepseek-coder-7b-instruct-v1.5",  # TODO: Confirm exact repo_id
    },
}


def find_latest_checkpoint(adapter_dir: Path) -> Optional[str]:
    """
    Return the name of the highest-numbered ``checkpoint-<N>`` subdirectory.

    Args:
        adapter_dir: Directory expected to contain ``checkpoint-<N>`` folders

    Returns:
        The directory name (e.g. ``"checkpoint-500"``), or None when
        ``adapter_dir`` does not exist or holds no parseable checkpoint folder
    """
    if not adapter_dir.exists():
        return None

    def _step_of(entry: Path) -> Optional[int]:
        # Only directories named "checkpoint-<int>..." qualify; anything whose
        # second dash-separated field is not an integer is ignored.
        if not (entry.is_dir() and entry.name.startswith("checkpoint-")):
            return None
        try:
            return int(entry.name.split("-")[1])
        except (ValueError, IndexError):
            return None

    numbered = []
    for entry in adapter_dir.iterdir():
        step = _step_of(entry)
        if step is not None:
            numbered.append((step, entry.name))

    if not numbered:
        return None

    # max() keeps the first entry among equal step numbers, which is the same
    # winner a stable descending sort would put at the front.
    _, latest_name = max(numbered, key=lambda pair: pair[0])
    return latest_name


def find_adapters_for_ablation(ablation_name: str, seed: Optional[int] = None) -> Dict[str, Optional[str]]:
    """
    Locate the newest checkpoint of every experience's adapter for one ablation.

    Adapter directories are expected under EXPERIMENTS_ROOT and are named
    ``{experience_name}-{ablation_name}``, with ``_seed{seed}`` appended when a
    seed is given. The ablation named ``"ablate_joint_training"`` is special:
    a single ``joint-...`` directory is looked up and its checkpoint is reported
    for every experience.

    Args:
        ablation_name: Name of the ablation
        seed: Optional seed value appended to the directory name

    Returns:
        Dict with one key per entry of EXPERIENCES. Each value is
        ``"<adapter_dir>/<checkpoint_dir>"`` (relative to EXPERIMENTS_ROOT), or
        None when the directory or a checkpoint inside it is missing
    """
    run_suffix = ablation_name if seed is None else f"{ablation_name}_seed{seed}"

    def _latest_in(dir_name: str) -> Optional[str]:
        # Resolve "<dir_name>/<latest checkpoint>" or None if nothing usable exists.
        exp_dir = EXPERIMENTS_ROOT / dir_name
        if not exp_dir.exists():
            return None
        checkpoint = find_latest_checkpoint(exp_dir)
        return f"{dir_name}/{checkpoint}" if checkpoint else None

    if ablation_name == "ablate_joint_training":
        # One adapter covers all experiences, so every key shares the same value.
        shared_adapter = _latest_in(f"joint-{run_suffix}")
        return {experience_name: shared_adapter for experience_name in EXPERIENCES}

    return {
        experience_name: _latest_in(f"{experience_name}-{run_suffix}")
        for experience_name in EXPERIENCES
    }


def get_all_ablations() -> List[Tuple[str, Optional[str]]]:
    """
    List every ablation defined in the ABLATIONS table.

    Returns:
        One ``(name, None)`` tuple per ABLATIONS entry, in table order. The
        second element is always None.
    """
    pairs: List[Tuple[str, Optional[str]]] = []
    for spec in ABLATIONS.values():
        pairs.append((spec["name"], None))
    return pairs


def run_evaluation(
    adapter_path: str,
    test_experience: str,
    output_dir: Path,
    dry_run: bool = False
) -> bool:
    """
    Launch one ``cco-eval-carve`` run for an adapter / test-set pair.

    The command is always echoed first. It is executed with SCRIPT_DIR as the
    working directory and inherits this process's stdout/stderr.

    Args:
        adapter_path: Path to adapter (relative to cco/experiments/)
        test_experience: Experience name to test on
        output_dir: Accepted for interface compatibility; this function does
            not read it (the caller relocates result files afterwards)
        dry_run: If True, only print the command without executing it

    Returns:
        True on a dry run or a zero exit status; False when the command exits
        non-zero or cannot be started
    """
    fixed_flags = [
        "--hierarchical_eval",
        "--hierarchical_topk", "1",
        "--hier_domain_score_mode", "hybrid",
        "--use_router",
    ]
    cmd = [
        "cco-eval-carve",
        "--lora_adapters", adapter_path,
        "--experience_name", test_experience,
        *fixed_flags,
    ]

    print(f"  Running: {' '.join(cmd)}")

    if dry_run:
        print("    [DRY RUN] Would execute command above")
        return True

    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        completed = subprocess.run(
            cmd,
            cwd=SCRIPT_DIR,
            check=True,
            capture_output=False
        )
    except subprocess.CalledProcessError as err:
        print(f"    ✗ Error: {err}")
        return False
    except Exception as err:
        # E.g. the executable is not on PATH.
        print(f"    ✗ Unexpected error: {err}")
        return False

    return completed.returncode == 0


def find_eval_results(adapter_path: str, test_experience: str) -> Optional[Tuple[Path, Path]]:
    """
    Locate the ``answers.jsonl`` / ``metrics.json`` pair written by an evaluation.

    The exact output layout of the evaluator is not known here, so several
    candidate directories under ``results/{test_experience}/`` are probed in
    order: the adapter path as given, the adapter path with ``/`` replaced by
    ``-``, and the adapter path's last component. If none holds both files,
    every subdirectory whose name contains the last component or the dashed
    path is tried.

    Args:
        adapter_path: Path to adapter (relative to cco/experiments/)
        test_experience: Experience name the adapter was tested on

    Returns:
        Tuple of (answers_path, metrics_path) or None if not found
    """
    results_root = SCRIPT_DIR / "results" / test_experience
    dashed_path = adapter_path.replace("/", "-")
    leaf_name = Path(adapter_path).name

    def _result_pair(directory: Path) -> Optional[Tuple[Path, Path]]:
        # Both files must be present for a directory to count as a hit.
        answers_file = directory / "answers.jsonl"
        metrics_file = directory / "metrics.json"
        if answers_file.exists() and metrics_file.exists():
            return (answers_file, metrics_file)
        return None

    # Exact candidates first, most specific to least specific.
    for candidate in (
        results_root / adapter_path,
        results_root / dashed_path,
        results_root / leaf_name,
    ):
        found = _result_pair(candidate)
        if found is not None:
            return found

    # Fallback: loose name match against whatever directories exist.
    if results_root.exists():
        for subdir in results_root.iterdir():
            if not subdir.is_dir():
                continue
            if leaf_name in subdir.name or dashed_path in subdir.name:
                found = _result_pair(subdir)
                if found is not None:
                    return found

    return None


def convert_metrics_to_jsonl(metrics_path: Path, output_path: Path) -> bool:
    """
    Convert metrics.json to metrics.jsonl format (a single JSON record on one line).

    The source is parsed and serialized before the destination is opened, so a
    missing or malformed source never leaves an empty output file behind. Both
    files are handled as UTF-8 rather than the locale's default encoding.

    Args:
        metrics_path: Existing JSON file to read
        output_path: JSONL file to write (overwritten if present)

    Returns:
        True if the output was written; False if reading, parsing or writing
        failed (a warning is printed in that case)
    """
    try:
        with open(metrics_path, "r", encoding="utf-8") as src:
            metrics = json.load(src)

        record = json.dumps(metrics) + "\n"

        with open(output_path, "w", encoding="utf-8") as dst:
            dst.write(record)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"    Warning: Could not convert metrics: {e}")
        return False
    return True


def copy_and_rename_results(
    adapter_experience: str,
    test_experience: str,
    ablation_dir: Path,
    answers_path: Path,
    metrics_path: Path
):
    """
    Copy and rename results to the ablation directory with specified naming.

    Output files are named
    ``exp:{adapter_experience}_test_set:{test_experience}_answers.jsonl`` and
    ``..._metrics.jsonl``. The answers file is copied as-is; the metrics file is
    converted from JSON to single-line JSONL.

    Args:
        adapter_experience: Experience the adapter was last trained on
        test_experience: Experience the adapter was tested on
        ablation_dir: Existing directory that receives both files
        answers_path: Source answers.jsonl
        metrics_path: Source metrics.json
    """
    # Kept as a function-local import, as in the original implementation.
    import shutil

    stem = f"exp:{adapter_experience}_test_set:{test_experience}"
    answers_output = ablation_dir / f"{stem}_answers.jsonl"
    metrics_output = ablation_dir / f"{stem}_metrics.jsonl"

    # Copy answers
    shutil.copy2(answers_path, answers_output)
    print(f"    ✓ Saved: {answers_output.name}")

    # Convert and copy metrics; only claim success when the file was written.
    if convert_metrics_to_jsonl(metrics_path, metrics_output):
        print(f"    ✓ Saved: {metrics_output.name}")
    else:
        print(f"    ⚠ Metrics not saved: {metrics_output.name}")


def _evaluate_adapter(
    adapter_exp: str,
    adapter_path: str,
    test_experiences: List[str],
    ablation_dir: Path,
    dry_run: bool,
    continue_on_error: bool
) -> Tuple[int, int, bool]:
    """
    Evaluate one adapter on each listed test set and collect its result files.

    Args:
        adapter_exp: Label of the adapter (experience name or "joint"), used in
            messages and in the output file names
        adapter_path: Path to adapter (relative to cco/experiments/)
        test_experiences: Experiences to test on, in order
        ablation_dir: Directory that receives the renamed result files
        dry_run: If True, only print commands without executing
        continue_on_error: If False, stop at the first failed evaluation

    Returns:
        ``(succeeded, attempted, aborted)``. ``aborted`` is True only when an
        evaluation failed and ``continue_on_error`` is False.
    """
    succeeded = 0
    attempted = 0

    for test_exp in test_experiences:
        attempted += 1
        print(f"\n    Evaluating on {test_exp}...")

        ran_ok = run_evaluation(
            adapter_path=adapter_path,
            test_experience=test_exp,
            output_dir=ablation_dir,
            dry_run=dry_run
        )

        if not ran_ok:
            print(f"    ✗ Failed to evaluate {adapter_exp} on {test_exp}")
            if not continue_on_error:
                return succeeded, attempted, True
            continue

        if dry_run:
            # Nothing was executed, so there are no files to collect.
            succeeded += 1
            continue

        # Wait a moment for files to be written
        time.sleep(1)

        # Find and copy results
        results = find_eval_results(adapter_path, test_exp)
        if results is None:
            # The run itself succeeded but is not counted without its outputs.
            print(f"    ⚠ Results not found for {adapter_exp} on {test_exp}")
            print(f"      Searched in: results/{test_exp}/")
            continue

        answers_path, metrics_path = results
        copy_and_rename_results(
            adapter_experience=adapter_exp,
            test_experience=test_exp,
            ablation_dir=ablation_dir,
            answers_path=answers_path,
            metrics_path=metrics_path
        )
        succeeded += 1

    return succeeded, attempted, False


def run_ablation_evaluations(
    ablation_name: str,
    seed: Optional[int] = None,
    dry_run: bool = False,
    continue_on_error: bool = False
) -> bool:
    """
    Run all evaluations for a single ablation.

    Evaluation pattern:
    - joint training: the single joint adapter is tested on every experience
    - otherwise: the adapter of experience i is tested on experiences 0..i
      (apibench → apibench; mllm → apibench, mllm; and so on)

    Adapters that cannot be found are skipped. If nothing at all could be
    evaluated the ablation is reported as failed rather than as a vacuous
    "0/0" success.

    Args:
        ablation_name: Name of the ablation
        seed: Optional seed value (used for finding adapters and organizing results)
        dry_run: If True, only print commands without executing
        continue_on_error: If True, continue even on error

    Returns:
        True only if at least one evaluation was attempted and every attempted
        evaluation succeeded and had its results collected
    """
    seed_str = f" (seed={seed})" if seed is not None else ""
    print(f"\n{'='*80}")
    print(f"Ablation: {ablation_name}{seed_str}")
    print(f"{'='*80}")

    # Check if this is joint training
    ablation = next((a for a in ABLATIONS.values() if a["name"] == ablation_name), None)
    is_joint_training = bool(ablation and ablation.get("joint_training", False))

    # Find adapters for this ablation
    adapters = find_adapters_for_ablation(ablation_name, seed)

    print("Found adapters:")
    for exp, adapter in adapters.items():
        if adapter:
            print(f"  {exp}: {adapter}")
        else:
            print(f"  {exp}: (not found)")

    # Create ablation results directory (include seed in path if provided)
    if seed is not None:
        ablation_dir = RESULTS_BASE / f"{ablation_name}_seed{seed}"
    else:
        ablation_dir = RESULTS_BASE / ablation_name
    ablation_dir.mkdir(parents=True, exist_ok=True)

    # Build the (adapter label, adapter path, test sets) schedule.
    if is_joint_training:
        # All experiences point to the same adapter
        joint_path = adapters.get(EXPERIENCES[0])
        if not joint_path:
            print("\n  Skipping joint training (adapter not found)")
            return False
        schedule = [("joint", joint_path, EXPERIENCES)]
    else:
        # Test on all experiences up to and including the adapter's own
        schedule = [
            (adapter_exp, adapters.get(adapter_exp), EXPERIENCES[:i + 1])
            for i, adapter_exp in enumerate(EXPERIENCES)
        ]

    success_count = 0
    total_count = 0

    for adapter_exp, adapter_path, test_experiences in schedule:
        if not adapter_path:
            print(f"\n  Skipping {adapter_exp} (adapter not found)")
            continue

        print(f"\n  Adapter: {adapter_exp} ({adapter_path})")
        print(f"  Testing on: {', '.join(test_experiences)}")

        succeeded, attempted, aborted = _evaluate_adapter(
            adapter_exp=adapter_exp,
            adapter_path=adapter_path,
            test_experiences=test_experiences,
            ablation_dir=ablation_dir,
            dry_run=dry_run,
            continue_on_error=continue_on_error
        )
        success_count += succeeded
        total_count += attempted
        if aborted:
            return False

    if total_count == 0:
        # Every adapter was missing: this is not a success.
        print("\n  No adapters found; nothing was evaluated")
        return False

    print(f"\n  Completed: {success_count}/{total_count} evaluations")
    return success_count == total_count


def main():
    """
    Command-line entry point: evaluate the selected ablations for every seed.

    Exactly one of ``--ablation`` (number or name) and ``--all`` must be given.
    A name selects the ablation with exactly that name; only when no exact
    match exists is it treated as a substring filter. Without
    ``--continue-on-error`` the whole run stops at the first failed ablation,
    including the remaining seeds.

    Returns:
        Process exit code: 0 if no ablation failed, 1 otherwise (also for an
        invalid ``--ablation`` selection)
    """
    # Derive the advertised number range from the table so it cannot go stale.
    first_num, last_num = min(ABLATIONS), max(ABLATIONS)

    parser = argparse.ArgumentParser(
        description="Run cco-eval over all ablations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run all ablations with default seeds (40, 41, 42)
  python run_ablation_evals.py --all
  
  # Run ablation 1 with default seeds
  python run_ablation_evals.py --ablation 1
  
  # Run ablation 1 with specific seeds
  python run_ablation_evals.py --ablation 1 --seeds 40 41 42
  
  # Run ablation 1 with single seed
  python run_ablation_evals.py --ablation 1 --seeds 40
  
  # Dry run (show commands without executing)
  python run_ablation_evals.py --ablation 1 --dry-run
  
  # Continue on error
  python run_ablation_evals.py --all --continue-on-error
        """
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be executed without actually running"
    )
    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue to next ablation even if current one fails"
    )
    parser.add_argument(
        "--ablation",
        type=str,
        help=(
            f"Run only a specific ablation by number ({first_num}-{last_num}) "
            "or name (e.g., 'ablate_replay_domain_05')"
        )
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help=f"Run all ablations ({first_num}-{last_num})"
    )
    parser.add_argument(
        "--seeds",
        type=int,
        nargs="+",
        default=[40, 41, 42],
        metavar="SEED",
        help="Random seeds to use (default: 40 41 42). Adapters will be searched with seed suffix."
    )

    args = parser.parse_args()

    if not args.ablation and not args.all:
        parser.error("Must specify either --ablation N or --all")

    if args.ablation and args.all:
        parser.error("Cannot specify both --ablation and --all")

    print("="*80)
    print("Ablation Evaluation Runner")
    print("="*80)
    print(f"Experiments root: {EXPERIMENTS_ROOT}")
    print(f"Results base: {RESULTS_BASE}")
    print(f"Seeds: {args.seeds}")
    print(f"Dry run: {args.dry_run}")
    print(f"Continue on error: {args.continue_on_error}")

    # Get all ablations
    all_ablations = get_all_ablations()

    if not all_ablations:
        print("No ablations found!")
        return 1

    print(f"\nFound {len(all_ablations)} ablations:")
    for ablation_name, _ in all_ablations:
        print(f"  - {ablation_name}")

    # Filter if specific ablation requested
    if args.ablation:
        # Try to parse as number first
        try:
            ablation_num = int(args.ablation)
        except ValueError:
            ablation_num = None

        if ablation_num is not None:
            if ablation_num not in ABLATIONS:
                print(
                    f"\nError: Invalid ablation number {ablation_num}. "
                    f"Must be {first_num}-{last_num}."
                )
                return 1
            wanted_name = ABLATIONS[ablation_num]["name"]
            all_ablations = [(name, path) for name, path in all_ablations if name == wanted_name]
        else:
            # Not a number, treat as name. An exact match wins so that a name
            # which is a prefix of another ablation can still be selected alone.
            exact = [(name, path) for name, path in all_ablations if name == args.ablation]
            all_ablations = exact or [
                (name, path) for name, path in all_ablations if args.ablation in name
            ]

        if not all_ablations:
            print(f"\nNo ablation found matching: {args.ablation}")
            return 1
        print(f"\nRunning only: {', '.join([name for name, _ in all_ablations])}")

    # Run evaluations for each seed
    successful_ablations = []
    failed_ablations = []
    aborted = False

    for seed in args.seeds:
        print(f"\n{'#'*80}")
        print(f"# Running ablations with seed={seed}")
        print(f"{'#'*80}\n")

        for ablation_name, _ in all_ablations:
            ablation_key = f"{ablation_name}_seed{seed}"
            success = run_ablation_evaluations(
                ablation_name=ablation_name,
                seed=seed,
                dry_run=args.dry_run,
                continue_on_error=args.continue_on_error
            )

            if success:
                successful_ablations.append(ablation_key)
            else:
                failed_ablations.append(ablation_key)
                if not args.continue_on_error:
                    print(f"\nStopping due to error in {ablation_key}")
                    aborted = True
                    break

        # A bare `break` above only leaves the inner loop; stop the seeds too.
        if aborted:
            break

    # Summary
    print("\n" + "="*80)
    print("Summary")
    print("="*80)
    total_expected = len(all_ablations) * len(args.seeds)
    print(f"Total ablations (across all seeds): {total_expected}")
    print(f"Successful: {len(successful_ablations)}")
    if successful_ablations:
        print(f"  {successful_ablations[:10]}..." if len(successful_ablations) > 10 else f"  {successful_ablations}")
    if failed_ablations:
        print(f"Failed: {len(failed_ablations)}")
        print(f"  {failed_ablations[:10]}..." if len(failed_ablations) > 10 else f"  {failed_ablations}")
    not_run = total_expected - len(successful_ablations) - len(failed_ablations)
    if not_run > 0:
        print(f"Not run (stopped early): {not_run}")
    print(f"\nResults saved to: {RESULTS_BASE}")

    return 0 if len(failed_ablations) == 0 else 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())

