#!/usr/bin/env python3
"""
Run a full ablation end-to-end across all experiences.

This script runs a single ablation (1-12) or all ablations across the full
experience sequence: apibench -> mllm -> hugging-bench-1 -> hugging-bench-2

For exp1 (apibench), uses configurations_carve/apibench.yaml
For exp2+ (mllm, hugging-bench-1, hugging-bench-2), uses configurations_carve/ablations_mllm_icml2026/mllm_onwards.yaml
Joint-training style runs use configurations_carve/joint_training_config.yaml

The variant names are set consistently so adapters can be correctly found for resuming.
router_registry_base_path is inferred from lora_adapters when not explicitly set.
"""

import subprocess
import sys
import os
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Optional, Tuple
import argparse
import yaml

# Constants
# All paths are anchored at the directory containing this script, so the
# script behaves the same regardless of the caller's working directory.
SCRIPT_DIR = Path(__file__).parent.absolute()
# Root under which adapter/checkpoint directories are looked up.
EXPERIMENTS_ROOT = SCRIPT_DIR / "cco" / "experiments"
# One log file per training invocation is written here.
LOG_DIR = SCRIPT_DIR / "batch_training_logs"
# Computed once at import time; shared by every log file name of this process.
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M%S")

# Create log directory
# NOTE: this is an import-time side effect (the directory is created as soon
# as the module is loaded, even for --dry-run).
LOG_DIR.mkdir(exist_ok=True)

# Experience sequence
# Order matters: index i in this list is the `exp_idx` used throughout the file.
EXPERIENCES = ["apibench", "mllm", "hugging-bench-1", "hugging-bench-2"]

# Base configs
# Paths are relative to SCRIPT_DIR.
EXP1_CONFIG = "configurations_carve/apibench.yaml"
EXP2_ONWARDS_CONFIG = "configurations_carve/ablations_mllm_icml2026/mllm_onwards.yaml"
JOINT_TRAINING_CONFIG = "configurations_carve/joint_training_config.yaml"

# Base variant name (consistent across all experiences)
# NOTE(review): not referenced in the visible part of this file; variant names
# are derived from the ablation name instead - confirm whether still needed.
BASE_VARIANT_NAME = "router_lasttok_dim1024_soft05_k10"

