import importlib.util
import json
import os
from pathlib import Path
import shutil
import subprocess
import sys
from types import SimpleNamespace

import numpy as np
import pytest

# Placeholder API key, set only if the environment does not already define one
# (setdefault). Presumably some imports/config loading below expect the key to
# exist — NOTE(review): confirm which import needs it.
os.environ.setdefault("OPENAI_API_KEY", "test")

# Repository root = parent of the directory containing this test file. It is put
# first on sys.path so the ``openevolve`` imports below resolve to this checkout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from openevolve.config import Config
from openevolve.database import Program, ProgramDatabase
from openevolve.multi_task_shared_then_specialize.robust_regression import (
    DEFAULT_FULL_TIMEOUT_SECONDS,
    FULL_EVAL_SEEDS,
    ROBUST_REGRESSION_TASK_SELECTOR_ENV_VAR,
    ROBUST_REGRESSION_TASK_SPECS,
    STAGE1_SEEDS,
    aggregate_seed_results,
    aggregate_task_results,
    build_task_result,
    extract_task_result,
    generate_regression_dataset,
)
import openevolve.multi_task_shared_then_specialize.spawn as mt_sts_spawn
from openevolve.multi_task_shared_then_specialize.spawn import spawn_task_checkpoints
from openevolve.multi_task_shared_then_specialize.workflow import (
    build_mt_sts_wandb_namespace,
    build_mt_sts_wandb_run_id,
    build_phase_wandb_config,
    fair_mt_sts_baseline_iterations,
    load_manifest,
    phase_checkpoint_status,
    resolve_mt_sts_wandb_run_id,
    validate_mt_sts_iteration_budget,
    write_phase_config,
)


