"""Evaluator for the MT-STS rectangle circle-packing family."""

from __future__ import annotations

import json
import os
from pathlib import Path
import subprocess
import sys
import tempfile
import textwrap
import time
import traceback
from typing import Any, Dict, Iterable, Mapping, Sequence

import numpy as np

# Directory two levels above the folder that contains this file. It is put at
# the front of sys.path (once) before the `openevolve` imports below run --
# presumably this is the repository root that holds that package; confirm
# against the checkout layout.
REPO_ROOT = Path(__file__).resolve().parents[2]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from openevolve.evaluation_result import EvaluationResult

try:
    from openevolve.multi_task_shared_then_specialize.circle_packing_rectangle import (
        CIRCLE_PACKING_RECTANGLE_SHARED_SELECTOR,
        CIRCLE_PACKING_RECTANGLE_TASK_SELECTOR_ENV_VAR,
        CirclePackingRectangleTaskSpec,
        aggregate_task_results,
        build_task_result,
        resolve_eval_task_specs,
    )
except ImportError:
    # Older long-lived worker processes may still expose a family module shape
    # without the newer evaluator-only selector helper. For in-distribution
    # reevaluation, falling back to resolve_task_specs is equivalent.
    from openevolve.multi_task_shared_then_specialize.circle_packing_rectangle import (
        CIRCLE_PACKING_RECTANGLE_SHARED_SELECTOR,
        CIRCLE_PACKING_RECTANGLE_TASK_SELECTOR_ENV_VAR,
        CirclePackingRectangleTaskSpec,
        aggregate_task_results,
        build_task_result,
        resolve_task_specs as resolve_eval_task_specs,
    )


# Absolute slack applied to both the rectangle-containment test and the
# pairwise-overlap test in validate_rectangle_packing.
BOUNDARY_TOLERANCE = 1.0e-6
# Upper bound on how many human-readable messages are kept under the
# "example_messages" key of a validation-details dict.
MAX_VALIDATION_MESSAGES = 5
# Default character budget used by _truncate_text for stdout/stderr/traceback
# excerpts.
MAX_TEXT_EXCERPT = 400
# Name of the environment variable read by _apply_ood_timeout_override; its
# value can raise, but never lower, the per-task subprocess timeout.
OOD_TIMEOUT_OVERRIDE_ENV_VAR = "MT_STS_OOD_TIMEOUT_OVERRIDE_SECONDS"


def _truncate_text(value: str | None, *, limit: int = MAX_TEXT_EXCERPT) -> str:
    """Return the stripped tail of *value*, at most *limit* characters long.

    The end of the text is kept rather than the beginning because the most
    relevant part of a traceback or log is usually its last lines.

    Args:
        value: Text to shorten; ``None`` is treated as an empty string.
        limit: Maximum number of characters to keep. Non-positive limits
            yield an empty string.

    Returns:
        The whitespace-stripped text, cut down to its last *limit* characters
        when it is longer than that.
    """
    text = (value or "").strip()
    if limit <= 0:
        # ``text[-0:]`` is the whole string and a negative limit would slice
        # from the front, so non-positive limits need their own branch.
        return ""
    if len(text) <= limit:
        return text
    return text[-limit:]


def _apply_ood_timeout_override(timeout_seconds: float) -> float:
    """Return the task timeout, optionally raised by the environment override.

    The variable named by ``OOD_TIMEOUT_OVERRIDE_ENV_VAR`` may only extend the
    timeout: the result is the larger of *timeout_seconds* and the override.
    An override that is missing, not parseable as a float, non-positive, or
    non-finite (``nan``/``inf``) is ignored.

    Args:
        timeout_seconds: The task's own timeout in seconds.

    Returns:
        The effective timeout in seconds as a ``float``.
    """
    base_timeout = float(timeout_seconds)
    raw_override = os.environ.get(OOD_TIMEOUT_OVERRIDE_ENV_VAR)
    if raw_override is None:
        return base_timeout
    try:
        override_seconds = float(raw_override)
    except (TypeError, ValueError):
        return base_timeout
    # The chained comparison is False for NaN, +/-inf and anything <= 0. An
    # infinite timeout must not reach Popen.communicate, whose selector-based
    # wait cannot convert it to an integer number of milliseconds.
    if not (0.0 < override_seconds < float("inf")):
        return base_timeout
    return max(base_timeout, override_seconds)


