#!/usr/bin/env python3
"""Custom ADK evaluator for the agent package in this folder.

This script runs evaluation directly with AgentEvaluator (Python API),
without using the `adk eval` CLI command.
"""

from __future__ import annotations

import argparse
import asyncio
from contextlib import aclosing
from datetime import datetime, timezone
from enum import Enum
import importlib
import json
import logging
import os
import sys
import time
from typing import Any
from pathlib import Path


SCRIPT_DIR = Path(__file__).resolve().parent  # folder containing this script
CODE_ROOT = SCRIPT_DIR.parent  # parent folder; holds local packages and eval_config.json
DEFAULT_CONFIG = CODE_ROOT.joinpath("eval_config.json")

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%S",
    level=logging.INFO,
)
log = logging.getLogger("run_eval")

# Make CODE_ROOT importable so local packages such as `CanadaDRIPlanner` and
# `eval.*` resolve. It is inserted at the front of sys.path, so it is searched
# before any other entry.
if not any(entry == str(CODE_ROOT) for entry in sys.path):
    sys.path.insert(0, str(CODE_ROOT))
    log.debug("Inserted %s into sys.path", CODE_ROOT)


try:
    from adk.evaluation import AgentEvaluator  # type: ignore
    from adk.evaluation.eval_config import get_evaluation_criteria_or_default  # type: ignore
    from adk.evaluation.eval_config import get_eval_metrics_from_config  # type: ignore
    from adk.evaluation.eval_set import EvalSet  # type: ignore
    from adk.evaluation.base_eval_service import EvaluateConfig  # type: ignore
    from adk.evaluation.base_eval_service import EvaluateRequest  # type: ignore
    from adk.evaluation.base_eval_service import InferenceConfig  # type: ignore
    from adk.evaluation.base_eval_service import InferenceRequest  # type: ignore
    from adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator  # type: ignore
    from adk.evaluation.local_eval_service import LocalEvalService  # type: ignore
    from adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager  # type: ignore
    from adk.evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY  # type: ignore
    from adk.evaluation.eval_metrics import Interval  # type: ignore
    from adk.evaluation.eval_metrics import MetricInfo  # type: ignore
    from adk.evaluation.eval_metrics import MetricValueInfo  # type: ignore
    from adk.evaluation.simulation.user_simulator_provider import UserSimulatorProvider  # type: ignore
    from adk.agents import Agent  # type: ignore
except ImportError:
    from google.adk.evaluation import AgentEvaluator
    from google.adk.evaluation.eval_config import get_evaluation_criteria_or_default
    from google.adk.evaluation.eval_config import get_eval_metrics_from_config
    from google.adk.evaluation.eval_set import EvalSet
    from google.adk.evaluation.base_eval_service import EvaluateConfig
    from google.adk.evaluation.base_eval_service import EvaluateRequest
    from google.adk.evaluation.base_eval_service import InferenceConfig
    from google.adk.evaluation.base_eval_service import InferenceRequest
    from google.adk.evaluation.custom_metric_evaluator import _CustomMetricEvaluator
    from google.adk.evaluation.local_eval_service import LocalEvalService
    from google.adk.evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
    from google.adk.evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
    from google.adk.evaluation.eval_metrics import Interval
    from google.adk.evaluation.eval_metrics import MetricInfo
    from google.adk.evaluation.eval_metrics import MetricValueInfo
    from google.adk.evaluation.simulation.user_simulator_provider import UserSimulatorProvider
    from google.adk.agents import Agent


def resolve_evalset(evalset_arg: str | None) -> Path:
    """Locate the eval set JSON file to run.

    Args:
        evalset_arg: Explicit path from ``--evalset``. Relative paths are anchored
            at ``SCRIPT_DIR``. When falsy, the first ``*.evalset.json`` in
            ``SCRIPT_DIR`` (in sorted order) is used.

    Returns:
        The resolved absolute path to the eval set file.

    Raises:
        FileNotFoundError: if the explicit path does not exist, or if no
            ``*.evalset.json`` file can be auto-discovered.
    """
    if not evalset_arg:
        candidates = sorted(SCRIPT_DIR.glob("*.evalset.json"))
        if not candidates:
            raise FileNotFoundError(
                f"No eval set found in {SCRIPT_DIR}. Expected *.evalset.json"
            )
        chosen = candidates[0].resolve()
        log.info("Auto-discovered eval set: %s", chosen)
        return chosen

    requested = Path(evalset_arg)
    if not requested.is_absolute():
        requested = SCRIPT_DIR / requested
    if not requested.exists():
        raise FileNotFoundError(f"Eval set not found: {requested}")
    chosen = requested.resolve()
    log.info("Using specified eval set: %s", chosen)
    return chosen


def resolve_config(config_arg: str | None, no_config: bool) -> Path | None:
    """Decide which eval config file (if any) to pass to ADK.

    Args:
        config_arg: Explicit path from ``--config``. Relative paths are anchored
            at ``CODE_ROOT``.
        no_config: When True, no config file is used at all.

    Returns:
        The resolved config path, or None to let ADK use its default criteria.
        ``DEFAULT_CONFIG`` is used when no explicit path is given and it exists.

    Raises:
        FileNotFoundError: if an explicit ``config_arg`` does not exist.
    """
    if no_config:
        log.info("Config disabled via --no-config; using ADK defaults")
        return None

    if not config_arg:
        if not DEFAULT_CONFIG.exists():
            log.info("No config file found; using ADK defaults")
            return None
        default_resolved = DEFAULT_CONFIG.resolve()
        log.info("Using default config: %s", default_resolved)
        return default_resolved

    candidate = Path(config_arg)
    if not candidate.is_absolute():
        candidate = CODE_ROOT / candidate
    if not candidate.exists():
        raise FileNotFoundError(f"Config file not found: {candidate}")
    explicit_resolved = candidate.resolve()
    log.info("Using specified config: %s", explicit_resolved)
    return explicit_resolved


