import importlib.util
import json
import os
from pathlib import Path
import sys
from dataclasses import replace

import numpy as np
import pytest

# Provide a placeholder API key; setdefault keeps any value already present in
# the environment.  NOTE(review): presumably required when the openevolve
# modules below are imported or configured - confirm.
os.environ.setdefault("OPENAI_API_KEY", "test")

# REPO_ROOT is the parent of the directory containing this file.  It is put at
# the front of sys.path (once) before the `openevolve` imports that follow.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from openevolve.config import Config
from openevolve.database import Program, ProgramDatabase
from openevolve.multi_task_shared_then_specialize.registry import get_family_definition
from openevolve.multi_task_shared_then_specialize.symbolic_regression_phys_osc import (
    SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR,
    SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS,
    SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID,
    aggregate_task_results,
    build_task_result,
    extract_task_result,
    resolve_task_specs,
)
from openevolve.multi_task_shared_then_specialize.spawn import spawn_task_checkpoints


def _load_module(path: Path, module_name: str):
    """Load the Python source file at ``path`` as a module named ``module_name``.

    The module object is registered in ``sys.modules`` before its body is
    executed, as the importlib recipe for importing a source file directly
    does, so that code which resolves a module through ``__module__`` (for
    example ``dataclasses`` or ``pickle``) can find it.  If executing the
    module body fails, the registration is rolled back so that no
    half-initialised module is left behind.

    Args:
        path: Location of the source file to execute.
        module_name: Name given to the module (its ``__name__`` and its key
            in ``sys.modules``).

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``.
        Exception: Whatever the module body itself raises is propagated.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load module from {path}")
    module = importlib.util.module_from_spec(spec)

    # Remember what (if anything) was registered under this name so a failed
    # load can restore the previous state exactly.
    not_registered = object()
    previous = sys.modules.get(module_name, not_registered)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        if previous is not_registered:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


# The example's data loader and evaluator, and the two runner scripts, live
# outside the importable package, so each is loaded from its file path under a
# test-specific module name.
sr_data_loader = _load_module(
    REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "data_loader.py",
    "symbolic_regression_phys_osc_mt_sts_data_loader_test",
)
sr_evaluator = _load_module(
    REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "evaluator.py",
    "symbolic_regression_phys_osc_mt_sts_evaluator_test",
)
sr_mt_sts_runner = _load_module(
    REPO_ROOT / "multi_task_shared_then_adapt" / "run_multi_task_shared_then_specialize.py",
    "symbolic_regression_phys_osc_mt_sts_runner_test",
)
sr_mt_sts_baselines_runner = _load_module(
    REPO_ROOT / "multi_task_shared_then_adapt" / "run_multi_task_shared_then_specialize_baselines.py",
    "symbolic_regression_phys_osc_mt_sts_baselines_runner_test",
)
# The module in which the evaluator's own `load_task_data` was defined, looked
# up through sys.modules.  NOTE(review): presumably a different module object
# from `sr_data_loader` above (which was loaded under a test-specific name), so
# tests that go through the evaluator patch this one instead - confirm.
sr_runtime_data_loader = sys.modules[sr_evaluator.load_task_data.__module__]


def _candidate_program_code(num_params: int = 10) -> str:
    """Return the source text of a small candidate program.

    The generated program defines ``func(x, params)``, sets
    ``func.num_params`` to ``num_params`` and exposes ``run_search()``, which
    returns ``func``.

    Args:
        num_params: Value written into the ``func.num_params`` assignment.

    Returns:
        The program source, terminated by a newline.
    """
    source_lines = [
        "import numpy as np",
        "",
        "def func(x, params):",
        "    x = np.asarray(x, dtype=float)",
        "    params = np.asarray(params, dtype=float)",
        "    if params.size < 4:",
        "        params = np.pad(params, (0, 4 - params.size))",
        "    pos = x[:, 0]",
        "    t_val = x[:, 1]",
        "    vel = x[:, 2]",
        "    return -(params[0] * pos) - (params[1] * vel) + params[2] * np.sin(t_val) + params[3]",
        "",
        f"func.num_params = {num_params}",
        "",
        "def run_search():",
        "    return func",
        "",  # yields the trailing newline after the join
    ]
    return "\n".join(source_lines)


def _fake_raw_metrics(task_index: int, program_index: int) -> dict[str, float]:
    """Build a deterministic raw-metrics dictionary for one (task, program) pair.

    Both indices shift the values: the NMSE entries grow and the R2 entries
    shrink by ``0.05 * task_index + 0.02 * program_index``.

    Args:
        task_index: Index of the task the metrics belong to.
        program_index: Index of the program the metrics belong to.

    Returns:
        A dictionary with NMSE and R2 values for the train/test/ood splits
        plus restart counts, parameter count and evaluation time.
    """
    degradation = 0.05 * task_index + 0.02 * program_index
    elapsed = 0.15 + 0.02 * task_index + 0.01 * program_index

    error_metrics = {
        "train_nmse": 0.10 + degradation,
        "test_nmse": 0.20 + degradation,
        "ood_nmse": 0.30 + degradation,
    }
    fit_metrics = {
        "train_r2": 0.90 - degradation,
        "test_r2": 0.82 - degradation,
        "ood_r2": 0.74 - degradation,
    }
    bookkeeping = {
        "successful_restarts": 5,
        "total_restarts": 6,
        "num_params_used": 4 + program_index,
        "eval_time": elapsed,
    }
    return {**error_metrics, **fit_metrics, **bookkeeping}


def _artifact_keys(obj):
    """Collect every dictionary key found anywhere inside a nested structure.

    Dictionaries contribute their keys (converted with ``str``) and are
    searched through their values; lists and tuples are searched element by
    element.  Tuples are traversed because ``json.dumps`` serialises them
    exactly like lists, so a dictionary nested inside a tuple would otherwise
    reach the serialised output without its keys being inspected.  Any other
    value contributes nothing.

    Args:
        obj: Arbitrarily nested combination of dicts, lists, tuples and
            scalar values.

    Returns:
        The set of stringified keys found at any depth.
    """
    keys: set[str] = set()
    if isinstance(obj, dict):
        for key, value in obj.items():
            keys.add(str(key))
            keys.update(_artifact_keys(value))
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            keys.update(_artifact_keys(value))
    return keys


def _write_generated_problem_dir(
    problem_dir: Path,
    *,
    fill_value: float = 1.0,
    input_var_names: tuple[str, ...] = ("x", "t", "v"),
    output_var_name: str = "dv_dt",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Create ``problem_dir`` and fill it with constant-valued split arrays plus metadata.

    The train, test and ood splits are filled with ``fill_value``,
    ``fill_value + 1.0`` and ``fill_value + 2.0`` respectively.  Each split
    gets a ``(64, 3)`` feature array and a length-64 target array, saved as
    ``X_<split>_for_eval.npy`` / ``y_<split>_for_eval.npy``.  A
    ``problem_metadata.json`` file records the variable names.

    Args:
        problem_dir: Directory to create; it must not exist yet.
        fill_value: Constant used for the train split.
        input_var_names: Names stored under ``input_var_names`` in the metadata.
        output_var_name: Name stored under ``output_var_name`` in the metadata.

    Returns:
        ``(x_train, y_train, x_test, y_test, x_ood, y_ood)``.
    """
    problem_dir.mkdir(parents=True)

    split_levels = (
        ("train", fill_value),
        ("test", fill_value + 1.0),
        ("ood_test", fill_value + 2.0),
    )
    written: list[np.ndarray] = []
    for file_tag, level in split_levels:
        features = np.full((64, 3), level, dtype=float)
        targets = np.full(64, level, dtype=float)
        np.save(problem_dir / f"X_{file_tag}_for_eval.npy", features)
        np.save(problem_dir / f"y_{file_tag}_for_eval.npy", targets)
        written += [features, targets]

    metadata = {
        "input_var_names": list(input_var_names),
        "output_var_name": output_var_name,
    }
    metadata_path = problem_dir / "problem_metadata.json"
    metadata_path.write_text(json.dumps(metadata), encoding="utf-8")
    return tuple(written)


