#!/usr/bin/env python3

import os
import subprocess
import sys
import time
from datetime import datetime

# Model names and rhyme family pairs from the original script.
#
# Each entry of MODEL_NAMES is forwarded verbatim to the stage scripts as the
# value of --model_name. Every unsuffixed name has a "_Base" counterpart.
MODEL_NAMES = [
    "Gemma2_2B",
    "Gemma2_9B",
    "Gemma2_27B",
    "Gemma3_1B",
    "Gemma3_4B",
    "Gemma3_12B",
    "Gemma3_27B",
    "Llama3.2_3B",
    "Llama3.1_8B",
    "Qwen3_8B",
    "Qwen3_14B",
    "Qwen3_32B",
    "Gemma2_2B_Base",
    "Gemma2_9B_Base",
    "Gemma2_27B_Base",
    "Gemma3_1B_Base",
    "Gemma3_4B_Base",
    "Gemma3_12B_Base",
    "Gemma3_27B_Base",
    "Llama3.2_3B_Base",
    "Llama3.1_8B_Base",
    "Qwen3_8B_Base",
    "Qwen3_14B_Base",
    "Qwen3_32B_Base",
    "Llama3.3_70B",
    "Llama3.3_70B_Base",
]

# Earlier pair set. This list used to be bound to RHYME_FAMILY_PAIRS as well,
# but that binding was overwritten by the assignment right below it and so
# never took effect. It is kept under a private name for reference only;
# nothing in this file reads it.
_SUPERSEDED_RHYME_FAMILY_PAIRS = [
    ("ing", "oat"),
    ("ee", "ow"),
    ("oat", "ake"),
    ("ing", "ake"),
    ("oat", "ight"),
    ("ird", "it"),
    ("air", "ip"),
    ("ip", "it"),
    ("ow", "it"),
    ("ing", "it"),
    ("oat", "it"),
    ("air", "it"),
    ("ake", "ow"),
    ("oat", "ow"),
    ("ip", "ow"),
    ("air", "oat"),
    ("air", "ird"),
    ("oat", "ee"),
    ("ee", "ake"),
    ("ird", "ee"),
]

# Active pair set: ten rhyme families, each paired with the next two families
# in the cyclic order ing, air, ip, oat, ird, ee, ight, ake, ow, it
# (10 families x 2 = 20 pairs).
RHYME_FAMILY_PAIRS = [
    ("ing", "air"),
    ("ing", "ip"),
    ("air", "ip"),
    ("air", "oat"),
    ("ip", "oat"),
    ("ip", "ird"),
    ("oat", "ird"),
    ("oat", "ee"),
    ("ird", "ee"),
    ("ird", "ight"),
    ("ee", "ight"),
    ("ee", "ake"),
    ("ight", "ake"),
    ("ight", "ow"),
    ("ake", "ow"),
    ("ake", "it"),
    ("ow", "it"),
    ("ow", "ing"),
    ("it", "ing"),
    ("it", "air"),
]

# Debug override: uncomment to restrict a run to a small subset of models.
# MODEL_NAMES = ["Llama3.2_3B", "Gemma2_9B", "Qwen3_8B"]