def _positive_int(raw: str) -> int:
    """Parse ``raw`` as an integer >= 1; intended as an argparse ``type=`` converter.

    Raises:
        argparse.ArgumentTypeError: if ``raw`` is not an integer or is < 1, so
            argparse reports a normal usage error instead of a traceback.
    """
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid positive int value: {raw!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this evaluation script.

    Validation performed by the parser itself:
      * ``--num-runs`` and the 1-based ``--batch-index`` must be integers >= 1
        (zero runs would evaluate nothing yet report every case as passed);
      * ``--resume-traces`` and ``--overwrite`` are mutually exclusive, as their
        help text states.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Run eval using ADK AgentEvaluator for local agent and evalset."
    )
    parser.add_argument(
        "--agent-module",
        default=SCRIPT_DIR.name,
        help=(
            "Python module containing the agent package. "
            "Default: current folder name."
        ),
    )
    parser.add_argument(
        "--agent-name",
        default=None,
        help="Optional sub-agent name. If omitted, evaluates root_agent.",
    )
    parser.add_argument(
        "--evalset",
        default=None,
        help="Path to evalset JSON. Default: first *.evalset.json in this folder.",
    )
    parser.add_argument(
        "--config",
        default=None,
        help="Config file path. Default: ../eval_config.json if present.",
    )
    parser.add_argument(
        "--no-config",
        action="store_true",
        help="Ignore config file and use ADK default criteria.",
    )
    parser.add_argument(
        "--num-runs",
        type=_positive_int,
        default=1,
        help="How many times to run each eval case (default: 1).",
    )
    parser.add_argument(
        "--sample-id",
        action="append",
        default=[],
        help=(
            "Run only specific eval sample ID(s), e.g. sample_00123. "
            "Can be repeated and can include comma-separated values."
        ),
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Resume by skipping eval IDs already present in prior result files.",
    )
    parser.add_argument(
        "--resume-dir",
        action="append",
        default=None,
        help=(
            "Directory or file to scan for *.evalset_result.json when resuming. "
            "Can be repeated. Defaults to local .adk/.adk_old eval_history* folders."
        ),
    )
    # These two flags give contradictory instructions for existing trace files,
    # so argparse rejects them together rather than silently applying both.
    trace_mode = parser.add_mutually_exclusive_group()
    trace_mode.add_argument(
        "--resume-traces",
        action="store_true",
        help=(
            "Resume by skipping eval IDs that already have at least --num-runs trace file(s) "
            "in the trace directory. Mutually exclusive with --overwrite."
        ),
    )
    trace_mode.add_argument(
        "--overwrite",
        action="store_true",
        help=(
            "Delete existing trace files for selected samples before running so every "
            "sample is re-evaluated from scratch. Mutually exclusive with --resume-traces."
        ),
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=20,
        help="Run eval in chunks of this many cases per batch (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-index",
        type=_positive_int,
        default=None,
        help=(
            "1-based index of batch to run when --batch-size is set. "
            "If omitted, all batches run sequentially."
        ),
    )
    parser.add_argument(
        "--batch-concurrency",
        type=int,
        default=4,
        help=(
            # %(default)s keeps the help text in sync with the real default.
            "Maximum number of batches to run concurrently (default: %(default)s). "
            "Higher values can reduce wall time when model calls are IO-bound."
        ),
    )
    parser.add_argument(
        "--no-detailed-results",
        action="store_true",
        help="Disable per-case detailed failure printing.",
    )
    parser.add_argument(
        "--trace-dir",
        default=".adk/traces",
        help=(
            "Directory to write per-eval trace JSON files. "
            "Relative paths are resolved from this script folder."
        ),
    )
    return parser


def _load_evalset_payload(evalset_path: Path) -> dict:
    """Read an eval set JSON file and check its minimal shape.

    Args:
        evalset_path: Path to the ``*.evalset.json`` file.

    Returns:
        The parsed JSON object, guaranteed to be a dict whose ``eval_cases``
        field is a list.

    Raises:
        ValueError: if the file is not valid JSON, its root is not an object,
            or ``eval_cases`` is missing / not a list.
    """
    log.debug("Loading eval set payload from %s", evalset_path)
    raw_text = evalset_path.read_text(encoding="utf-8")
    try:
        payload = json.loads(raw_text)
    except json.JSONDecodeError as exc:
        log.error("Invalid evalset JSON in %s: %s", evalset_path, exc)
        raise ValueError(f"Invalid evalset JSON: {evalset_path} ({exc})") from exc

    if not isinstance(payload, dict):
        raise ValueError(f"Evalset root must be a JSON object: {evalset_path}")
    cases = payload.get("eval_cases")
    if not isinstance(cases, list):
        raise ValueError(f"Evalset missing list field 'eval_cases': {evalset_path}")

    log.info("Loaded eval set '%s' with %d cases", payload.get("eval_set_id", "<unnamed>"), len(cases))
    return payload


def _parse_sample_ids(raw_ids: list[str]) -> set[str]:
    sample_ids: set[str] = set()
    for raw in raw_ids:
        for token in raw.split(","):
            sample_id = token.strip()
            if sample_id:
                sample_ids.add(sample_id)
    return sample_ids


def _resolve_resume_paths(raw_paths: list[str] | None) -> list[Path]:
    """Turn ``--resume-dir`` values into existing absolute paths.

    Args:
        raw_paths: User-supplied files/directories. Absolute paths are used
            as-is; relative ones are tried against ``SCRIPT_DIR`` then
            ``CODE_ROOT``, and the first that exists wins. When empty or None,
            the ``.adk/eval_history*`` and ``.adk_old/eval_history*``
            directories under ``SCRIPT_DIR`` are returned instead
            (de-duplicated and sorted).

    Raises:
        FileNotFoundError: if an explicit path exists in none of the tried locations.
    """
    if not raw_paths:
        discovered = {
            match.resolve()
            for pattern in (".adk/eval_history*", ".adk_old/eval_history*")
            for match in SCRIPT_DIR.glob(pattern)
            if match.is_dir()
        }
        return sorted(discovered)

    resolved: list[Path] = []
    for raw_path in raw_paths:
        given = Path(raw_path)
        if given.is_absolute():
            attempts = [given]
        else:
            attempts = [SCRIPT_DIR / given, CODE_ROOT / given]
        hit = next((attempt for attempt in attempts if attempt.exists()), None)
        if hit is None:
            tried = ", ".join(str(attempt) for attempt in attempts)
            raise FileNotFoundError(f"Resume path not found: {raw_path}. Tried: {tried}")
        resolved.append(hit.resolve())
    return resolved


def _iter_result_files(paths: list[Path]):
    for path in paths:
        if path.is_file() and path.name.endswith(".evalset_result.json"):
            yield path
            continue
        if path.is_dir():
            for result_file in sorted(path.glob("*.evalset_result.json")):
                yield result_file