def _append_validation_message(validation_details: dict[str, Any], message: str) -> None:
    """Record *message* under ``"example_messages"`` unless the cap is reached.

    The list is created on first use; once it holds
    ``MAX_VALIDATION_MESSAGES`` entries further messages are dropped.
    """
    collected = validation_details.setdefault("example_messages", [])
    if len(collected) >= MAX_VALIDATION_MESSAGES:
        return
    collected.append(str(message))


def _empty_validation_details() -> dict[str, Any]:
    return {
        "alpha_valid": True,
        "shape_message": None,
        "boundary_violations": 0,
        "overlap_violations": 0,
        "negative_radius_violations": 0,
        "non_finite_value_violations": 0,
        "example_messages": [],
    }


def _validation_summary(
    validation_details: Mapping[str, Any],
    *,
    sum_mismatch: bool,
) -> dict[str, Any]:
    summary = {
        "alpha_valid": bool(validation_details.get("alpha_valid", False)),
        "boundary_violations": int(validation_details.get("boundary_violations", 0) or 0),
        "overlap_violations": int(validation_details.get("overlap_violations", 0) or 0),
        "negative_radius_violations": int(
            validation_details.get("negative_radius_violations", 0) or 0
        ),
        "non_finite_value_violations": int(
            validation_details.get("non_finite_value_violations", 0) or 0
        ),
        "sum_mismatch": bool(sum_mismatch),
    }
    shape_message = validation_details.get("shape_message")
    if isinstance(shape_message, str) and shape_message:
        summary["shape_message"] = shape_message
    example_messages = validation_details.get("example_messages")
    if isinstance(example_messages, Sequence) and not isinstance(example_messages, (str, bytes)):
        compact_messages = [str(message) for message in example_messages[:MAX_VALIDATION_MESSAGES]]
        if compact_messages:
            summary["example_messages"] = compact_messages
    return summary


