"""Canonical symbolic-regression physics-oscillator task family for MT-STS."""

from __future__ import annotations

from dataclasses import dataclass
import math
from typing import Any, Dict, Iterable, List, Mapping, Optional


# Name of the task-selector variable for this family. Presumably looked up in
# the process environment by the evaluator entry point -- that lookup is not
# part of this module, so confirm against the caller.
SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR = (
    "SYMBOLIC_REGRESSION_PHYS_OSC_TASK_ID"
)
# Selector value that resolve_task_specs() expands to every registered task.
SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR = "all"
# Upper clamp applied to "num_params_used" during normalization; also the
# value reported for "num_params_used" in the failure metrics.
SYMBOLIC_REGRESSION_PHYS_OSC_MAX_PARAMETER_BUDGET = 10
# Default timeouts in seconds. They are the defaults of the task spec's
# timeout fields; the full timeout is also the default for the metric helpers.
DEFAULT_STAGE1_TIMEOUT_SECONDS = 45.0
DEFAULT_FULL_TIMEOUT_SECONDS = 120.0

# Metric schema shared by the helpers below: aggregate_task_results() averages
# exactly these keys and extract_task_result() requires each one to be present
# and finite.
SYMBOLIC_REGRESSION_PHYS_OSC_METRIC_KEYS: tuple[str, ...] = (
    "train_nmse",
    "test_nmse",
    "ood_nmse",
    "train_r2",
    "test_r2",
    "ood_r2",
    "train_score",
    "test_score",
    "ood_score",
    "parsimony_score",
    "restart_success_rate",
    "num_params_used",
    "eval_time",
    "score",
    "combined_score",
)


@dataclass(frozen=True)
class SymbolicRegressionPhysOscTaskSpec:
    """Immutable description of one MT-STS symbolic-regression task.

    Only the two timeout fields have defaults (the module-level full and
    stage-1 timeout constants); every other field must be supplied.
    """

    # Identity of the task inside the family.
    task_id: str
    task_index: int
    display_name: str
    # Dataset / equation the task refers to and its variable names.
    dataset_identifier: str
    equation_idx: str
    input_var_names: tuple[str, ...]
    output_var_name: str
    # Search budgets, with separate values for the "full" and "stage1" runs.
    parameter_budget: int
    num_restarts_full: int
    num_restarts_stage1: int
    maxiter_full: int
    maxiter_stage1: int
    # Per-run timeouts in seconds.
    timeout_seconds_full: float = DEFAULT_FULL_TIMEOUT_SECONDS
    timeout_seconds_stage1: float = DEFAULT_STAGE1_TIMEOUT_SECONDS

    def to_spec_dict(self) -> Dict[str, Any]:
        """Return the artifact-facing subset of this spec as a plain dict.

        Only the display name, dataset/equation identifiers, variable names
        and parameter budget are included; the input variable names are
        copied into a fresh list.
        """
        spec: Dict[str, Any] = {}
        spec["display_name"] = self.display_name
        spec["dataset_identifier"] = self.dataset_identifier
        spec["equation_idx"] = self.equation_idx
        spec["input_var_names"] = [*self.input_var_names]
        spec["output_var_name"] = self.output_var_name
        spec["parameter_budget"] = self.parameter_budget
        return spec