def _collect_completed_eval_ids(paths: list[Path], eval_set_id: str | None) -> tuple[set[str], int]:
    """Collect eval IDs already present in prior ``*.evalset_result.json`` files.

    Args:
        paths: Result files and/or directories to scan (see ``_iter_result_files``).
        eval_set_id: When truthy, only result files whose ``eval_set_id`` matches
            are considered.

    Returns:
        ``(completed_ids, scanned)``: the set of non-empty string ``eval_id``
        values found under ``eval_case_results``, and the number of result files
        visited (including ones that were skipped as unreadable or mismatched).
    """
    completed_ids: set[str] = set()
    scanned = 0
    log.debug("Scanning %d resume path(s) for completed eval IDs", len(paths))
    for result_file in _iter_result_files(paths):
        scanned += 1
        try:
            payload = json.loads(result_file.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            # An unreadable, non-UTF-8, or corrupt file should not abort the
            # whole resume scan (consistent with the trace-based scanner).
            log.warning("Skipping unreadable result file: %s", result_file)
            continue
        if not isinstance(payload, dict):
            continue
        if eval_set_id and payload.get("eval_set_id") != eval_set_id:
            log.debug("Skipping result file with mismatched eval_set_id: %s", result_file)
            continue
        case_results = payload.get("eval_case_results")
        if not isinstance(case_results, list):
            continue
        for case_result in case_results:
            if not isinstance(case_result, dict):
                continue
            eval_id = case_result.get("eval_id")
            if isinstance(eval_id, str) and eval_id:
                completed_ids.add(eval_id)
    log.info("Resume: scanned %d result file(s), found %d already-completed eval ID(s)", scanned, len(completed_ids))
    return completed_ids, scanned


def _collect_completed_eval_ids_from_traces(trace_dir: Path, num_runs: int = 1) -> set[str]:
    """Return eval_ids with at least ``num_runs`` evaluation trace files in ``trace_dir``.

    Every ``*.trace.json`` file is read. Only traces whose ``phase`` is
    ``"evaluation"`` are counted. Traces with no ``phase`` key are also treated
    as evaluation traces, presumably for files written before the field
    existed; confirm if that matters. Unreadable files and files whose JSON
    root is not an object are skipped.

    Args:
        trace_dir: Directory holding trace files. A missing directory yields an
            empty set.
        num_runs: Minimum number of evaluation traces required for an eval_id
            to count as completed.
    """
    from collections import Counter

    if not trace_dir.exists():
        log.info("Trace-based resume: trace dir does not exist yet: %s", trace_dir)
        return set()

    counts: Counter[str] = Counter()
    for trace_file in trace_dir.glob("*.trace.json"):
        try:
            data = json.loads(trace_file.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError, OSError):
            log.warning("Skipping unreadable trace file: %s", trace_file)
            continue
        # Valid JSON that is not an object (e.g. a list) has no .get(); skip it
        # instead of crashing the whole resume scan.
        if not isinstance(data, dict):
            log.debug("Skipping trace file with non-object JSON root: %s", trace_file)
            continue
        phase = data.get("phase", "evaluation")
        if phase != "evaluation":
            continue
        eval_id = data.get("eval_id")
        if isinstance(eval_id, str) and eval_id:
            counts[eval_id] += 1

    completed = {eid for eid, count in counts.items() if count >= num_runs}
    log.info(
        "Trace-based resume: found %d completed eval ID(s) (num_runs>=%d) in %s",
        len(completed), num_runs, trace_dir,
    )
    return completed


def _delete_traces_for_eval_ids(trace_dir: Path, eval_ids: set[str]) -> int:
    """Delete all trace files belonging to `eval_ids`. Returns number of files deleted."""
    if not trace_dir.exists() or not eval_ids:
        return 0

    deleted = 0
    for trace_file in list(trace_dir.glob("*.trace.json")):
        try:
            data = json.loads(trace_file.read_text(encoding="utf-8"))
        except (json.JSONDecodeError, OSError):
            continue
        eval_id = data.get("eval_id")
        if isinstance(eval_id, str) and eval_id in eval_ids:
            trace_file.unlink()
            deleted += 1
            log.debug("Deleted trace: %s", trace_file)

    log.info(
        "Overwrite: deleted %d trace file(s) for %d eval ID(s)",
        deleted, len(eval_ids),
    )
    return deleted


def _filter_eval_cases(
    evalset_payload: dict,
    sample_ids: set[str],
    skip_eval_ids: set[str],
) -> tuple[list[dict], int]:
    """Select the eval cases to run from a raw eval set payload.

    Entries that are not dicts or lack a non-empty string ``eval_id`` are
    ignored. A case is kept when it passes the ``sample_ids`` filter (if any)
    and is not listed in ``skip_eval_ids``.

    Returns:
        ``(selected_cases, total)``, where ``total`` is the number of raw
        entries in ``eval_cases``.

    Raises:
        ValueError: if any requested sample ID does not exist in the eval set.
    """
    raw_cases = evalset_payload["eval_cases"]
    selected: list[dict] = []
    known_ids: set[str] = set()

    for entry in raw_cases:
        entry_id = entry.get("eval_id") if isinstance(entry, dict) else None
        if not (isinstance(entry_id, str) and entry_id):
            continue
        known_ids.add(entry_id)

        if sample_ids and entry_id not in sample_ids:
            log.debug("Skipping eval_id '%s' (not in sample filter)", entry_id)
        elif skip_eval_ids and entry_id in skip_eval_ids:
            log.debug("Skipping eval_id '%s' (already completed, resuming)", entry_id)
        else:
            selected.append(entry)

    unknown = sorted(sample_ids - known_ids) if sample_ids else []
    if unknown:
        raise ValueError(
            "Requested sample ID(s) not found in evalset: " + ", ".join(unknown)
        )

    log.info("Filtered: %d/%d eval cases selected", len(selected), len(raw_cases))
    return selected, len(raw_cases)


def _build_eval_set(evalset_payload: dict, eval_cases: list[dict]) -> EvalSet:
    """Validate an ``EvalSet`` from ``evalset_payload`` with ``eval_cases`` swapped in.

    The original payload dict is not mutated; a shallow copy is validated.
    """
    return EvalSet.model_validate({**evalset_payload, "eval_cases": eval_cases})


def _chunk_cases(eval_cases: list[dict], batch_size: int) -> list[list[dict]]:
    return [eval_cases[i : i + batch_size] for i in range(0, len(eval_cases), batch_size)]


def _load_root_agent_if_available(agent_module: str):
    """Best-effort load to validate root_agent type for fast fail and clarity."""
    module = importlib.import_module(agent_module)
    module_with_agent = module.agent if hasattr(module, "agent") else module
    root_agent = getattr(module_with_agent, "root_agent", None)
    return root_agent


def _resolve_trace_dir(trace_dir_arg: str | None) -> Path:
    """Return the absolute trace directory, creating it (with parents) if needed.

    Relative paths are anchored at ``SCRIPT_DIR``. A None/empty argument selects
    ``SCRIPT_DIR/.adk/traces``.
    """
    target = Path(trace_dir_arg) if trace_dir_arg else Path(".adk", "traces")
    if not target.is_absolute():
        target = SCRIPT_DIR / target
    target.mkdir(parents=True, exist_ok=True)
    return target.resolve()


def _safe_file_stem(value: str) -> str:
    return "".join(ch if ch.isalnum() or ch in {"_", "-"} else "_" for ch in value)


def _json_fallback(value: Any) -> Any:
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, Path):
        return str(value)
    return str(value)