def run_stage(
    stage_script,
    model_name,
    output_dir="results",
    mode="rhyme_family_steering",
    strip=False,
    LAYER_FRACTION=0.8,
    rhyme_family1=None,
    rhyme_family2=None,
):
    """Launch one stage script as a child Python process and report success.

    The script is resolved relative to this file's directory and executed with
    the current interpreter. Its stdout/stderr are inherited, so output shows
    up live in the parent's terminal.

    Args:
        stage_script: File name of the stage script to run.
        model_name: Value passed as ``--model_name``.
        output_dir: Value passed as ``--output_dir``.
        mode: Value passed as ``--mode``.
        strip: When true, the bare ``--strip`` flag is appended.
        LAYER_FRACTION: Passed as ``--LAYER_FRACTION`` after ``str()``.
        rhyme_family1: Passed as ``--rhyme_family1`` unless it is ``None``.
        rhyme_family2: Passed as ``--rhyme_family2`` unless it is ``None``.

    Returns:
        ``True`` if the child exited with status 0, ``False`` if it exited
        with a non-zero status.
    """
    command = [
        sys.executable,
        os.path.join(os.path.dirname(__file__), stage_script),
        "--model_name",
        model_name,
        "--output_dir",
        output_dir,
        "--mode",
        mode,
        "--LAYER_FRACTION",
        str(LAYER_FRACTION),
    ]

    if strip:
        command.append("--strip")

    # Optional flags are emitted only when a value was supplied.
    optional_flags = (
        ("--rhyme_family1", rhyme_family1),
        ("--rhyme_family2", rhyme_family2),
    )
    for flag, value in optional_flags:
        if value is not None:
            command.extend([flag, value])

    print(f"Running: {' '.join(command)}")
    began = time.time()

    try:
        # Output is deliberately not captured so it streams in real time.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as exc:
        print(f"✗ Failed after {time.time() - began:.1f}s")
        print(f"Error: {exc}")
        return False

    print(f"✓ Completed in {time.time() - began:.1f}s")
    return True


def check_stage_completed(
    stage_name, model_name, output_dir="results", mode="rhyme_family_steering"
):
    """Report whether a stage's output file is already present on disk.

    The file is looked up at ``<output_dir>/<mode>/<model_name>/<file>``,
    where ``<file>`` is the fixed result file name associated with the stage.

    Args:
        stage_name: Stage identifier; unknown identifiers yield ``False``.
        model_name: Model whose experiment directory is inspected.
        output_dir: Root directory of all results.
        mode: Sub-directory of ``output_dir`` for the experiment mode.

    Returns:
        ``True`` if the stage is known and its result file exists.
    """
    experiment_dir = os.path.join(output_dir, mode, model_name)

    # Result file each stage is expected to leave behind.
    result_file_by_stage = {
        "best_layer_and_token_pos": "best_layer_and_token_pos.json",
        "line_generation": "generated_lines.json",
        "standard_metrics": "standard_metrics.json",
        "prob_based_metrics": "prob_based_metrics.json",
        "combination": "combined_results.json",
    }

    result_file = result_file_by_stage.get(stage_name)
    if result_file is None:
        return False

    return os.path.exists(os.path.join(experiment_dir, result_file))


def run_experiment(
    model_name,
    output_dir="results",
    resume=True,
    mode="rhyme_family_steering",
    strip=False,
    LAYER_FRACTION=0.8,
    rhyme_family1=None,
    rhyme_family2=None,
):
    """Run all pipeline stages, in order, for a single model.

    Stages run sequentially and the experiment stops at the first stage that
    fails. With ``resume`` enabled, a stage whose result file already exists
    for this ``output_dir``/``mode``/``model_name`` is skipped.

    Args:
        model_name: Model to run; forwarded to every stage script.
        output_dir: Root results directory; forwarded to every stage script.
        resume: Skip stages that ``check_stage_completed`` reports as done.
        mode: Experiment mode; forwarded to every stage script and used to
            locate existing results when resuming.
        strip: Forwarded to ``run_stage``.
        LAYER_FRACTION: Forwarded to ``run_stage``.
        rhyme_family1: Forwarded to ``run_stage``.
        rhyme_family2: Forwarded to ``run_stage``.

    Returns:
        ``True`` if every stage completed or was skipped, ``False`` as soon as
        one stage fails.
    """
    print(f"\n{'=' * 60}")
    print(f"EXPERIMENT: {model_name}")
    print(f"{'=' * 60}")

    # (stage identifier used for the resume check, script that implements it)
    stages = [
        ("best_layer_and_token_pos", "find_best_layer_and_token_pos.py"),
        ("line_generation", "stage_line_generation.py"),
        ("standard_metrics", "stage_standard_metrics.py"),
        ("prob_based_metrics", "stage_prob_based_metrics.py"),
        ("combination", "stage_combination.py"),
    ]

    for stage_name, stage_script in stages:
        print(f"\n--- Stage: {stage_name} ---")

        # The mode must be passed along: results live under
        # <output_dir>/<mode>/<model_name>, so omitting it would always probe
        # the default mode's directory regardless of the mode being run.
        # NOTE: the completion check does not depend on strip, LAYER_FRACTION
        # or the rhyme families, so results from a run with different values
        # of those settings also count as "completed".
        if resume and check_stage_completed(
            stage_name, model_name, output_dir, mode
        ):
            print(f"✓ Stage {stage_name} already completed, skipping")
            continue

        success = run_stage(
            stage_script,
            model_name,
            output_dir=output_dir,
            mode=mode,
            strip=strip,
            LAYER_FRACTION=LAYER_FRACTION,
            rhyme_family1=rhyme_family1,
            rhyme_family2=rhyme_family2,
        )

        if not success:
            print(f"✗ Stage {stage_name} failed, stopping experiment")
            return False

    print(f"\n✓ All stages completed for {model_name}")
    return True