# Registry of the tasks in this family, in task_index order. The four entries
# differ only in task_id, task_index, display_name and equation_idx; they share
# the dataset identifier, variable names, parameter budget and restart /
# iteration budgets, and rely on the dataclass defaults for both timeouts.
SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS: tuple[SymbolicRegressionPhysOscTaskSpec, ...] = (
    SymbolicRegressionPhysOscTaskSpec(
        task_id="sr_po11",
        task_index=0,
        display_name="PhysOscPO11",
        dataset_identifier="phys_osc",
        equation_idx="PO11",
        input_var_names=("x", "t", "v"),
        output_var_name="dv_dt",
        parameter_budget=10,
        num_restarts_full=6,
        num_restarts_stage1=3,
        maxiter_full=400,
        maxiter_stage1=150,
    ),
    SymbolicRegressionPhysOscTaskSpec(
        task_id="sr_po17",
        task_index=1,
        display_name="PhysOscPO17",
        dataset_identifier="phys_osc",
        equation_idx="PO17",
        input_var_names=("x", "t", "v"),
        output_var_name="dv_dt",
        parameter_budget=10,
        num_restarts_full=6,
        num_restarts_stage1=3,
        maxiter_full=400,
        maxiter_stage1=150,
    ),
    SymbolicRegressionPhysOscTaskSpec(
        task_id="sr_po30",
        task_index=2,
        display_name="PhysOscPO30",
        dataset_identifier="phys_osc",
        equation_idx="PO30",
        input_var_names=("x", "t", "v"),
        output_var_name="dv_dt",
        parameter_budget=10,
        num_restarts_full=6,
        num_restarts_stage1=3,
        maxiter_full=400,
        maxiter_stage1=150,
    ),
    SymbolicRegressionPhysOscTaskSpec(
        task_id="sr_po37",
        task_index=3,
        display_name="PhysOscPO37",
        dataset_identifier="phys_osc",
        equation_idx="PO37",
        input_var_names=("x", "t", "v"),
        output_var_name="dv_dt",
        parameter_budget=10,
        num_restarts_full=6,
        num_restarts_stage1=3,
        maxiter_full=400,
        maxiter_stage1=150,
    ),
)

# Lookup table from task_id to its spec, built once from the registry tuple.
SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID: Dict[str, SymbolicRegressionPhysOscTaskSpec] = {
    spec.task_id: spec
    for spec in SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS
}


def resolve_task_specs(selector: Optional[str]) -> List[SymbolicRegressionPhysOscTaskSpec]:
    """Turn a task selector into the list of task specs it designates.

    Args:
        selector: A task id, the shared selector, or a falsy / blank value.

    Returns:
        A new list holding every registered spec when the selector is
        missing, blank or the shared selector; otherwise a one-element list
        with the spec whose ``task_id`` equals the stripped selector.

    Raises:
        ValueError: If the stripped selector is not a registered task id.
    """
    requested = selector if selector else SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR
    requested = requested.strip()
    if requested in ("", SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR):
        return list(SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS)

    spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID.get(requested)
    if spec is not None:
        return [spec]

    known_ids = ", ".join(
        known.task_id for known in SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS
    )
    raise ValueError(
        f"Unknown symbolic-regression physics-oscillator task '{requested}'. "
        f"Available: {known_ids}"
    )


def _coerce_finite_float(value: Any, default: float) -> float:
    """Convert ``value`` to a finite float, falling back to ``default``.

    Args:
        value: Anything ``float()`` may accept (numbers, numeric strings, ...).
        default: Returned unchanged when ``value`` cannot be converted or the
            converted number is NaN or infinite.

    Returns:
        ``float(value)`` when that succeeds and is finite, else ``default``.
    """
    try:
        numeric = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers ints too large for a double (e.g. 10 ** 400),
        # which must degrade to the default like any other unusable value.
        return default
    return numeric if math.isfinite(numeric) else default


def _coerce_non_negative_float(value: Any, default: float) -> float:
    """Coerce ``value`` like ``_coerce_finite_float`` and floor the result at 0.0.

    The floor is applied to whatever the coercion returns, so a negative
    ``default`` is also raised to 0.0.
    """
    coerced = _coerce_finite_float(value, default)
    return coerced if coerced > 0.0 else 0.0


def _clamp_float(value: float, low: float, high: float) -> float:
    """Restrict ``value`` to the closed interval ``[low, high]``.

    The upper bound is applied first and the lower bound second, so ``low``
    wins if the bounds are ever passed in inverted order.
    """
    upper_bounded = min(high, value)
    return max(low, upper_bounded)


def _clamp_int(value: Any, default: int, low: int, high: int) -> int:
    """Convert ``value`` to an int and clamp it to ``[low, high]``.

    Args:
        value: Anything ``int()`` may accept; floats are truncated.
        default: Used in place of ``value`` when the conversion fails.
        low: Inclusive lower bound of the result.
        high: Inclusive upper bound of the result.

    Returns:
        The converted (or default) integer, clamped to the bounds. The clamp
        is applied to the default as well.
    """
    try:
        numeric = int(value)
    except (TypeError, ValueError, OverflowError):
        # int(nan) raises ValueError but int(+/-inf) raises OverflowError;
        # both are unusable counts and fall back to the default.
        numeric = int(default)
    return max(int(low), min(int(high), numeric))