def validate_rectangle_packing(
    centers: np.ndarray,
    radii: np.ndarray,
    alpha: float,
    *,
    n_expected: int | None = None,
) -> tuple[bool, dict[str, Any]]:
    """Validate rectangle containment, alpha bounds, and no-overlap constraints.

    The packing is checked against the rectangle ``[0, alpha] x [0, 2 - alpha]``.
    Checks run in this order, and the first three groups stop at the first
    failure: alpha (numeric, finite, ``0 < alpha <= 1``), array shapes, and
    finiteness of all values. Negative radii, boundary violations and pairwise
    overlaps are then all counted so the details report every problem.

    Args:
        centers: Array of circle centers, expected shape ``(n, 2)``.
        radii: Array of circle radii, expected shape ``(n,)``.
        alpha: Rectangle width; the height is ``2 - alpha``.
        n_expected: When given, the exact number of circles required. When
            omitted, ``centers`` and ``radii`` must still agree on ``n``.

    Returns:
        ``(is_valid, validation_details)`` where the details dict has the keys
        produced by ``_empty_validation_details``.
    """
    validation_details = _empty_validation_details()

    try:
        alpha_value = float(alpha)
    except (TypeError, ValueError):
        validation_details["alpha_valid"] = False
        _append_validation_message(validation_details, f"alpha is not numeric: {alpha!r}")
        return False, validation_details

    if not np.isfinite(alpha_value) or not (0.0 < alpha_value <= 1.0):
        validation_details["alpha_valid"] = False
        _append_validation_message(
            validation_details,
            f"alpha must satisfy 0 < alpha <= 1, got {alpha_value!r}",
        )
        return False, validation_details

    height = 2.0 - alpha_value
    # Defensive: with 0 < alpha <= 1 the height is always in [1, 2).
    if not np.isfinite(height) or height <= 0.0:
        validation_details["alpha_valid"] = False
        _append_validation_message(
            validation_details,
            f"height must be positive and finite, got {height!r}",
        )
        return False, validation_details

    # Shape checks: the first failing rule decides the message.
    shape_message = None
    if n_expected is not None and centers.shape != (n_expected, 2):
        shape_message = (
            f"centers must have shape ({n_expected}, 2), got {tuple(centers.shape)}"
        )
    elif n_expected is not None and radii.shape != (n_expected,):
        shape_message = f"radii must have shape ({n_expected},), got {tuple(radii.shape)}"
    elif centers.ndim != 2 or centers.shape[1] != 2:
        shape_message = f"centers must have shape (n, 2), got {tuple(centers.shape)}"
    elif radii.ndim != 1:
        shape_message = f"radii must have shape (n,), got {tuple(radii.shape)}"
    elif centers.shape[0] != radii.shape[0]:
        # Only reachable when n_expected is None. Without this check the
        # per-circle loop would silently truncate to the shorter array and the
        # pairwise loop could index past the end of ``centers``.
        shape_message = (
            "centers and radii must describe the same number of circles, "
            f"got {centers.shape[0]} and {radii.shape[0]}"
        )
    if shape_message is not None:
        validation_details["shape_message"] = shape_message
        _append_validation_message(validation_details, shape_message)
        return False, validation_details

    non_finite_values = int(np.size(centers) - int(np.isfinite(centers).sum()))
    non_finite_values += int(np.size(radii) - int(np.isfinite(radii).sum()))
    validation_details["non_finite_value_violations"] = non_finite_values
    if non_finite_values > 0:
        # Geometry comparisons against NaN/inf are meaningless, so stop here.
        _append_validation_message(
            validation_details,
            f"Packing contains {non_finite_values} non-finite center/radius values",
        )
        return False, validation_details

    negative_count = int(np.sum(radii < 0.0))
    validation_details["negative_radius_violations"] = negative_count
    if negative_count > 0:
        _append_validation_message(
            validation_details,
            f"Packing contains {negative_count} negative radii",
        )

    n_circles = radii.shape[0]
    for index, (center, radius) in enumerate(zip(centers, radii)):
        x, y = float(center[0]), float(center[1])
        r = float(radius)
        if (
            x < r - BOUNDARY_TOLERANCE
            or x > alpha_value - r + BOUNDARY_TOLERANCE
            or y < r - BOUNDARY_TOLERANCE
            or y > height - r + BOUNDARY_TOLERANCE
        ):
            validation_details["boundary_violations"] += 1
            _append_validation_message(
                validation_details,
                (
                    f"Circle {index} at ({x:.6f}, {y:.6f}) with radius {r:.6f} "
                    f"violates rectangle bounds [0,{alpha_value:.6f}] x [0,{height:.6f}]"
                ),
            )

    # Each unordered pair is tested once; circles may touch within tolerance.
    for i in range(n_circles):
        for j in range(i + 1, n_circles):
            distance = float(np.linalg.norm(centers[i] - centers[j]))
            min_distance = float(radii[i] + radii[j])
            if distance < min_distance - BOUNDARY_TOLERANCE:
                validation_details["overlap_violations"] += 1
                _append_validation_message(
                    validation_details,
                    (
                        f"Circles {i} and {j} overlap: distance={distance:.6f}, "
                        f"required={min_distance:.6f}"
                    ),
                )

    is_valid = (
        validation_details["alpha_valid"]
        and validation_details["shape_message"] is None
        and validation_details["negative_radius_violations"] == 0
        and validation_details["non_finite_value_violations"] == 0
        and validation_details["boundary_violations"] == 0
        and validation_details["overlap_violations"] == 0
    )
    return is_valid, validation_details