# Ablation definitions (1-12), keyed by the ablation number given on the CLI.
#
# Fields read by the rest of this file:
#   name               - used as the trainer's variant_name (a "_seed<N>" suffix is
#                        appended when a seed is given) and in temp-config/log file names
#   replay_strategy    - copied to config["replay_strategy"] for regular exp2+ runs
#   replay_percentage  - copied to config["replay_percentage"] for regular exp2+ runs
#   two_phase          - copied to config["router_two_phase_enable"]
#   emb_anchor         - copied to config["router_anchor_enable"]
#   proj_anchor        - copied to config["router_proj_anchor_enable"]
#   repo_id            - optional; overrides config["repo_id"] when present
#   cumulative / from_scratch / joint_training
#                      - optional mode flags that switch create_temp_config and
#                        run_ablation into their special-case branches
ABLATIONS = {
    1: {
        "name": "ablate_replay_domain_05",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.05,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    2: {
        "name": "ablate_replay_domain_10",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    3: {
        "name": "ablate_replay_domain_20",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.20,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    4: {
        "name": "ablate_replay_random_05",
        "replay_strategy": "random",
        "replay_percentage": 0.05,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    5: {
        "name": "ablate_replay_random_10",
        "replay_strategy": "random",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    6: {
        "name": "ablate_replay_random_20",
        "replay_strategy": "random",
        "replay_percentage": 0.20,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
    },
    7: {
        "name": "ablate_no_emb_anchor",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": False,
        "proj_anchor": True,
    },
    8: {
        "name": "ablate_no_proj_anchor",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": False,
    },
        9: {
        "name": "ablate_cumulative_training",
        "replay_strategy": "random",
        "replay_percentage": 1.00,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
        "cumulative": True  # Special flag for cumulative training
    },
    10: {
        "name": "ablate_from_scratch_training",
        "replay_strategy": "random",
        "replay_percentage": 0.00,
        "two_phase": False,
        "emb_anchor": False,
        "proj_anchor": False,
        "from_scratch": True  # Special flag for from-scratch training (separate adapters)
    },
    11: {
        "name": "ablate_joint_training",
        "replay_strategy": "random",
        "replay_percentage": None,
        "two_phase": False,
        "emb_anchor": False,
        "proj_anchor": False,
        "joint_training": True  # Special flag for normal joint training (all experiences at once)
    },
    12: {
        "name": "ablate_replay_domain_10_alt_repo",
        "replay_strategy": "domain_model_coreset",
        "replay_percentage": 0.10,
        "two_phase": True,
        "emb_anchor": True,
        "proj_anchor": True,
        "repo_id": "Qwen/Qwen2.5-7B-Instruct",  # TODO: Confirm exact repo_id
    },
}


def find_latest_checkpoint(exp_dir: Path) -> Optional[str]:
    """Return the name of the highest-numbered checkpoint directory in ``exp_dir``.

    Only sub-directories named ``checkpoint-<int>`` are considered; the integer
    is taken from the text between the first and second ``-`` of the name.

    Args:
        exp_dir: Experiment directory to scan.

    Returns:
        The directory name (e.g. ``"checkpoint-310"``), or ``None`` when
        ``exp_dir`` is missing, is not a directory, or holds no checkpoints.
    """
    # is_dir() instead of exists(): a regular file at this path used to slip
    # through and make iterdir() raise NotADirectoryError.
    if not exp_dir.is_dir():
        return None

    numbered = []
    for entry in exp_dir.iterdir():
        if not (entry.is_dir() and entry.name.startswith("checkpoint-")):
            continue
        try:
            # The startswith() check guarantees at least two "-"-separated parts.
            step = int(entry.name.split("-")[1])
        except ValueError:
            # e.g. "checkpoint-best": not a numbered checkpoint, skip it
            continue
        numbered.append((step, entry.name))

    if not numbered:
        return None

    return max(numbered, key=lambda pair: pair[0])[1]


def get_adapter_path(experience_name: str, variant_name: str, extra_info: str = "", seed: Optional[int] = None) -> Optional[str]:
    """
    Locate the newest checkpoint produced for an experience/variant pair.

    The directory name is rebuilt here as
    ``{experience_name}-{variant_name}[_seed{seed}][-{extra_info}]``, so
    ``variant_name`` must NOT already contain the experience name.

    Args:
        experience_name: Name of the experience
        variant_name: Base variant name (seed suffix is appended here if given)
        extra_info: Optional trailing component of the directory name
        seed: Optional seed value to include in the variant name

    Returns:
        Adapter path relative to cco/experiments/
        (e.g., "apibench-ablate_replay_domain_05_seed40/checkpoint-310"),
        or None if the directory or a checkpoint inside it does not exist
    """
    seeded_variant = variant_name if seed is None else f"{variant_name}_seed{seed}"
    suffix = f"-{extra_info}" if extra_info else ""
    run_dir_name = f"{experience_name}-{seeded_variant}{suffix}"

    run_dir = EXPERIMENTS_ROOT / run_dir_name
    latest = find_latest_checkpoint(run_dir) if run_dir.exists() else None

    return f"{run_dir_name}/{latest}" if latest else None


def infer_router_registry_base_path(lora_adapter: str) -> str:
    """
    Derive the model-registry location that belongs to an adapter checkpoint.

    Args:
        lora_adapter: Adapter path relative to cco/experiments/
            (e.g., "apibench-ablate_replay_domain_05/checkpoint-310")

    Returns:
        Path to model_registry.json relative to the project root
        (e.g., "cco/experiments/apibench-ablate_replay_domain_05/checkpoint-310/model_registry.json")
    """
    registry_template = "cco/experiments/{}/model_registry.json"
    return registry_template.format(lora_adapter)


def _select_base_config_path(ablation: Dict, exp_idx: int) -> Path:
    """Pick the base YAML config for this ablation type and experience index."""
    if ablation.get("joint_training", False):
        return SCRIPT_DIR / JOINT_TRAINING_CONFIG
    if ablation.get("from_scratch", False) and exp_idx > 0:
        # From-scratch steps after the first run in joint-training mode
        return SCRIPT_DIR / JOINT_TRAINING_CONFIG
    if exp_idx == 0:
        return SCRIPT_DIR / EXP1_CONFIG
    return SCRIPT_DIR / EXP2_ONWARDS_CONFIG


def _reset_for_first_experience(config: Dict) -> None:
    """Drop/neutralise settings that only apply from the second experience on.

    Keys are only touched when already present, so nothing new is introduced
    into a base config that does not define them.
    """
    for key in ("experiences_sequence", "replay_percentage", "replay_strategy"):
        config.pop(key, None)
    for key in ("router_two_phase_enable", "router_anchor_enable", "router_proj_anchor_enable"):
        if key in config:
            config[key] = False
    if "lora_adapters" in config:
        config["lora_adapters"] = []
    if "router_registry_base_path" in config:
        config["router_registry_base_path"] = None
    if "router_registry_init_mode" in config:
        config["router_registry_init_mode"] = "fresh"  # Exp1 builds registry from scratch


def _chain_from_previous_adapter(config: Dict, previous_adapter: Optional[str]) -> None:
    """Point the config at the previous adapter and its model registry, if any."""
    if previous_adapter:
        config["lora_adapters"] = [previous_adapter]
        config["router_registry_base_path"] = infer_router_registry_base_path(previous_adapter)


def _apply_router_flags(config: Dict, ablation: Dict) -> None:
    """Copy the two-phase / embedding-anchor / projection-anchor switches."""
    config["router_two_phase_enable"] = ablation["two_phase"]
    config["router_anchor_enable"] = ablation["emb_anchor"]
    config["router_proj_anchor_enable"] = ablation["proj_anchor"]


def create_temp_config(
    experience_name: str,
    exp_idx: int,
    ablation: Dict,
    previous_adapter: Optional[str] = None,
    temp_config_dir: Optional[Path] = None,
    seed: Optional[int] = None
) -> Path:
    """
    Create a temporary config file with the correct overrides for this experience and ablation.

    The base YAML is chosen from the ablation mode and ``exp_idx``, the
    overrides are applied in memory, and the result is written to
    ``ablation_{name}_{experience_name}[_seed{seed}].yaml``.

    Args:
        experience_name: Name of the experience
        exp_idx: Index of experience in sequence (0-based)
        ablation: Ablation configuration dictionary (an entry of ABLATIONS)
        previous_adapter: Adapter path from previous experience
        temp_config_dir: Directory for temp config files; defaults to
            SCRIPT_DIR/configurations/temp_ablation_configs
        seed: Optional seed value to include in variant name and set in config

    Returns:
        Path to temporary config file

    Raises:
        ValueError: If a joint-training ablation is requested with exp_idx != 0
    """
    base_config_path = _select_base_config_path(ablation, exp_idx)

    with open(base_config_path, 'r', encoding="utf-8") as f:
        # safe_load returns None for an empty document; fall back to an empty
        # mapping so the overrides below do not crash with a TypeError.
        config = yaml.safe_load(f) or {}

    # The variant name is the ablation name (plus seed suffix). It must NOT
    # include the experience name: the trainer prepends that itself, giving
    # e.g. "apibench-ablate_replay_domain_05_seed40".
    variant_name = ablation["name"]
    if seed is not None:
        variant_name = f"{variant_name}_seed{seed}"
    config["variant_name"] = variant_name
    config["extra_info"] = ""  # variant_name already carries the ablation info

    if seed is not None:
        config["seed"] = seed

    if "repo_id" in ablation:
        config["repo_id"] = ablation["repo_id"]

    # Exp1 is a single experience with nothing to replay or anchor to.
    if exp_idx == 0:
        config["experience_name"] = experience_name
        _reset_for_first_experience(config)

    if ablation.get("cumulative", False):
        # Cumulative training (ablation 9): step k trains jointly on the first
        # k experiences, warm-started from the previous step's adapter.
        if exp_idx == 0:
            # The exp1 reset above has already been applied.
            config["experience_name"] = EXPERIENCES[0]
        else:
            config["joint_training"] = True
            config["experiences_sequence"] = EXPERIENCES[:exp_idx + 1]
            # The last experience names the output directory
            config["experience_name"] = EXPERIENCES[exp_idx]

            _chain_from_previous_adapter(config, previous_adapter)

            # Replay is switched off here (the ablation's replay fields are
            # not used); only the router switches come from the ablation.
            config["replay_percentage"] = None
            config["replay_strategy"] = None
            _apply_router_flags(config, ablation)

    elif ablation.get("joint_training", False):
        # Joint training (ablation 11): a single run over all experiences.
        if exp_idx != 0:
            raise ValueError("Joint training should only run on exp_idx=0")

        config["joint_training"] = True
        config["experiences_sequence"] = EXPERIENCES  # All experiences
        config["experience_name"] = "joint"  # names the output directory

        # Fresh start: no previous adapter or registry
        config["lora_adapters"] = []
        config["router_registry_base_path"] = None
        if "router_registry_init_mode" in config:
            config["router_registry_init_mode"] = "fresh"

        # Replay, two-phase and anchors are disabled for joint training
        config["replay_percentage"] = None
        config["replay_strategy"] = None
        config["router_two_phase_enable"] = False
        config["router_anchor_enable"] = False
        config["router_proj_anchor_enable"] = False

        config["loss_mode"] = "supervised+router"
        config["router_freeze_lm"] = False

    elif ablation.get("from_scratch", False):
        # From-scratch training (ablation 10): like cumulative, but every step
        # starts without any previous adapter or registry.
        if exp_idx == 0:
            # The exp1 reset above has already been applied.
            config["experience_name"] = EXPERIENCES[0]
        else:
            config["joint_training"] = True
            config["experiences_sequence"] = EXPERIENCES[:exp_idx + 1]
            # The last experience names the output directory
            config["experience_name"] = EXPERIENCES[exp_idx]

            config["lora_adapters"] = []
            config["resume_from"] = None
            config["router_registry_base_path"] = None
            config["router_registry_init_mode"] = "fresh"

            config["loss_mode"] = "supervised+router"
            config["router_freeze_lm"] = False

            # Replay is switched off here; only the router switches come from
            # the ablation.
            config["replay_percentage"] = None
            config["replay_strategy"] = None
            _apply_router_flags(config, ablation)

    elif exp_idx > 0:
        # Regular exp2+ run: sequence up to the current experience, chained
        # from the previous adapter, with the ablation's replay settings.
        config["experiences_sequence"] = EXPERIENCES[:exp_idx + 1]

        _chain_from_previous_adapter(config, previous_adapter)

        config["replay_percentage"] = ablation["replay_percentage"]
        config["replay_strategy"] = ablation["replay_strategy"]
        _apply_router_flags(config, ablation)

    if temp_config_dir is None:
        temp_config_dir = SCRIPT_DIR / "configurations" / "temp_ablation_configs"
    temp_config_dir.mkdir(parents=True, exist_ok=True)

    temp_filename = f"ablation_{ablation['name']}_{experience_name}"
    if seed is not None:
        temp_filename += f"_seed{seed}"
    temp_config_path = temp_config_dir / f"{temp_filename}.yaml"

    with open(temp_config_path, 'w') as f:
        # sort_keys=False keeps the base config's key order readable
        yaml.dump(config, f, default_flow_style=False, sort_keys=False)

    return temp_config_path


def build_training_command(
    experience_name: str,
    exp_idx: int,
    ablation: Dict,
    previous_adapter: Optional[str] = None,
    dry_run: bool = False,
    seed: Optional[int] = None
) -> Tuple[List[str], Optional[Path]]:
    """
    Assemble the ``cco-train-carve`` invocation for one experience.

    A temporary config holding all overrides is written first (this happens
    regardless of ``dry_run``), then the command line pointing at it is built.

    Args:
        experience_name: Name of the experience (apibench, mllm, etc.)
        exp_idx: Index of experience in sequence (0-based)
        ablation: Ablation configuration dictionary
        previous_adapter: Adapter path from previous experience (for exp2+)
        dry_run: Accepted for call-site symmetry; not used by this function
        seed: Optional seed value to include in variant name and set in config

    Returns:
        Tuple of (command arguments, temp_config_path)
    """
    config_path = create_temp_config(
        experience_name=experience_name,
        exp_idx=exp_idx,
        ablation=ablation,
        previous_adapter=previous_adapter,
        seed=seed
    )

    command = ["cco-train-carve", "--config", str(config_path)]

    # Cumulative, from-scratch and joint ablations run in joint-training mode,
    # so they never skip earlier experiences. Everything else, from exp2 on,
    # tells the trainer to skip training on (but still load) earlier ones.
    joint_mode_flags = ("cumulative", "from_scratch", "joint_training")
    if exp_idx > 0 and not any(ablation.get(flag, False) for flag in joint_mode_flags):
        earlier_experiences = EXPERIENCES[:exp_idx]
        if earlier_experiences:
            command += ["--skip_training_experiences", *earlier_experiences]

    return command, config_path


def _remove_temp_config(temp_config_path: Optional[Path]) -> None:
    """Best-effort removal of a generated temp config file; never raises."""
    if temp_config_path is None:
        return
    try:
        temp_config_path.unlink()
    except FileNotFoundError:
        pass  # already gone, nothing to clean up
    except OSError as err:
        # Cleanup failure must not mask the training result; just report it.
        print(f"  Warning: could not remove temp config {temp_config_path}: {err}")


def _output_experience_name(ablation: Dict, experience_name: str, exp_idx: int) -> str:
    """Return the experience name the trainer's output directory is filed under."""
    if ablation.get("joint_training", False):
        return "joint"
    if ablation.get("cumulative", False) or ablation.get("from_scratch", False):
        # These modes name the output after the last experience in the sequence
        return EXPERIENCES[exp_idx]
    return experience_name


def run_training(
    experience_name: str,
    exp_idx: int,
    ablation: Dict,
    previous_adapter: Optional[str] = None,
    dry_run: bool = False,
    continue_on_error: bool = False,
    seed: Optional[int] = None
) -> Tuple[bool, Optional[str]]:
    """
    Run training for a single experience.

    Builds the command (which writes a temp config), runs it with stdout and
    stderr redirected to a log file under LOG_DIR, removes the temp config,
    and looks up the resulting adapter checkpoint.

    Args:
        experience_name: Name of the experience
        exp_idx: Index of experience in sequence
        ablation: Ablation configuration dictionary
        previous_adapter: Adapter path from previous experience
        dry_run: If True, only print the command and return a placeholder
            adapter path; the temp config is still written and left on disk
        continue_on_error: Accepted for interface compatibility; failures are
            always reported through the return value and the caller decides
            whether to continue
        seed: Optional seed value to include in variant name and set in config

    Returns:
        Tuple of (success: bool, adapter_path: Optional[str]). ``adapter_path``
        is None when training failed or no checkpoint could be found.
    """
    is_joint = ablation.get("joint_training", False)
    seed_str = f" (seed={seed})" if seed is not None else ""
    print(f"\n{'='*80}")
    if is_joint:
        print(f"Joint Training: All {len(EXPERIENCES)} experiences together{seed_str}")
    else:
        print(f"Experience {exp_idx + 1}/{len(EXPERIENCES)}: {experience_name}{seed_str}")
    print(f"{'='*80}")

    cmd, temp_config_path = build_training_command(experience_name, exp_idx, ablation, previous_adapter, dry_run, seed)

    print(f"Command: {' '.join(cmd)}")
    print(f"Temp config: {temp_config_path}")

    if dry_run:
        print("[DRY RUN] Would execute command above")
        # Placeholder path, named like the directory a real run would create
        variant_name = ablation['name']
        if seed is not None:
            variant_name = f"{variant_name}_seed{seed}"
        exp_dir_name = f"{_output_experience_name(ablation, experience_name, exp_idx)}-{variant_name}"
        return (True, f"{exp_dir_name}/checkpoint-XXX")

    # Joint training logs under "joint" rather than the first experience name
    log_experience_name = "joint" if is_joint else experience_name
    log_filename = f"ablation_{ablation['name']}_{log_experience_name}"
    if seed is not None:
        log_filename += f"_seed{seed}"
    log_file = LOG_DIR / f"{log_filename}_{TIMESTAMP}.log"

    print(f"Logging to: {log_file}")

    try:
        try:
            with open(log_file, 'w') as f:
                result = subprocess.run(
                    cmd,
                    cwd=SCRIPT_DIR,
                    check=False,  # the return code is inspected below
                    stdout=f,
                    stderr=subprocess.STDOUT
                )
        finally:
            # Remove the temp config whether or not the subprocess could run
            _remove_temp_config(temp_config_path)

        if result.returncode != 0:
            print(f"✗ Training failed with return code {result.returncode}")
            print(f"  Check log: {log_file}")
            return (False, None)

        # variant_name is the ablation name, extra_info is empty
        lookup_experience = _output_experience_name(ablation, experience_name, exp_idx)
        adapter_path = get_adapter_path(lookup_experience, ablation["name"], "", seed)
        if adapter_path:
            print("✓ Training completed successfully")
            print(f"  Adapter: {adapter_path}")
            return (True, adapter_path)

        print("⚠ Training completed but adapter not found")
        return (True, None)

    except Exception as e:
        # Boundary handler: any launch/lookup error (e.g. missing executable,
        # unwritable log file) becomes a failed step instead of a crash.
        print(f"✗ Error running training: {e}")
        return (False, None)


def run_ablation(
    ablation_num: int,
    dry_run: bool = False,
    continue_on_error: bool = False,
    resume_from: Optional[str] = None,
    seed: Optional[int] = None
) -> bool:
    """
    Run a single ablation across all experiences.

    Regular and cumulative ablations run one step per experience, feeding each
    step's adapter into the next. From-scratch ablations run 3 steps and joint
    training runs 1 step; neither chains adapters between steps.

    Args:
        ablation_num: Ablation number (a key of ABLATIONS)
        dry_run: If True, only print commands without executing
        continue_on_error: If True, continue to next experience even if current fails
        resume_from: Experience name to resume from (e.g., "mllm"). If None, starts from beginning.
        seed: Optional seed value to include in variant name and set in config

    Returns:
        True if all experiences completed successfully, False otherwise
        (including when failures were skipped because of continue_on_error)
    """
    if ablation_num not in ABLATIONS:
        print(f"Error: Invalid ablation number {ablation_num}. Must be {min(ABLATIONS)}-{max(ABLATIONS)}.")
        return False

    ablation = ABLATIONS[ablation_num]
    is_joint = ablation.get("joint_training", False)
    is_from_scratch = ablation.get("from_scratch", False)
    # These modes never consume the previous step's adapter
    starts_fresh = is_joint or is_from_scratch

    # Joint training only runs once, so resuming doesn't make sense
    if is_joint and resume_from:
        print(f"Error: Joint training (ablation {ablation_num}) only runs once and cannot be resumed from a later experience.")
        return False

    # Joint training runs 1 step, from-scratch runs 3 steps, others run one
    # step per experience.
    if is_joint:
        max_steps = 1
    elif is_from_scratch:
        max_steps = 3
    else:
        max_steps = len(EXPERIENCES)

    # Determine starting experience index
    start_idx = 0
    if resume_from:
        if resume_from not in EXPERIENCES:
            print(f"Error: Invalid experience name '{resume_from}'. Must be one of: {', '.join(EXPERIENCES)}")
            return False
        start_idx = EXPERIENCES.index(resume_from)
        if start_idx >= max_steps:
            # Otherwise the loop below would run nothing and report success
            print(f"Error: Cannot resume from '{resume_from}': ablation {ablation_num} only runs {max_steps} step(s).")
            return False
        print(f"\n{'='*80}")
        print(f"Resuming from experience: {resume_from} (index {start_idx})")
        print(f"{'='*80}\n")

    print(f"\n{'='*80}")
    print(f"Running Ablation {ablation_num}: {ablation['name']}")
    print(f"{'='*80}")
    if is_joint:
        print("Joint Training: All experiences at once")
    else:
        print(f"Replay Strategy: {ablation['replay_strategy']}")
        print(f"Replay Percentage: {ablation['replay_percentage']}")
        print(f"Two-Phase: {ablation['two_phase']}")
        print(f"Embedding Anchor: {ablation['emb_anchor']}")
        print(f"Projection Anchor: {ablation['proj_anchor']}")
    if resume_from:
        print(f"Resuming from: {resume_from}")
    print(f"{'='*80}\n")

    # When resuming mid-sequence, the previous step's adapter must be located
    # on disk. Modes that start fresh never use it, so they skip the lookup.
    previous_adapter = None
    if start_idx > 0 and not starts_fresh:
        prev_experience = EXPERIENCES[start_idx - 1]
        previous_adapter = get_adapter_path(prev_experience, ablation["name"], "", seed)

        if previous_adapter:
            print(f"Found previous adapter from {prev_experience}: {previous_adapter}")
        else:
            print(f"Warning: Could not find adapter for previous experience '{prev_experience}'")
            print(f"  This may cause issues when resuming from '{resume_from}'")
            if not continue_on_error:
                print("  Aborting. Use --continue-on-error to proceed anyway.")
                return False
            print("  Continuing anyway due to --continue-on-error flag.")

    all_succeeded = True
    for exp_idx in range(start_idx, max_steps):
        experience_name = EXPERIENCES[exp_idx]

        adapter_to_use = None if starts_fresh else previous_adapter

        success, adapter_path = run_training(
            experience_name=experience_name,
            exp_idx=exp_idx,
            ablation=ablation,
            previous_adapter=adapter_to_use,
            dry_run=dry_run,
            continue_on_error=continue_on_error,
            seed=seed
        )

        if not success:
            print(f"\n✗ Ablation {ablation_num} failed at experience {exp_idx + 1}: {experience_name}")
            if not continue_on_error:
                return False
            all_succeeded = False

        # Chain the new adapter into the next step (chained modes only)
        if not starts_fresh:
            if adapter_path:
                previous_adapter = adapter_path
            elif not dry_run:
                # No adapter was produced; the next step would reuse a stale one
                print(f"\n✗ Cannot continue: no adapter path for {experience_name}")
                if not continue_on_error:
                    return False
                all_succeeded = False

    seed_str = f" (seed={seed})" if seed is not None else ""
    print(f"\n{'='*80}")
    if all_succeeded:
        print(f"✓ Ablation {ablation_num} ({ablation['name']}){seed_str} completed successfully")
    else:
        print(f"✗ Ablation {ablation_num} ({ablation['name']}){seed_str} finished with failures")
    print(f"{'='*80}\n")

    return all_succeeded


def main() -> int:
    """Parse command-line arguments and run the requested ablation(s).

    Either a single ablation (``--ablation N``) or the ``--all`` set
    (ablations 1-11) is run once per seed in ``--seeds``. Each run is
    delegated to ``run_ablation``.

    Unless ``--continue-on-error`` is given, the first failed ablation
    stops the whole run: no further ablations and no further seeds are
    started.

    Returns:
        Process exit code: 0 if every executed ablation succeeded,
        1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Run ablations end-to-end across all experiences",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Run ablation 1 with default seeds (40, 41, 42)
  python run_ablation_full.py --ablation 1
  
  # Run ablation 1 with specific seeds
  python run_ablation_full.py --ablation 1 --seeds 40 41 42
  
  # Run ablation 1 with single seed
  python run_ablation_full.py --ablation 1 --seeds 40
  
  # Run all ablations with default seeds
  python run_ablation_full.py --all
  
  # Dry run (show commands without executing)
  python run_ablation_full.py --ablation 1 --dry-run
  
  # Continue on error
  python run_ablation_full.py --all --continue-on-error
  
  # Resume from a specific experience (e.g., mllm)
  python run_ablation_full.py --ablation 1 --resume-from mllm
        """
    )

    parser.add_argument(
        "--ablation",
        type=int,
        choices=range(1, 13),
        metavar="N",
        help="Run a specific ablation (1-12)"
    )

    parser.add_argument(
        "--all",
        action="store_true",
        help="Run all ablations (1-11)"
    )

    parser.add_argument(
        "--seeds",
        type=int,
        nargs="+",
        default=[40, 41, 42],
        metavar="SEED",
        help="Random seeds to use (default: 40 41 42). Variant names will include seed suffix."
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show commands without executing"
    )

    parser.add_argument(
        "--continue-on-error",
        action="store_true",
        help="Continue to next experience/ablation even if current one fails"
    )

    parser.add_argument(
        "--resume-from",
        type=str,
        choices=EXPERIENCES,
        metavar="EXPERIENCE",
        help=f"Resume from a specific experience. Must be one of: {', '.join(EXPERIENCES)}"
    )

    args = parser.parse_args()

    # --ablation and --all are mutually exclusive, and one of them is required.
    if args.ablation is None and not args.all:
        parser.error("Must specify either --ablation N or --all")

    if args.ablation is not None and args.all:
        parser.error("Cannot specify both --ablation and --all")

    print("="*80)
    print("Ablation Full Run Script")
    print("="*80)
    print(f"Experiments root: {EXPERIMENTS_ROOT}")
    print(f"Log directory: {LOG_DIR}")
    print(f"Seeds: {args.seeds}")
    print(f"Dry run: {args.dry_run}")
    print(f"Continue on error: {args.continue_on_error}")
    if args.resume_from:
        print(f"Resume from: {args.resume_from}")
    print()

    if args.all:
        # Ablation numbers covered by --all; the summary total below is
        # derived from this same range so the two cannot drift apart.
        all_ablation_nums = range(1, 12)

        # Run all ablations for each seed
        print(f"Running all ablations (1-11) for seeds: {args.seeds}...")
        print()

        successful_ablations = []
        failed_ablations = []
        # Set when a failure must abort the whole run. A bare `break` only
        # leaves the inner ablation loop, so the seed loop checks this flag.
        stop_requested = False

        for seed in args.seeds:
            print(f"\n{'#'*80}")
            print(f"# Running all ablations with seed={seed}")
            print(f"{'#'*80}\n")

            for ablation_num in all_ablation_nums:
                ablation_key = f"{ablation_num}_seed{seed}"
                success = run_ablation(
                    ablation_num=ablation_num,
                    dry_run=args.dry_run,
                    continue_on_error=args.continue_on_error,
                    resume_from=args.resume_from,
                    seed=seed
                )

                if success:
                    successful_ablations.append(ablation_key)
                else:
                    failed_ablations.append(ablation_key)
                    if not args.continue_on_error:
                        print(f"\nStopping due to error in ablation {ablation_num} (seed={seed})")
                        stop_requested = True
                        break

            if stop_requested:
                # Do not start the remaining seeds after a fatal failure.
                break

        # Summary
        print("\n" + "="*80)
        print("Summary")
        print("="*80)
        total_expected = len(all_ablation_nums) * len(args.seeds)
        print(f"Total ablations (across all seeds): {total_expected}")
        print(f"Successful: {len(successful_ablations)}")
        if successful_ablations:
            print(f"  {successful_ablations[:10]}..." if len(successful_ablations) > 10 else f"  {successful_ablations}")
        if failed_ablations:
            print(f"Failed: {len(failed_ablations)}")
            print(f"  {failed_ablations[:10]}..." if len(failed_ablations) > 10 else f"  {failed_ablations}")
        print("="*80)

        return 0 if not failed_ablations else 1

    else:
        # Run single ablation for each seed
        all_success = True
        for seed in args.seeds:
            print(f"\n{'#'*80}")
            print(f"# Running ablation {args.ablation} with seed={seed}")
            print(f"{'#'*80}\n")

            success = run_ablation(
                ablation_num=args.ablation,
                dry_run=args.dry_run,
                continue_on_error=args.continue_on_error,
                resume_from=args.resume_from,
                seed=seed
            )

            if not success:
                all_success = False
                if not args.continue_on_error:
                    print(f"\nStopping due to error in ablation {args.ablation} (seed={seed})")
                    break

        return 0 if all_success else 1


# Script entry point: run main() and use its return value (0 on success,
# 1 if any ablation failed) as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