def empty_task_metrics(
    timeout_seconds: float = DEFAULT_FULL_TIMEOUT_SECONDS,
) -> Dict[str, float]:
    """Return the metric payload reported for a failed task evaluation.

    Args:
        timeout_seconds: Reported as ``eval_time`` after conversion to float
            and flooring at 0.0.

    Returns:
        A fresh dict with the worst-case value for every metric key: NMSE
        100000.0, R^2 -1.0, all scores 0.0 and ``num_params_used`` set to the
        maximum parameter budget.
    """
    budget_ceiling = float(SYMBOLIC_REGRESSION_PHYS_OSC_MAX_PARAMETER_BUDGET)
    reported_time = max(0.0, float(timeout_seconds))

    # Assembled group by group; insertion order matches the metric schema.
    failure: Dict[str, float] = dict.fromkeys(
        ("train_nmse", "test_nmse", "ood_nmse"), 100000.0
    )
    failure.update(dict.fromkeys(("train_r2", "test_r2", "ood_r2"), -1.0))
    failure.update(
        dict.fromkeys(
            (
                "train_score",
                "test_score",
                "ood_score",
                "parsimony_score",
                "restart_success_rate",
            ),
            0.0,
        )
    )
    failure["num_params_used"] = budget_ceiling
    failure["eval_time"] = reported_time
    failure["score"] = 0.0
    failure["combined_score"] = 0.0
    return failure


def normalize_task_metrics(
    raw_metrics: Mapping[str, Any] | None,
    *,
    timeout_seconds: float = DEFAULT_FULL_TIMEOUT_SECONDS,
) -> Dict[str, float]:
    """Map raw evaluator output onto the stable per-task metric schema.

    Args:
        raw_metrics: Evaluator output; ``None`` or an empty mapping is treated
            as "nothing reported".
        timeout_seconds: Forwarded to ``empty_task_metrics`` to build the
            fallback values (notably the default ``eval_time``).

    Returns:
        A dict with every schema key. If the restart success rate resolves to
        0, the failure metrics are returned with only ``num_params_used`` and
        ``eval_time`` taken from the input. Otherwise NMSE / R^2 values are
        sanitized, OOD values fall back to the test values when absent or
        unusable, and ``score`` (== ``combined_score``) is
        ``0.35 * test_score + 0.55 * ood_score + 0.10 * parsimony_score``.
    """
    fallback = empty_task_metrics(timeout_seconds=timeout_seconds)
    source: Mapping[str, Any] = raw_metrics or {}

    def _or_test_value(key, test_value, coerce):
        # OOD metrics reuse the test value when missing or not coercible.
        reported = source.get(key)
        return test_value if reported is None else coerce(reported, test_value)

    def _error_to_score(nmse, non_finite_score):
        # Squash an error in [0, inf) into a score in (0, 1].
        return 1.0 / (1.0 + nmse) if math.isfinite(nmse) else non_finite_score

    train_nmse = _coerce_non_negative_float(source.get("train_nmse"), fallback["train_nmse"])
    test_nmse = _coerce_non_negative_float(source.get("test_nmse"), fallback["test_nmse"])
    ood_nmse = _or_test_value("ood_nmse", test_nmse, _coerce_non_negative_float)

    train_r2 = _coerce_finite_float(source.get("train_r2"), fallback["train_r2"])
    test_r2 = _coerce_finite_float(source.get("test_r2"), fallback["test_r2"])
    ood_r2 = _or_test_value("ood_r2", test_r2, _coerce_finite_float)

    clamped_params = _clamp_int(
        source.get("num_params_used"),
        SYMBOLIC_REGRESSION_PHYS_OSC_MAX_PARAMETER_BUDGET,
        1,
        SYMBOLIC_REGRESSION_PHYS_OSC_MAX_PARAMETER_BUDGET,
    )
    num_params_used = float(clamped_params)
    eval_time = _coerce_non_negative_float(source.get("eval_time"), fallback["eval_time"])

    # Explicit restart counts take precedence over a pre-computed rate.
    restarts_ok = _coerce_non_negative_float(source.get("successful_restarts"), 0.0)
    restarts_total = _coerce_non_negative_float(source.get("total_restarts"), 0.0)
    if restarts_total > 0.0:
        rate_candidate = restarts_ok / restarts_total
    else:
        rate_candidate = _coerce_non_negative_float(
            source.get("restart_success_rate"),
            fallback["restart_success_rate"],
        )
    restart_success_rate = _clamp_float(rate_candidate, 0.0, 1.0)

    if restart_success_rate <= 0.0:
        # No restart succeeded: report failure metrics, keeping only the
        # measured parameter count and evaluation time.
        fallback["num_params_used"] = num_params_used
        fallback["eval_time"] = eval_time
        return fallback

    train_score = _error_to_score(train_nmse, 0.0)
    test_score = _error_to_score(test_nmse, 0.0)
    ood_score = _error_to_score(ood_nmse, test_score)
    parsimony_score = 1.0 / (1.0 + 0.25 * max(0.0, num_params_used - 1.0))
    score = 0.35 * test_score + 0.55 * ood_score + 0.10 * parsimony_score

    normalized: Dict[str, float] = {
        "train_nmse": train_nmse,
        "test_nmse": test_nmse,
        "ood_nmse": ood_nmse,
        "train_r2": train_r2,
        "test_r2": test_r2,
        "ood_r2": ood_r2,
        "train_score": train_score,
        "test_score": test_score,
        "ood_score": ood_score,
        "parsimony_score": parsimony_score,
        "restart_success_rate": restart_success_rate,
        "num_params_used": num_params_used,
        "eval_time": eval_time,
        "score": score,
        "combined_score": score,
    }
    return normalized


