"""Evaluator for the MT-STS symbolic-regression physics-oscillator family."""

from __future__ import annotations

import argparse
from collections import Counter
import importlib.util
import json
import math
import os
from pathlib import Path
import subprocess
import sys
import tempfile
import time
import traceback
from typing import Any, Callable, Dict, Mapping

import numpy as np
from types import SimpleNamespace

try:
    from scipy.optimize import minimize as scipy_minimize
except ImportError:  # pragma: no cover - depends on local environment
    scipy_minimize = None

# REPO_ROOT is parents[2] of this file, i.e. two levels above the directory
# containing this file; it is also used as the worker subprocess cwd below.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory containing this evaluator file.
EXAMPLE_DIR = Path(__file__).resolve().parent
# Make the project-local `data_loader` and `openevolve` imports below
# importable (presumably data_loader lives in EXAMPLE_DIR and openevolve under
# REPO_ROOT — confirm). Both paths are inserted at index 0, so when neither was
# already present EXAMPLE_DIR ends up ahead of REPO_ROOT on sys.path.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
if str(EXAMPLE_DIR) not in sys.path:
    sys.path.insert(0, str(EXAMPLE_DIR))

from data_loader import (
    SymbolicRegressionDataError,
    load_task_data,
    validate_symbolic_regression_phys_osc_assets,
)
from openevolve.evaluation_result import EvaluationResult
from openevolve.multi_task_shared_then_specialize.symbolic_regression_phys_osc import (
    DEFAULT_FULL_TIMEOUT_SECONDS,
    DEFAULT_STAGE1_TIMEOUT_SECONDS,
    SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR,
    SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR,
    SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS,
    SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID,
    SymbolicRegressionPhysOscTaskSpec,
    aggregate_task_results,
    build_task_result,
    resolve_task_specs,
)


class EvaluatorTaskError(RuntimeError):
    """Task-evaluation failure that carries structured classification fields.

    All three classification fields are coerced to ``str`` so they can be
    copied verbatim into JSON payloads.

    Args:
        message: Human-readable error text (becomes ``str(exc)``).
        failure_kind: Category of failure, e.g. ``"optimization_failed"``.
        failure_stage: Pipeline stage at which the failure happened.
        data_source_mode: Data source in use when the failure happened.
    """

    def __init__(
        self,
        message: str,
        *,
        failure_kind: str,
        failure_stage: str,
        data_source_mode: str = "unknown",
    ) -> None:
        super().__init__(message)
        self.failure_kind, self.failure_stage, self.data_source_mode = (
            str(failure_kind),
            str(failure_stage),
            str(data_source_mode),
        )


class PredictionInvalidError(EvaluatorTaskError):
    """Failure whose ``failure_kind`` is always ``"prediction_invalid"``.

    ``_predict`` raises it when ``func(x, params)`` returns output of the
    wrong shape or containing non-finite values. ``data_source_mode`` is left
    at the base-class default (``"unknown"``).
    """

    def __init__(self, message: str, *, failure_stage: str) -> None:
        super().__init__(
            message,
            failure_stage=failure_stage,
            failure_kind="prediction_invalid",
        )


