#!/usr/bin/env python
"""Run multi-seed budgeted optimization experiments for statistical comparison.

This script orchestrates multiple runs of budgeted optimization across different
methods and random seeds, enabling robust statistical comparison with confidence intervals.

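Stages:
    run_seeds: Execute every method x seed combination (already-completed runs are skipped)
    aggregate: Load all run logs and write an index to aggregated_results.json
    plot: Generate all figures into <output dir>/figures/
    report: Write a summary table and pairwise statistical tests to <output dir>/figures/
    all: Run all of the above stages in order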

Example usage:
    # Run all stages
    python scripts/run_multi_seed_experiment.py --config configs/experiments/multi_seed_comparison.yaml

    # Run only seed experiments with parallel execution
    python scripts/run_multi_seed_experiment.py --stages run_seeds --workers 4

    # Generate plots from existing runs
    python scripts/run_multi_seed_experiment.py --stages plot,report --log-dir experiments/multi_seed_comparison/

    # Quick test with fewer seeds
    python scripts/run_multi_seed_experiment.py --n-seeds 3 --methods moltenflow,gradient_ascent
"""

import argparse
import json
import shutil
import subprocess
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Sequence

from tqdm import tqdm

from moltenflow.optimization import (
    format_summary_table,
    generate_summary_table,
    plot_all_figures,
)
from moltenflow.utils.config import load_yaml
from moltenflow.utils.logging import get_logger

# Module-wide logger shared by every stage function below.
logger = get_logger(__name__)

# Pipeline stages, listed in the order main() executes them; "--stages all" expands to this list.
ALL_STAGES = ["run_seeds", "aggregate", "plot", "report"]


def parse_stages(stages_str: str) -> list[str]:
    """Parse a comma-separated ``--stages`` value into a list of stage names.

    Whitespace around each name is stripped, empty entries (e.g. from a
    trailing comma) and duplicates are dropped, and ``all`` appearing anywhere
    expands to a fresh copy of ``ALL_STAGES``.

    Args:
        stages_str: Raw value such as ``"plot,report"`` or ``"all"``.

    Returns:
        Stage names in the order given (first occurrence wins).
    """
    names = [part.strip() for part in stages_str.split(",")]
    names = [name for name in names if name]
    if "all" in names:
        # Return a copy so callers cannot mutate the module-level constant.
        return list(ALL_STAGES)
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(names))


def _load_recorded_runtime(timing_path: Path) -> float:
    """Return ``runtime_seconds`` recorded in *timing_path*, or 0.0 if unavailable.

    A missing, unreadable, or malformed file is treated as "runtime unknown"
    so that a finished run is still skipped instead of crashing the stage.
    """
    try:
        with open(timing_path) as f:
            timing_data = json.load(f)
    except (OSError, json.JSONDecodeError):
        return 0.0
    if not isinstance(timing_data, dict):
        return 0.0
    return timing_data.get("runtime_seconds", 0.0)


def _save_timing(timing_path: Path, method: str, seed: int, runtime: float) -> None:
    """Write the measured runtime of one run to *timing_path* as JSON."""
    timing_path.parent.mkdir(parents=True, exist_ok=True)
    with open(timing_path, "w") as f:
        json.dump(
            {
                "method": method,
                "seed": seed,
                "runtime_seconds": runtime,
                "runtime_minutes": runtime / 60,
            },
            f,
            indent=2,
        )


def run_single_optimization(
    method: str,
    seed: int,
    config_path: str,
    output_dir: Path,
    init_method: str = "random",
    budget: int | None = None,
) -> tuple[str, int, bool, str, float]:
    """Run a single optimization with given method and seed.

    The run is executed as a subprocess of ``run_budgeted_optimization.py``
    (located next to this script) writing into
    ``<output_dir>/<method>_<init_method>_seed<seed>``. If that directory
    already holds ``optimization_log.csv`` and ``summary.json`` the run is
    skipped, which makes re-invocation resume an interrupted experiment.

    Args:
        method: Optimization method
        seed: Random seed
        config_path: Path to config file
        output_dir: Output directory for this run
        init_method: Initialization method
        budget: Oracle budget (overrides config if set)

    Returns:
        Tuple of (method, seed, success, message, runtime_seconds). Failures
        are reported through ``success=False`` rather than raised, so one bad
        run never aborts the batch.
    """
    run_dir = output_dir / f"{method}_{init_method}_seed{seed}"
    timing_path = run_dir / "timing.json"

    # A run counts as complete once both its log and its summary exist.
    log_path = run_dir / "optimization_log.csv"
    summary_path = run_dir / "summary.json"
    if log_path.exists() and summary_path.exists():
        runtime = _load_recorded_runtime(timing_path)
        logger.info(f"Skipping completed run: {method} seed={seed}")
        return (method, seed, True, "Already completed", runtime)

    # Resolve the worker script next to this file so the call does not depend
    # on the current working directory.
    script_path = Path(__file__).resolve().parent / "run_budgeted_optimization.py"
    cmd = [
        sys.executable,
        str(script_path),
        "--config",
        config_path,
        "--method",
        method,
        "--seed",
        str(seed),
        "--init",
        init_method,
        "--output-dir",
        str(run_dir),
    ]
    if budget is not None:
        cmd.extend(["--budget", str(budget)])

    # perf_counter is monotonic, so wall-clock adjustments cannot skew runtimes.
    start_time = time.perf_counter()
    try:
        subprocess.run(cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        runtime = time.perf_counter() - start_time
        # Tracebacks end with the actual exception, so keep the tail of the
        # output; fall back to stdout when the child wrote nothing to stderr.
        output = (e.stderr or e.stdout or "").strip()
        error_msg = f"Failed (exit code {e.returncode}): {output[-200:]}"
        logger.error(f"Run failed for {method} seed={seed}: {error_msg}")
        return (method, seed, False, error_msg, runtime)
    except OSError as e:
        # The child process could not be started at all (e.g. missing interpreter).
        runtime = time.perf_counter() - start_time
        error_msg = f"Failed to launch: {e}"
        logger.error(f"Run failed for {method} seed={seed}: {error_msg}")
        return (method, seed, False, error_msg, runtime)
    runtime = time.perf_counter() - start_time

    # Timing is bookkeeping only: failing to record it must not turn a
    # successful run into a failure.
    try:
        _save_timing(timing_path, method, seed, runtime)
    except OSError as e:
        logger.warning(f"Could not save timing for {method} seed={seed}: {e}")

    return (method, seed, True, "Success", runtime)


def stage_run_seeds(
    config: dict,
    output_dir: Path,
    config_path: str,
    methods: Sequence[str] | None = None,
    seeds: Sequence[int] | None = None,
    workers: int = 1,
) -> dict:
    """Stage: Run all method x seed combinations.

    Each combination is delegated to ``run_single_optimization``, which skips
    runs whose outputs already exist, so this stage can be re-run to resume.

    Args:
        config: Experiment configuration. Reads ``experiment.methods``,
            ``experiment.seeds``, ``init.method`` and ``optimization.budget``.
        output_dir: Output directory
        config_path: Path to config file, forwarded to every run
        methods: Methods to run (default: from config)
        seeds: Seeds to run (default: from config)
        workers: Number of parallel worker processes; values <= 1 run
            sequentially in this process.

    Returns:
        Dictionary with ``total``, ``successful``, ``failed``,
        ``total_runtime_seconds`` (sum of per-run runtimes),
        ``wall_clock_seconds`` (elapsed time of this stage) and ``runs``
        (one record per combination, in method-then-seed order).
    """
    logger.info("=" * 70)
    logger.info("STAGE: Run Seeds")
    logger.info("=" * 70)

    # Get methods and seeds from config or arguments
    exp_config = config.get("experiment", {})
    if methods is None:
        methods = exp_config.get("methods", ["moltenflow", "gradient_ascent", "bo_mogp", "bo_2gp"])
    if seeds is None:
        seeds = exp_config.get("seeds", [42, 123, 456, 789, 1000])

    init_method = config.get("init", {}).get("method", "random")
    budget = config.get("optimization", {}).get("budget")

    combinations = [(m, s) for m in methods for s in seeds]
    n_total = len(combinations)

    logger.info(f"Running {n_total} experiments:")
    logger.info(f"  Methods: {methods}")
    logger.info(f"  Seeds: {seeds}")
    logger.info(f"  Workers: {workers}")

    results: list[tuple[str, int, bool, str, float]] = []
    wall_start = time.perf_counter()

    if workers <= 1:
        # Sequential execution; also avoids ProcessPoolExecutor's ValueError
        # for max_workers <= 0.
        for method, seed in tqdm(combinations, desc="Running experiments"):
            results.append(
                run_single_optimization(method, seed, config_path, output_dir, init_method, budget)
            )
    else:
        # Never start more processes than there are jobs.
        pool_size = min(workers, max(n_total, 1))
        with ProcessPoolExecutor(max_workers=pool_size) as executor:
            futures = {
                executor.submit(
                    run_single_optimization,
                    method,
                    seed,
                    config_path,
                    output_dir,
                    init_method,
                    budget,
                ): (method, seed)
                for method, seed in combinations
            }

            for future in tqdm(as_completed(futures), total=n_total, desc="Running experiments"):
                method, seed = futures[future]
                try:
                    results.append(future.result())
                except Exception as exc:
                    # A crashed worker process must not abort the remaining
                    # runs; record it as a failed run instead.
                    error_msg = f"Worker error: {exc}"
                    logger.error(f"Run failed for {method} seed={seed}: {error_msg}")
                    results.append((method, seed, False, error_msg, 0.0))

        # as_completed yields in finish order; restore method-then-seed order
        # so the returned records are deterministic.
        position = {combo: idx for idx, combo in enumerate(combinations)}
        results.sort(key=lambda r: position[(r[0], r[1])])

    wall_clock = time.perf_counter() - wall_start

    successful = sum(1 for _, _, success, _, _ in results if success)
    failed = n_total - successful
    # Summed over runs: exceeds wall-clock time when workers > 1.
    total_runtime = sum(runtime for _, _, _, _, runtime in results)

    logger.info(f"Completed: {successful}/{n_total} successful, {failed} failed")
    logger.info(f"Total runtime (summed over runs): {total_runtime / 60:.1f} minutes")
    logger.info(f"Wall-clock time: {wall_clock / 60:.1f} minutes")

    return {
        "total": n_total,
        "successful": successful,
        "failed": failed,
        "total_runtime_seconds": total_runtime,
        "wall_clock_seconds": wall_clock,
        "runs": [
            {"method": m, "seed": s, "success": success, "message": msg, "runtime_seconds": runtime}
            for m, s, success, msg, runtime in results
        ],
    }


def stage_aggregate(output_dir: Path) -> dict:
    """Stage: Load every run log under *output_dir* and write an index file.

    The index, ``aggregated_results.json``, records how many runs were found
    and the identifiers they were loaded under.

    Args:
        output_dir: Directory containing run subdirectories

    Returns:
        ``{"n_runs": <count>, "logs": <loaded logs>}``, or an empty dict when
        no logs were found.
    """
    banner = "=" * 70
    for line in (banner, "STAGE: Aggregate Results", banner):
        logger.info(line)

    from moltenflow.optimization.logger import load_experiment_logs

    run_logs = load_experiment_logs(output_dir)
    if not run_logs:
        logger.warning(f"No logs found in {output_dir}")
        return {}

    run_count = len(run_logs)
    logger.info(f"Loaded {run_count} optimization runs")

    index_path = output_dir / "aggregated_results.json"
    index = {"n_runs": run_count, "run_ids": list(run_logs.keys())}
    with open(index_path, "w") as handle:
        json.dump(index, handle, indent=2)
    logger.info(f"Saved aggregated results to {index_path}")

    return {"n_runs": run_count, "logs": run_logs}


def stage_plot(
    output_dir: Path,
    config: dict,
    methods: Sequence[str] | None = None,
) -> dict[str, Path]:
    """Stage: Render the standard figure set plus two multi-seed extras.

    The standard figures come from ``plot_all_figures``. The Pareto density
    plot and the best/worst example-runs grid are then attempted one by one.
    A failure in either extra is logged and skipped, so it never aborts the
    stage.

    Args:
        output_dir: Directory containing run subdirectories
        config: Experiment configuration (``plotting``, ``init`` and
            ``optimization`` sections are read)
        methods: Methods to include (default: all)

    Returns:
        Mapping from figure name to the file it was written to.
    """
    banner = "=" * 70
    for line in (banner, "STAGE: Generate Plots", banner):
        logger.info(line)

    figures_dir = output_dir / "figures"
    figures_dir.mkdir(exist_ok=True, parents=True)

    plotting = config.get("plotting", {})
    confidence = plotting.get("confidence", 0.95)
    n_bootstrap = plotting.get("n_bootstrap", 1000)
    ci_type = plotting.get("ci_type", "bootstrap")

    init_method = config.get("init", {}).get("method", "random")
    budget = config.get("optimization", {}).get("budget")

    logger.info(f"Generating standard figures (ci_type={ci_type}, confidence={confidence})...")
    outputs = plot_all_figures(
        log_dir=output_dir,
        output_dir=figures_dir,
        init=init_method,
        budget=budget,
        methods=methods,
        format="pdf",
        n_bootstrap=n_bootstrap,
        confidence=confidence,
        ci_type=ci_type,
    )

    from moltenflow.optimization.plotting import (
        plot_example_runs,
        plot_pareto_density,
    )

    def _attempt(key: str, path: Path, label: str, render) -> None:
        # Best effort: render one extra figure and register it on success.
        try:
            render(path)
            outputs[key] = path
        except Exception as exc:
            logger.error(f"Failed to generate {label}: {exc}")

    logger.info("Generating Pareto density plot...")
    density_path = figures_dir / f"pareto_density_{init_method}.pdf"
    _attempt(
        "pareto_density",
        density_path,
        "density plot",
        lambda target: plot_pareto_density(
            output_dir,
            target,
            methods=methods,
            grid_resolution=plotting.get("density_grid_resolution", 50),
        ),
    )

    logger.info("Generating example runs grid...")
    examples = plotting.get("example_runs", {})
    best_count = examples.get("n_best", 2)
    worst_count = examples.get("n_worst", 2)
    examples_path = figures_dir / f"pareto_examples_{init_method}.pdf"
    _attempt(
        "pareto_examples",
        examples_path,
        "example runs",
        lambda target: plot_example_runs(
            output_dir,
            target,
            methods=methods,
            n_best=best_count,
            n_worst=worst_count,
        ),
    )

    logger.info(f"Generated {len(outputs)} figures in {figures_dir}")
    return outputs


def _print_pairwise_tests(test_results: dict, alpha: float = 0.05) -> None:
    """Print each comparison's p-value, starring those below *alpha*.

    Entries whose ``p_value`` is missing or not numeric are shown as ``n/a``
    instead of aborting the whole listing.
    """
    print("\nStatistical Tests (Mann-Whitney U):")
    print("=" * 60)
    for comparison, result in test_results.items():
        raw = result.get("p_value") if isinstance(result, dict) else None
        try:
            p_value = float(raw)
        except (TypeError, ValueError):
            print(f"{comparison}: p=n/a")
            continue
        marker = " *" if p_value < alpha else ""
        print(f"{comparison}: p={p_value:.4f}{marker}")
    print("=" * 60)
    print(f"* p < {alpha:g} (significant)")


def stage_report(
    output_dir: Path,
    config: dict,
    methods: Sequence[str] | None = None,
) -> dict:
    """Stage: Generate summary tables and statistical tests.

    Writes ``figures/summary.csv`` and ``figures/statistical_tests.json``
    under *output_dir* and prints both to stdout. The statistical tests are
    best effort: if they cannot be computed, the report still returns the
    summary with an empty ``statistical_tests`` entry.

    Args:
        output_dir: Directory containing run subdirectories
        config: Experiment configuration (``plotting.confidence`` and
            ``plotting.n_bootstrap`` are read)
        methods: Methods to include (default: all)

    Returns:
        ``{"summary": <list of row dicts>, "statistical_tests": <dict>}``, or
        an empty dict when there is nothing to summarize.
    """
    logger.info("=" * 70)
    logger.info("STAGE: Generate Report")
    logger.info("=" * 70)

    figures_dir = output_dir / "figures"
    figures_dir.mkdir(exist_ok=True, parents=True)

    plot_cfg = config.get("plotting", {})
    confidence = plot_cfg.get("confidence", 0.95)
    n_bootstrap = plot_cfg.get("n_bootstrap", 1000)

    logger.info("Generating summary table...")
    df = generate_summary_table(
        output_dir,
        methods=methods,
        n_bootstrap=n_bootstrap,
        confidence=confidence,
    )

    if len(df) == 0:
        logger.warning("No results to summarize")
        return {}

    print("\n" + format_summary_table(df, confidence))

    summary_path = figures_dir / "summary.csv"
    df.to_csv(summary_path, index=False)
    logger.info(f"Saved summary to {summary_path}")

    logger.info("Performing pairwise statistical tests...")
    from moltenflow.optimization.summary import compute_pairwise_tests

    # Only the computation sits in this try. Save/print problems are then
    # reported accurately and never discard results that were computed.
    try:
        test_results = compute_pairwise_tests(output_dir, methods=methods)
    except Exception as e:
        logger.error(f"Failed to compute statistical tests: {e}")
        test_results = {}
    else:
        tests_path = figures_dir / "statistical_tests.json"
        try:
            # Serialize before touching the file so a non-JSON-able value
            # cannot leave a truncated file behind.
            tests_path.write_text(json.dumps(test_results, indent=2))
        except (TypeError, ValueError, OSError) as e:
            logger.error(f"Failed to save statistical tests to {tests_path}: {e}")
        else:
            logger.info(f"Saved statistical tests to {tests_path}")
        _print_pairwise_tests(test_results)

    return {
        "summary": df.to_dict("records"),
        "statistical_tests": test_results,
    }


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Run multi-seed budgeted optimization experiments")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/experiments/multi_seed_comparison.yaml",
        help="Path to experiment config",
    )
    parser.add_argument(
        "--stages",
        type=str,
        default="all",
        help="Comma-separated stages to run (run_seeds,aggregate,plot,report,all)",
    )
    parser.add_argument(
        "--log-dir",
        type=str,
        default=None,
        help="Log directory (default: from config)",
    )
    parser.add_argument(
        "--methods",
        type=str,
        default=None,
        help="Comma-separated methods to include (default: from config)",
    )
    parser.add_argument(
        "--n-seeds",
        type=int,
        default=None,
        help="Number of seeds (overrides config)",
    )
    parser.add_argument(
        "--seed-start",
        type=int,
        default=42,
        help="Starting seed for auto-generated seed list",
    )
    parser.add_argument(
        "--workers",
        type=int,
        default=1,
        help="Number of parallel workers for run_seeds stage",
    )
    return parser


def main():
    """Parse CLI arguments and run the requested pipeline stages in order.

    Stage selection, output directory, and method/seed overrides come from
    the command line; everything else is read from the YAML config. Invalid
    arguments (unknown stage names, non-positive ``--n-seeds`` or
    ``--workers``) are rejected with a usage error before any work starts.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # A non-positive --n-seeds would silently yield an empty seed list, and
    # ProcessPoolExecutor raises ValueError for max_workers <= 0.
    if args.n_seeds is not None and args.n_seeds < 1:
        parser.error("--n-seeds must be at least 1")
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    stages = parse_stages(args.stages)
    # Reject typos such as "--stages plots" instead of silently running nothing.
    valid_stages = set(ALL_STAGES) | {"all"}
    unknown = [s for s in stages if s and s not in valid_stages]
    if unknown:
        parser.error(
            f"unknown stage(s): {', '.join(unknown)} "
            f"(choose from {', '.join(ALL_STAGES)}, all)"
        )
    logger.info(f"Running stages: {stages}")

    config = load_yaml(args.config)
    config_path = args.config

    # Output directory: explicit --log-dir wins, then config, then a default
    # derived from the experiment name.
    if args.log_dir:
        output_dir = Path(args.log_dir)
    else:
        exp_name = config.get("experiment", {}).get("name", "multi_seed_comparison")
        output_dir = Path(config.get("output", {}).get("dir", f"experiments/{exp_name}"))

    output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Output directory: {output_dir}")

    methods = None
    if args.methods:
        methods = [m.strip() for m in args.methods.split(",")]

    # Auto-generated seeds are spaced 100 apart starting at --seed-start.
    seeds = None
    if args.n_seeds is not None:
        seeds = list(range(args.seed_start, args.seed_start + args.n_seeds * 100, 100))

    # Keep a copy of the config next to the runs; an existing copy is left
    # untouched so resumed experiments keep their original config.
    if "run_seeds" in stages:
        config_copy = output_dir / "config.yaml"
        if not config_copy.exists():
            shutil.copy(config_path, config_copy)
            logger.info(f"Copied config to {config_copy}")

    results = {}

    if "run_seeds" in stages:
        results["run_seeds"] = stage_run_seeds(
            config, output_dir, config_path, methods, seeds, args.workers
        )

    if "aggregate" in stages:
        results["aggregate"] = stage_aggregate(output_dir)

    if "plot" in stages:
        results["plot"] = stage_plot(output_dir, config, methods)

    if "report" in stages:
        results["report"] = stage_report(output_dir, config, methods)

    logger.info("=" * 70)
    logger.info("PIPELINE COMPLETE")
    logger.info("=" * 70)
    logger.info(f"Results saved to: {output_dir}")

    if "run_seeds" in results:
        rs = results["run_seeds"]
        logger.info(f"Runs: {rs['successful']}/{rs['total']} successful")
        if rs["failed"]:
            logger.warning(f"{rs['failed']} run(s) failed; see the errors logged above")

    if "plot" in results:
        logger.info(f"Figures: {len(results['plot'])} generated")


if __name__ == "__main__":
    main()
