"""Data loading helpers for the MT-STS symbolic-regression physics subset."""

from __future__ import annotations

from collections import Counter
import json
import os
from pathlib import Path
import re
import sys
from typing import Any, Dict, Iterable, Mapping
import warnings

import numpy as np

# Assumed repository root: the directory two levels above the one containing this file.
REPO_ROOT = Path(__file__).resolve().parents[2]
# Directory containing this module; the local generated-problem cache lives beneath it.
EXAMPLE_DIR = Path(__file__).resolve().parent
# Standalone symbolic-regression example; its "problems" subdirectory is one of the
# default search roots for generated problem directories.
STANDALONE_SYMBOLIC_REGRESSION_DIR = REPO_ROOT / "examples" / "symbolic_regression"
# Make the repo root and the standalone example importable (presumably required by the
# openevolve / bench imports that follow). Each path is inserted at the FRONT of
# sys.path, so when both are newly added the standalone directory ends up ahead of the
# repo root -- keep this statement order if import precedence matters.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))
if str(STANDALONE_SYMBOLIC_REGRESSION_DIR) not in sys.path:
    sys.path.insert(0, str(STANDALONE_SYMBOLIC_REGRESSION_DIR))

from openevolve.multi_task_shared_then_specialize.symbolic_regression_phys_osc import (
    SymbolicRegressionPhysOscTaskSpec,
)

SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR = (
    "SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE"
)
SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR = "SYMBOLIC_REGRESSION_PROBLEMS_ROOT"
PROBLEM_METADATA_FILENAME = "problem_metadata.json"
REQUIRED_PROBLEM_FILENAMES: tuple[str, ...] = (
    "X_train_for_eval.npy",
    "y_train_for_eval.npy",
    "X_test_for_eval.npy",
    "y_test_for_eval.npy",
    "X_ood_test_for_eval.npy",
    "y_ood_test_for_eval.npy",
)

_TASK_DATA_CACHE: dict[tuple[str, str, str], Dict[str, Any]] = {}
_DATAMODULE_CACHE: dict[str, Any] = {}

try:
    from bench.datamodules import get_datamodule
except Exception:  # pragma: no cover - depends on optional benchmark deps
    get_datamodule = None


class SymbolicRegressionDataError(RuntimeError):
    """Root of every explicit symbolic-regression data-loading failure.

    Besides the message, each instance carries three string tags describing the
    failure: ``failure_kind``, ``failure_stage`` and ``data_source_mode``. Whatever is
    passed for them is coerced with ``str()``.
    """

    def __init__(
        self,
        message: str,
        *,
        failure_kind: str = "data_loading",
        failure_stage: str = "load_task_data",
        data_source_mode: str = "unknown",
    ) -> None:
        super().__init__(message)
        self.failure_kind, self.failure_stage, self.data_source_mode = (
            str(failure_kind),
            str(failure_stage),
            str(data_source_mode),
        )


class BenchmarkDependencyError(SymbolicRegressionDataError):
    """Raised when the optional ``bench`` phys_osc datamodule cannot be imported or set up."""

    def __init__(self, message: str) -> None:
        tags = {
            "failure_kind": "benchmark_dependency_missing",
            "failure_stage": "load_task_data",
            "data_source_mode": "benchmark_generation_fallback",
        }
        super().__init__(message, **tags)


class GeneratedProblemMissingError(SymbolicRegressionDataError):
    """Raised when a generated problem directory (or one of its required files) is absent."""

    def __init__(self, message: str) -> None:
        tags = {
            "failure_kind": "generated_problem_missing",
            "failure_stage": "resolve_candidate",
            "data_source_mode": "generated_problem_dir",
        }
        super().__init__(message, **tags)


class ProblemMetadataError(SymbolicRegressionDataError):
    """Raised when a generated problem directory's metadata JSON is missing or malformed."""

    def __init__(self, message: str) -> None:
        tags = {
            "failure_kind": "metadata_missing",
            "failure_stage": "resolve_candidate",
            "data_source_mode": "generated_problem_dir",
        }
        super().__init__(message, **tags)


class GeneratedProblemRootsExhaustedError(SymbolicRegressionDataError):
    """Raised when problem directories were found under implicit roots but none was usable."""

    def __init__(self, message: str) -> None:
        tags = {
            "failure_kind": "generated_problem_invalid",
            "failure_stage": "resolve_candidate",
            "data_source_mode": "generated_problem_dir",
        }
        super().__init__(message, **tags)


def _task_identity(task_spec: SymbolicRegressionPhysOscTaskSpec) -> str:
    return f"{task_spec.task_id} ({task_spec.dataset_identifier}/{task_spec.equation_idx})"