def build_task_result(
    task: SymbolicRegressionPhysOscTaskSpec,
    *,
    raw_metrics: Mapping[str, Any] | None,
    error: Optional[str] = None,
    timeout_seconds: float = DEFAULT_FULL_TIMEOUT_SECONDS,
    data_source_mode: Optional[str] = None,
    failure_kind: Optional[str] = None,
    failure_stage: Optional[str] = None,
) -> Dict[str, Any]:
    """Assemble the per-task artifact entry for one evaluated task.

    Args:
        task: Spec of the evaluated task.
        raw_metrics: Raw evaluator output, normalized via
            ``normalize_task_metrics``.
        error: Error text, if any; stored as given.
        timeout_seconds: Forwarded to the metric normalization.
        data_source_mode: When not ``None``, stored (as ``str``) under
            ``"data_source_mode"``.
        failure_kind: Explicit failure kind; when falsy it becomes
            ``"unknown"`` if ``error`` is truthy, else ``"none"``.
        failure_stage: Explicit failure stage, with the same fallback rule.

    Returns:
        A dict with the task identity, spec dict, normalized metrics, the
        task score (under both ``case_score`` and ``final_task_score``) and
        the error / failure fields.
    """
    normalized = normalize_task_metrics(raw_metrics, timeout_seconds=timeout_seconds)
    task_score = float(normalized["score"])
    normalized["combined_score"] = task_score

    implied_label = "unknown" if error else "none"
    entry: Dict[str, Any] = {
        "task_id": task.task_id,
        "task_index": task.task_index,
        "spec": task.to_spec_dict(),
        "metrics": normalized,
        "case_score": task_score,
        "final_task_score": task_score,
        "error": error,
        "failure_kind": str(failure_kind or implied_label),
        "failure_stage": str(failure_stage or implied_label),
    }
    if data_source_mode is None:
        return entry
    entry["data_source_mode"] = str(data_source_mode)
    return entry


def aggregate_task_results(task_results: Iterable[Mapping[str, Any]]) -> Dict[str, float]:
    """Average per-task metrics into one family-level payload.

    Args:
        task_results: Per-task entries, each with a ``"metrics"`` mapping
            that holds every schema key.

    Returns:
        The arithmetic mean of every schema metric plus ``task_count``,
        ``successful_task_count`` (entries without a truthy ``"error"``) and
        ``failed_task_count``. For empty input the failure metrics are
        returned with all three counts set to 0.0.
    """
    entries = list(task_results)
    entry_count = len(entries)
    if entry_count == 0:
        summary = empty_task_metrics()
        summary["task_count"] = 0.0
        summary["successful_task_count"] = 0.0
        summary["failed_task_count"] = 0.0
        return summary

    summary: Dict[str, float] = {
        key: sum(float(entry["metrics"][key]) for entry in entries) / entry_count
        for key in SYMBOLIC_REGRESSION_PHYS_OSC_METRIC_KEYS
    }
    # "score" is defined as the averaged combined score.
    summary["score"] = summary["combined_score"]

    total = float(entry_count)
    succeeded = float(sum(1 for entry in entries if not entry.get("error")))
    summary["task_count"] = total
    summary["successful_task_count"] = succeeded
    summary["failed_task_count"] = total - succeeded
    return summary