def _load_module_from_path(module_name: str, module_path: Path, description: str):
    """Import the Python source file at *module_path* as a module named *module_name*.

    Follows the ``importlib`` "importing a source file directly" recipe: the new
    module is registered in ``sys.modules`` *before* its body runs, so code in it
    that looks its own module up by name (e.g. ``dataclasses`` handling string
    annotations, or ``pickle``) works. If executing the module raises, the partial
    registration is removed again and the exception propagates.

    Args:
        module_name: Name to register the module under (should be unique).
        module_path: Path to the ``.py`` file to load.
        description: Human-readable label used in the error message.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for *module_path*.
    """
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load {description} from {module_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind in sys.modules.
        sys.modules.pop(module_name, None)
        raise
    return module


def _load_rr_evaluator_module():
    """Load ``examples/r_robust_regression/evaluator.py`` as ``rr_evaluator_test_module``."""
    return _load_module_from_path(
        "rr_evaluator_test_module",
        REPO_ROOT / "examples" / "r_robust_regression" / "evaluator.py",
        "evaluator",
    )


def _load_mt_sts_runner_module():
    """Load the MT-STS runner script as ``mt_sts_runner_test_module``."""
    return _load_module_from_path(
        "mt_sts_runner_test_module",
        REPO_ROOT / "multi_task_shared_then_adapt" / "run_multi_task_shared_then_specialize.py",
        "MT-STS runner",
    )


def _load_mt_sts_report_module():
    """Load the MT-STS report helper script as ``mt_sts_report_test_module``."""
    return _load_module_from_path(
        "mt_sts_report_test_module",
        REPO_ROOT / "multi_task_shared_then_adapt" / "report_mt_sts_results.py",
        "MT-STS report helper",
    )


# Script modules under test, imported once at collection time.
rr_evaluator = _load_rr_evaluator_module()
mt_sts_runner = _load_mt_sts_runner_module()
mt_sts_report = _load_mt_sts_report_module()

# Task ids in spec declaration order; the first one is the default single task.
ROBUST_REGRESSION_TASK_IDS = [task_spec.task_id for task_spec in ROBUST_REGRESSION_TASK_SPECS]
FIRST_ROBUST_REGRESSION_TASK_ID = ROBUST_REGRESSION_TASK_IDS[0]


def _build_seed_result(
    task,
    base_seed: int,
    *,
    prediction_mode: str = "perfect_signal",
    runtime: float = 0.01,
):
    """Score one generated seed with canned outputs using the real evaluator metrics.

    ``prediction_mode`` selects the fake program output:

    * ``"perfect_signal"``: clean test signal plus the true coefficients;
    * ``"perfect_noisy"``: noisy test targets plus the true coefficients;
    * ``"zeros"``: all-zero predictions and ``n_features + 1`` zero coefficients.

    Any other mode raises ``ValueError``. ``runtime`` is forwarded unchanged.
    """
    dataset = generate_regression_dataset(task, base_seed)
    if prediction_mode == "zeros":
        fitted = np.zeros(task.n_test, dtype=float)
        coefs = np.zeros(task.n_features + 1, dtype=float)
    elif prediction_mode in ("perfect_signal", "perfect_noisy"):
        if prediction_mode == "perfect_signal":
            fitted = dataset.y_test_clean_signal
        else:
            fitted = dataset.y_test_noisy
        coefs = dataset.true_coefficients
    else:
        raise ValueError(f"Unknown prediction_mode: {prediction_mode}")

    fitted_array = np.asarray(fitted, dtype=float)
    coef_array = np.asarray(coefs, dtype=float)
    return rr_evaluator.compute_seed_metrics(
        fitted_array,
        coef_array,
        dataset,
        runtime=runtime,
    )


def _build_task_result_from_seed_results(task, seed_results):
    """Aggregate *seed_results* and wrap them into a task result for *task*.

    Both the aggregation and the task result use the task's full-evaluation
    timeout (``trial_timeout_seconds_full``); the seed count is the number of
    supplied seed results.
    """
    full_timeout = task.trial_timeout_seconds_full
    return build_task_result(
        task,
        raw_metrics=aggregate_seed_results(
            seed_results,
            seed_count=len(seed_results),
            timeout_seconds=full_timeout,
        ),
        seed_results=seed_results,
        timeout_seconds=full_timeout,
    )


def _build_full_task_results(*, program_index: int) -> list[dict]:
    """Build one full-evaluation task result per robust-regression task.

    Program 0 receives ``"perfect_signal"`` outputs on every seed; any other
    index receives ``"perfect_noisy"`` outputs, so fake programs differ.
    """
    mode = "perfect_signal" if program_index == 0 else "perfect_noisy"
    return [
        _build_task_result_from_seed_results(
            task_spec,
            [
                _build_seed_result(task_spec, seed, prediction_mode=mode)
                for seed in FULL_EVAL_SEEDS
            ],
        )
        for task_spec in ROBUST_REGRESSION_TASK_SPECS
    ]


def test_robust_regression_family_uses_expected_task_ids_and_order():
    expected_ids = [
        "rr_outliers10_100x3",
        "rr_outliers20_100x3",
        "rr_leverage10_100x3",
        "rr_hard_120x8",
    ]
    assert ROBUST_REGRESSION_TASK_IDS == expected_ids

    # The last spec is the hard task: pin its sizes exactly and its noise /
    # contamination parameters approximately.
    hard_task = ROBUST_REGRESSION_TASK_SPECS[-1]
    exact_fields = {
        "task_id": "rr_hard_120x8",
        "n_train": 120,
        "n_test": 400,
        "n_features": 8,
    }
    float_fields = {
        "noise_std": 0.20,
        "rho": 0.85,
        "vertical_outlier_fraction_train": 0.20,
        "leverage_outlier_fraction_train": 0.15,
        "hetero_strength": 1.0,
    }
    assert {name: getattr(hard_task, name) for name in exact_fields} == exact_fields
    assert {name: getattr(hard_task, name) for name in float_fields} == pytest.approx(
        float_fields
    )


def test_mt_sts_single_run_wandb_config_reuses_one_run_id(tmp_path):
    manifest = load_manifest(
        REPO_ROOT / "multi_task_shared_then_adapt" / "r_robust_regression_mt_sts.yaml"
    )
    run_root = tmp_path / "mt_sts_demo"
    run_id = resolve_mt_sts_wandb_run_id(run_root)

    # Arguments common to both phases; only the phase (and task) differ.
    common_kwargs = {
        "run_name": "demo",
        "run_root": run_root,
        "wandb_run_id": run_id,
        "run_label": "demo",
        "shared_iterations": 20,
        "adaptation_iterations": 20,
        "baseline_iterations": 25,
    }
    shared_wandb = build_phase_wandb_config(manifest, phase="shared", **common_kwargs)
    adaptation_wandb = build_phase_wandb_config(
        manifest,
        phase="adaptation",
        task_id=FIRST_ROBUST_REGRESSION_TASK_ID,
        **common_kwargs,
    )

    # Single-run mode: every phase logs into the same W&B run, separated by namespace.
    assert manifest.wandb_single_run is True
    for phase_wandb in (shared_wandb, adaptation_wandb):
        assert phase_wandb["run_id"] == run_id
    assert shared_wandb["namespace"] == build_mt_sts_wandb_namespace(phase="shared")
    assert adaptation_wandb["namespace"] == build_mt_sts_wandb_namespace(
        phase="adaptation",
        task_id=FIRST_ROBUST_REGRESSION_TASK_ID,
    )

    # The W&B settings survive a round trip through the written phase config.
    written_config_path = write_phase_config(
        base_config_path=manifest.base_config,
        output_config_path=tmp_path / "shared_config.yaml",
        iterations=5,
        wandb_config=shared_wandb,
    )
    reloaded = Config.from_yaml(written_config_path)
    assert reloaded.wandb.run_id == run_id
    assert reloaded.wandb.namespace == "shared"


def test_robust_regression_config_timeout_covers_full_seeded_evaluation_budget():
    config = Config.from_yaml(REPO_ROOT / "examples" / "r_robust_regression" / "config.yaml")

    # Worst case: tasks x full seeds x DEFAULT_FULL_TIMEOUT_SECONDS.
    task_count = len(ROBUST_REGRESSION_TASK_SPECS)
    seed_count = len(FULL_EVAL_SEEDS)
    worst_case_seconds = task_count * seed_count * DEFAULT_FULL_TIMEOUT_SECONDS

    assert config.evaluator.timeout > worst_case_seconds


def test_iteration_fair_baseline_helper_and_validator():
    budget = {"task_count": 4, "shared_iterations": 44, "adaptation_iterations": 20}
    tiny_budget = {"task_count": 4, "shared_iterations": 2, "adaptation_iterations": 2}

    # 44 shared iterations over 4 tasks plus 20 adaptation iterations -> 31
    # (consistent with shared / task_count + adaptation; confirm in the helper).
    assert fair_mt_sts_baseline_iterations(**budget) == 31
    # A 2 + 2 budget over 4 tasks yields no fair baseline.
    assert fair_mt_sts_baseline_iterations(**tiny_budget) is None

    # The matching setting must not raise ...
    validate_mt_sts_iteration_budget(**budget, baseline_iterations=31)

    # ... while raising the shared budget to 60 against the same baseline must.
    with pytest.raises(ValueError, match="Unsafe MT-STS iteration setting"):
        validate_mt_sts_iteration_budget(
            **{**budget, "shared_iterations": 60},
            baseline_iterations=31,
        )


def test_phase_checkpoint_status_requires_iteration_budget_and_best_info(tmp_path):
    phase_dir = tmp_path / "adaptation" / FIRST_ROBUST_REGRESSION_TASK_ID
    checkpoints_dir = phase_dir / "checkpoints"
    partial_checkpoint = checkpoints_dir / "checkpoint_5"
    final_checkpoint = checkpoints_dir / "checkpoint_20"

    def status(*, require_best_info):
        # Target budget for the phase is 20 iterations.
        return phase_checkpoint_status(phase_dir, 20, require_best_info=require_best_info)

    # Only iteration 5 is checkpointed: not complete, resume from checkpoint_5.
    partial_checkpoint.mkdir(parents=True, exist_ok=True)
    done, checkpoint = status(require_best_info=True)
    assert done is False
    assert checkpoint == partial_checkpoint

    # checkpoint_20 reaches the budget; enough when best info is not required.
    final_checkpoint.mkdir(parents=True, exist_ok=True)
    done, checkpoint = status(require_best_info=False)
    assert done is True
    assert checkpoint == final_checkpoint

    # Once best/best_program_info.json exists, the stricter check passes too.
    best_info_path = phase_dir / "best" / "best_program_info.json"
    best_info_path.parent.mkdir(parents=True, exist_ok=True)
    best_info_path.write_text("{}", encoding="utf-8")
    done, checkpoint = status(require_best_info=True)
    assert done is True
    assert checkpoint == final_checkpoint


def test_evaluator_ignores_fake_self_reported_metrics(monkeypatch):
    task = ROBUST_REGRESSION_TASK_SPECS[0]

    def run_with_bogus_self_report(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        # Wildly wrong outputs paired with "perfect" self-reported metrics.
        payload = {
            "predictions": [1e6] * task.n_test,
            "coefficients": [1e6] * (task.n_features + 1),
            "mse": 0.0,
            "mae": 0.0,
            "r_squared": 1.0,
            "combined_score": 1.0,
        }
        return payload, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", run_with_bogus_self_report)

    outcome = rr_evaluator.evaluate_one_task(
        "unused_program.r",
        task,
        base_seeds=(0,),
        timeout_seconds=1.0,
    )
    metrics = outcome["metrics"]

    assert metrics["successful_seed_count"] == 1
    assert metrics["combined_score"] == pytest.approx(metrics["score"])
    # The score must reflect the bad outputs, not the self-reported values.
    assert metrics["score"] < 1e-6


def test_train_test_split_is_used(monkeypatch):
    task = ROBUST_REGRESSION_TASK_SPECS[0]

    def make_runner(predict):
        """Return a fake runner whose predictions come from ``predict(dataset)``."""

        def runner(program_path, task_spec, dataset, timeout_seconds=30.0):
            del program_path, task_spec, timeout_seconds
            outputs = {
                "predictions": predict(dataset).tolist(),
                "coefficients": dataset.true_coefficients.tolist(),
            }
            return outputs, None, 0.01

        return runner

    def score_with(predict):
        monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", make_runner(predict))
        outcome = rr_evaluator.evaluate_one_task(
            "unused_program.r",
            task,
            base_seeds=(0,),
            timeout_seconds=1.0,
        )
        return outcome["metrics"]["score"]

    oracle_score = score_with(lambda ds: ds.y_test_clean_signal)
    # Echoing training targets (resized to the test length) must score clearly worse.
    memorizer_score = score_with(lambda ds: np.resize(ds.y_train, task.n_test))

    assert oracle_score > memorizer_score + 0.2


def test_multiple_seeds_are_aggregated(monkeypatch):
    """Per-seed results over FULL_EVAL_SEEDS are aggregated into the task metrics.

    The fake runner is perfect on seeds 0 and 1 and outputs zeros elsewhere. The
    expected aggregate is recomputed independently with ``_build_seed_result``
    (the same helper the other tests use) and ``aggregate_seed_results``.
    """
    task = ROBUST_REGRESSION_TASK_SPECS[0]
    perfect_seeds = (0, 1)
    fake_runtime = 0.02

    def mixed_run(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        if dataset.base_seed in perfect_seeds:
            predictions = dataset.y_test_clean_signal
            coefficients = dataset.true_coefficients
        else:
            predictions = np.zeros(task.n_test, dtype=float)
            coefficients = np.zeros(task.n_features + 1, dtype=float)
        return (
            {
                "predictions": predictions.tolist(),
                "coefficients": coefficients.tolist(),
            },
            None,
            fake_runtime,
        )

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", mixed_run)

    task_result = rr_evaluator.evaluate_one_task(
        "unused_program.r",
        task,
        base_seeds=FULL_EVAL_SEEDS,
        timeout_seconds=1.0,
    )

    # Reuse the shared seed helper instead of re-implementing it inline, so the
    # expectation mirrors mixed_run exactly (same outputs and reported runtime).
    expected_seed_results = [
        _build_seed_result(
            task,
            base_seed,
            prediction_mode="perfect_signal" if base_seed in perfect_seeds else "zeros",
            runtime=fake_runtime,
        )
        for base_seed in FULL_EVAL_SEEDS
    ]
    expected_metrics = aggregate_seed_results(
        expected_seed_results,
        seed_count=len(FULL_EVAL_SEEDS),
        timeout_seconds=1.0,
    )

    metrics = task_result["metrics"]
    assert metrics["seed_count"] == len(FULL_EVAL_SEEDS)
    assert len(task_result["seed_results"]) == len(FULL_EVAL_SEEDS)
    assert metrics["successful_seed_count"] == len(FULL_EVAL_SEEDS)
    assert metrics["score"] == pytest.approx(expected_metrics["score"])
    assert metrics["signal_score"] == pytest.approx(expected_metrics["signal_score"])


def test_missing_predictions_fail_gracefully(monkeypatch):
    task = ROBUST_REGRESSION_TASK_SPECS[0]

    def run_without_predictions(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, dataset, timeout_seconds
        # Output omits the "predictions" field entirely.
        payload = {"coefficients": [0.0] * (task.n_features + 1)}
        return payload, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", run_without_predictions)

    outcome = rr_evaluator.evaluate_one_task(
        "unused_program.r",
        task,
        base_seeds=(0,),
        timeout_seconds=1.0,
    )
    metrics = outcome["metrics"]

    assert metrics["successful_seed_count"] == 0
    assert metrics["success_rate"] == pytest.approx(0.0)
    assert metrics["score"] == pytest.approx(0.0)
    # The reported error should name the missing field.
    assert "predictions" in (outcome["error"] or "")


def test_wrong_coefficient_length_fails_gracefully(monkeypatch):
    task = ROBUST_REGRESSION_TASK_SPECS[0]

    def run_with_short_coefficients(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        # Correct predictions, but n_features coefficients instead of n_features + 1.
        payload = {
            "predictions": dataset.y_test_clean_signal.tolist(),
            "coefficients": [0.0] * task.n_features,
        }
        return payload, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", run_with_short_coefficients)

    outcome = rr_evaluator.evaluate_one_task(
        "unused_program.r",
        task,
        base_seeds=(0,),
        timeout_seconds=1.0,
    )

    assert outcome["metrics"]["successful_seed_count"] == 0
    assert outcome["metrics"]["combined_score"] == pytest.approx(0.0)
    assert "coefficients length" in (outcome["error"] or "")


def test_stage1_uses_reduced_seed_bank(monkeypatch):
    selected_task = ROBUST_REGRESSION_TASK_SPECS[1]

    def oracle_run(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        outputs = {
            "predictions": dataset.y_test_clean_signal.tolist(),
            "coefficients": dataset.true_coefficients.tolist(),
        }
        return outputs, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", oracle_run)
    # Restrict evaluation to a single task selected via the environment.
    monkeypatch.setenv(ROBUST_REGRESSION_TASK_SELECTOR_ENV_VAR, selected_task.task_id)

    quick = rr_evaluator.evaluate_stage1("unused_program.r")
    thorough = rr_evaluator.evaluate("unused_program.r")

    # Each stage runs its own seed bank and tags its artifacts with its stage name.
    assert quick.metrics["seed_count"] == len(STAGE1_SEEDS)
    assert thorough.metrics["seed_count"] == len(FULL_EVAL_SEEDS)
    assert quick.metrics["combined_score"] == pytest.approx(quick.metrics["score"])
    assert quick.artifacts["evaluation_stage"] == "stage1"
    assert thorough.artifacts["evaluation_stage"] == "full"


def test_shared_mode_returns_aggregate_metrics_and_task_artifacts(monkeypatch):
    def oracle_run(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        outputs = {
            "predictions": dataset.y_test_clean_signal.tolist(),
            "coefficients": dataset.true_coefficients.tolist(),
        }
        return outputs, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", oracle_run)
    # "all" selects every task in the family at once.
    monkeypatch.setenv(ROBUST_REGRESSION_TASK_SELECTOR_ENV_VAR, "all")

    result = rr_evaluator.evaluate("unused_program.r")

    assert result.metrics["combined_score"] == pytest.approx(result.metrics["score"])
    assert result.metrics["task_count"] == pytest.approx(4.0)
    per_task = result.artifacts["task_results"]
    assert len(per_task) == 4
    for entry in per_task:
        entry_metrics = entry["metrics"]
        assert entry_metrics["combined_score"] == pytest.approx(entry_metrics["score"])
        assert len(entry["seed_results"]) == len(FULL_EVAL_SEEDS)


def test_report_helper_uses_new_task_order(tmp_path):
    manifest_path = (
        REPO_ROOT / "multi_task_shared_then_adapt" / "r_robust_regression_mt_sts.yaml"
    )
    manifest = load_manifest(manifest_path)
    run_root = tmp_path / "mt_sts_demo"
    configs_root = run_root / "configs"
    configs_root.mkdir(parents=True, exist_ok=True)

    shared_config_lines = [
        "max_iterations: 20",
        "llm:",
        "  primary_model: claude-opus-4-6",
        "diff_based_evolution: false",
    ]
    # Minimal phase configs for the run directory: file name -> YAML text.
    phase_configs = {
        "shared_config.yaml": "\n".join(shared_config_lines) + "\n",
        f"adaptation_{FIRST_ROBUST_REGRESSION_TASK_ID}.yaml": "max_iterations: 20\n",
        f"baseline_{FIRST_ROBUST_REGRESSION_TASK_ID}.yaml": "max_iterations: 25\n",
    }
    for file_name, yaml_text in phase_configs.items():
        (configs_root / file_name).write_text(yaml_text, encoding="utf-8")

    report = mt_sts_report.load_run_report(
        run_root,
        repo_root=REPO_ROOT,
        manifest_path=manifest_path,
        manifest_family=manifest.family,
        task_specs=mt_sts_report.family_task_specs(manifest),
        wandb_entity_override=None,
    )

    assert list(report["tasks"].keys()) == ROBUST_REGRESSION_TASK_IDS
    expected_leverage_spec = {
        "n_train": 100,
        "n_test": 200,
        "n_features": 3,
        "noise_std": 0.10,
        "rho": 0.0,
        "vertical_outlier_fraction_train": 0.0,
        "leverage_outlier_fraction_train": 0.10,
        "hetero_strength": 0.0,
    }
    assert report["tasks"]["rr_leverage10_100x3"]["task_spec"] == expected_leverage_spec


def test_extract_task_result_is_strict_on_malformed_metrics():
    task = ROBUST_REGRESSION_TASK_SPECS[0]
    good_seed_results = [
        _build_seed_result(task, seed, prediction_mode="perfect_signal")
        for seed in FULL_EVAL_SEEDS
    ]
    good_task_result = _build_task_result_from_seed_results(task, good_seed_results)

    def extract(metrics, seed_results):
        """Run extract_task_result on full-stage artifacts holding one task entry."""
        artifacts = {
            "evaluation_stage": "full",
            "task_results": [
                {
                    "task_id": task.task_id,
                    "metrics": metrics,
                    "seed_results": seed_results,
                }
            ],
        }
        return extract_task_result(artifacts, task.task_id)

    # Metrics that are not a mapping are rejected.
    assert extract("bad", good_seed_results) is None

    # Metrics with "coef_score" removed are rejected.
    metrics_without_coef = dict(good_task_result["metrics"])
    metrics_without_coef.pop("coef_score")
    assert extract(metrics_without_coef, good_seed_results) is None

    # Valid metrics with one seed result fewer than FULL_EVAL_SEEDS are rejected.
    assert extract(good_task_result["metrics"], good_seed_results[:-1]) is None


def test_extract_task_result_rejects_stage1_artifacts():
    task = ROBUST_REGRESSION_TASK_SPECS[0]
    stage1_task_result = _build_task_result_from_seed_results(
        task,
        [
            _build_seed_result(task, seed, prediction_mode="perfect_signal")
            for seed in STAGE1_SEEDS
        ],
    )

    # A task result built from STAGE1_SEEDS under a "stage1" tag must not be accepted.
    stage1_artifacts = {
        "evaluation_stage": "stage1",
        "task_results": [stage1_task_result],
    }
    assert extract_task_result(stage1_artifacts, task.task_id) is None


def test_spawn_uses_stored_task_results_without_reevaluation(tmp_path, monkeypatch):
    """Spawning per-task checkpoints reuses stored full-stage task results.

    Builds a shared database of two programs whose artifacts carry valid full
    task results, spawns a checkpoint for the first task with re-evaluation
    patched to fail, and checks the spawned programs' metrics and artifacts
    against the stored per-task results.
    """
    example_dir = REPO_ROOT / "examples" / "r_robust_regression"
    base_config_path = example_dir / "config.yaml"
    evaluation_file = example_dir / "evaluator.py"

    config = Config.from_yaml(base_config_path)
    config.database.db_path = None
    shared_database = ProgramDatabase(config.database)

    # program_id -> task_id -> stored task metrics, for later comparison.
    source_metrics_by_program: dict[str, dict[str, dict[str, float]]] = {}
    for program_index in range(2):
        task_results = _build_full_task_results(program_index=program_index)
        metrics = aggregate_task_results(task_results)
        program = Program(
            id=f"program_{program_index}",
            code=f"# program {program_index}\nmain <- function(X_train, y_train, X_test) list()\n",
            changes_description=f"program {program_index}",
            language="r",
            generation=program_index,
            iteration_found=program_index,
            metrics=metrics,
            metadata={"island": program_index % config.database.num_islands},
            artifacts_json=json.dumps(
                {
                    "task_selector": "all",
                    "evaluation_stage": "full",
                    "task_results": task_results,
                }
            ),
        )
        shared_database.add(program, target_island=program.metadata["island"])
        source_metrics_by_program[program.id] = {
            task_result["task_id"]: task_result["metrics"] for task_result in task_results
        }

    shared_checkpoint = tmp_path / "shared_checkpoint"
    shared_database.save(str(shared_checkpoint), iteration=5)

    def fail_if_reevaluated(**kwargs):
        del kwargs
        raise AssertionError("Spawn should use stored task_results without reevaluation")

    # Patch through the already-imported module object (same target as the dotted path).
    monkeypatch.setattr(mt_sts_spawn, "_reevaluate_program_for_task", fail_if_reevaluated)

    spawned_root = tmp_path / "spawned"
    spawn_results = spawn_task_checkpoints(
        shared_checkpoint_path=shared_checkpoint,
        output_root=spawned_root,
        base_config_path=base_config_path,
        evaluation_file=evaluation_file,
        task_ids=[FIRST_ROBUST_REGRESSION_TASK_ID],
    )

    assert spawn_results[FIRST_ROBUST_REGRESSION_TASK_ID]["reevaluated_program_ids"] == []

    spawned_config = Config.from_yaml(base_config_path)
    spawned_config.database.db_path = None
    spawned_database = ProgramDatabase(spawned_config.database)
    spawned_database.load(str(spawned_root / FIRST_ROBUST_REGRESSION_TASK_ID))

    assert spawned_database.last_iteration == 0
    # Guard against a vacuous pass: the per-program checks below must run at least once.
    assert spawned_database.programs, "spawned checkpoint contains no programs"
    for program_id, program in spawned_database.programs.items():
        expected_metrics = source_metrics_by_program[program_id][FIRST_ROBUST_REGRESSION_TASK_ID]
        assert program.metrics["combined_score"] == pytest.approx(expected_metrics["combined_score"])
        assert program.metadata["sts_warmstarted"] is True
        task_artifacts = spawned_database.get_artifacts(program_id)
        assert task_artifacts["evaluation_stage"] == "full"
        assert len(task_artifacts["task_results"]) == 1
        assert len(task_artifacts["task_results"][0]["seed_results"]) == len(FULL_EVAL_SEEDS)


def test_spawn_reevaluates_only_when_stored_task_results_are_malformed(tmp_path, monkeypatch):
    example_dir = REPO_ROOT / "examples" / "r_robust_regression"
    base_config_path = example_dir / "config.yaml"
    evaluation_file = example_dir / "evaluator.py"

    config = Config.from_yaml(base_config_path)
    config.database.db_path = None
    shared_database = ProgramDatabase(config.database)

    task_results = _build_full_task_results(program_index=0)
    shared_metrics = aggregate_task_results(task_results)
    # Drop the last seed from the first task so its stored artifact is incomplete.
    first_task_truncated = {
        **task_results[0],
        "seed_results": task_results[0]["seed_results"][:-1],
    }
    stored_artifacts = {
        "task_selector": "all",
        "evaluation_stage": "full",
        "task_results": [first_task_truncated, *task_results[1:]],
    }
    bad_program = Program(
        id="program_bad_artifact",
        code="# malformed artifact\nmain <- function(X_train, y_train, X_test) list()\n",
        changes_description="bad artifact source",
        language="r",
        generation=0,
        iteration_found=0,
        metrics=shared_metrics,
        metadata={"island": 0},
        artifacts_json=json.dumps(stored_artifacts),
    )
    shared_database.add(bad_program, target_island=0)

    shared_checkpoint = tmp_path / "shared_checkpoint"
    shared_database.save(str(shared_checkpoint), iteration=5)

    reevaluated_ids: list[str] = []

    def fake_reevaluate_program_for_task(*, program, task_id, evaluation_file):
        del evaluation_file
        reevaluated_ids.append(program.id)
        assert task_id == FIRST_ROBUST_REGRESSION_TASK_ID
        return task_results[0]

    monkeypatch.setattr(
        mt_sts_spawn,
        "_reevaluate_program_for_task",
        fake_reevaluate_program_for_task,
    )

    spawn_results = spawn_task_checkpoints(
        shared_checkpoint_path=shared_checkpoint,
        output_root=tmp_path / "spawned",
        base_config_path=base_config_path,
        evaluation_file=evaluation_file,
        task_ids=[FIRST_ROBUST_REGRESSION_TASK_ID],
    )

    # Only the program with the truncated artifact is re-evaluated.
    assert reevaluated_ids == ["program_bad_artifact"]
    reported_ids = spawn_results[FIRST_ROBUST_REGRESSION_TASK_ID]["reevaluated_program_ids"]
    assert reported_ids == ["program_bad_artifact"]


def test_reevaluate_program_for_task_supports_sync_evaluate(tmp_path, monkeypatch):
    task_results = _build_full_task_results(program_index=0)
    all_task_metrics = aggregate_task_results(task_results)
    evaluation_result = rr_evaluator.EvaluationResult(
        metrics=all_task_metrics,
        artifacts={
            "task_selector": "all",
            "evaluation_stage": "full",
            "task_results": task_results,
        },
    )

    def load_fake_evaluation_module(_):
        # Stand-in module exposing a plain synchronous ``evaluate`` callable.
        return SimpleNamespace(evaluate=lambda _: evaluation_result)

    monkeypatch.setattr(mt_sts_spawn, "_load_evaluation_module", load_fake_evaluation_module)

    sync_program = Program(
        id="program_sync_eval",
        code="# sync eval\nmain <- function(X_train, y_train, X_test) list()\n",
        changes_description="sync evaluator compatibility",
        language="r",
        metrics={"score": 0.0, "combined_score": 0.0},
    )

    extracted = mt_sts_spawn._reevaluate_program_for_task(
        program=sync_program,
        task_id=FIRST_ROBUST_REGRESSION_TASK_ID,
        evaluation_file=tmp_path / "unused_evaluator.py",
    )

    assert extracted["task_id"] == FIRST_ROBUST_REGRESSION_TASK_ID
    assert extracted["metrics"]["combined_score"] == pytest.approx(extracted["metrics"]["score"])


def test_evaluator_score_tracks_predictions_not_fake_metrics(monkeypatch):
    task = ROBUST_REGRESSION_TASK_SPECS[0]
    # Every seed self-reports perfect metrics; only seed 0 actually predicts well.
    bogus_metrics = {
        "mse": 0.0,
        "mae": 0.0,
        "r_squared": 1.0,
        "outlier_robustness": 1.0,
    }

    def fake_run(program_path, task_spec, dataset, timeout_seconds=30.0):
        del program_path, task_spec, timeout_seconds
        if dataset.base_seed != 0:
            predictions = np.full(task.n_test, 1e5, dtype=float)
            coefficients = np.full(task.n_features + 1, 1e5, dtype=float)
        else:
            predictions = dataset.y_test_clean_signal
            coefficients = dataset.true_coefficients
        payload = {
            "predictions": predictions.tolist(),
            "coefficients": coefficients.tolist(),
            **bogus_metrics,
        }
        return payload, None, 0.01

    monkeypatch.setattr(rr_evaluator, "run_r_program_on_dataset", fake_run)

    def score_for_seed(seed):
        outcome = rr_evaluator.evaluate_one_task(
            "unused_program.r",
            task,
            base_seeds=(seed,),
            timeout_seconds=1.0,
        )
        return outcome["metrics"]["score"]

    good_score = score_for_seed(0)
    bad_score = score_for_seed(1)

    assert good_score > bad_score
    assert good_score > 0.5
    assert bad_score < 1e-4


def test_standalone_example_works_with_new_contract(monkeypatch):
    rscript = shutil.which("Rscript")
    if rscript is None:
        pytest.skip("Rscript is not installed")

    # Probe for jsonlite: the R snippet exits 0 only when the package is loadable
    # (presumably required by the R-side harness — confirm in the evaluator).
    probe_expr = "quit(status=ifelse(requireNamespace('jsonlite', quietly=TRUE), 0, 1))"
    probe = subprocess.run(
        [rscript, "-e", probe_expr],
        capture_output=True,
        text=True,
        check=False,
    )
    if probe.returncode != 0:
        pytest.skip("R package 'jsonlite' is not installed")

    example_program = REPO_ROOT / "examples" / "r_robust_regression" / "initial_program.r"
    monkeypatch.setenv(ROBUST_REGRESSION_TASK_SELECTOR_ENV_VAR, FIRST_ROBUST_REGRESSION_TASK_ID)

    outcome = rr_evaluator.evaluate(str(example_program))
    metrics = outcome.metrics

    assert metrics["combined_score"] == pytest.approx(metrics["score"])
    assert metrics["successful_seed_count"] == len(FULL_EVAL_SEEDS)
    assert metrics["score"] > 0.0