def _copy_loaded_data(task_data: Mapping[str, Any]) -> Dict[str, Any]:
    copied: Dict[str, Any] = {"metadata": dict(task_data["metadata"])}
    for split in ("train", "test", "ood"):
        x_array, y_array = task_data[split]
        copied[split] = (np.array(x_array, copy=True), np.array(y_array, copy=True))
    return copied


def _cache_key(task_spec: SymbolicRegressionPhysOscTaskSpec) -> tuple[str, str, str]:
    """Build the memo key used for ``_TASK_DATA_CACHE``.

    Besides the task id, the key includes the stripped values of the two environment
    variables that change where data comes from. Toggling either one within a process
    therefore never returns a payload loaded under the previous setting.
    """
    env = os.environ
    fixture_flag = env.get(SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR, "").strip()
    problems_root = env.get(SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR, "").strip()
    return task_spec.task_id, fixture_flag, problems_root


def _canonical_equation_name(value: str) -> str:
    match = re.fullmatch(r"([A-Za-z_]+)0*([0-9]+)", value.strip())
    if match:
        return f"{match.group(1).lower()}{int(match.group(2))}"
    return value.strip().lower()


def _candidate_problem_roots() -> list[Path]:
    """Return only the root directories from ``_candidate_problem_root_entries()``, in order."""
    entries = _candidate_problem_root_entries()
    return [directory for _label, directory in entries]


def _candidate_problem_root_entries() -> list[tuple[str, Path]]:
    """Return ``(label, root)`` pairs to search for generated problem directories.

    Search order:
    1. the ``SYMBOLIC_REGRESSION_PROBLEMS_ROOT`` override, when set;
    2. the local generated-problem cache;
    3. ``<repo>/problems``;
    4. the standalone example's ``problems`` directory.

    Every root is resolved to an absolute path. Later duplicates of an already-listed
    path are dropped, so the first label for a path wins.
    """
    roots: list[tuple[str, Path]] = []
    # Strip like every other reader of this variable (_cache_key, load_task_data,
    # _load_from_generated_problem_dir). A whitespace-only value is then ignored, and
    # stray surrounding whitespace never becomes part of the directory name.
    env_root = os.environ.get(SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR, "").strip()
    if env_root:
        roots.append((f"{SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR} env var", Path(env_root).expanduser()))
    roots.append(("local generated-problem cache", _dataset_cache_root()))
    roots.append(("repo-root problems", REPO_ROOT / "problems"))
    roots.append(("standalone symbolic-regression problems", STANDALONE_SYMBOLIC_REGRESSION_DIR / "problems"))

    deduped: list[tuple[str, Path]] = []
    seen: set[Path] = set()
    for label, root in roots:
        resolved = root.resolve()
        if resolved in seen:
            continue
        seen.add(resolved)
        deduped.append((label, resolved))
    return deduped


def _resolve_problem_dir(
    root: Path,
    task_spec: SymbolicRegressionPhysOscTaskSpec,
) -> Path | None:
    dataset_dir = root / task_spec.dataset_identifier
    exact_dir = dataset_dir / task_spec.equation_idx
    if exact_dir.is_dir():
        return exact_dir
    if not dataset_dir.is_dir():
        return None

    expected_key = _canonical_equation_name(task_spec.equation_idx)
    matches = [
        candidate
        for candidate in dataset_dir.iterdir()
        if candidate.is_dir() and _canonical_equation_name(candidate.name) == expected_key
    ]
    if len(matches) == 1:
        return matches[0]
    if len(matches) > 1:
        raise SymbolicRegressionDataError(
            f"Ambiguous generated problem directories for {_task_identity(task_spec)} under "
            f"{dataset_dir}: {[candidate.name for candidate in matches]}",
            failure_kind="data_loading",
            failure_stage="resolve_candidate",
            data_source_mode="generated_problem_dir",
        )
    return None