def main():
    """Parse command-line options and run the experiment for each model.

    Runs ``run_experiment`` once per selected model (all of ``MODEL_NAMES``,
    or just ``--model`` when given) and prints a summary with counts, total
    wall-clock time and the list of failed models.

    NOTE(review): ``--num_prompts`` is parsed but never forwarded, and
    ``RHYME_FAMILY_PAIRS`` is only printed here; the families actually used
    are the ``--rhyme_family1`` / ``--rhyme_family2`` values.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Run all rhyme steering experiments")
    add = parser.add_argument

    add("--output_dir", default="results", help="Output directory")
    # --resume / --no-resume share one destination; resuming is the default.
    add(
        "--resume",
        action="store_true",
        default=True,
        help="Resume from completed stages",
    )
    add(
        "--no-resume",
        dest="resume",
        action="store_false",
        help="Don't resume, run all stages",
    )
    add("--model", help="Run only specific model (default: all)")
    add(
        "--mode",
        default="rhyme_family_steering",
        help="Mode (rhyme_family_steering, specific_word_steering)",
    )
    add("--rhyme_family1", default=None, help="First rhyme family")
    add("--rhyme_family2", default=None, help="Second rhyme family")
    add(
        "--num_prompts",
        type=int,
        default=None,
        help="Number of prompts to generate (for debugging)",
    )
    add("--LAYER_FRACTION", type=float, default=0.8, help="Model layer fraction to use")
    add(
        "--strip",
        action="store_true",
        default=False,
        help="Ignore end character such as newline for steering",
    )

    args = parser.parse_args()

    # Narrow the run to a single model when one was requested.
    requested_model = args.model
    if requested_model and requested_model not in MODEL_NAMES:
        print(f"Error: Model '{requested_model}' not in {MODEL_NAMES}")
        return
    models_to_run = [requested_model] if requested_model else MODEL_NAMES

    print(f"Starting experiments at {datetime.now()}")
    print(f"Models: {models_to_run}")
    print(f"Pairs: {RHYME_FAMILY_PAIRS}")
    print(f"Output directory: {args.output_dir}")
    print(f"Resume mode: {args.resume}")
    print(f"Mode: {args.mode}")

    total_experiments = len(models_to_run)
    failed_experiments = []
    began = time.time()

    for model_name in models_to_run:
        succeeded = run_experiment(
            model_name,
            args.output_dir,
            args.resume,
            args.mode,
            args.strip,
            args.LAYER_FRACTION,
            args.rhyme_family1,
            args.rhyme_family2,
        )
        if not succeeded:
            failed_experiments.append(model_name)

    total_time = time.time() - began
    # Every experiment either succeeds or is recorded as failed.
    completed_experiments = total_experiments - len(failed_experiments)

    print(f"\n{'=' * 60}")
    print("SUMMARY")
    print(f"{'=' * 60}")
    print(f"Total experiments: {total_experiments}")
    print(f"Completed: {completed_experiments}")
    print(f"Failed: {len(failed_experiments)}")
    print(f"Total time: {total_time / 3600:.1f} hours")

    if failed_experiments:
        print("\nFailed experiments:")
        for failed_model in failed_experiments:
            print(f"  {failed_model}")

    print(f"\nFinished at {datetime.now()}")


# Run the experiment driver only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
