#!/usr/bin/env python3
"""
Run Ablation Experiments

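Reproduces Tables 5, 6, 9, and 10 from the paper.

Note: the training function used here is currently a placeholder that returns
fixed dummy scores, so the reported numbers are not real results until actual
training and evaluation are implemented.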

Usage:
    # Run all ablations
    python scripts/run_ablations.py --config configs/gdo_dpo_llama3_8b.yaml --output_dir outputs/ablations --all

    # Run specific ablations
    python scripts/run_ablations.py --config configs/gdo_dpo_llama3_8b.yaml --output_dir outputs/ablations --curriculum
    python scripts/run_ablations.py --config configs/gdo_dpo_llama3_8b.yaml --output_dir outputs/ablations --monitoring
    python scripts/run_ablations.py --config configs/gdo_dpo_llama3_8b.yaml --output_dir outputs/ablations --sensitivity
"""

import argparse
import os
import sys
import yaml
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.ablation.ablation_experiments import (
    CurriculumAblations,
    MonitoringAblations,
    SensitivityAnalysis,
    AblationRunner
)
from src.core.gdo_dpo import GDODPOConfig


def parse_args(argv=None):
    """Parse command-line arguments for the ablation runner.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            callers are unaffected; passing a list allows programmatic use
            and testing.

    Returns:
        argparse.Namespace with attributes ``config``, ``output_dir``,
        ``all``, ``curriculum``, ``monitoring`` and ``sensitivity``.
    """
    parser = argparse.ArgumentParser(description="Run ablation experiments")
    parser.add_argument("--config", type=str, required=True,
                        help="Path to base config file")
    parser.add_argument("--output_dir", type=str, default="outputs/ablations",
                        help="Directory where ablation_results.json is "
                             "written (default: %(default)s)")

    # Which ablations to run (flags may be combined; --all enables every one)
    parser.add_argument("--all", action="store_true",
                        help="Run all ablation experiments")
    parser.add_argument("--curriculum", action="store_true",
                        help="Run curriculum structure ablations (Table 5)")
    parser.add_argument("--monitoring", action="store_true",
                        help="Run monitoring mechanism ablations (Table 6)")
    parser.add_argument("--sensitivity", action="store_true",
                        help="Run sensitivity analysis (Tables 9, 10)")

    return parser.parse_args(argv)


def load_config(config_path: str) -> dict:
    """Load a YAML config file and return its top-level mapping.

    Args:
        config_path: Path to the YAML config file.

    Returns:
        The parsed top-level mapping.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the document is empty (``yaml.safe_load`` returns
            ``None``) or its top level is not a mapping.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    # Fail here with a clear message instead of an opaque TypeError later
    # when callers index into the result.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )
    return config


def _build_gdo_config(config: dict) -> "GDODPOConfig":
    """Build the base GDODPOConfig from a parsed YAML config.

    Reads the ``gdo_dpo`` section keys and ``training.beta``; a missing key
    raises ``KeyError``.
    """
    gdo = config['gdo_dpo']
    return GDODPOConfig(
        tau_stable=gdo['tau_stable'],
        tau_acc=gdo['tau_acc'],
        delta_sem=gdo['delta_sem'],
        delta_unc=gdo['delta_unc'],
        layer_mid=gdo['layer_mid'],
        ema_decay=gdo['ema_decay'],
        eval_interval=gdo['eval_interval'],
        beta=config['training']['beta'],
    )


def _print_banner(title: str) -> None:
    """Print ``title`` framed by two 60-character '=' rules."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _print_results(results: dict) -> None:
    """Print each named result as a header followed by ``metric: score`` lines.

    Assumes every score is numeric (formatted with two decimals).
    """
    for name, result in results.items():
        print(f"\n{name}:")
        for metric, score in result.items():
            print(f"  {metric}: {score:.2f}")


def _placeholder_train_function(dataset, gdo_config, **kwargs):
    """
    Placeholder for training function.

    In practice, this should:
    1. Load model and reference model
    2. Create GDODPOTrainer with the given config
    3. Train the model
    4. Evaluate on benchmarks (MT-Bench, AlpacaEval, Arena-Hard)
    5. Return evaluation results

    Currently it ignores its inputs (beyond printing them) and returns fixed
    dummy scores.
    """
    print(f"Training with config: {gdo_config}")
    print(f"Dataset size: {len(dataset)}")

    # TODO: Implement actual training
    # For now, return dummy results
    return {
        'mt_bench': 8.0,
        'alpacaeval': 30.0,
        'arena_hard': 20.0,
    }


def main():
    """Run the selected ablation experiments and save their results.

    Results are printed per ablation and then written through
    ``runner.save_results`` to ``<output_dir>/ablation_results.json``.
    Exits with an error message if no ablation flag was given.
    """
    args = parse_args()

    # Determine which ablations to run
    run_all = args.all
    run_curriculum = args.curriculum or run_all
    run_monitoring = args.monitoring or run_all
    run_sensitivity = args.sensitivity or run_all
    if not (run_curriculum or run_monitoring or run_sensitivity):
        # Without a selection nothing would run, yet an empty results file
        # would still be written; fail loudly instead.
        sys.exit("No ablation selected: pass --all, --curriculum, "
                 "--monitoring and/or --sensitivity.")

    config = load_config(args.config)

    # Create ablation runner around the base GDO-DPO config
    runner = AblationRunner(_build_gdo_config(config))

    # Create the output directory before any (potentially long) experiment
    # runs, so an unusable --output_dir fails fast rather than after training.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "ablation_results.json")

    _print_banner("GDO-DPO Ablation Experiments")

    # Note: The actual training function needs to be implemented
    # This is a template showing the structure
    print("WARNING: using a placeholder training function that returns fixed "
          "dummy scores; reported results are not real.", file=sys.stderr)
    train_function = _placeholder_train_function

    # Run curriculum ablations (Table 5)
    if run_curriculum:
        _print_banner("Running Curriculum Structure Ablations (Table 5)")

        dataset = []  # TODO: Load actual dataset

        results = runner.run_curriculum_ablations(
            train_fn=train_function,
            dataset=dataset,
        )

        print("\nCurriculum Ablation Results:")
        _print_results(results)

    # Run monitoring ablations (Table 6)
    if run_monitoring:
        _print_banner("Running Monitoring Mechanism Ablations (Table 6)")

        dataset = []  # TODO: Load actual dataset
        total_steps = 1000  # TODO: Compute from config

        results = runner.run_monitoring_ablations(
            train_fn=train_function,
            dataset=dataset,
            total_steps=total_steps,
        )

        print("\nMonitoring Ablation Results:")
        _print_results(results)

    # Run sensitivity analysis (Tables 9, 10)
    if run_sensitivity:
        _print_banner("Running Sensitivity Analysis (Tables 9, 10)")

        dataset = []  # TODO: Load actual dataset
        num_layers = 32  # TODO: Get from model config

        results = runner.run_sensitivity_analysis(
            train_fn=train_function,
            dataset=dataset,
            num_layers=num_layers,
        )

        print("\nSensitivity Analysis Results:")

        print("\nLayer Boundary Sensitivity (Table 9):")
        _print_results(results['layer_boundary'])

        print("\nThreshold Sensitivity (Table 10):")
        _print_results(results['thresholds'])

    # Save all results
    runner.save_results(output_path)

    print("\n" + "=" * 60)
    print("Ablation Experiments Complete!")
    print(f"Results saved to {output_path}")
    print("=" * 60)


if __name__ == "__main__":
    main()