def _ensure_split_shapes(
    split_name: str,
    x_array: np.ndarray,
    y_array: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    x_array = np.asarray(x_array, dtype=float)
    y_array = np.asarray(y_array, dtype=float).reshape(-1)
    if x_array.ndim != 2:
        raise ValueError(f"{split_name} X must be 2D, got shape {x_array.shape}")
    if x_array.shape[1] != 3:
        raise ValueError(f"{split_name} X must have 3 columns in [x, t, v] order, got {x_array.shape}")
    if y_array.ndim != 1:
        raise ValueError(f"{split_name} y must be 1D, got shape {y_array.shape}")
    if x_array.shape[0] != y_array.shape[0]:
        raise ValueError(
            f"{split_name} sample count mismatch: X has {x_array.shape[0]}, y has {y_array.shape[0]}"
        )
    if not np.all(np.isfinite(x_array)):
        raise ValueError(f"{split_name} X contains non-finite values")
    if not np.all(np.isfinite(y_array)):
        raise ValueError(f"{split_name} y contains non-finite values")
    return x_array, y_array


def _validate_metadata(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    *,
    input_var_names: tuple[str, ...] | list[str],
    output_var_name: str,
    data_source_mode: str,
) -> None:
    normalized_inputs = tuple(str(name) for name in input_var_names)
    if normalized_inputs != tuple(task_spec.input_var_names):
        raise SymbolicRegressionDataError(
            f"{_task_identity(task_spec)} metadata mismatch: expected input vars "
            f"{task_spec.input_var_names}, got {normalized_inputs}",
            failure_kind="data_loading",
            failure_stage="resolve_candidate",
            data_source_mode=data_source_mode,
        )
    if str(output_var_name) != str(task_spec.output_var_name):
        raise SymbolicRegressionDataError(
            f"{_task_identity(task_spec)} metadata mismatch: expected output var "
            f"{task_spec.output_var_name}, got {output_var_name}",
            failure_kind="data_loading",
            failure_stage="resolve_candidate",
            data_source_mode=data_source_mode,
        )


def _problem_metadata_from_cached_file(problem_dir: Path) -> dict[str, Any] | None:
    """Read ``problem_dir``'s metadata JSON, returning ``None`` if it is unusable.

    ``None`` is returned when the file is absent or unreadable, is not valid UTF-8, is
    not valid JSON, or does not decode to a JSON object. Callers decide how to report
    that; this helper never raises for a bad file.
    """
    metadata_path = problem_dir / PROBLEM_METADATA_FILENAME
    if not metadata_path.is_file():
        return None
    try:
        payload = json.loads(metadata_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError as well as UnicodeDecodeError from
        # read_text(); previously non-UTF-8 bytes escaped as an unexpected error.
        return None
    return payload if isinstance(payload, dict) else None


def _require_problem_metadata(problem_dir: Path) -> dict[str, Any]:
    """Load and validate ``problem_dir``'s metadata, or refuse to use the directory.

    The metadata must be a JSON object with ``input_var_names`` (exactly three non-empty
    strings) and ``output_var_name`` (a non-empty string).

    Returns:
        The decoded metadata dict, unchanged.

    Raises:
        ProblemMetadataError: if the file is missing/unreadable or violates the rules above.
    """
    payload = _problem_metadata_from_cached_file(problem_dir)
    metadata_path = problem_dir / PROBLEM_METADATA_FILENAME
    if payload is None:
        raise ProblemMetadataError(
            "Generated symbolic-regression problem directory "
            f"{problem_dir} is missing valid metadata at {metadata_path}. "
            "Refusing to assume [x, t, v] -> dv_dt semantics without explicit metadata."
        )

    def _invalid(reason: str) -> ProblemMetadataError:
        return ProblemMetadataError(
            "Generated symbolic-regression problem directory "
            f"{problem_dir} has invalid metadata at {metadata_path}: "
            + reason
        )

    required_keys = (
        ("input_var_names", "missing required key 'input_var_names'."),
        ("output_var_name", "missing required key 'output_var_name'."),
    )
    for key, reason in required_keys:
        if key not in payload:
            raise _invalid(reason)

    names = payload["input_var_names"]
    if not isinstance(names, (list, tuple)) or len(names) != 3:
        raise _invalid("'input_var_names' must be a list or tuple of length 3.")
    if not all(isinstance(name, str) and name for name in names):
        raise _invalid("'input_var_names' must contain only non-empty strings.")

    output_name = payload["output_var_name"]
    if not (isinstance(output_name, str) and output_name):
        raise _invalid("'output_var_name' must be a non-empty string.")
    return payload


def _dataset_cache_root() -> Path:
    """Directory next to this module where benchmark-generated problems are cached."""
    return EXAMPLE_DIR.joinpath("_problem_cache")


def _write_problem_cache(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    task_data: Mapping[str, Any],
) -> None:
    """Persist ``task_data`` as a generated-problem directory in the local cache.

    Files are written under ``<cache>/<dataset_identifier>/<equation_idx>``: the six
    ``*_for_eval.npy`` arrays, then the metadata JSON. Existing files are overwritten.
    """
    target_dir = _dataset_cache_root().joinpath(task_spec.dataset_identifier, task_spec.equation_idx)
    target_dir.mkdir(parents=True, exist_ok=True)

    # (file name, split key, position inside the (X, y) pair), in write order.
    layout = (
        ("X_train_for_eval.npy", "train", 0),
        ("y_train_for_eval.npy", "train", 1),
        ("X_test_for_eval.npy", "test", 0),
        ("y_test_for_eval.npy", "test", 1),
        ("X_ood_test_for_eval.npy", "ood", 0),
        ("y_ood_test_for_eval.npy", "ood", 1),
    )
    for filename, split, position in layout:
        np.save(target_dir / filename, task_data[split][position])

    source_metadata = task_data["metadata"]
    # Normalize the input names to a list before serializing.
    payload = {**source_metadata, "input_var_names": list(source_metadata["input_var_names"])}
    (target_dir / PROBLEM_METADATA_FILENAME).write_text(
        json.dumps(payload, indent=2),
        encoding="utf-8",
    )


def _synthetic_formula(task_spec: SymbolicRegressionPhysOscTaskSpec, x_array: np.ndarray) -> np.ndarray:
    pos = x_array[:, 0]
    t_val = x_array[:, 1]
    vel = x_array[:, 2]
    if task_spec.task_id == "sr_po11":
        return -1.2 * pos - 0.18 * pos**3 - 0.35 * vel + 0.55 * np.sin(0.8 * t_val)
    if task_spec.task_id == "sr_po17":
        return (
            -0.95 * pos
            - 0.22 * pos**3
            - 0.28 * vel
            + 0.18 * pos * t_val
            + 0.40 * np.sin(1.1 * t_val)
        )
    if task_spec.task_id == "sr_po30":
        return (
            -1.05 * pos
            - 0.16 * pos**3
            - 0.42 * vel
            - 0.08 * vel**3
            + 0.24 * np.sin(vel)
            + 0.30 * np.cos(0.6 * t_val)
        )
    if task_spec.task_id == "sr_po37":
        exp_term = np.exp(-np.clip(np.abs(pos), 0.0, 6.0))
        return (
            -0.85 * pos
            - 0.20 * pos**3
            - 0.30 * vel
            + 0.16 * pos * t_val
            + 0.22 * exp_term * vel
            + 0.38 * np.sin(0.9 * t_val)
        )
    raise ValueError(f"Unsupported synthetic task id: {task_spec.task_id}")


def _synthetic_split(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    *,
    seed: int,
    count: int,
    pos_range: tuple[float, float],
    time_range: tuple[float, float],
    vel_range: tuple[float, float],
) -> tuple[np.ndarray, np.ndarray]:
    """Sample one synthetic split of ``count`` rows.

    Position, time and velocity are drawn uniformly (in that order) from their
    ``(low, high)`` ranges with a generator seeded by ``seed``. Targets are produced by
    ``_synthetic_formula``.
    """
    generator = np.random.default_rng(seed)
    columns = [
        generator.uniform(bounds[0], bounds[1], count)
        for bounds in (pos_range, time_range, vel_range)
    ]
    features = np.column_stack(columns).astype(float, copy=False)
    targets = _synthetic_formula(task_spec, features).astype(float, copy=False)
    return features, targets


def _load_synthetic_fixture(task_spec: SymbolicRegressionPhysOscTaskSpec) -> Dict[str, Any]:
    """Build a deterministic synthetic train/test/OOD payload for smoke tests.

    Each split has 96 uniformly sampled rows. Seeds are offset by
    ``task_spec.task_index``, so different tasks get different samples. The OOD split
    uses wider position/velocity ranges and a shifted time window.
    """
    offset = task_spec.task_index
    # (split, seed base, position range, time range, velocity range)
    plan = (
        ("train", 1000, (-1.5, 1.5), (0.0, 4.0), (-1.2, 1.2)),
        ("test", 2000, (-1.8, 1.8), (0.0, 5.0), (-1.4, 1.4)),
        ("ood", 3000, (-2.8, 2.8), (2.0, 7.0), (-2.1, 2.1)),
    )
    payload: Dict[str, Any] = {
        split: _synthetic_split(
            task_spec,
            seed=seed_base + offset,
            count=96,
            pos_range=pos_bounds,
            time_range=time_bounds,
            vel_range=vel_bounds,
        )
        for split, seed_base, pos_bounds, time_bounds, vel_bounds in plan
    }
    payload["metadata"] = {
        "data_source_mode": "synthetic_fixture",
        "dataset_identifier": task_spec.dataset_identifier,
        "equation_idx": task_spec.equation_idx,
        "input_var_names": tuple(task_spec.input_var_names),
        "output_var_name": task_spec.output_var_name,
    }
    return payload


def _build_task_data(
    *,
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    train: tuple[np.ndarray, np.ndarray],
    test: tuple[np.ndarray, np.ndarray],
    ood: tuple[np.ndarray, np.ndarray],
    metadata: Mapping[str, Any],
) -> Dict[str, Any]:
    """Validate variable names and split arrays, then assemble the normalized payload.

    Returns:
        ``{"train": (X, y), "test": (X, y), "ood": (X, y), "metadata": {...}}``. The
        arrays are float, and the metadata values are coerced to ``str``, except
        ``input_var_names``, which becomes a tuple.

    Raises:
        SymbolicRegressionDataError: on a variable-name mismatch, or when any split
            fails ``_ensure_split_shapes``.
    """
    input_names = metadata["input_var_names"]
    output_name = metadata["output_var_name"]
    source_mode = str(metadata["data_source_mode"])
    _validate_metadata(
        task_spec,
        input_var_names=input_names,
        output_var_name=output_name,
        data_source_mode=source_mode,
    )

    validated: Dict[str, Any] = {}
    try:
        for split_name, pair in (("train", train), ("test", test), ("ood", ood)):
            validated[split_name] = _ensure_split_shapes(split_name, pair[0], pair[1])
    except ValueError as exc:
        raise SymbolicRegressionDataError(
            f"Invalid cached arrays for {_task_identity(task_spec)} loaded from "
            f"{source_mode}: {exc}",
            failure_kind="data_loading",
            failure_stage="load_task_data",
            data_source_mode=source_mode,
        ) from exc

    validated["metadata"] = {
        "data_source_mode": source_mode,
        "dataset_identifier": str(metadata["dataset_identifier"]),
        "equation_idx": str(metadata["equation_idx"]),
        "input_var_names": tuple(input_names),
        "output_var_name": str(output_name),
    }
    return validated


def _expected_problem_dir(root: Path, task_spec: SymbolicRegressionPhysOscTaskSpec) -> Path:
    return root / task_spec.dataset_identifier / task_spec.equation_idx


def _problem_file_map(problem_dir: Path) -> dict[str, Path]:
    """Map each required ``.npy`` file name to its path inside ``problem_dir``."""
    names = REQUIRED_PROBLEM_FILENAMES
    return dict(zip(names, map(problem_dir.joinpath, names)))


def _format_missing_expected_dirs(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    root_entries: Iterable[tuple[str, Path]],
) -> str:
    """Render one ``- <label>: expected <dir>`` bullet per search root, newline-joined."""
    return "\n".join(
        f"- {label}: expected {_expected_problem_dir(root, task_spec)}"
        for label, root in root_entries
    )


def _format_generated_problem_root_failures(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    root_failures: Iterable[tuple[str, Path, SymbolicRegressionDataError]],
) -> str:
    """Summarize why each found-but-unusable problem root was rejected.

    The first line names the task. Each following bullet shows the root label, the
    expected directory, the failure kind/stage tags and the error message.
    """
    header = (
        f"Generated symbolic-regression problem directories for {_task_identity(task_spec)} "
        "were found but could not be reused:"
    )
    bullets = [
        f"- {label} ({_expected_problem_dir(root, task_spec)}): "
        f"[{exc.failure_kind} @ {exc.failure_stage}] {exc}"
        for label, root, exc in root_failures
    ]
    return "\n".join([header, *bullets])


def _load_problem_array(
    path: Path,
    *,
    task_spec: SymbolicRegressionPhysOscTaskSpec,
) -> np.ndarray:
    try:
        return np.load(path)
    except Exception as exc:
        raise SymbolicRegressionDataError(
            f"Failed to load cached symbolic-regression array for {_task_identity(task_spec)} "
            f"from {path}: {exc}",
            failure_kind="data_loading",
            failure_stage="load_task_data",
            data_source_mode="generated_problem_dir",
        ) from exc


def _read_problem_dir_under_root(
    root: Path,
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    *,
    explicit_env_root: bool,
) -> Dict[str, Any] | None:
    """Load ``task_spec`` from a single search root.

    Returns ``None`` when the root has no directory for the task. The exception is a
    root that came from the explicit environment variable, where absence is an error.

    Raises:
        GeneratedProblemMissingError: the explicit root lacks the task directory, or the
            directory lacks one of the required ``.npy`` files.
        ProblemMetadataError: the metadata file is missing or malformed.
        SymbolicRegressionDataError: the directory is ambiguous, or arrays fail to load
            or validate.
    """
    problem_dir = _resolve_problem_dir(root, task_spec)
    if problem_dir is None:
        if not explicit_env_root:
            return None
        raise GeneratedProblemMissingError(
            f"{SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR}={root} does not contain "
            f"{_task_identity(task_spec)}.\n"
            f"Expected directory: {_expected_problem_dir(root, task_spec)}"
        )

    file_paths = _problem_file_map(problem_dir)
    absent = [str(path) for path in file_paths.values() if not path.is_file()]
    if absent:
        raise GeneratedProblemMissingError(
            f"Generated symbolic-regression problem directory {problem_dir} for "
            f"{_task_identity(task_spec)} is missing required files:\n"
            + "\n".join(f"- {path}" for path in absent)
        )

    declared = _require_problem_metadata(problem_dir)
    # Dataset/equation identity comes from the task spec; variable names come from disk.
    metadata = {
        "data_source_mode": "generated_problem_dir",
        "dataset_identifier": task_spec.dataset_identifier,
        "equation_idx": task_spec.equation_idx,
        "input_var_names": tuple(declared["input_var_names"]),
        "output_var_name": declared["output_var_name"],
    }

    def load(filename: str) -> np.ndarray:
        return _load_problem_array(file_paths[filename], task_spec=task_spec)

    return _build_task_data(
        task_spec=task_spec,
        train=(load("X_train_for_eval.npy"), load("y_train_for_eval.npy")),
        test=(load("X_test_for_eval.npy"), load("y_test_for_eval.npy")),
        ood=(load("X_ood_test_for_eval.npy"), load("y_ood_test_for_eval.npy")),
        metadata=metadata,
    )


def _load_from_generated_problem_dir(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
) -> Dict[str, Any] | None:
    """Return the task from the first candidate root that loads it cleanly.

    Roots come from ``_candidate_problem_root_entries()``. A failure under the explicit
    ``SYMBOLIC_REGRESSION_PROBLEMS_ROOT`` root is raised immediately. Failures under
    implicit roots are collected while the search continues.

    Returns:
        The task payload, or ``None`` when no root contains a directory for the task.

    Raises:
        GeneratedProblemRootsExhaustedError: directories were found under implicit roots
            but none could be used.
        SymbolicRegressionDataError: any failure under the explicit env-var root.
    """
    search_roots = _candidate_problem_root_entries()
    env_root = os.environ.get(SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR, "").strip()
    unusable: list[tuple[str, Path, SymbolicRegressionDataError]] = []

    for root_label, root in search_roots:
        explicit = bool(env_root) and root_label.startswith(SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR)
        try:
            loaded = _read_problem_dir_under_root(root, task_spec, explicit_env_root=explicit)
        except SymbolicRegressionDataError as exc:
            if explicit:
                raise
            unusable.append((root_label, root, exc))
        else:
            if loaded is not None:
                return loaded

    if not unusable:
        return None
    raise GeneratedProblemRootsExhaustedError(
        _format_generated_problem_root_failures(task_spec, unusable)
    )


def _load_phys_osc_datamodule():
    """Return the set-up ``phys_osc`` benchmark datamodule, memoized per process.

    Raises:
        BenchmarkDependencyError: if ``bench.datamodules.get_datamodule`` could not be
            imported, or if creating/setting up the datamodule raises anything other
            than a ``SymbolicRegressionDataError``.
    """
    memo_key = "phys_osc"
    if memo_key in _DATAMODULE_CACHE:
        return _DATAMODULE_CACHE[memo_key]
    if get_datamodule is None:
        raise BenchmarkDependencyError(
            "bench.datamodules.get_datamodule is unavailable. Install the benchmark "
            "datamodule dependencies, set SYMBOLIC_REGRESSION_PROBLEMS_ROOT to a valid "
            "generated problems root, generate local problems first, or use "
            "SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE=1 for smoke tests."
        )
    try:
        module = get_datamodule("phys_osc", None)
        module.setup()
    except SymbolicRegressionDataError:
        # Already carries explicit failure tags; pass it through untouched.
        raise
    except Exception as exc:
        raise BenchmarkDependencyError(
            "Failed to initialize the phys_osc benchmark datamodule via "
            "bench.datamodules.get_datamodule('phys_osc', None). Install the benchmark "
            "dependencies, ensure the benchmark assets are available, set "
            "SYMBOLIC_REGRESSION_PROBLEMS_ROOT to a valid generated problems root, or use "
            "SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE=1 for smoke tests. "
            f"Underlying error: {type(exc).__name__}: {exc}"
        ) from exc
    else:
        _DATAMODULE_CACHE[memo_key] = module
        return module


def _extract_from_problem(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    problem: Any,
) -> Dict[str, Any]:
    """Convert one benchmark problem object into the normalized task payload.

    Columns whose symbol property is ``"V"`` become the inputs, in symbol order. The
    single ``"O"`` column becomes the target. Columns with any other property are ignored.

    Raises:
        SymbolicRegressionDataError: if there is not exactly one output column, if the
            variable names disagree with ``task_spec``, or if the arrays fail validation.
    """
    equation = problem.gt_equation
    symbol_names = tuple(map(str, equation.symbols))
    roles = tuple(map(str, equation.symbol_properties))
    input_columns = [position for position, role in enumerate(roles) if role == "V"]
    output_columns = [position for position, role in enumerate(roles) if role == "O"]
    if len(output_columns) != 1:
        raise SymbolicRegressionDataError(
            f"{_task_identity(task_spec)} expected one output variable, found "
            f"{len(output_columns)} in benchmark metadata",
            failure_kind="data_loading",
            failure_stage="load_task_data",
            data_source_mode="benchmark_generation_fallback",
        )
    target_column = output_columns[0]
    input_var_names = tuple(symbol_names[position] for position in input_columns)
    output_var_name = symbol_names[target_column]
    _validate_metadata(
        task_spec,
        input_var_names=input_var_names,
        output_var_name=output_var_name,
        data_source_mode="benchmark_generation_fallback",
    )

    train_samples = np.asarray(problem.train_samples, dtype=float)
    test_samples = np.asarray(problem.test_samples, dtype=float)
    ood_samples = np.asarray(problem.ood_test_samples, dtype=float)

    def columns(samples: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        return samples[:, input_columns], samples[:, target_column]

    return _build_task_data(
        task_spec=task_spec,
        train=columns(train_samples),
        test=columns(test_samples),
        ood=columns(ood_samples),
        metadata={
            "data_source_mode": "benchmark_generation_fallback",
            "dataset_identifier": problem.dataset_identifier,
            "equation_idx": problem.equation_idx,
            "input_var_names": input_var_names,
            "output_var_name": output_var_name,
        },
    )


def _load_from_datamodule(task_spec: SymbolicRegressionPhysOscTaskSpec) -> Dict[str, Any]:
    """Resolve ``task_spec`` inside the phys_osc benchmark datamodule and extract its splits.

    The problem is looked up via ``datamodule.name2id`` when available, otherwise by
    scanning ``datamodule.problems`` for a matching ``equation_idx``. The extracted data
    is then written to the local generated-problem cache, which is one of the default
    search roots. That write is best-effort: an ``OSError`` (e.g. a read-only checkout)
    emits a ``RuntimeWarning`` instead of discarding data that was loaded successfully.

    Raises:
        BenchmarkDependencyError: the datamodule is unavailable.
        SymbolicRegressionDataError: the task is not in the datamodule, or its data is
            invalid.
    """
    datamodule = _load_phys_osc_datamodule()
    problem = None
    if hasattr(datamodule, "name2id") and task_spec.equation_idx in datamodule.name2id:
        problem = datamodule.problems[datamodule.name2id[task_spec.equation_idx]]
    else:
        for candidate in getattr(datamodule, "problems", ()):
            if getattr(candidate, "equation_idx", None) == task_spec.equation_idx:
                problem = candidate
                break
    if problem is None:
        # Same getattr default as the scan above, so a datamodule without ``problems``
        # yields the domain error below rather than an AttributeError.
        available = [
            getattr(candidate, "equation_idx", "?")
            for candidate in getattr(datamodule, "problems", ())
        ]
        raise SymbolicRegressionDataError(
            f"Could not resolve {_task_identity(task_spec)} inside the phys_osc datamodule. "
            f"Available equation_idx values include: {available[:8]}",
            failure_kind="data_loading",
            failure_stage="load_task_data",
            data_source_mode="benchmark_generation_fallback",
        )

    task_data = _extract_from_problem(task_spec, problem)
    try:
        _write_problem_cache(task_spec, task_data)
    except OSError as exc:
        warnings.warn(
            f"Could not write the generated-problem cache for {_task_identity(task_spec)} "
            f"under {_dataset_cache_root()}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
    return task_data


def _describe_unavailable_task(
    task_spec: SymbolicRegressionPhysOscTaskSpec,
    roots_error: GeneratedProblemRootsExhaustedError | None,
    cause: BenchmarkDependencyError,
) -> str:
    """Compose the multi-line explanation raised when no data source can serve a task."""
    parts = [f"Could not load {_task_identity(task_spec)}."]
    if roots_error is not None:
        parts.append(str(roots_error))
    else:
        parts.append(
            "No reusable generated problem directory was found under the default search roots:"
        )
        parts.append(_format_missing_expected_dirs(task_spec, _candidate_problem_root_entries()))
    parts.append(
        "The benchmark datamodule dependency is also unavailable. Set "
        f"{SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR} to a valid problems root, "
        "generate local problems first so the cache can be reused automatically, "
        "install the benchmark datamodule dependencies, or use "
        f"{SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR}=1 for smoke tests."
    )
    parts.append(f"Underlying error: {cause}")
    return "\n".join(parts)


def _load_uncached_task_data(task_spec: SymbolicRegressionPhysOscTaskSpec) -> Dict[str, Any]:
    """Load a task without consulting the memo.

    Sources, in priority order: the synthetic fixture (when its env var is ``"1"``), a
    generated problem directory, then the benchmark datamodule.
    """
    if os.environ.get(SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR, "").strip() == "1":
        return _load_synthetic_fixture(task_spec)

    roots_error: GeneratedProblemRootsExhaustedError | None = None
    try:
        try:
            reused = _load_from_generated_problem_dir(task_spec)
        except GeneratedProblemRootsExhaustedError as exc:
            # Unusable generated dirs are not fatal: regenerate from the datamodule.
            roots_error = exc
            reused = None
        if reused is not None:
            return reused
        return _load_from_datamodule(task_spec)
    except BenchmarkDependencyError as exc:
        # With an explicit problems root configured, surface the original error as-is.
        if os.environ.get(SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR, "").strip():
            raise
        raise BenchmarkDependencyError(
            _describe_unavailable_task(task_spec, roots_error, exc)
        ) from exc


def load_task_data(task_spec: SymbolicRegressionPhysOscTaskSpec) -> Dict[str, Any]:
    """Load train/test/OOD arrays for one symbolic-regression MT-STS task.

    Results are memoized per ``(task_id, fixture flag, problems root)``. Every call
    returns a fresh copy, so callers may mutate the arrays freely.
    """
    key = _cache_key(task_spec)
    if key not in _TASK_DATA_CACHE:
        _TASK_DATA_CACHE[key] = _load_uncached_task_data(task_spec)
    return _copy_loaded_data(_TASK_DATA_CACHE[key])


def validate_symbolic_regression_phys_osc_assets(
    task_specs: Iterable[SymbolicRegressionPhysOscTaskSpec],
) -> Dict[str, Any]:
    """Validate that the selected symbolic-regression tasks are resolvable.

    With the synthetic fixture enabled, tasks are not loaded here and are reported as
    ``"synthetic_fixture"``. Otherwise every task is loaded via ``load_task_data``.
    All ``SymbolicRegressionDataError`` failures are collected and reported together.

    Returns:
        A summary dict containing:
        - ``family``;
        - ``synthetic_fixture_enabled``;
        - ``selected_task_ids`` (input order, duplicates kept);
        - ``data_source_modes`` (task id -> mode);
        - ``data_source_mode_summary`` (mode -> number of distinct task ids).

    Raises:
        SymbolicRegressionDataError: if any selected task failed to load.
    """
    selected_tasks = list(task_specs)
    selected_task_ids = [task.task_id for task in selected_tasks]
    synthetic_fixture_enabled = (
        os.environ.get(SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR, "").strip() == "1"
    )

    data_source_modes: dict[str, str] = {}
    if synthetic_fixture_enabled:
        data_source_modes = {task_id: "synthetic_fixture" for task_id in selected_task_ids}
    else:
        failures: list[tuple[str, SymbolicRegressionDataError]] = []
        for task in selected_tasks:
            try:
                task_data = load_task_data(task)
            except SymbolicRegressionDataError as exc:
                failures.append((task.task_id, exc))
            else:
                data_source_modes[task.task_id] = str(task_data["metadata"]["data_source_mode"])

        if failures:
            lines = [
                "symbolic-regression benchmark assets are unavailable or invalid for the selected MT-STS tasks.",
                "Fix one of the following before rerunning:",
                f"- set {SYMBOLIC_REGRESSION_PROBLEMS_ROOT_ENV_VAR} to a valid generated problems root",
                "- generate local problems first so the family cache can be reused automatically",
                "- install the benchmark datamodule dependencies",
                f"- or use {SYMBOLIC_REGRESSION_PHYS_OSC_USE_SYNTHETIC_FIXTURE_ENV_VAR}=1 for smoke tests",
                "Per-task failures:",
            ]
            lines.extend(
                f"- {task_id}: [{exc.failure_kind} @ {exc.failure_stage}] {exc}"
                for task_id, exc in failures
            )
            raise SymbolicRegressionDataError(
                "\n".join(lines),
                failure_kind="data_loading",
                failure_stage="load_task_data",
                data_source_mode="unknown",
            )

    return {
        "family": "symbolic_regression_phys_osc",
        "synthetic_fixture_enabled": synthetic_fixture_enabled,
        "selected_task_ids": selected_task_ids,
        "data_source_modes": data_source_modes,
        # Count from the per-task mapping in both branches. Duplicate task ids are then
        # counted once, and an empty selection yields an empty summary in either mode.
        "data_source_mode_summary": dict(Counter(data_source_modes.values())),
    }