def _subprocess_runner_script(
    *,
    program_path: str,
    result_path: str,
    n_circles: int,
) -> str:
    """Return the source of a standalone script that runs one candidate.

    The generated script loads the module at *program_path*, calls its
    ``run_packing`` (or, failing that, ``construct_packing``) with
    *n_circles*, and writes a JSON payload to *result_path*. On success the
    payload has ``ok: True`` plus ``centers``, ``radii``, ``alpha`` and
    ``reported_sum_radii``; any ``Exception`` raised along the way produces
    ``ok: False`` with ``error`` and ``traceback`` instead.

    Args:
        program_path: Path of the candidate program; embedded via ``repr``.
        result_path: Path the script writes its JSON payload to.
        n_circles: Circle count passed to the candidate's entry point.

    Returns:
        The dedented, stripped script text.
    """
    return textwrap.dedent(
        f"""
        import importlib.util
        import json
        import os
        import sys
        import traceback
        import numpy as np

        PROGRAM_PATH = {program_path!r}
        RESULT_PATH = {result_path!r}
        N_CIRCLES = {int(n_circles)}

        payload = None
        try:
            sys.path.insert(0, os.path.dirname(PROGRAM_PATH))
            module_name = "circle_packing_rectangle_mt_sts_candidate"
            spec = importlib.util.spec_from_file_location(module_name, PROGRAM_PATH)
            if spec is None or spec.loader is None:
                raise ImportError(f"Could not load program from {{PROGRAM_PATH}}")
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            runner = getattr(module, "run_packing", None)
            if not callable(runner):
                runner = getattr(module, "construct_packing", None)
            if not callable(runner):
                raise AttributeError(
                    "Program must define run_packing(n) or construct_packing(n)"
                )

            result = runner(N_CIRCLES)
            if not isinstance(result, (tuple, list)):
                raise TypeError(
                    f"Expected tuple/list return value, got {{type(result).__name__}}"
                )

            if len(result) == 3:
                centers, radii, alpha = result
                reported_sum_radii = None
            elif len(result) == 4:
                centers, radii, alpha, reported_sum_radii = result
                if reported_sum_radii is not None:
                    reported_sum_radii = float(reported_sum_radii)
            else:
                raise ValueError(
                    "Expected (centers, radii, alpha) or "
                    "(centers, radii, alpha, sum_radii)"
                )

            centers_array = np.asarray(centers, dtype=float)
            radii_array = np.asarray(radii, dtype=float)
            alpha_value = float(alpha)
            payload = {{
                "ok": True,
                "centers": centers_array.tolist(),
                "radii": radii_array.tolist(),
                "alpha": alpha_value,
                "reported_sum_radii": reported_sum_radii,
            }}
        except Exception as exc:
            payload = {{
                "ok": False,
                "error": f"{{type(exc).__name__}}: {{exc}}",
                "traceback": traceback.format_exc(),
            }}

        with open(RESULT_PATH, "w", encoding="utf-8") as handle:
            json.dump(payload, handle)
        """
    ).strip()


def _run_program_in_subprocess(
    program_path: str,
    task: CirclePackingRectangleTaskSpec,
    *,
    timeout_seconds: float,
) -> dict[str, Any]:
    """Run the candidate program for *task* in a fresh Python subprocess.

    A runner script is written to a temporary directory and executed with the
    current interpreter. The child is killed if it exceeds *timeout_seconds*.
    Whatever JSON payload the runner managed to write is read back before the
    temporary directory is removed.

    Args:
        program_path: Path to the candidate program file.
        task: Task spec; only ``task.n_circles`` is used here.
        timeout_seconds: Wall-clock limit for the child process.

    Returns:
        A dict with ``timed_out``, ``exit_code``, ``stdout_excerpt``,
        ``stderr_excerpt``, ``eval_time`` (seconds, including process start-up)
        and ``payload`` (the decoded JSON, or ``None`` when the result file is
        missing or unreadable).
    """
    start_time = time.perf_counter()
    with tempfile.TemporaryDirectory(prefix="circle_packing_rectangle_mt_sts_") as temp_dir:
        temp_root = Path(temp_dir)
        runner_path = temp_root / "runner.py"
        result_path = temp_root / "result.json"
        runner_path.write_text(
            _subprocess_runner_script(
                program_path=str(Path(program_path).resolve()),
                result_path=str(result_path),
                n_circles=task.n_circles,
            ),
            encoding="utf-8",
        )

        process = subprocess.Popen(
            [sys.executable, str(runner_path)],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            # Candidate output is untrusted: bytes that are not valid in the
            # locale encoding must not crash the evaluator while decoding.
            errors="replace",
        )

        timed_out = False
        try:
            stdout, stderr = process.communicate(timeout=float(timeout_seconds))
        except subprocess.TimeoutExpired:
            timed_out = True
            process.kill()
            # Second communicate() reaps the child and drains the pipes.
            stdout, stderr = process.communicate()
        except BaseException:
            # Interrupted (e.g. KeyboardInterrupt): do not leave the child
            # running against a temp directory that is about to be deleted.
            process.kill()
            process.wait()
            raise

        eval_time = time.perf_counter() - start_time
        payload = None
        if result_path.is_file():
            try:
                payload = json.loads(result_path.read_text(encoding="utf-8"))
            except Exception:
                # Deliberately best-effort: a truncated or malformed result
                # file is reported to the caller as a missing payload.
                payload = None

    return {
        "timed_out": timed_out,
        "exit_code": process.returncode,
        "stdout_excerpt": _truncate_text(stdout),
        "stderr_excerpt": _truncate_text(stderr),
        "eval_time": float(eval_time),
        "payload": payload,
    }