def extract_task_result(artifacts: Mapping[str, Any], task_id: str) -> Optional[Dict[str, Any]]:
    """Pull one validated per-task entry out of stored evaluation artifacts.

    Only the first mapping entry in ``artifacts["task_results"]`` whose
    ``"task_id"`` matches is considered.

    Returns:
        ``None`` when ``"task_results"`` is not a list, no entry matches, or
        the matching entry's metrics are not a mapping containing a finite
        value for every schema key. For a registered task id the entry is
        rebuilt through ``build_task_result``; for an unregistered id a
        shallow copy of the stored entry is returned.
    """
    stored_entries = artifacts.get("task_results")
    if not isinstance(stored_entries, list):
        return None

    entry = next(
        (
            candidate
            for candidate in stored_entries
            if isinstance(candidate, Mapping) and candidate.get("task_id") == task_id
        ),
        None,
    )
    if entry is None:
        return None

    stored_metrics = entry.get("metrics")
    if not isinstance(stored_metrics, Mapping):
        return None

    # Strict check: every schema key must be present and finite.
    def _is_usable(key: str) -> bool:
        if key not in stored_metrics:
            return False
        return math.isfinite(_coerce_finite_float(stored_metrics.get(key), float("nan")))

    if not all(_is_usable(key) for key in SYMBOLIC_REGRESSION_PHYS_OSC_METRIC_KEYS):
        return None

    spec = SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID.get(task_id)
    if spec is None:
        return dict(entry)
    return build_task_result(
        spec,
        raw_metrics=stored_metrics,
        error=entry.get("error"),
        data_source_mode=entry.get("data_source_mode"),
        failure_kind=entry.get("failure_kind"),
        failure_stage=entry.get("failure_stage"),
    )


def project_task_artifacts(
    artifacts: Mapping[str, Any],
    task_id: str,
    task_result: Mapping[str, Any],
) -> Dict[str, Any]:
    """Build the task-local view of shared evaluation artifacts.

    Args:
        artifacts: Shared artifacts; only ``"evaluation_stage"`` is read,
            defaulting to ``"full"``.
        task_id: Task the view is scoped to.
        task_result: Per-task entry, shallow-copied into the view.

    Returns:
        A dict naming ``task_id`` as the only selected task, with
        ``"evaluation_mode"`` fixed to ``"task_specific"``. A
        ``"data_source_modes"`` entry is added only when ``task_result``
        carries a truthy ``"data_source_mode"``.
    """
    stage = artifacts.get("evaluation_stage", "full")
    entry_copy = dict(task_result)

    view: Dict[str, Any] = {}
    view["task_selector"] = task_id
    view["evaluation_stage"] = stage
    view["evaluation_mode"] = "task_specific"
    view["selected_task_ids"] = [task_id]
    view["task_results"] = [entry_copy]

    source_mode = task_result.get("data_source_mode")
    if not source_mode:
        return view
    view["data_source_modes"] = {task_id: source_mode}
    return view


# Names exported by a star import of this module. The underscore-prefixed
# coercion / clamping helpers are not listed.
__all__ = [
    "DEFAULT_FULL_TIMEOUT_SECONDS",
    "DEFAULT_STAGE1_TIMEOUT_SECONDS",
    "SYMBOLIC_REGRESSION_PHYS_OSC_MAX_PARAMETER_BUDGET",
    "SYMBOLIC_REGRESSION_PHYS_OSC_METRIC_KEYS",
    "SYMBOLIC_REGRESSION_PHYS_OSC_SHARED_SELECTOR",
    "SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SELECTOR_ENV_VAR",
    "SYMBOLIC_REGRESSION_PHYS_OSC_TASK_SPECS",
    "SYMBOLIC_REGRESSION_PHYS_OSC_TASKS_BY_ID",
    "SymbolicRegressionPhysOscTaskSpec",
    "aggregate_task_results",
    "build_task_result",
    "empty_task_metrics",
    "extract_task_result",
    "normalize_task_metrics",
    "project_task_artifacts",
    "resolve_task_specs",
]