def _json_default(value: Any):
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (np.integer,)):
        return int(value)
    if isinstance(value, (np.floating,)):
        return float(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _load_program_module(program_path: str):
    module_name = f"symbolic_regression_phys_osc_mt_sts_{hash(Path(program_path).resolve())}"
    spec = importlib.util.spec_from_file_location(module_name, program_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Could not load program from {program_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


def _resolve_model_func(module: Any) -> Callable[[np.ndarray, np.ndarray], np.ndarray]:
    run_search = getattr(module, "run_search", None)
    if callable(run_search):
        candidate = run_search()
        if not callable(candidate):
            raise TypeError("run_search() must return a callable func(x, params)")
        return candidate

    func = getattr(module, "func", None)
    if callable(func):
        return func
    raise AttributeError("Program must define run_search() -> func or a callable func(x, params)")


def _resolve_num_params(model_func: Callable[..., Any]) -> int:
    raw_value = getattr(model_func, "num_params", 10)
    try:
        numeric = int(raw_value)
    except (TypeError, ValueError):
        numeric = 10
    return max(1, min(10, numeric))


def _predict(
    model_func: Callable[..., Any],
    x_array: np.ndarray,
    params: np.ndarray,
    *,
    failure_stage: str,
) -> np.ndarray:
    predictions = np.asarray(model_func(x_array, params), dtype=float).reshape(-1)
    if predictions.shape != (x_array.shape[0],):
        raise PredictionInvalidError(
            f"func(x, params) must return shape {(x_array.shape[0],)}, got {predictions.shape}",
            failure_stage=failure_stage,
        )
    if not np.all(np.isfinite(predictions)):
        raise PredictionInvalidError(
            "func(x, params) returned non-finite predictions",
            failure_stage=failure_stage,
        )
    return predictions


def objective_function(
    params: np.ndarray,
    model_func: Callable[..., Any],
    x_array: np.ndarray,
    y_array: np.ndarray,
) -> float:
    """Training-set MSE of ``model_func`` at ``params``; the fitting objective.

    ``_run_optimizer`` passes this to SciPy's BFGS or to the local-search
    fallback. A failure to produce valid predictions, or a non-finite MSE,
    is reported as ``+inf`` instead of raising.

    ``y_array`` is flattened before the residual is formed, matching
    ``compute_regression_metrics``. Predictions from ``_predict`` are always
    1-D, so an ``(n, 1)`` target would otherwise broadcast into an ``(n, n)``
    residual matrix and give a meaningless objective.
    """
    try:
        predictions = _predict(
            model_func,
            x_array,
            np.asarray(params, dtype=float),
            failure_stage="fit_restart",
        )
    except Exception:
        # Invalid candidates are scored as infinitely bad rather than raising.
        return float("inf")
    targets = np.asarray(y_array, dtype=float).reshape(-1)
    mse = float(np.mean((targets - predictions) ** 2))
    return mse if math.isfinite(mse) else float("inf")


def _fallback_minimize(
    objective_fn: Callable[..., float],
    x0: np.ndarray,
    *,
    args: tuple[Any, ...],
    options: Mapping[str, Any] | None,
):
    """Deterministic local search fallback when scipy is unavailable."""
    maxiter = int((options or {}).get("maxiter", 100))
    current = np.asarray(x0, dtype=float).copy()
    current_value = float(objective_fn(current, *args))
    step_scale = 0.25

    for _ in range(max(1, maxiter)):
        improved = False
        for dim in range(current.size):
            for direction in (-1.0, 1.0):
                candidate = current.copy()
                candidate[dim] += direction * step_scale
                candidate_value = float(objective_fn(candidate, *args))
                if candidate_value < current_value:
                    current = candidate
                    current_value = candidate_value
                    improved = True
        step_scale *= 0.97
        if not improved and step_scale < 1.0e-4:
            break

    return SimpleNamespace(
        x=current,
        fun=current_value,
        success=math.isfinite(current_value),
        message=(
            "Fallback local-search optimizer used because scipy is not installed"
        ),
    )


def _run_optimizer(
    initial_params: np.ndarray,
    model_func: Callable[..., Any],
    x_train: np.ndarray,
    y_train: np.ndarray,
    *,
    maxiter: int,
):
    """Minimise ``objective_function`` on the training split from ``initial_params``.

    Uses ``scipy.optimize.minimize`` with BFGS when scipy was importable, and
    ``_fallback_minimize`` otherwise. Both receive the same objective and
    arguments and are capped at ``maxiter`` iterations. The returned result
    exposes ``.x`` (read by ``run_task_worker``).
    """
    objective_args = (model_func, x_train, y_train)
    iteration_cap = int(maxiter)

    if scipy_minimize is None:
        return _fallback_minimize(
            objective_function,
            initial_params,
            args=objective_args,
            options={"maxiter": iteration_cap},
        )

    return scipy_minimize(
        objective_function,
        initial_params,
        args=objective_args,
        method="BFGS",
        options={"maxiter": iteration_cap, "disp": False},
    )


def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> Dict[str, float]:
    """Compute MSE, normalised MSE and R^2 for one dataset split.

    Both inputs are flattened to 1-D floats.

    Returns:
        ``{"mse", "nmse", "r2"}``. ``nmse`` is ``mse / var(y_true)``. When
        ``y_true`` is constant it is ``0.0`` for a perfect fit and ``+inf``
        otherwise. ``r2`` likewise is ``1.0`` or ``-inf`` when the total sum
        of squares is zero. Values are therefore not guaranteed to be finite.

    Raises:
        ValueError: if the inputs differ in length (which would otherwise
            silently broadcast a length-1 array), or if they are empty
            (which would otherwise report nmse=inf alongside r2=1.0).
    """
    y_pred = np.asarray(y_pred, dtype=float).reshape(-1)
    y_true = np.asarray(y_true, dtype=float).reshape(-1)
    if y_pred.size != y_true.size:
        raise ValueError(
            "y_pred and y_true must have the same length, "
            f"got {y_pred.size} and {y_true.size}"
        )
    if y_true.size == 0:
        raise ValueError("cannot compute regression metrics for an empty split")

    # Shared by MSE (mean) and SSE (sum); computed once.
    squared_error = (y_true - y_pred) ** 2
    mse = float(np.mean(squared_error))
    var_y = float(np.var(y_true))
    if var_y > 0.0:
        nmse = mse / var_y
    else:
        nmse = 0.0 if mse == 0.0 else float("inf")

    sse = float(np.sum(squared_error))
    sst = float(np.sum((y_true - np.mean(y_true)) ** 2))
    if sst > 0.0:
        r2 = 1.0 - (sse / sst)
    else:
        r2 = 1.0 if sse == 0.0 else -float("inf")

    return {
        "mse": mse,
        "nmse": nmse,
        "r2": r2,
    }


def _stage_config(
    task: SymbolicRegressionPhysOscTaskSpec,
    *,
    stage1: bool,
) -> dict[str, float | int]:
    if stage1:
        return {
            "num_restarts": task.num_restarts_stage1,
            "maxiter": task.maxiter_stage1,
            "timeout_seconds": task.timeout_seconds_stage1,
        }
    return {
        "num_restarts": task.num_restarts_full,
        "maxiter": task.maxiter_full,
        "timeout_seconds": task.timeout_seconds_full,
    }


def _failure_details_from_exception(
    exc: Exception,
    *,
    default_kind: str,
    default_stage: str,
    default_data_source_mode: str = "unknown",
) -> dict[str, str]:
    raw_data_source_mode = getattr(exc, "data_source_mode", None)
    return {
        "failure_kind": str(getattr(exc, "failure_kind", default_kind) or default_kind),
        "failure_stage": str(getattr(exc, "failure_stage", default_stage) or default_stage),
        "data_source_mode": str(
            default_data_source_mode
            if raw_data_source_mode in (None, "", "unknown")
            else raw_data_source_mode
        ),
        "message": f"{type(exc).__name__}: {exc}",
    }


def _build_failure_payload(
    task: SymbolicRegressionPhysOscTaskSpec,
    *,
    raw_metrics: Mapping[str, Any] | None,
    error: str,
    timeout_seconds: float,
    data_source_mode: str,
    failure_kind: str,
    failure_stage: str,
    task_artifacts: Mapping[str, Any],
) -> dict[str, Any]:
    """Package a failed task evaluation as ``{"task_result", "task_artifacts"}``.

    The artifacts start with ``task_id`` and the three failure fields. Any
    keys in ``task_artifacts`` are then layered on top and win on collision.
    """
    failed_result = build_task_result(
        task,
        raw_metrics=raw_metrics,
        error=error,
        timeout_seconds=timeout_seconds,
        data_source_mode=data_source_mode,
        failure_kind=failure_kind,
        failure_stage=failure_stage,
    )
    header = {
        "task_id": task.task_id,
        "failure_kind": failure_kind,
        "failure_stage": failure_stage,
        "data_source_mode": data_source_mode,
    }
    merged_artifacts = {**header, **dict(task_artifacts)}
    return {"task_result": failed_result, "task_artifacts": merged_artifacts}


def _final_restart_failure_details(
    failed_restart_details: list[dict[str, str]],
) -> tuple[str, str, str]:
    if failed_restart_details and all(
        detail["failure_kind"] == "prediction_invalid" for detail in failed_restart_details
    ):
        sample = failed_restart_details[0]
        return (
            "prediction_invalid",
            sample["failure_stage"],
            "All optimization restarts produced invalid candidate predictions",
        )
    return (
        "optimization_failed",
        "select_best_restart",
        "All optimization restarts failed",
    )


def run_task_worker(
    *,
    program_path: str,
    task_id: str,
    stage1: bool,
) -> dict[str, Any]:
    """Evaluate one task inside a dedicated subprocess worker.

    Loads the task's train/test/OOD splits and imports the candidate program.
    It then fits the program's ``func(x, params)`` on the training split with
    several seeded restarts of ``_run_optimizer``, keeps the restart with the
    lowest training MSE, and scores that parameter vector on the test and
    OOD splits.

    Args:
        program_path: Path to the candidate program file.
        task_id: Key into ``SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID``.
        stage1: When true, use the stage-1 restart/iteration budget from
            ``_stage_config``; otherwise the full budget.

    Returns:
        ``{"task_result": ..., "task_artifacts": ...}``. Failures inside the
        ``try`` are caught and reported in this payload rather than raised.
        Note that the task lookup and ``_stage_config`` run before the
        ``try``, so an unknown ``task_id`` raises ``KeyError``.
    """
    start_time = time.perf_counter()
    task = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID[task_id]
    stage_config = _stage_config(task, stage1=stage1)
    selected_timeout_seconds = float(stage_config["timeout_seconds"])
    total_restarts = int(stage_config["num_restarts"])
    # Placeholders reported if a failure happens before they are resolved.
    num_params_used = 10
    data_source_mode = "unknown"

    try:
        task_data = load_task_data(task)
        data_source_mode = str(task_data["metadata"]["data_source_mode"])
        module = _load_program_module(program_path)
        model_func = _resolve_model_func(module)
        num_params = _resolve_num_params(model_func)
        num_params_used = num_params

        x_train, y_train = task_data["train"]
        x_test, y_test = task_data["test"]
        x_ood, y_ood = task_data["ood"]

        best_result: dict[str, Any] | None = None
        best_train_mse = float("inf")
        successful_restarts = 0
        failed_restart_messages: list[str] = []
        failed_restart_details: list[dict[str, str]] = []

        # The restart index doubles as the RNG seed, so initial guesses are
        # reproducible across runs.
        for seed in range(total_restarts):
            rng = np.random.default_rng(seed)
            initial_params = rng.normal(loc=0.0, scale=0.35, size=num_params)
            try:
                optimization_result = _run_optimizer(
                    initial_params,
                    model_func,
                    x_train,
                    y_train,
                    maxiter=int(stage_config["maxiter"]),
                )
                candidate_params = np.asarray(optimization_result.x, dtype=float).reshape(-1)
                # Re-validate the optimizer's final point outside the objective,
                # which silently maps invalid predictions to +inf.
                train_predictions = _predict(
                    model_func,
                    x_train,
                    candidate_params,
                    failure_stage="evaluate_train",
                )
                train_metrics = compute_regression_metrics(train_predictions, y_train)
                train_mse = float(train_metrics["mse"])
                if not math.isfinite(train_mse):
                    raise EvaluatorTaskError(
                        "non-finite train MSE",
                        failure_kind="optimization_failed",
                        failure_stage="evaluate_train",
                        data_source_mode=data_source_mode,
                    )
                successful_restarts += 1
                # Strict "<": on ties the earliest seed is kept.
                if train_mse < best_train_mse:
                    best_train_mse = train_mse
                    best_result = {
                        "best_restart_seed": seed,
                        "params": candidate_params,
                        "train_predictions": train_predictions,
                        "train_metrics": train_metrics,
                    }
            except Exception as exc:
                # A failed restart is recorded and the loop moves on.
                failure_details = _failure_details_from_exception(
                    exc,
                    default_kind="optimization_failed",
                    default_stage="fit_restart",
                    default_data_source_mode=data_source_mode,
                )
                failed_restart_messages.append(
                    f"seed={seed}: {failure_details['message']}"
                )
                failed_restart_details.append(
                    {
                        "seed": str(seed),
                        "failure_kind": failure_details["failure_kind"],
                        "failure_stage": failure_details["failure_stage"],
                        "message": failure_details["message"],
                    }
                )

        # NOTE: measured here, before test/OOD evaluation, so eval_time does
        # not include that work.
        eval_time = time.perf_counter() - start_time

        if best_result is None:
            # Every restart failed: classify and report without test/OOD scores.
            failure_kind, failure_stage, failure_message = _final_restart_failure_details(
                failed_restart_details
            )
            return _build_failure_payload(
                task,
                raw_metrics={
                    "successful_restarts": 0,
                    "total_restarts": total_restarts,
                    "num_params_used": num_params_used,
                    "eval_time": eval_time,
                },
                error=f"{failure_message} across {total_restarts} restarts",
                timeout_seconds=selected_timeout_seconds,
                data_source_mode=data_source_mode,
                failure_kind=failure_kind,
                failure_stage=failure_stage,
                task_artifacts={
                    "best_restart_seed": None,
                    "successful_restarts": 0,
                    "total_restarts": total_restarts,
                    "failed_restart_messages": failed_restart_messages,
                    "failed_restart_details": failed_restart_details,
                },
            )

        # Errors from here on are not caught per-restart; they fall through to
        # the outer handler below.
        best_params = best_result["params"]
        test_predictions = _predict(
            model_func,
            x_test,
            best_params,
            failure_stage="evaluate_test",
        )
        ood_predictions = _predict(
            model_func,
            x_ood,
            best_params,
            failure_stage="evaluate_ood",
        )
        test_metrics = compute_regression_metrics(test_predictions, y_test)
        ood_metrics = compute_regression_metrics(ood_predictions, y_ood)

        raw_metrics = {
            "train_nmse": best_result["train_metrics"]["nmse"],
            "test_nmse": test_metrics["nmse"],
            "ood_nmse": ood_metrics["nmse"],
            "train_r2": best_result["train_metrics"]["r2"],
            "test_r2": test_metrics["r2"],
            "ood_r2": ood_metrics["r2"],
            "successful_restarts": successful_restarts,
            "total_restarts": total_restarts,
            "num_params_used": num_params,
            "eval_time": eval_time,
        }
        task_result = build_task_result(
            task,
            raw_metrics=raw_metrics,
            error=None,
            timeout_seconds=selected_timeout_seconds,
            data_source_mode=data_source_mode,
            failure_kind="none",
            failure_stage="none",
        )
        return {
            "task_result": task_result,
            "task_artifacts": {
                "task_id": task_id,
                "best_restart_seed": int(best_result["best_restart_seed"]),
                "successful_restarts": successful_restarts,
                "total_restarts": total_restarts,
                "failed_restart_messages": failed_restart_messages,
                "failed_restart_details": failed_restart_details,
                "data_source_mode": data_source_mode,
                "failure_kind": "none",
                "failure_stage": "none",
            },
        }
    except Exception as exc:
        # Catch-all for data loading, program import, model resolution and
        # test/OOD evaluation failures. Attributes on the exception (if any)
        # override the worker_exception/unknown defaults.
        eval_time = time.perf_counter() - start_time
        failure_details = _failure_details_from_exception(
            exc,
            default_kind="worker_exception",
            default_stage="unknown",
            default_data_source_mode=data_source_mode,
        )
        return _build_failure_payload(
            task,
            raw_metrics={
                "successful_restarts": 0,
                "total_restarts": total_restarts,
                "num_params_used": num_params_used,
                "eval_time": eval_time,
            },
            error=failure_details["message"],
            timeout_seconds=selected_timeout_seconds,
            data_source_mode=failure_details["data_source_mode"],
            failure_kind=failure_details["failure_kind"],
            failure_stage=failure_details["failure_stage"],
            task_artifacts={
                "best_restart_seed": None,
                "successful_restarts": 0,
                "total_restarts": total_restarts,
                "failed_restart_messages": [failure_details["message"]],
                "failed_restart_details": [],
            },
        )


def _worker_failure_result(
    task: SymbolicRegressionPhysOscTaskSpec,
    *,
    error: str,
    timeout_seconds: float,
    total_restarts: int,
    failure_kind: str,
    stdout: str | None,
    stderr: str | None,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Build the ``(task_result, task_artifacts)`` pair for a worker-level failure."""
    task_result = build_task_result(
        task,
        raw_metrics=None,
        error=error,
        timeout_seconds=timeout_seconds,
        data_source_mode="unknown",
        failure_kind=failure_kind,
        failure_stage="worker_subprocess",
    )
    return task_result, {
        "task_id": task.task_id,
        "best_restart_seed": None,
        "successful_restarts": 0,
        "total_restarts": total_restarts,
        "failed_restart_messages": [error],
        "failed_restart_details": [],
        "worker_stdout": stdout[-500:] if stdout else "",
        "worker_stderr": stderr[-500:] if stderr else "",
        "data_source_mode": "unknown",
        "failure_kind": failure_kind,
        "failure_stage": "worker_subprocess",
    }


def run_task_with_hard_timeout(
    *,
    program_path: str,
    task: SymbolicRegressionPhysOscTaskSpec,
    stage1: bool,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Run one task worker in a subprocess and enforce a hard timeout.

    The worker (this file re-invoked with ``--worker``) writes its JSON
    payload to a temporary file. That file is always removed before
    returning.

    Returns:
        ``(task_result, task_artifacts)``. The three worker-level failure
        modes are all reported through ``_worker_failure_result`` instead of
        raising:
        - timeout: the process is killed; failure kind ``worker_timeout``.
        - nonzero exit: failure kind ``worker_exception``; the error text
          keeps the last 500 chars of stderr, where a traceback's exception
          line lives.
        - unreadable, non-JSON or malformed payload: failure kind
          ``worker_exception``.
    """
    stage_config = _stage_config(task, stage1=stage1)
    timeout_seconds = float(stage_config["timeout_seconds"])
    total_restarts = int(stage_config["num_restarts"])

    # delete=False so the worker can write to the path after we close it.
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as handle:
        result_path = Path(handle.name)

    try:
        command = [
            sys.executable,
            str(Path(__file__).resolve()),
            "--worker",
            "--program",
            str(program_path),
            "--task-id",
            task.task_id,
            "--stage",
            "stage1" if stage1 else "full",
            "--result-path",
            str(result_path),
        ]
        process = subprocess.Popen(
            command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            cwd=str(REPO_ROOT),
            env=dict(os.environ),
        )

        try:
            stdout, stderr = process.communicate(timeout=timeout_seconds)
        except subprocess.TimeoutExpired:
            process.kill()
            # Reap the killed process and collect whatever output it produced.
            stdout, stderr = process.communicate()
            return _worker_failure_result(
                task,
                error=f"Task worker timed out after {timeout_seconds:.0f}s",
                timeout_seconds=timeout_seconds,
                total_restarts=total_restarts,
                failure_kind="worker_timeout",
                stdout=stdout,
                stderr=stderr,
            )

        if process.returncode != 0 or not result_path.is_file():
            # Keep the tail: a Python traceback ends with the exception line.
            stderr_tail = (stderr or "").strip()[-500:]
            error = (
                f"Task worker failed for {task.task_id} with return code {process.returncode}. "
                f"stderr={stderr_tail}"
            )
            return _worker_failure_result(
                task,
                error=error,
                timeout_seconds=timeout_seconds,
                total_restarts=total_restarts,
                failure_kind="worker_exception",
                stdout=stdout,
                stderr=stderr,
            )

        try:
            payload = json.loads(result_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError; an
            # empty or corrupt file is reported as a malformed payload below
            # instead of aborting the whole multi-task evaluation.
            payload = None

        if isinstance(payload, dict):
            task_result = payload.get("task_result")
            task_artifacts = payload.get("task_artifacts")
        else:
            task_result = task_artifacts = None

        if not isinstance(task_result, dict) or not isinstance(task_artifacts, dict):
            return _worker_failure_result(
                task,
                error=f"Malformed worker payload for {task.task_id}",
                timeout_seconds=timeout_seconds,
                total_restarts=total_restarts,
                failure_kind="worker_exception",
                stdout=stdout,
                stderr=stderr,
            )
        return task_result, task_artifacts
    finally:
        # Runs on every path, including if Popen itself raises.
        result_path.unlink(missing_ok=True)


def _public_artifacts(
    *,
    selector: str,
    stage_name: str,
    task_results: list[dict[str, Any]],
    task_artifacts: list[dict[str, Any]],
) -> dict[str, Any]:
    artifacts: dict[str, Any] = {
        "task_selector": selector,
        "evaluation_stage": stage_name,
        "evaluation_mode": "shared" if len(task_results) > 1 else "task_specific",
        "selected_task_ids": [task_result["task_id"] for task_result in task_results],
        "task_results": task_results,
        "data_source_modes": {
            task_result["task_id"]: task_result.get("data_source_mode", "unknown")
            for task_result in task_results
        },
        "optimization_summary": {},
        "failure_kind_counts": {},
        "task_failure_summary": {},
        "data_source_mode_summary": {},
    }

    failure_kind_counts: Counter[str] = Counter()
    data_source_mode_summary: Counter[str] = Counter()
    for per_task_artifacts in task_artifacts:
        task_id = per_task_artifacts["task_id"]
        artifacts["optimization_summary"][task_id] = {
            "best_restart_seed": per_task_artifacts.get("best_restart_seed"),
            "successful_restarts": int(per_task_artifacts.get("successful_restarts", 0)),
            "total_restarts": int(per_task_artifacts.get("total_restarts", 0)),
        }
        failure_kind = str(per_task_artifacts.get("failure_kind", "unknown"))
        failure_stage = per_task_artifacts.get("failure_stage")
        data_source_mode = str(per_task_artifacts.get("data_source_mode", "unknown"))
        failure_kind_counts[failure_kind] += 1
        data_source_mode_summary[data_source_mode] += 1
        artifacts["task_failure_summary"][task_id] = {
            "failure_kind": failure_kind,
            "failure_stage": failure_stage,
            "error": next(
                (
                    task_result.get("error")
                    for task_result in task_results
                    if task_result.get("task_id") == task_id
                ),
                None,
            ),
            "data_source_mode": data_source_mode,
        }
        failed_messages = per_task_artifacts.get("failed_restart_messages") or []
        if failed_messages:
            artifacts.setdefault("failed_restart_messages", {})[task_id] = list(failed_messages)
        failed_details = per_task_artifacts.get("failed_restart_details") or []
        if failed_details:
            artifacts.setdefault("failed_restart_details", {})[task_id] = list(failed_details)

    artifacts["failure_kind_counts"] = dict(failure_kind_counts)
    artifacts["data_source_mode_summary"] = dict(data_source_mode_summary)

    if task_results:
        artifacts["aggregate_score_summary"] = {
            "mean_score": float(np.mean([result["metrics"]["score"] for result in task_results])),
            "mean_test_nmse": float(np.mean([result["metrics"]["test_nmse"] for result in task_results])),
            "mean_ood_nmse": float(np.mean([result["metrics"]["ood_nmse"] for result in task_results])),
        }
    return artifacts


def _evaluate(program_path: str, *, stage1: bool) -> EvaluationResult:
    """Run every selected task in its own worker and build the result.

    Tasks come from the selector in the task-selector environment variable,
    falling back to the shared selector. A single selected task contributes
    its metrics directly. Several tasks are combined with
    ``aggregate_task_results``.
    """
    selector = os.environ.get(
        SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR,
        SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR,
    )

    task_results: list[dict[str, Any]] = []
    task_artifacts: list[dict[str, Any]] = []
    for spec in resolve_task_specs(selector):
        outcome, per_task = run_task_with_hard_timeout(
            program_path=program_path,
            task=spec,
            stage1=stage1,
        )
        task_results.append(outcome)
        task_artifacts.append(per_task)

    artifacts = _public_artifacts(
        selector=selector,
        stage_name="stage1" if stage1 else "full",
        task_results=task_results,
        task_artifacts=task_artifacts,
    )
    if len(task_results) != 1:
        metrics = aggregate_task_results(task_results)
    else:
        metrics = dict(task_results[0]["metrics"])
    return EvaluationResult(metrics=metrics, artifacts=artifacts)


def evaluate(program_path: str) -> EvaluationResult:
    """Evaluate a candidate over the shared family or one selected task.

    Uses the full (non-stage-1) restart and iteration budget.
    """
    return _evaluate(program_path=program_path, stage1=False)


def evaluate_stage1(program_path: str) -> EvaluationResult:
    """Cheaper first-stage evaluation using each task's stage-1 budget."""
    return _evaluate(program_path=program_path, stage1=True)


def evaluate_stage2(program_path: str) -> EvaluationResult:
    """Second-stage evaluation; delegates to the full ``evaluate``."""
    return evaluate(program_path=program_path)


def preflight_check_symbolic_regression_phys_osc(
    task_ids: list[str] | None = None,
) -> dict[str, Any]:
    """Validate that the selected symbolic-regression tasks can load real data.

    Explicit ``task_ids`` are looked up directly (an unknown id raises
    ``KeyError``). Otherwise tasks are resolved from the selector environment
    variable, as in ``_evaluate``.
    """
    if task_ids is not None:
        chosen = [SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID[tid] for tid in task_ids]
    else:
        env_selector = os.environ.get(
            SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR,
            SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR,
        )
        chosen = resolve_task_specs(env_selector)
    return validate_symbolic_regression_phys_osc_assets(chosen)


def _run_worker_cli(args: argparse.Namespace) -> int:
    """Worker-mode entry: evaluate one task and write its payload as JSON.

    The payload goes to ``args.result_path`` (UTF-8, indented). Returns 0.
    """
    payload = run_task_worker(
        program_path=args.program,
        task_id=args.task_id,
        stage1=(args.stage == "stage1"),
    )
    serialized = json.dumps(payload, indent=2, default=_json_default)
    Path(args.result_path).write_text(serialized, encoding="utf-8")
    return 0


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser.

    Only ``program_path`` is public. The worker-mode flags are hidden from
    ``--help`` via ``argparse.SUPPRESS``.
    """
    hidden = argparse.SUPPRESS
    parser = argparse.ArgumentParser(description="Evaluate symbolic-regression MT-STS programs.")
    parser.add_argument("program_path", nargs="?", help="Path to the candidate program")
    parser.add_argument("--worker", action="store_true", help=hidden)
    parser.add_argument("--program", help=hidden)
    parser.add_argument(
        "--task-id",
        choices=sorted(SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID),
        help=hidden,
    )
    parser.add_argument(
        "--stage",
        choices=("stage1", "full"),
        default="full",
        help=hidden,
    )
    parser.add_argument("--result-path", help=hidden)
    return parser


def main(argv: list[str] | None = None) -> int:
    """CLI entry point.

    In ``--worker`` mode, evaluates one task and writes the result file.
    Otherwise it runs the full evaluation on ``program_path`` and prints the
    metrics/artifacts as JSON. Missing required arguments raise
    ``SystemExit`` with a message.
    """
    args = _build_arg_parser().parse_args(argv)

    if args.worker:
        if not (args.program and args.task_id and args.result_path):
            raise SystemExit("--worker requires --program, --task-id, and --result-path")
        return _run_worker_cli(args)

    if not args.program_path:
        raise SystemExit("Usage: python evaluator.py <program_path>")

    result = evaluate(args.program_path)
    report = {"metrics": result.metrics, "artifacts": result.artifacts}
    print(json.dumps(report, indent=2, default=_json_default))
    return 0


if __name__ == "__main__":
    # Print an unexpected error's traceback once and exit with status 1,
    # which is the same status an uncaught exception produces. Re-raising
    # here would make the interpreter print the same traceback a second time.
    # SystemExit (from argparse or main's usage checks) is not an Exception
    # subclass, so it still propagates with its own code/message.
    try:
        _exit_code = main()
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        _exit_code = 1
    raise SystemExit(_exit_code)