def test_family_registry_resolves_new_family():
    """The registry returns the phys-osc family with its selector env var and four task ids."""
    expected_task_ids = ["sr_po11", "sr_po17", "sr_po30", "sr_po37"]

    definition = get_family_definition("symbolic_regression_phys_osc")

    assert definition.family == "symbolic_regression_phys_osc"
    assert definition.task_selector_env_var == "SYMBOLIC_REGRESSION_PHYS_OSC_TASK_ID"
    assert sorted(definition.tasks_by_id) == expected_task_ids


def test_resolve_task_specs_all_returns_expected_tasks():
    """The selector "all" resolves to the four task ids in their fixed order."""
    resolved_ids = [spec.task_id for spec in resolve_task_specs("all")]

    assert resolved_ids == ["sr_po11", "sr_po17", "sr_po30", "sr_po37"]


def test_synthetic_fixture_mode_loads_without_benchmark_files(monkeypatch):
    """With the synthetic-fixture flag set, task data loads without any problems root.

    Every split must be a feature array with three columns and at least 64
    rows, paired with a one-dimensional target array of matching length.
    """
    # Start from an empty task-data cache, like the other loader tests in this
    # file, so the result does not rely on what an earlier test may have cached
    # for the same task.
    sr_data_loader._TASK_DATA_CACHE.clear()
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", raising=False)

    task_data = sr_data_loader.load_task_data(SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"])

    assert task_data["metadata"]["data_source_mode"] == "synthetic_fixture"
    for split in ("train", "test", "ood"):
        x_array, y_array = task_data[split]
        assert x_array.shape[1] == 3
        assert x_array.shape[0] >= 64
        assert y_array.shape == (x_array.shape[0],)


def test_generated_problem_directory_loading_works(monkeypatch, tmp_path):
    """Arrays written under an explicit problems root come back from the loader unchanged."""
    sr_data_loader._TASK_DATA_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", str(tmp_path))

    # The helper returns (x_train, y_train, x_test, y_test, x_ood, y_ood).
    written_arrays = _write_generated_problem_dir(tmp_path / "phys_osc" / "PO11")
    expected_x_train = written_arrays[0]
    expected_y_ood = written_arrays[5]

    po11_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"]
    loaded = sr_data_loader.load_task_data(po11_spec)

    assert loaded["metadata"]["data_source_mode"] == "generated_problem_dir"
    assert np.array_equal(loaded["train"][0], expected_x_train)
    assert np.array_equal(loaded["ood"][1], expected_y_ood)


def test_local_problem_cache_root_is_reused_automatically(monkeypatch, tmp_path):
    """With no env configuration, data in the (patched) dataset cache root is picked up."""
    sr_data_loader._TASK_DATA_CACHE.clear()
    for env_name in (
        "SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE",
        "SYMBOLIC_REGRESSION_PROBLEMS_ROOT",
    ):
        monkeypatch.delenv(env_name, raising=False)
    cache_root = tmp_path / "_problem_cache"
    monkeypatch.setattr(sr_data_loader, "_dataset_cache_root", lambda: cache_root)

    problem_dir = cache_root / "phys_osc" / "PO17"
    problem_dir.mkdir(parents=True)

    # Non-constant data; the test and ood splits are the train split shifted
    # by 1.0 and 2.0.
    x_train = np.arange(192, dtype=float).reshape(64, 3)
    y_train = np.linspace(-1.0, 1.0, 64, dtype=float)
    y_test = y_train + 1.0
    arrays_by_tag = {
        "train": (x_train, y_train),
        "test": (x_train + 1.0, y_test),
        "ood_test": (x_train + 2.0, y_train + 2.0),
    }
    for file_tag, (features, targets) in arrays_by_tag.items():
        np.save(problem_dir / f"X_{file_tag}_for_eval.npy", features)
        np.save(problem_dir / f"y_{file_tag}_for_eval.npy", targets)

    metadata = {"input_var_names": ["x", "t", "v"], "output_var_name": "dv_dt"}
    (problem_dir / "problem_metadata.json").write_text(json.dumps(metadata), encoding="utf-8")

    po17_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po17"]
    loaded = sr_data_loader.load_task_data(po17_spec)

    assert loaded["metadata"]["data_source_mode"] == "generated_problem_dir"
    assert np.array_equal(loaded["train"][0], x_train)
    assert np.array_equal(loaded["test"][1], y_test)


def test_malformed_implicit_cache_does_not_shadow_later_valid_root(monkeypatch, tmp_path):
    """An incomplete first candidate root is passed over in favour of a later complete one."""
    sr_data_loader._TASK_DATA_CACHE.clear()
    sr_data_loader._DATAMODULE_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", raising=False)

    cache_root = tmp_path / "_problem_cache"
    problems_root = tmp_path / "problems"

    def fake_root_entries():
        return [
            ("local generated-problem cache", cache_root),
            ("repo-root problems", problems_root),
        ]

    def forbid_benchmark_fallback(task_spec):
        raise AssertionError("benchmark fallback should not run")

    monkeypatch.setattr(sr_data_loader, "_candidate_problem_root_entries", fake_root_entries)
    monkeypatch.setattr(sr_data_loader, "_load_from_datamodule", forbid_benchmark_fallback)

    # First root: only the train arrays exist, so the directory is incomplete.
    incomplete_dir = cache_root / "phys_osc" / "PO11"
    incomplete_dir.mkdir(parents=True)
    np.save(incomplete_dir / "X_train_for_eval.npy", np.ones((64, 3), dtype=float))
    np.save(incomplete_dir / "y_train_for_eval.npy", np.ones(64, dtype=float))

    # Second root: a complete directory with distinguishable values.
    complete_arrays = _write_generated_problem_dir(
        problems_root / "phys_osc" / "PO11", fill_value=5.0
    )
    expected_x_train = complete_arrays[0]
    expected_y_ood = complete_arrays[5]

    loaded = sr_data_loader.load_task_data(SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"])

    assert loaded["metadata"]["data_source_mode"] == "generated_problem_dir"
    assert np.array_equal(loaded["train"][0], expected_x_train)
    assert np.array_equal(loaded["ood"][1], expected_y_ood)


def test_explicit_root_stays_strict_even_if_later_implicit_root_is_valid(monkeypatch, tmp_path):
    """An incomplete explicit root raises even though the implicit cache holds valid data."""
    explicit_root = tmp_path / "explicit_root"
    implicit_cache_root = tmp_path / "_problem_cache"

    sr_data_loader._TASK_DATA_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", str(explicit_root))
    monkeypatch.setattr(sr_data_loader, "_dataset_cache_root", lambda: implicit_cache_root)

    # Explicit root: only the train arrays are present.
    incomplete_dir = explicit_root / "phys_osc" / "PO11"
    incomplete_dir.mkdir(parents=True)
    np.save(incomplete_dir / "X_train_for_eval.npy", np.ones((64, 3), dtype=float))
    np.save(incomplete_dir / "y_train_for_eval.npy", np.ones(64, dtype=float))

    # Implicit cache: a complete problem directory.
    _write_generated_problem_dir(implicit_cache_root / "phys_osc" / "PO11", fill_value=7.0)

    po11_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"]
    with pytest.raises(sr_data_loader.GeneratedProblemMissingError, match="missing required files"):
        sr_data_loader.load_task_data(po11_spec)


def test_generated_problem_directory_without_metadata_fails_loudly(monkeypatch, tmp_path):
    """All six arrays present but no metadata file: the loader raises ProblemMetadataError."""
    sr_data_loader._TASK_DATA_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", str(tmp_path))

    problem_dir = tmp_path / "phys_osc" / "PO30"
    problem_dir.mkdir(parents=True)
    payload_by_prefix = (
        ("X", np.ones((64, 3), dtype=float)),
        ("y", np.ones(64, dtype=float)),
    )
    for file_tag in ("train", "test", "ood_test"):
        for prefix, payload in payload_by_prefix:
            np.save(problem_dir / f"{prefix}_{file_tag}_for_eval.npy", payload)

    po30_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po30"]
    with pytest.raises(sr_data_loader.ProblemMetadataError, match="missing valid metadata"):
        sr_data_loader.load_task_data(po30_spec)


def test_generated_problem_directory_with_missing_metadata_key_fails_loudly(monkeypatch, tmp_path):
    """Metadata lacking ``output_var_name`` makes the loader raise ProblemMetadataError."""
    sr_data_loader._TASK_DATA_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", str(tmp_path))

    problem_dir = tmp_path / "phys_osc" / "PO30"
    problem_dir.mkdir(parents=True)
    payload_by_prefix = (
        ("X", np.ones((64, 3), dtype=float)),
        ("y", np.ones(64, dtype=float)),
    )
    for file_tag in ("train", "test", "ood_test"):
        for prefix, payload in payload_by_prefix:
            np.save(problem_dir / f"{prefix}_{file_tag}_for_eval.npy", payload)

    # Only the input names are recorded; the output name is deliberately left out.
    partial_metadata = {"input_var_names": ["x", "t", "v"]}
    (problem_dir / "problem_metadata.json").write_text(
        json.dumps(partial_metadata), encoding="utf-8"
    )

    po30_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po30"]
    with pytest.raises(
        sr_data_loader.ProblemMetadataError,
        match="missing required key 'output_var_name'",
    ):
        sr_data_loader.load_task_data(po30_spec)


def test_benchmark_dependency_missing_raises_actionable_loader_error(monkeypatch, tmp_path):
    """An unusable cache plus ``get_datamodule`` patched to None raises BenchmarkDependencyError.

    The error message must mention both that the cache could not be reused and
    the missing ``bench.datamodules.get_datamodule`` dependency.
    """
    sr_data_loader._TASK_DATA_CACHE.clear()
    sr_data_loader._DATAMODULE_CACHE.clear()
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", raising=False)
    monkeypatch.delenv("SYMBOLIC_REGRESSION_PROBLEMS_ROOT", raising=False)

    cache_root = tmp_path / "_problem_cache"
    incomplete_dir = cache_root / "phys_osc" / "PO11"
    incomplete_dir.mkdir(parents=True)
    np.save(incomplete_dir / "X_train_for_eval.npy", np.ones((64, 3), dtype=float))
    np.save(incomplete_dir / "y_train_for_eval.npy", np.ones(64, dtype=float))

    def only_cache_root():
        return [("local generated-problem cache", cache_root)]

    monkeypatch.setattr(sr_data_loader, "_candidate_problem_root_entries", only_cache_root)
    monkeypatch.setattr(sr_data_loader, "get_datamodule", None)

    po11_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"]
    with pytest.raises(
        sr_data_loader.BenchmarkDependencyError,
        match="(?s)could not be reused.*bench\\.datamodules\\.get_datamodule",
    ):
        sr_data_loader.load_task_data(po11_spec)


def test_worker_classifies_benchmark_dependency_failure(monkeypatch):
    """A BenchmarkDependencyError from data loading is reported with its own failure kind."""

    def raise_dependency_error(task):
        raise sr_runtime_data_loader.BenchmarkDependencyError(
            "bench.datamodules.get_datamodule is unavailable"
        )

    monkeypatch.setattr(sr_evaluator, "load_task_data", raise_dependency_error)

    initial_program = (
        REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"
    )
    payload = sr_evaluator.run_task_worker(
        program_path=str(initial_program),
        task_id="sr_po11",
        stage1=False,
    )

    task_result = payload["task_result"]
    assert task_result["failure_kind"] == "benchmark_dependency_missing"
    assert task_result["failure_stage"] == "load_task_data"
    task_artifacts = payload["task_artifacts"]
    assert task_artifacts["failure_kind"] == "benchmark_dependency_missing"
    assert task_artifacts["data_source_mode"] == "benchmark_generation_fallback"


def test_preflight_fails_loudly_when_real_assets_are_missing(monkeypatch, tmp_path):
    """Preflight raises when the only candidate root is absent and ``get_datamodule`` is None."""
    sr_runtime_data_loader._TASK_DATA_CACHE.clear()
    sr_runtime_data_loader._DATAMODULE_CACHE.clear()
    for env_name in (
        "SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE",
        "SYMBOLIC_REGRESSION_PROBLEMS_ROOT",
    ):
        monkeypatch.delenv(env_name, raising=False)

    # This directory is never created, so no generated problem can be found in it.
    missing_cache_root = tmp_path / "_problem_cache"

    def only_missing_root():
        return [("local generated-problem cache", missing_cache_root)]

    monkeypatch.setattr(
        sr_runtime_data_loader, "_candidate_problem_root_entries", only_missing_root
    )
    monkeypatch.setattr(sr_runtime_data_loader, "get_datamodule", None)

    with pytest.raises(Exception, match="symbolic-regression benchmark assets are unavailable"):
        sr_evaluator.preflight_check_symbolic_regression_phys_osc(task_ids=["sr_po11", "sr_po17"])


def test_preflight_passes_in_synthetic_fixture_mode(monkeypatch):
    """With the synthetic-fixture flag set, preflight reports two synthetic-fixture tasks."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    requested_task_ids = ["sr_po11", "sr_po17"]

    report = sr_evaluator.preflight_check_symbolic_regression_phys_osc(
        task_ids=requested_task_ids
    )

    assert report["synthetic_fixture_enabled"] is True
    assert report["data_source_mode_summary"] == {"synthetic_fixture": 2}


def test_shared_mode_returns_aggregate_metrics_and_task_artifacts(monkeypatch):
    """Selector "all" evaluates four tasks and reports aggregate metrics plus per-task results."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "all")

    initial_program = (
        REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"
    )
    outcome = sr_evaluator.evaluate(str(initial_program))
    metrics = outcome.metrics
    artifacts = outcome.artifacts

    assert metrics["task_count"] == pytest.approx(4.0)
    assert metrics["score"] == pytest.approx(metrics["combined_score"])
    assert artifacts["evaluation_mode"] == "shared"
    assert len(artifacts["task_results"]) == 4
    expected_task_ids = {"sr_po11", "sr_po17", "sr_po30", "sr_po37"}
    assert set(artifacts["selected_task_ids"]) == expected_task_ids
    assert artifacts["failure_kind_counts"] == {"none": 4}
    assert artifacts["data_source_mode_summary"] == {"synthetic_fixture": 4}
    for entry in artifacts["task_results"]:
        entry_metrics = entry["metrics"]
        assert entry_metrics["score"] == pytest.approx(entry_metrics["combined_score"])
        assert entry["spec"]["input_var_names"] == ["x", "t", "v"]


def test_task_specific_mode_returns_exactly_one_task(monkeypatch):
    """Selecting a single task id yields task-specific mode with exactly that task's result."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "sr_po17")

    initial_program = (
        REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"
    )
    outcome = sr_evaluator.evaluate(str(initial_program))
    artifacts = outcome.artifacts

    assert artifacts["evaluation_mode"] == "task_specific"
    assert artifacts["task_selector"] == "sr_po17"
    assert len(artifacts["task_results"]) == 1
    assert artifacts["task_results"][0]["task_id"] == "sr_po17"
    assert outcome.metrics["score"] == pytest.approx(outcome.metrics["combined_score"])


def test_evaluator_accepts_run_search_candidate(monkeypatch, tmp_path):
    """A candidate exposing ``run_search()`` evaluates without error and reports its parameter count."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "sr_po11")

    program_file = tmp_path / "candidate_run_search.py"
    program_file.write_text(_candidate_program_code(num_params=6), encoding="utf-8")

    outcome = sr_evaluator.evaluate(str(program_file))

    first_result = outcome.artifacts["task_results"][0]
    assert first_result["error"] is None
    assert first_result["failure_kind"] == "none"
    assert outcome.metrics["num_params_used"] == pytest.approx(6.0)


def test_evaluator_accepts_direct_func_candidate(monkeypatch, tmp_path):
    """A candidate that only defines ``func`` (no ``run_search``) evaluates without error."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "sr_po30")

    candidate_source = (
        "import numpy as np\n\n"
        "def func(x, params):\n"
        "    x = np.asarray(x, dtype=float)\n"
        "    params = np.asarray(params, dtype=float)\n"
        "    if params.size < 3:\n"
        "        params = np.pad(params, (0, 3 - params.size))\n"
        "    return -(params[0] * x[:, 0]) - (params[1] * x[:, 2]) + params[2] * np.sin(x[:, 1])\n\n"
        "func.num_params = 3\n"
    )
    program_file = tmp_path / "candidate_direct_func.py"
    program_file.write_text(candidate_source, encoding="utf-8")

    outcome = sr_evaluator.evaluate(str(program_file))

    first_result = outcome.artifacts["task_results"][0]
    assert first_result["error"] is None
    assert first_result["failure_kind"] == "none"
    assert outcome.metrics["num_params_used"] == pytest.approx(3.0)


def test_extract_task_result_returns_none_for_malformed_stored_metrics():
    """Stored results with non-dict metrics or a NaN score are rejected (``None`` is returned)."""

    def stored_artifacts(metrics):
        # Minimal stored-artifacts payload carrying one result for sr_po11.
        return {"task_results": [{"task_id": "sr_po11", "metrics": metrics}]}

    non_dict_metrics = stored_artifacts("bad")
    nan_score_metrics = stored_artifacts(
        {
            "test_nmse": 0.2,
            "ood_nmse": 0.3,
            "score": float("nan"),
            "combined_score": 0.5,
        }
    )

    assert extract_task_result(non_dict_metrics, "sr_po11") is None
    assert extract_task_result(nan_score_metrics, "sr_po11") is None


def test_spawn_uses_stored_task_results_without_reevaluation(monkeypatch, tmp_path):
    """Spawning a task checkpoint reuses stored per-task results instead of re-evaluating.

    A shared checkpoint with two programs is built, each carrying stored
    ``task_results`` for every task.  The re-evaluation hook in the spawn
    module is replaced by a stub that records the call and raises.  The spawned
    ``sr_po11`` checkpoint must contain both programs with the stored
    ``sr_po11`` metrics, and the stub must never have been called.
    """
    base_config_path = REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "config.yaml"
    evaluation_file = REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "evaluator.py"
    initial_program = REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"

    config = Config.from_yaml(base_config_path)
    config.database.db_path = None
    shared_database = ProgramDatabase(config.database)

    # program id -> task id -> stored metrics, used later as the expected values.
    source_metrics_by_program: dict[str, dict[str, dict[str, float]]] = {}
    for program_index in range(2):
        task_results = [
            build_task_result(
                task,
                raw_metrics=_fake_raw_metrics(task.task_index, program_index),
                data_source_mode="synthetic_fixture",
            )
            for task in SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS
        ]
        metrics = aggregate_task_results(task_results)
        program = Program(
            id=f"program_{program_index}",
            code=_candidate_program_code(),
            changes_description=f"program {program_index}",
            language="python",
            generation=program_index,
            iteration_found=program_index,
            metrics=metrics,
            metadata={"island": program_index % config.database.num_islands},
            artifacts_json=json.dumps(
                {
                    "task_selector": "all",
                    "evaluation_stage": "full",
                    "task_results": task_results,
                }
            ),
        )
        shared_database.add(program, target_island=program.metadata["island"])
        source_metrics_by_program[program.id] = {
            task_result["task_id"]: task_result["metrics"] for task_result in task_results
        }

    shared_checkpoint = tmp_path / "shared_checkpoint"
    shared_database.save(str(shared_checkpoint), iteration=5)

    # Record every call in addition to raising: the explicit check after the
    # spawn still catches a re-evaluation even if the exception raised here
    # were handled inside the spawn code.  Accept any call signature so the
    # guard triggers regardless of how the hook is invoked.
    reevaluation_calls: list[tuple[tuple, dict]] = []

    def fail_if_reevaluated(*args, **kwargs):
        reevaluation_calls.append((args, kwargs))
        raise AssertionError("Spawn should not reevaluate when valid task_results are stored")

    monkeypatch.setattr(
        "openevolve.multi_task_shared_then_specialize.spawn._reevaluate_program_for_task",
        fail_if_reevaluated,
    )

    spawned_root = tmp_path / "spawned"
    spawn_results = spawn_task_checkpoints(
        shared_checkpoint_path=shared_checkpoint,
        output_root=spawned_root,
        base_config_path=base_config_path,
        evaluation_file=evaluation_file,
        family="symbolic_regression_phys_osc",
        task_ids=["sr_po11"],
        initial_program=initial_program,
    )

    assert not reevaluation_calls, "Spawn should not reevaluate when valid task_results are stored"
    assert "sr_po11" in spawn_results
    spawned_checkpoint = spawned_root / "sr_po11"
    assert (spawned_checkpoint / "metadata.json").is_file()
    assert (spawned_checkpoint / "best_program_info.json").is_file()

    # Load the spawned checkpoint into a fresh database to inspect its contents.
    spawned_config = Config.from_yaml(base_config_path)
    spawned_config.database.db_path = None
    spawned_database = ProgramDatabase(spawned_config.database)
    spawned_database.load(str(spawned_checkpoint))

    assert spawned_database.last_iteration == 0
    assert len(spawned_database.programs) == 2
    for program_id, program in spawned_database.programs.items():
        expected_metrics = source_metrics_by_program[program_id]["sr_po11"]
        assert program.metrics["combined_score"] == pytest.approx(program.metrics["score"])
        assert program.metrics["combined_score"] == pytest.approx(expected_metrics["combined_score"])
        assert program.metadata["sts_warmstarted"] is True
        assert program.metadata["sts_target_task_id"] == "sr_po11"
        task_artifacts = spawned_database.get_artifacts(program_id)
        assert task_artifacts["task_selector"] == "sr_po11"
        assert len(task_artifacts["task_results"]) == 1
        assert task_artifacts["task_results"][0]["task_id"] == "sr_po11"


def test_hard_timeout_returns_finite_failure_metrics(monkeypatch, tmp_path):
    """A candidate that never returns is reported as a worker timeout with finite metrics."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")

    # func() spins forever, so the evaluation can only end through the timeout.
    hanging_source = (
        "import numpy as np\n\n"
        "def func(x, params):\n"
        "    while True:\n"
        "        pass\n\n"
        "func.num_params = 2\n\n"
        "def run_search():\n"
        "    return func\n"
    )
    program_file = tmp_path / "candidate_hangs.py"
    program_file.write_text(hanging_source, encoding="utf-8")

    # Copy of the sr_po11 spec with a one-second full-evaluation timeout.
    po11_spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID["sr_po11"]
    short_timeout_spec = replace(po11_spec, timeout_seconds_full=1.0)

    task_result, task_artifacts = sr_evaluator.run_task_with_hard_timeout(
        program_path=str(program_file),
        task=short_timeout_spec,
        stage1=False,
    )

    result_metrics = task_result["metrics"]
    assert task_result["error"] is not None
    assert result_metrics["score"] == pytest.approx(0.0)
    assert result_metrics["combined_score"] == pytest.approx(0.0)
    assert np.isfinite(result_metrics["test_nmse"])
    assert np.isfinite(result_metrics["ood_nmse"])
    assert task_result["failure_kind"] == "worker_timeout"
    assert task_result["failure_stage"] == "worker_subprocess"
    assert task_artifacts["successful_restarts"] == 0


def test_wrong_shaped_predictions_fail_cleanly(monkeypatch, tmp_path):
    """Two-column predictions give a zero score and a ``prediction_invalid`` failure."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "sr_po11")

    bad_shape_source = (
        "import numpy as np\n\n"
        "def func(x, params):\n"
        "    return np.zeros((x.shape[0], 2), dtype=float)\n\n"
        "func.num_params = 2\n\n"
        "def run_search():\n"
        "    return func\n"
    )
    program_file = tmp_path / "candidate_bad_shape.py"
    program_file.write_text(bad_shape_source, encoding="utf-8")

    outcome = sr_evaluator.evaluate(str(program_file))

    assert outcome.metrics["score"] == pytest.approx(0.0)
    first_result = outcome.artifacts["task_results"][0]
    assert first_result["error"] is not None
    assert first_result["failure_kind"] == "prediction_invalid"


def test_non_finite_predictions_fail_cleanly(monkeypatch, tmp_path):
    """All-NaN predictions give a zero score and a ``prediction_invalid`` failure."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "sr_po11")

    nan_source = (
        "import numpy as np\n\n"
        "def func(x, params):\n"
        "    return np.full(x.shape[0], np.nan, dtype=float)\n\n"
        "func.num_params = 2\n\n"
        "def run_search():\n"
        "    return func\n"
    )
    program_file = tmp_path / "candidate_non_finite.py"
    program_file.write_text(nan_source, encoding="utf-8")

    outcome = sr_evaluator.evaluate(str(program_file))

    assert outcome.metrics["score"] == pytest.approx(0.0)
    first_result = outcome.artifacts["task_results"][0]
    assert first_result["error"] is not None
    assert first_result["failure_kind"] == "prediction_invalid"


def test_optimization_failure_is_distinct_from_data_loading(monkeypatch):
    """An optimizer that raises is classified as ``optimization_failed``, not as a loading error."""
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")

    def failing_optimizer(*args, **kwargs):
        raise RuntimeError("optimizer failed")

    monkeypatch.setattr(sr_evaluator, "_run_optimizer", failing_optimizer)

    initial_program = (
        REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"
    )
    payload = sr_evaluator.run_task_worker(
        program_path=str(initial_program),
        task_id="sr_po11",
        stage1=False,
    )

    task_result = payload["task_result"]
    assert task_result["failure_kind"] == "optimization_failed"
    assert task_result["failure_stage"] == "select_best_restart"
    assert payload["task_artifacts"]["failure_kind"] == "optimization_failed"


def test_artifacts_do_not_contain_raw_arrays_or_ground_truth_equations(monkeypatch):
    """Evaluation artifacts expose neither dataset arrays nor ground-truth equation text.

    Both the set of dictionary keys found in the artifacts and their JSON
    serialisation are checked.
    """
    monkeypatch.setenv("SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE", "1")
    monkeypatch.setenv(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR, "all")

    initial_program = (
        REPO_ROOT / "examples" / "symbolic_regression_phys_osc_mt_sts" / "initial_program.py"
    )
    outcome = sr_evaluator.evaluate(str(initial_program))

    serialized = json.dumps(outcome.artifacts, sort_keys=True)
    present_keys = _artifact_keys(outcome.artifacts)

    split_keys = {"train", "test", "ood"}
    array_keys = {"X_train", "y_train", "X_test", "y_test", "X_ood_test", "y_ood_test"}
    equation_keys = {"expression", "gt_equation", "ground_truth_equation"}
    forbidden_keys = split_keys | array_keys | equation_keys

    assert present_keys.isdisjoint(forbidden_keys)
    assert "Acceleration in Nonl-linear Harmonic Oscillator" not in serialized
    assert "expression" not in serialized


def test_shared_runner_fails_fast_before_launching_zero_score_run(monkeypatch, tmp_path):
    """A failing preflight makes the shared runner exit without ever calling ``run_command``."""

    def forbid_run_command(*args, **kwargs):
        raise AssertionError("run_command should not be called")

    def failing_preflight(*args, **kwargs):
        raise RuntimeError("preflight failed")

    monkeypatch.setattr(sr_mt_sts_runner, "run_command", forbid_run_command)
    monkeypatch.setattr(sr_mt_sts_runner, "run_mt_sts_family_preflight", failing_preflight)

    manifest_path = (
        REPO_ROOT / "multi_task_shared_then_adapt" / "symbolic_regression_phys_osc_mt_sts.yaml"
    )
    cli_arguments = [
        "run_multi_task_shared_then_specialize.py",
        "--manifest",
        str(manifest_path),
        "--output-root",
        str(tmp_path),
    ]
    monkeypatch.setattr(sys, "argv", cli_arguments)

    with pytest.raises(SystemExit, match="preflight failed"):
        sr_mt_sts_runner.main()


def test_baseline_runner_fails_fast_before_launching_zero_score_run(monkeypatch, tmp_path):
    """A failing preflight makes the baselines runner exit without ever calling ``run_command``."""

    def forbid_run_command(*args, **kwargs):
        raise AssertionError("run_command should not be called")

    def failing_preflight(*args, **kwargs):
        raise RuntimeError("preflight failed")

    monkeypatch.setattr(sr_mt_sts_baselines_runner, "run_command", forbid_run_command)
    monkeypatch.setattr(
        sr_mt_sts_baselines_runner, "run_mt_sts_family_preflight", failing_preflight
    )

    manifest_path = (
        REPO_ROOT / "multi_task_shared_then_adapt" / "symbolic_regression_phys_osc_mt_sts.yaml"
    )
    cli_arguments = [
        "run_multi_task_shared_then_specialize_baselines.py",
        "--manifest",
        str(manifest_path),
        "--output-root",
        str(tmp_path),
    ]
    monkeypatch.setattr(sys, "argv", cli_arguments)

    with pytest.raises(SystemExit, match="preflight failed"):
        sr_mt_sts_baselines_runner.main()