def _task_metrics(
    task: CirclePackingRectangleTaskSpec,
    *,
    centers: np.ndarray,
    radii: np.ndarray,
    alpha: float,
    eval_time: float,
    valid: bool,
) -> dict[str, float]:
    if not valid:
        return {
            "sum_radii": 0.0,
            "target_sum_radii": float(task.target_sum_radii),
            "target_ratio": 0.0,
            "validity": 0.0,
            "alpha": 0.0,
            "height": 0.0,
            "aspect_ratio": 0.0,
            "radius_variance": 0.0,
            "spatial_spread": 0.0,
            "min_radius": 0.0,
            "max_radius": 0.0,
            "eval_time": float(max(0.0, eval_time)),
            "score": 0.0,
            "combined_score": 0.0,
        }

    alpha_value = float(alpha)
    height = 2.0 - alpha_value
    sum_radii = float(np.sum(radii))
    target_ratio = sum_radii / float(task.target_sum_radii)
    centroid = np.mean(centers, axis=0)
    distances_from_centroid = np.linalg.norm(centers - centroid, axis=1)
    radius_variance = float(np.var(radii) / 0.0625)
    spatial_spread = float(np.std(distances_from_centroid) / (0.5 * np.sqrt(5.0)))
    score = target_ratio
    return {
        "sum_radii": sum_radii,
        "target_sum_radii": float(task.target_sum_radii),
        "target_ratio": target_ratio,
        "validity": 1.0,
        "alpha": alpha_value,
        "height": height,
        "aspect_ratio": float(alpha_value / height) if height > 0.0 else 0.0,
        "radius_variance": float(np.clip(radius_variance, 0.0, 1.0)),
        "spatial_spread": float(np.clip(spatial_spread, 0.0, 1.0)),
        "min_radius": float(np.min(radii)),
        "max_radius": float(np.max(radii)),
        "eval_time": float(max(0.0, eval_time)),
        "score": score,
        "combined_score": score,
    }