def _model_dump_json(value: Any) -> Any:
    model_dump = getattr(value, "model_dump", None)
    if callable(model_dump):
        try:
            return model_dump(mode="json")
        except TypeError:
            return model_dump()
    return value


def _write_json_payload(path: Path, payload: dict[str, Any]) -> None:
    """Atomically and durably write ``payload`` as indented JSON to ``path``.

    The JSON goes to a sibling temp file named ``<name>.tmp``, so it is not
    matched by ``*.trace.json`` globs. That file is flushed and fsynced and
    then moved over ``path`` with ``os.replace``. Readers therefore see either
    the previous state or the complete new file, never a truncated one. If
    serialization fails (e.g. a circular reference), no file is left at
    ``path`` and the temp file is removed. Values JSON cannot encode natively
    are converted by ``_json_fallback``.

    Raises:
        Whatever ``json.dump`` or the filesystem raises, after cleaning up the
        temp file.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, default=_json_fallback)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        # Covers KeyboardInterrupt/CancelledError too: never leave partial files.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise


def _compute_percentile(values: list[float], percentile: float) -> float | None:
    if not values:
        return None
    if len(values) == 1:
        return values[0]

    ordered = sorted(values)
    rank = (len(ordered) - 1) * percentile
    lower_index = int(rank)
    upper_index = min(lower_index + 1, len(ordered) - 1)
    fraction = rank - lower_index
    lower_value = ordered[lower_index]
    upper_value = ordered[upper_index]
    return lower_value + (upper_value - lower_value) * fraction


def _build_duration_summary(durations_seconds: list[float]) -> dict[str, float] | None:
    """Summarize a list of durations in seconds.

    Returns:
        None for an empty list; otherwise a dict with ``count`` (as a float),
        ``mean_seconds``, ``p50_seconds``, ``p95_seconds``, ``min_seconds`` and
        ``max_seconds``.
    """
    if not durations_seconds:
        return None

    sample_count = len(durations_seconds)
    median = _compute_percentile(durations_seconds, 0.50)
    tail = _compute_percentile(durations_seconds, 0.95)
    return {
        "count": float(sample_count),
        "mean_seconds": sum(durations_seconds) / sample_count,
        "p50_seconds": median or 0.0,
        "p95_seconds": tail or 0.0,
        "min_seconds": min(durations_seconds),
        "max_seconds": max(durations_seconds),
    }


def _print_run_timing_summary(run_number: int, num_runs: int, durations_seconds: list[float]) -> None:
    """Print and log a one-line timing summary for one run; do nothing if there are no durations."""
    stats = _build_duration_summary(durations_seconds)
    if stats is not None:
        fields = [
            f"samples={int(stats['count'])}",
            f"mean={stats['mean_seconds']:.2f}s",
            f"p50={stats['p50_seconds']:.2f}s",
            f"p95={stats['p95_seconds']:.2f}s",
            f"min={stats['min_seconds']:.2f}s",
            f"max={stats['max_seconds']:.2f}s",
        ]
        line = f"Run {run_number}/{num_runs} timing summary: " + ", ".join(fields)
        print(line)
        log.info(line)


def _write_inference_trace(
    trace_dir: Path,
    inference_result: Any,
    run_number: int,
    sample_index: int,
    run_elapsed_seconds: float | None = None,
) -> None:
    """Write one inference-phase trace JSON file for ``inference_result`` into ``trace_dir``.

    The file is named
    ``<eval_id>.run_<run>.sample_<index>.<utc ts>.inference.trace.json``. When the
    result has no usable (non-empty string) ``eval_id``, the placeholder
    ``unknown_eval_id_<index>`` is used in both the file name and the payload.
    """
    created = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S.%fZ")
    candidate_id = getattr(inference_result, "eval_id", None)
    has_real_id = isinstance(candidate_id, str) and bool(candidate_id)
    eval_id = candidate_id if has_real_id else f"unknown_eval_id_{sample_index:05d}"

    stem = _safe_file_stem(eval_id)
    trace_path = trace_dir / (
        f"{stem}.run_{run_number}.sample_{sample_index:05d}.{created}.inference.trace.json"
    )

    _write_json_payload(
        trace_path,
        {
            "created_at_utc": created,
            "phase": "inference",
            "run_number": run_number,
            "sample_index": sample_index,
            "eval_id": eval_id,
            "eval_set_id": getattr(inference_result, "eval_set_id", None),
            "run_elapsed_seconds": run_elapsed_seconds,
            "inference_result": _model_dump_json(inference_result),
        },
    )
    log.info("Inference trace saved [run %d, sample %d]: %s", run_number, sample_index, trace_path)


def _to_float(value: Any) -> float | None:
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(",", "").strip())
        except ValueError:
            return None
    return None


def _extract_timestamp_from_part_metadata(part: dict[str, Any]) -> float | None:
    """Return the first numeric timestamp found in ``part["part_metadata"]``.

    Keys are checked in a fixed priority order and each value is converted
    with ``_to_float``. Returns None if ``part_metadata`` is not a dict or no
    key yields a number.
    """
    metadata = part.get("part_metadata")
    if not isinstance(metadata, dict):
        return None

    candidate_keys = (
        "creation_timestamp",
        "timestamp",
        "created_at",
        "generated_at",
        "response_timestamp",
        "event_timestamp",
    )
    # Lazily convert keys in priority order and stop at the first hit.
    converted = (_to_float(metadata.get(key)) for key in candidate_keys)
    return next((ts for ts in converted if ts is not None), None)


def _parse_utc_iso_timestamp(value: Any) -> float | None:
    """Parse an ISO-8601 string into a POSIX timestamp, treating naive values as UTC.

    A trailing ``Z`` is accepted. Returns None for non-strings or unparseable
    text. Naive datetimes are pinned to UTC because ``datetime.timestamp()``
    would otherwise interpret them as local time.
    """
    if not isinstance(value, str):
        return None
    try:
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.timestamp()


def _function_response_timestamp(part: dict[str, Any]) -> float | None:
    """Return the final-message timestamp reported inside a ``function_response`` part, if any.

    Reads ``part["function_response"]["response"]``. The numeric
    ``final_message_timestamp`` is preferred; ``final_message_timestamp_utc``
    (an ISO string) is the fallback.
    """
    function_response = part.get("function_response")
    if not isinstance(function_response, dict):
        return None
    response_payload = function_response.get("response")
    if not isinstance(response_payload, dict):
        return None
    response_ts = _to_float(response_payload.get("final_message_timestamp"))
    if response_ts is None:
        response_ts = _parse_utc_iso_timestamp(
            response_payload.get("final_message_timestamp_utc")
        )
    return response_ts


def _final_response_part_timestamp(actual: dict[str, Any]) -> float | None:
    """Return the part-metadata timestamp of the last timestamped part of ``final_response``."""
    final_response = actual.get("final_response")
    final_parts = final_response.get("parts") if isinstance(final_response, dict) else None
    if not isinstance(final_parts, list):
        return None
    for part in reversed(final_parts):
        if isinstance(part, dict):
            ts = _extract_timestamp_from_part_metadata(part)
            if ts is not None:
                return ts
    return None


def _latest_invocation_event_timestamp(actual: dict[str, Any]) -> float | None:
    """Return the latest timestamp seen in ``intermediate_data.invocation_events``.

    Candidates per event are the timestamps in its function-response parts plus
    the event's own ``creation_timestamp`` / ``timestamp`` / ``event_timestamp``
    (first truthy one). Returns None when no candidate is found.
    """
    intermediate = actual.get("intermediate_data")
    events = intermediate.get("invocation_events") if isinstance(intermediate, dict) else None
    if not isinstance(events, list):
        return None

    event_timestamps: list[float] = []
    for event in events:
        if not isinstance(event, dict):
            continue
        content = event.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if isinstance(parts, list):
            for part in parts:
                if not isinstance(part, dict):
                    continue
                response_ts = _function_response_timestamp(part)
                if response_ts is not None:
                    event_timestamps.append(response_ts)

        event_ts = _to_float(
            event.get("creation_timestamp")
            or event.get("timestamp")
            or event.get("event_timestamp")
        )
        if event_ts is not None:
            event_timestamps.append(event_ts)
    return max(event_timestamps) if event_timestamps else None


def _extract_eval_trace_timing_fields(eval_result: Any) -> dict[str, Any]:
    """Extract user/final message timestamps and latency from an eval case result.

    Only the first entry of ``eval_metric_result_per_invocation`` (its
    ``actual_invocation``) is inspected. The final-message timestamp comes from
    the final response's part metadata, falling back to the latest timestamp
    among the invocation's intermediate events.

    Returns:
        ``{}`` when the serialized result lacks the expected structure.
        Otherwise a dict with ``user_message_timestamp``,
        ``final_message_timestamp`` and ``latency_seconds``, each possibly None.
        ``latency_seconds`` is only set when both timestamps exist and the
        difference is non-negative.
    """
    payload = _model_dump_json(eval_result)
    if not isinstance(payload, dict):
        return {}

    invocations = payload.get("eval_metric_result_per_invocation")
    if not isinstance(invocations, list) or not invocations:
        return {}

    first_invocation = invocations[0]
    actual = first_invocation.get("actual_invocation") if isinstance(first_invocation, dict) else None
    if not isinstance(actual, dict):
        return {}

    user_message_timestamp = _to_float(actual.get("creation_timestamp"))

    final_message_timestamp = _final_response_part_timestamp(actual)
    if final_message_timestamp is None:
        final_message_timestamp = _latest_invocation_event_timestamp(actual)

    latency_seconds = None
    if user_message_timestamp is not None and final_message_timestamp is not None:
        latency = final_message_timestamp - user_message_timestamp
        if latency >= 0:
            latency_seconds = latency

    return {
        "user_message_timestamp": user_message_timestamp,
        "final_message_timestamp": final_message_timestamp,
        "latency_seconds": latency_seconds,
    }


def _write_eval_trace(
    trace_dir: Path,
    eval_result: Any,
    run_number: int,
    run_elapsed_seconds: float | None = None,
    inference_completed_elapsed_seconds: float | None = None,
) -> None:
    """Write one evaluation-phase trace JSON file for ``eval_result`` into ``trace_dir``.

    The file is named ``<eval_id>.run_<run>.<utc ts>.trace.json``. It records
    the final status, per-run elapsed times, message timestamps/latency
    extracted from the result, and the full serialized eval case result.

    Args:
        trace_dir: Directory to write into.
        eval_result: Evaluation result object (read via ``getattr`` / ``model_dump``).
        run_number: Run number used in the file name and payload.
        run_elapsed_seconds: Seconds since the run started when this result arrived.
        inference_completed_elapsed_seconds: Seconds since the run started when
            inference for this eval_id finished, if known.

    When no final-message timestamp can be extracted, the write time is used in
    its place (and for the latency fallback).
    """
    # One clock reading, so the file name, created_at_utc and the completion-time
    # fallback all describe the same instant.
    now = datetime.now(timezone.utc)
    timestamp = now.strftime("%Y%m%dT%H%M%S.%fZ")
    completion_timestamp = now.timestamp()

    # Like _write_inference_trace: a None/empty/non-str eval_id would crash
    # _safe_file_stem or produce a hidden ".run_N..." file name.
    raw_eval_id = getattr(eval_result, "eval_id", None)
    eval_id = raw_eval_id if isinstance(raw_eval_id, str) and raw_eval_id else "unknown_eval_id"
    trace_name = f"{_safe_file_stem(eval_id)}.run_{run_number}.{timestamp}.trace.json"
    trace_path = trace_dir / trace_name

    status = str(getattr(eval_result, "final_eval_status", None))
    post_inference_evaluation_seconds = None
    if (
        run_elapsed_seconds is not None
        and inference_completed_elapsed_seconds is not None
    ):
        post_inference_evaluation_seconds = max(
            0.0,
            run_elapsed_seconds - inference_completed_elapsed_seconds,
        )

    timing_fields = _extract_eval_trace_timing_fields(eval_result)
    final_message_timestamp = timing_fields.get("final_message_timestamp")
    if final_message_timestamp is None:
        final_message_timestamp = completion_timestamp

    latency_seconds = timing_fields.get("latency_seconds")
    if latency_seconds is None:
        user_message_timestamp = timing_fields.get("user_message_timestamp")
        if user_message_timestamp is not None:
            latency_seconds = max(0.0, final_message_timestamp - user_message_timestamp)

    payload = {
        "created_at_utc": timestamp,
        "phase": "evaluation",
        "run_number": run_number,
        "eval_id": eval_id,
        "eval_set_id": getattr(eval_result, "eval_set_id", None),
        "final_eval_status": status,
        "run_elapsed_seconds": run_elapsed_seconds,
        "inference_completed_elapsed_seconds": inference_completed_elapsed_seconds,
        "post_inference_evaluation_seconds": post_inference_evaluation_seconds,
        "user_message_timestamp": timing_fields.get("user_message_timestamp"),
        "final_message_timestamp": final_message_timestamp,
        "latency_seconds": latency_seconds,
        "eval_case_result": _model_dump_json(eval_result),
    }
    # Written (and fsynced) immediately so traces survive a crash mid-run.
    _write_json_payload(trace_path, payload)
    log.info("Trace saved [run %d, status=%s]: %s", run_number, status, trace_path)


def _get_default_metric_info(metric_name: str, description: str = "") -> MetricInfo:
    """Build a ``MetricInfo`` for ``metric_name`` whose values lie in the interval [0.0, 1.0]."""
    unit_interval = Interval(min_value=0.0, max_value=1.0)
    return MetricInfo(
        metric_name=metric_name,
        description=description,
        metric_value_info=MetricValueInfo(interval=unit_interval),
    )


def _register_custom_metrics(eval_config: Any) -> None:
    custom_metrics = getattr(eval_config, "custom_metrics", None)
    if not custom_metrics:
        return

    for metric_name, config in custom_metrics.items():
        metric_info = None
        config_metric_info = getattr(config, "metric_info", None)
        if config_metric_info:
            metric_info = config_metric_info.model_copy()
            metric_info.metric_name = metric_name
        else:
            metric_info = _get_default_metric_info(
                metric_name=metric_name,
                description=getattr(config, "description", "") or "",
            )

        DEFAULT_METRIC_EVALUATOR_REGISTRY.register_evaluator(
            metric_info,
            _CustomMetricEvaluator,
        )


def _filter_known_inference_results(
    inference_results: list[Any],
    eval_case_ids: set[str],
) -> list[Any]:
    """Keep only inference results whose ``eval_id`` belongs to the eval set.

    Results with a missing, non-string, or unknown ``eval_id`` are dropped and
    logged, so one malformed result does not abort evaluation of the rest.
    """
    kept: list[Any] = []
    dropped = 0
    for inference_result in inference_results:
        eval_id = getattr(inference_result, "eval_id", None)
        if isinstance(eval_id, str) and eval_id in eval_case_ids:
            kept.append(inference_result)
            continue
        dropped += 1
        log.error(
            "Dropping malformed inference result with unknown eval_id=%r",
            eval_id,
        )

    if dropped:
        log.warning(
            "Dropped %d malformed inference result(s); evaluating remaining %d result(s)",
            dropped,
            len(kept),
        )
    return kept


def _raise_on_failures(failures: list[str], print_detailed_results: bool) -> None:
    """Raise ``AssertionError`` listing every failure, if there are any.

    An explicit ``raise`` is used instead of an ``assert`` statement because
    ``python -O`` strips asserts, which would let failing evaluations pass.
    """
    if not failures:
        return
    failure_message = "Following are all the test failures."
    if not print_detailed_results:
        failure_message += (
            " If you are looking to get more details on the failures, then please "
            "re-run without `--no-detailed-results` (i.e. with "
            "`print_detailed_results` set to `True`)."
        )
    failure_message += "\n" + "\n".join(failures)
    raise AssertionError(failure_message)


async def _evaluate_eval_set_with_streaming_traces(
    *,
    agent_module: str,
    agent_name: str | None,
    eval_set: EvalSet,
    eval_config: Any,
    num_runs: int,
    print_detailed_results: bool,
    trace_dir: Path,
) -> None:
    """Run inference and evaluation for ``eval_set``, writing a trace for each streamed result.

    For each of ``num_runs`` runs:
      1. stream results from ``LocalEvalService.perform_inference`` and write an
         inference trace for each one as soon as it arrives;
      2. drop inference results whose eval_id is not part of ``eval_set``;
      3. stream results from ``LocalEvalService.evaluate`` and write an
         evaluation trace for each one;
      4. print/log a timing summary for the run.

    The results for each eval_id are then aggregated across runs with
    ``AgentEvaluator`` helpers.

    Raises:
        AssertionError: if ``AgentEvaluator`` reports any failures; the message
            lists all of them.
    """
    log.info("Starting evaluation: agent_module=%s, agent_name=%s, num_runs=%d", agent_module, agent_name, num_runs)
    eval_metrics = get_eval_metrics_from_config(eval_config)
    log.debug("Eval metrics: %s", [getattr(m, 'metric_name', m) for m in eval_metrics])
    agent_for_eval = await AgentEvaluator._get_agent_for_eval(
        module_name=agent_module,
        agent_name=agent_name,
    )
    log.info("Agent loaded: %s", agent_for_eval)

    app_name = agent_module
    eval_sets_manager = AgentEvaluator._get_eval_sets_manager(
        app_name=app_name,
        eval_set=eval_set,
    )
    user_simulator_provider = UserSimulatorProvider(
        user_simulator_config=eval_config.user_simulator_config
    )

    eval_service = LocalEvalService(
        root_agent=agent_for_eval,
        eval_sets_manager=eval_sets_manager,
        user_simulator_provider=user_simulator_provider,
        metric_evaluator_registry=DEFAULT_METRIC_EVALUATOR_REGISTRY,
        eval_set_results_manager=LocalEvalSetResultsManager(agents_dir=str(CODE_ROOT)),
    )

    # The eval set is fixed for the whole call, so compute these once, not per run.
    eval_case_ids = {
        case.eval_id
        for case in eval_set.eval_cases
        if isinstance(getattr(case, "eval_id", None), str)
    }
    total_cases = len(eval_set.eval_cases)

    eval_results_by_eval_id: dict[str, list[Any]] = {}
    for run_number in range(1, num_runs + 1):
        log.info("--- Run %d/%d ---", run_number, num_runs)
        run_started_at = time.perf_counter()
        inference_results = []
        inference_completed_elapsed_by_eval_id: dict[str, float] = {}
        eval_elapsed_by_eval_id: dict[str, float] = {}
        inference_request = InferenceRequest(
            app_name=app_name,
            eval_set_id=eval_set.eval_set_id,
            inference_config=InferenceConfig(),
        )
        log.debug("Performing inference for eval_set_id='%s'", eval_set.eval_set_id)
        # aclosing() guarantees the async generator is finalized even on error.
        async with aclosing(
            eval_service.perform_inference(inference_request=inference_request)
        ) as inference_stream:
            async for inference_result in inference_stream:
                inference_results.append(inference_result)
                sample_index = len(inference_results)
                run_elapsed_seconds = time.perf_counter() - run_started_at
                eval_id = getattr(inference_result, "eval_id", None)
                if isinstance(eval_id, str) and eval_id:
                    inference_completed_elapsed_by_eval_id[eval_id] = run_elapsed_seconds
                log.info(
                    "[run %d/%d] Inference done for sample '%s' (%d/%d so far)",
                    run_number, num_runs,
                    getattr(inference_result, 'eval_id', '?'),
                    sample_index,
                    total_cases,
                )
                _write_inference_trace(
                    trace_dir,
                    inference_result,
                    run_number,
                    sample_index,
                    run_elapsed_seconds,
                )

        log.info("Inference complete: %d result(s); running evaluation", len(inference_results))
        filtered_inference_results = _filter_known_inference_results(
            inference_results, eval_case_ids
        )
        if not filtered_inference_results:
            log.error("No valid inference results to evaluate for run %d; skipping run", run_number)
            continue

        evaluate_request = EvaluateRequest(
            inference_results=filtered_inference_results,
            evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
        )
        async with aclosing(
            eval_service.evaluate(evaluate_request=evaluate_request)
        ) as eval_stream:
            async for eval_result in eval_stream:
                status = str(getattr(eval_result, 'final_eval_status', None))
                run_elapsed_seconds = time.perf_counter() - run_started_at
                inference_completed_elapsed_seconds = inference_completed_elapsed_by_eval_id.get(
                    eval_result.eval_id
                )
                eval_elapsed_by_eval_id[eval_result.eval_id] = run_elapsed_seconds
                log.info(
                    "[run %d/%d] Evaluated sample '%s' -> status=%s, e2e=%.2fs",
                    run_number, num_runs, eval_result.eval_id, status, run_elapsed_seconds,
                )
                _write_eval_trace(
                    trace_dir,
                    eval_result,
                    run_number,
                    run_elapsed_seconds,
                    inference_completed_elapsed_seconds,
                )
                eval_results_by_eval_id.setdefault(eval_result.eval_id, []).append(eval_result)

        _print_run_timing_summary(run_number, num_runs, list(eval_elapsed_by_eval_id.values()))

    failures: list[str] = []
    evaluated_agent_label = agent_name or agent_module
    log.info("Aggregating results for %d unique eval ID(s)", len(eval_results_by_eval_id))
    for eval_id, eval_results_per_eval_id in eval_results_by_eval_id.items():
        log.debug("Processing metrics for eval_id='%s'", eval_id)
        eval_metric_results = AgentEvaluator._get_eval_metric_results_with_invocation(
            eval_results_per_eval_id
        )
        new_failures = AgentEvaluator._process_metrics_and_get_failures(
            eval_metric_results=eval_metric_results,
            print_detailed_results=print_detailed_results,
            agent_module=evaluated_agent_label,
        )
        if new_failures:
            log.warning("eval_id='%s' produced %d failure(s)", eval_id, len(new_failures))
        failures.extend(new_failures)

    if failures:
        log.error("%d total failure(s) across all eval cases", len(failures))
    else:
        log.info("All eval cases passed")

    _raise_on_failures(failures, print_detailed_results)


async def run_evaluation(args: argparse.Namespace) -> None:
    """Run the ADK evaluation described by the parsed CLI ``args``.

    Steps, in order:

    1. Resolve the eval set, eval config and trace directory, and load the
       eval-set payload.
    2. Build the set of eval IDs to skip. IDs come from earlier result files
       when ``args.resume`` is set, and/or from existing trace files when
       ``args.resume_traces`` is set.
    3. Filter eval cases by the sample-id filter and the skip set. Return
       early (after logging a warning) when nothing is left to run.
    4. When ``args.overwrite`` is set, delete existing traces for the
       selected cases.
    5. Load the eval criteria and register custom metrics. If the agent
       module exposes a ``root_agent``, check that it is an ADK ``Agent``.
    6. Run the selected cases. Without ``args.batch_size`` they run as a
       single eval set. With it they run in numbered batches: either one
       batch (``args.batch_index``) or all of them, with up to
       ``args.batch_concurrency`` batches in flight at once.

    Raises:
        AssertionError: if any eval case (or, in batch mode, any batch)
            fails its metric thresholds.
        ValueError: if ``args.batch_index`` exceeds the number of batches.
        TypeError: if ``root_agent`` is present but not an ADK ``Agent``.
    """
    log.info("run_evaluation starting")
    evalset_path = resolve_evalset(args.evalset)
    config_path = resolve_config(args.config, args.no_config)
    evalset_payload = _load_evalset_payload(evalset_path)
    sample_ids = _parse_sample_ids(args.sample_id)

    # Resolve trace dir early so it can inform resume-traces filtering.
    trace_dir = _resolve_trace_dir(args.trace_dir)

    skip_eval_ids: set[str] = set()
    scanned_result_files = 0
    # Each resume source reports its own count. ``skip_eval_ids`` holds the
    # union, so its size alone would misreport when both flags are set.
    resume_skip_count = 0
    trace_skip_count = 0
    if args.resume:
        resume_paths = _resolve_resume_paths(args.resume_dir)
        skip_eval_ids, scanned_result_files = _collect_completed_eval_ids(
            resume_paths,
            eval_set_id=evalset_payload.get("eval_set_id"),
        )
        resume_skip_count = len(skip_eval_ids)

    if args.resume_traces:
        trace_skip_ids = _collect_completed_eval_ids_from_traces(trace_dir, args.num_runs)
        trace_skip_count = len(trace_skip_ids)
        skip_eval_ids |= trace_skip_ids

    filtered_cases, total_cases = _filter_eval_cases(
        evalset_payload=evalset_payload,
        sample_ids=sample_ids,
        skip_eval_ids=skip_eval_ids,
    )
    if not filtered_cases:
        log.warning("No eval cases to run after applying filters.")
        return

    if args.overwrite:
        selected_eval_ids = {
            case["eval_id"]
            for case in filtered_cases
            if isinstance(case.get("eval_id"), str)
        }
        deleted = _delete_traces_for_eval_ids(trace_dir, selected_eval_ids)
        print(f"Overwrite: deleted {deleted} existing trace file(s) for {len(selected_eval_ids)} sample(s)")

    eval_config = get_evaluation_criteria_or_default(
        str(config_path) if config_path else None
    )
    _register_custom_metrics(eval_config)

    # This import is only to verify the package wiring before eval starts.
    log.debug("Pre-loading root_agent from module '%s' for type validation", args.agent_module)
    root_agent = _load_root_agent_if_available(args.agent_module)
    if root_agent is not None and not isinstance(root_agent, Agent):
        raise TypeError(
            f"root_agent in module '{args.agent_module}' is not an ADK Agent instance"
        )
    log.debug("root_agent type check passed")

    print("Working directory:", CODE_ROOT)
    print("Agent module:", args.agent_module)
    print("Eval set:", evalset_path)
    print("Config:", config_path if config_path else "<ADK default>")
    print("Num runs:", args.num_runs)
    print("Total eval cases:", total_cases)
    if sample_ids:
        print("Sample filter count:", len(sample_ids))
    if args.resume:
        print("Resume scan files:", scanned_result_files)
        print("Resume skip count:", resume_skip_count)
    if args.resume_traces:
        print("Resume-traces skip count:", trace_skip_count)
    if args.overwrite:
        print("Overwrite: existing traces for selected samples deleted before run")
    print("Eval cases selected:", len(filtered_cases))
    print("Trace dir:", trace_dir)

    if args.batch_size:
        batches = _chunk_cases(filtered_cases, args.batch_size)
        total_batches = len(batches)
        if args.batch_index is not None:
            if not 1 <= args.batch_index <= total_batches:
                raise ValueError(
                    f"batch-index out of range: {args.batch_index}; valid range is 1..{total_batches}"
                )
            batches_to_run = [(args.batch_index, batches[args.batch_index - 1])]
        else:
            # Batch numbers are 1-based to match the --batch-index CLI option.
            batches_to_run = list(enumerate(batches, start=1))

        # Never start more concurrent batches than there are batches to run.
        batch_concurrency = min(args.batch_concurrency, len(batches_to_run))

        log.info("Batch mode: batch_size=%d, total_batches=%d", args.batch_size, total_batches)
        if batch_concurrency > 1:
            log.info("Running with batch concurrency=%d", batch_concurrency)

        async def _run_batch(batch_number: int, batch_cases: list[dict]) -> tuple[int, str] | None:
            """Evaluate one batch; return ``(batch_number, message)`` on threshold failure."""
            log.info("Running batch %d/%d (%d cases)", batch_number, total_batches, len(batch_cases))
            eval_set = _build_eval_set(evalset_payload, batch_cases)
            try:
                await _evaluate_eval_set_with_streaming_traces(
                    agent_module=args.agent_module,
                    agent_name=args.agent_name,
                    eval_set=eval_set,
                    eval_config=eval_config,
                    num_runs=args.num_runs,
                    print_detailed_results=not args.no_detailed_results,
                    trace_dir=trace_dir,
                )
                log.info("Batch %d/%d passed", batch_number, total_batches)
            except AssertionError as exc:
                log.error("Batch %d/%d failed thresholds: %s", batch_number, total_batches, exc)
                return batch_number, str(exc)
            return None

        batch_failures: list[tuple[int, str]] = []

        if batch_concurrency <= 1:
            for batch_number, batch_cases in batches_to_run:
                maybe_failure = await _run_batch(batch_number, batch_cases)
                if maybe_failure is not None:
                    batch_failures.append(maybe_failure)
        else:
            semaphore = asyncio.Semaphore(batch_concurrency)

            async def _run_batch_with_limit(
                batch_number: int,
                batch_cases: list[dict],
            ) -> tuple[int, str] | None:
                async with semaphore:
                    return await _run_batch(batch_number, batch_cases)

            tasks = [
                asyncio.create_task(_run_batch_with_limit(batch_number, batch_cases))
                for batch_number, batch_cases in batches_to_run
            ]
            try:
                for next_finished in asyncio.as_completed(tasks):
                    maybe_failure = await next_finished
                    if maybe_failure is not None:
                        batch_failures.append(maybe_failure)
            finally:
                # If a batch raised an unexpected (non-threshold) error, stop
                # the remaining batches instead of leaving them running.
                for task in tasks:
                    if not task.done():
                        task.cancel()

        if batch_failures:
            # as_completed yields in completion order; sort for a stable report.
            failed_batches = ", ".join(str(batch_no) for batch_no, _ in sorted(batch_failures))
            raise AssertionError(f"One or more batches failed thresholds. Failed batches: {failed_batches}")
        return

    eval_set = _build_eval_set(evalset_payload, filtered_cases)

    await _evaluate_eval_set_with_streaming_traces(
        agent_module=args.agent_module,
        agent_name=args.agent_name,
        eval_set=eval_set,
        eval_config=eval_config,
        num_runs=args.num_runs,
        print_detailed_results=not args.no_detailed_results,
        trace_dir=trace_dir,
    )


def main() -> int:
    """Parse and validate CLI arguments, then run the evaluation.

    Invalid argument combinations are reported via ``parser.error``. The
    return value is a process exit status:

    * ``0`` when the evaluation finishes without error.
    * ``1`` when it fails metric thresholds (``AssertionError``) or raises
      any other exception.
    """
    parser = build_parser()
    args = parser.parse_args()

    # Ordered validation rules; the first violated one aborts via parser.error.
    # Each predicate is a lambda, so it runs only after every earlier rule has
    # passed, just like a chain of sequential ``if`` statements.
    cli_rules = (
        (lambda: args.num_runs < 1,
         "--num-runs must be >= 1"),
        (lambda: args.batch_size is not None and args.batch_size < 1,
         "--batch-size must be >= 1"),
        (lambda: args.batch_index is not None and args.batch_size is None,
         "--batch-index requires --batch-size"),
        (lambda: args.batch_index is not None and args.batch_index < 1,
         "--batch-index must be >= 1"),
        (lambda: args.batch_concurrency < 1,
         "--batch-concurrency must be >= 1"),
        (lambda: args.resume_traces and args.overwrite,
         "--resume-traces and --overwrite are mutually exclusive"),
    )
    for is_violated, message in cli_rules:
        if is_violated():
            parser.error(message)

    try:
        asyncio.run(run_evaluation(args))
    except Exception as exc:  # pragma: no cover
        if isinstance(exc, AssertionError):
            # AgentEvaluator raises AssertionError when metrics fail thresholds.
            log.error("Evaluation failed thresholds: %s", exc)
        else:
            log.exception("Evaluation failed with unexpected error: %s", exc)
        return 1

    log.info("Evaluation completed successfully")
    return 0


if __name__ == "__main__":
    print("ADK Nutrition Agent LLM Optimizer - Evaluation Runner")
    # perf_counter is monotonic, so the reported duration is not skewed by
    # wall-clock adjustments during long evaluation runs.
    start_time = time.perf_counter()
    exit_code = main()
    print(f"Done. Total time: {time.perf_counter() - start_time:.2f} seconds")
    # Propagate main()'s status so shells/CI treat failed evaluations as failures.
    sys.exit(exit_code)