def evaluate_one_task(
    program_path: str,
    task: CirclePackingRectangleTaskSpec,
    *,
    stage1: bool = False,
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Run the candidate on one task and score the packing it returns.

    The candidate is executed in a subprocess under the stage-appropriate
    timeout (after the environment override is applied). A timeout, a missing
    or failed payload, or undecodable arrays all yield zeroed metrics and an
    error string. Otherwise the packing is validated and scored.

    Returns:
        ``(task_result, task_artifacts)``: the dict produced by
        ``build_task_result`` and a dict of per-task diagnostics (validation
        summary, execution summary and a compact metric summary).
    """
    stage_budget = task.timeout_seconds_stage1 if stage1 else task.timeout_seconds_full
    timeout_seconds = _apply_ood_timeout_override(float(stage_budget))
    execution = _run_program_in_subprocess(
        program_path,
        task,
        timeout_seconds=timeout_seconds,
    )
    did_time_out = bool(execution["timed_out"])
    elapsed = float(execution["eval_time"])

    validation_details = _empty_validation_details()
    sum_mismatch = False
    error = None
    # (centers, radii, alpha) once the payload has been decoded; None means
    # the run failed before a packing could be extracted.
    decoded = None

    if did_time_out:
        error = f"Timed out after {timeout_seconds:.0f}s"
    else:
        payload = execution.get("payload")
        if not isinstance(payload, Mapping):
            error = "Missing subprocess result payload"
        elif not bool(payload.get("ok", False)):
            error = str(payload.get("error") or "Candidate execution failed")
        else:
            try:
                decoded = (
                    np.asarray(payload.get("centers"), dtype=float),
                    np.asarray(payload.get("radii"), dtype=float),
                    float(payload.get("alpha")),
                )
            except Exception as exc:
                conversion_message = (
                    f"Could not convert centers/radii/alpha to numeric values: {exc}"
                )
                validation_details["shape_message"] = conversion_message
                _append_validation_message(validation_details, conversion_message)
                error = conversion_message

    if decoded is None:
        # Every failure mode is scored against an all-zero placeholder packing.
        metrics = _task_metrics(
            task,
            centers=np.zeros((task.n_circles, 2), dtype=float),
            radii=np.zeros(task.n_circles, dtype=float),
            alpha=0.0,
            eval_time=elapsed,
            valid=False,
        )
    else:
        centers, radii, alpha = decoded
        valid, validation_details = validate_rectangle_packing(
            centers,
            radii,
            alpha,
            n_expected=task.n_circles,
        )
        metrics = _task_metrics(
            task,
            centers=centers,
            radii=radii,
            alpha=alpha,
            eval_time=elapsed,
            valid=valid,
        )

        if valid:
            # The self-reported sum is only compared for valid packings; a
            # mismatch is recorded in the summary but does not set an error.
            reported_sum_radii = payload.get("reported_sum_radii")
            if reported_sum_radii is not None:
                try:
                    reported_value = float(reported_sum_radii)
                except (TypeError, ValueError):
                    sum_mismatch = True
                else:
                    sum_mismatch = abs(reported_value - float(np.sum(radii))) > 1.0e-9
        else:
            error = str(validation_details.get("shape_message") or "")
            if not error:
                recorded_messages = validation_details.get("example_messages") or []
                if recorded_messages:
                    error = str(recorded_messages[0])
            if not error:
                error = "Packing validation failed"

    validation_summary = _validation_summary(
        validation_details,
        sum_mismatch=sum_mismatch,
    )
    task_result = build_task_result(
        task,
        raw_metrics=metrics,
        error=error,
        validation_summary=validation_summary,
    )

    execution_summary: dict[str, Any] = {
        "status": "error" if error is not None else "success",
        "timed_out": did_time_out,
        "exit_code": execution.get("exit_code"),
        "eval_time": elapsed,
    }
    if error is not None:
        # Output excerpts are attached only to failed runs.
        for stream_key in ("stdout_excerpt", "stderr_excerpt"):
            excerpt = str(execution.get(stream_key) or "")
            if excerpt:
                execution_summary[stream_key] = excerpt
        final_payload = execution.get("payload")
        if isinstance(final_payload, Mapping):
            traceback_text = _truncate_text(str(final_payload.get("traceback") or ""))
            if traceback_text:
                execution_summary["traceback_excerpt"] = traceback_text

    result_metrics = task_result["metrics"]
    compact_summary: dict[str, Any] = {"score": float(task_result["final_task_score"])}
    for metric_name in (
        "sum_radii",
        "target_ratio",
        "validity",
        "alpha",
        "height",
        "aspect_ratio",
    ):
        compact_summary[metric_name] = float(result_metrics[metric_name])
    compact_summary["timed_out"] = did_time_out

    task_artifacts = {
        "task_id": task.task_id,
        "timed_out": did_time_out,
        "validation_summary": validation_summary,
        "execution_summary": execution_summary,
        "compact_task_summary": compact_summary,
    }
    return task_result, task_artifacts


def _public_artifacts(
    *,
    selector: str,
    stage_name: str,
    task_results: list[dict[str, Any]],
    per_task_artifacts: Iterable[Mapping[str, Any]],
) -> dict[str, Any]:
    task_artifact_list = list(per_task_artifacts)
    artifacts: dict[str, Any] = {
        "task_selector": selector,
        "selected_task_ids": [task_result["task_id"] for task_result in task_results],
        "evaluation_mode": "shared" if len(task_results) > 1 else "task_specific",
        "evaluation_stage": stage_name,
        "task_results": task_results,
        "subprocess_timeout_by_task": {},
        "validation_summaries": {},
        "compact_task_summary": {},
        "execution_summary": {
            "selected_task_count": len(task_results),
            "timed_out_task_count": 0,
            "successful_task_count": int(sum(1 for result in task_results if not result.get("error"))),
            "failed_task_count": int(sum(1 for result in task_results if result.get("error"))),
        },
    }

    for task_artifacts in task_artifact_list:
        task_id = str(task_artifacts["task_id"])
        timed_out = bool(task_artifacts.get("timed_out", False))
        artifacts["subprocess_timeout_by_task"][task_id] = timed_out
        artifacts["validation_summaries"][task_id] = dict(task_artifacts["validation_summary"])
        artifacts["compact_task_summary"][task_id] = dict(task_artifacts["compact_task_summary"])
        if timed_out:
            artifacts["execution_summary"]["timed_out_task_count"] += 1

    if task_results:
        best_task_result = max(
            task_results,
            key=lambda result: float(result["final_task_score"]),
        )
        artifacts["best_task_summary"] = {
            "task_id": best_task_result["task_id"],
            "score": float(best_task_result["final_task_score"]),
            "sum_radii": float(best_task_result["metrics"]["sum_radii"]),
            "target_ratio": float(best_task_result["metrics"]["target_ratio"]),
            "alpha": float(best_task_result["metrics"]["alpha"]),
        }

    return artifacts


def _evaluate(program_path: str, *, stage1: bool) -> EvaluationResult:
    """Evaluate *program_path* on every task chosen by the selector.

    The selector comes from the environment variable named by
    ``CIRCLE_PACKING_RECTANGLE_TASK_SELECTOR_ENV_VAR`` and defaults to the
    shared selector. A single selected task reports its own metrics; any
    other number of tasks is reduced with ``aggregate_task_results``.
    """
    selector = os.environ.get(
        CIRCLE_PACKING_RECTANGLE_TASK_SELECTOR_ENV_VAR,
        CIRCLE_PACKING_RECTANGLE_SHARED_SELECTOR,
    )
    # Tasks run one after another, in the order the resolver returns them.
    outcomes = [
        evaluate_one_task(program_path, task, stage1=stage1)
        for task in resolve_eval_task_specs(selector)
    ]
    task_results: list[dict[str, Any]] = [result for result, _ in outcomes]
    task_artifacts: list[dict[str, Any]] = [extras for _, extras in outcomes]

    artifacts = _public_artifacts(
        selector=selector,
        stage_name="stage1" if stage1 else "full",
        task_results=task_results,
        per_task_artifacts=task_artifacts,
    )
    if len(task_results) == 1:
        metrics = dict(task_results[0]["metrics"])
    else:
        metrics = aggregate_task_results(task_results)
    return EvaluationResult(metrics=metrics, artifacts=artifacts)


def evaluate(program_path: str) -> EvaluationResult:
    """Evaluate one task or the shared rectangle circle-packing family.

    Delegates to ``_evaluate`` with ``stage1=False``, so each selected task
    runs under its full (non-stage-1) timeout.

    Args:
        program_path: Path to the candidate program file.

    Returns:
        The ``EvaluationResult`` holding metrics and artifacts.
    """
    return _evaluate(program_path, stage1=False)


def evaluate_stage1(program_path: str) -> EvaluationResult:
    """Cheaper cascade stage that uses the lighter per-task timeout.

    Delegates to ``_evaluate`` with ``stage1=True``; the selected tasks and
    the scoring are otherwise the same as in ``evaluate``.

    Args:
        program_path: Path to the candidate program file.

    Returns:
        The ``EvaluationResult`` holding metrics and artifacts.
    """
    return _evaluate(program_path, stage1=True)


def evaluate_stage2(program_path: str) -> EvaluationResult:
    """Full evaluation for cascade mode.

    This is an alias for ``evaluate``: it simply forwards *program_path* and
    returns that function's result.
    """
    return evaluate(program_path)


if __name__ == "__main__":
    # Command-line entry point: run the full evaluation on the given program
    # and print the headline scores followed by one line per task.
    cli_arguments = sys.argv[1:]
    if not cli_arguments:
        raise SystemExit("Usage: python evaluator.py <program_path>")

    try:
        evaluation_result = evaluate(cli_arguments[0])
    except Exception:
        # Echo the traceback on stdout as well before letting the error
        # propagate.
        print(traceback.format_exc())
        raise

    headline_metrics = evaluation_result.metrics
    print(f"Score: {headline_metrics['score']:.6f}")
    print(f"Combined Score: {headline_metrics['combined_score']:.6f}")
    if "task_count" in headline_metrics:
        print(f"Task Count: {int(headline_metrics['task_count'])}")
    for task_result in evaluation_result.artifacts.get("task_results", []):
        per_task_metrics = task_result["metrics"]
        print(
            f"{task_result['task_id']}: "
            f"score={task_result['final_task_score']:.6f} "
            f"sum_radii={per_task_metrics['sum_radii']:.6f} "
            f"alpha={per_task_metrics['alpha']:.6f}"
        )
