#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import sys
from collections.abc import Callable
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# Directory that contains this script; used to derive the default history location.
SCRIPT_DIR = Path(__file__).resolve().parent
# Default value for the positional ``history_dir`` CLI argument.
DEFAULT_HISTORY_DIR = SCRIPT_DIR / ".adk" / "eval_history"
# Keys probed, in this order, inside a DRI response's ``recommended_g_per_day``
# mapping when looking for the calorie target.
# NOTE(review): the first entry contains a space ("calories eer_kcal"); it may be
# a deliberate match for a producer-side key or a typo - confirm before changing.
CALORIE_KEYS = (
    "calories eer_kcal",
    "calories_eer_kcal",
    "eer_kcal",
    "calories",
    "energy_kcal",
    "energy",
)
# Wrapper keys looked up in ``optimize_quantity`` responses for achieved values.
OPTIMIZER_ACHIEVED_WRAPPER_KEY = "achieved_targets_per_day"
BEST_EFFORT_WRAPPER_KEY = "best_effort_achieved_targets_per_day"
# Accepted ``metric_name`` values for the cuisine alignment score.
CUISINE_METRIC_NAMES = (
    "per_day_cuisine_alignment_score",
    "cuisine_alignment_score",
)
# Accepted ``metric_name`` values for the palatability score.
PALATABILITY_METRIC_NAMES = (
    "per_day_palatability_score",
    "palatability_score",
)
# Default glob for the ``--trace-pattern`` CLI option.
INFERENCE_TRACE_PATTERN = "*.inference.trace.json"


def _to_float(value: Any) -> float | None:
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(",", "").strip())
        except ValueError:
            return None
    return None


def _iter_invocation_events(case_result: dict[str, Any]) -> list[dict[str, Any]]:
    invocations = case_result.get("eval_metric_result_per_invocation") or []
    if not isinstance(invocations, list):
        return []

    events: list[dict[str, Any]] = []
    for invocation in invocations:
        if not isinstance(invocation, dict):
            continue
        actual_invocation = invocation.get("actual_invocation") or {}
        if not isinstance(actual_invocation, dict):
            continue

        intermediate_data = actual_invocation.get("intermediate_data") or {}
        if not isinstance(intermediate_data, dict):
            continue

        invocation_events = intermediate_data.get("invocation_events")
        if isinstance(invocation_events, list):
            events.extend(event for event in invocation_events if isinstance(event, dict))

    return events


def _unwrap_result(payload: dict[str, Any]) -> dict[str, Any]:
    current: dict[str, Any] = payload
    # Some ADK payloads nest result objects multiple levels deep.
    for _ in range(8):
        nested = current.get("result")
        if isinstance(nested, dict):
            current = nested
            continue
        break
    return current


def _extract_function_responses(case_result: dict[str, Any], function_name: str) -> list[dict[str, Any]]:
    """Return the response payloads of every call to *function_name* in a case.

    Scans the ``content.parts`` of each invocation event for
    ``function_response`` entries whose ``name`` equals *function_name* and
    collects their dict-typed ``response`` values in encounter order.
    """
    matched: list[dict[str, Any]] = []
    for event in _iter_invocation_events(case_result):
        if not isinstance(event, dict):
            continue
        content = event.get("content")
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, list):
            continue

        for part in parts:
            wrapper = part.get("function_response") if isinstance(part, dict) else None
            if not isinstance(wrapper, dict) or wrapper.get("name") != function_name:
                continue
            body = wrapper.get("response")
            if isinstance(body, dict):
                matched.append(body)
    return matched


def _extract_target_calories(case_result: dict[str, Any]) -> float | None:
    """Return the calorie target from the most recent usable DRI response.

    ``calculate_health_canada_dri`` responses are examined newest first. For
    each one the ``calories`` field is tried first (``eer_kcal`` when it is a
    dict, otherwise the value itself), then the ``CALORIE_KEYS`` entries of
    ``recommended_g_per_day``. Returns ``None`` when nothing numeric is found.
    """
    for raw_response in reversed(_extract_function_responses(case_result, "calculate_health_canada_dri")):
        dri = _unwrap_result(raw_response)

        calories_field = dri.get("calories")
        if isinstance(calories_field, dict):
            direct = _to_float(calories_field.get("eer_kcal"))
        else:
            direct = _to_float(calories_field)
        if direct is not None:
            return direct

        per_day = dri.get("recommended_g_per_day")
        if not isinstance(per_day, dict):
            continue
        # First key (in CALORIE_KEYS order) that holds a numeric value wins.
        parsed_candidates = (_to_float(per_day.get(key)) for key in CALORIE_KEYS)
        fallback = next((number for number in parsed_candidates if number is not None), None)
        if fallback is not None:
            return fallback

    return None


def _extract_metric_values(payload: dict[str, Any]) -> dict[str, float]:
    """Pick the known nutrient metrics out of *payload* as floats.

    Only keys whose value converts to a number are included in the result.
    """
    metric_keys = ("calories", "total_fibre", "protein", "carbohydrates", "total_fat")
    parsed = ((name, _to_float(payload.get(name))) for name in metric_keys)
    return {name: number for name, number in parsed if number is not None}


def _extract_metrics_from_payload(payload: dict[str, Any], *, wrapper_key: str | None) -> dict[str, float]:
    """Extract nutrient metrics stored under *wrapper_key* of an unwrapped payload.

    Returns an empty dict when *wrapper_key* is ``None``, when the key does
    not hold a dict, or when that dict contains no numeric metrics.
    """
    inner = _unwrap_result(payload)
    wrapped = inner.get(wrapper_key) if wrapper_key is not None else None
    if not isinstance(wrapped, dict):
        return {}
    return _extract_metric_values(wrapped)


def _extract_achieved_calories(case_result: dict[str, Any]) -> float | None:
    """Return the calories achieved by the most recent usable optimizer run.

    ``optimize_quantity`` responses are scanned newest first in two passes:

    1. Responses whose status is not ``no_solution`` (case-insensitive) are
       read through ``achieved_targets_per_day``.
    2. If none yields calories, ``no_solution`` responses are read through
       the legacy ``best_effort_achieved_targets_per_day`` wrapper.

    Returns ``None`` when neither pass finds a calorie value.
    """
    newest_first: list[tuple[dict[str, Any], bool]] = []
    for raw_response in reversed(_extract_function_responses(case_result, "optimize_quantity")):
        unwrapped = _unwrap_result(raw_response)
        raw_status = unwrapped.get("status")
        is_no_solution = isinstance(raw_status, str) and raw_status.lower() == "no_solution"
        newest_first.append((unwrapped, is_no_solution))

    # Pass 1: achieved_targets_per_day is authoritative unless the run
    # explicitly reported no_solution.
    for unwrapped, is_no_solution in newest_first:
        if is_no_solution:
            continue
        achieved = _extract_metrics_from_payload(
            unwrapped, wrapper_key=OPTIMIZER_ACHIEVED_WRAPPER_KEY
        ).get("calories")
        if achieved is not None:
            return achieved

    # Pass 2: for no_solution runs only the best-effort wrapper is trusted.
    for unwrapped, is_no_solution in newest_first:
        if not is_no_solution:
            continue
        best_effort = _extract_metrics_from_payload(
            unwrapped, wrapper_key=BEST_EFFORT_WRAPPER_KEY
        ).get("calories")
        if best_effort is not None:
            return best_effort

    return None


def _extract_metric_score(case_result: dict[str, Any], metric_names: tuple[str, ...]) -> float | None:
    """Return the first numeric score among the case's overall metrics.

    Only entries of ``overall_eval_metric_results`` whose ``metric_name`` is
    in *metric_names* are considered; entries with a non-numeric ``score``
    are passed over. Returns ``None`` when no entry qualifies.
    """
    overall = case_result.get("overall_eval_metric_results") or []
    if not isinstance(overall, list):
        return None

    matching_scores = (
        _to_float(entry.get("score"))
        for entry in overall
        if isinstance(entry, dict) and entry.get("metric_name") in metric_names
    )
    return next((score for score in matching_scores if score is not None), None)


def _parse_utc_timestamp(value: str) -> float | None:
    try:
        return datetime.strptime(value, "%Y%m%dT%H%M%S.%fZ").replace(tzinfo=timezone.utc).timestamp()
    except ValueError:
        return None


def _extract_first_user_to_final_response_seconds(trace_payload: dict[str, Any]) -> float | None:
    """Compute the elapsed seconds covered by an inference trace.

    The end time is the trace's ``created_at_utc``; the start time is the
    smallest ``creation_timestamp`` among ``inference_result.inferences``.
    The difference is clamped at zero.

    Returns ``None`` when the trace is not an inference-phase trace or when
    either timestamp cannot be determined.
    """
    if trace_payload.get("phase") != "inference":
        return None

    created_at = trace_payload.get("created_at_utc")
    end_ts = _parse_utc_timestamp(created_at) if isinstance(created_at, str) else None
    if end_ts is None:
        return None

    result = trace_payload.get("inference_result")
    inferences = result.get("inferences") if isinstance(result, dict) else None
    if not isinstance(inferences, list):
        return None

    parsed_stamps = (
        _to_float(item.get("creation_timestamp"))
        for item in inferences
        if isinstance(item, dict)
    )
    start_candidates = [stamp for stamp in parsed_stamps if stamp is not None]
    if not start_candidates:
        return None

    return max(0.0, end_ts - min(start_candidates))


def _extract_tool_retry_count(trace_payload: dict[str, Any]) -> int | None:
    if trace_payload.get("phase") != "inference":
        return None

    inference_result = trace_payload.get("inference_result")
    if not isinstance(inference_result, dict):
        return None

    inferences = inference_result.get("inferences")
    if not isinstance(inferences, list):
        return None

    optimize_quantity_calls = 0
    for inference in inferences:
        if not isinstance(inference, dict):
            continue
        intermediate_data = inference.get("intermediate_data")
        if not isinstance(intermediate_data, dict):
            continue

        invocation_events = intermediate_data.get("invocation_events")
        if not isinstance(invocation_events, list):
            continue

        for event in invocation_events:
            if not isinstance(event, dict):
                continue
            content = event.get("content")
            if not isinstance(content, dict):
                continue
            parts = content.get("parts")
            if not isinstance(parts, list):
                continue

            for part in parts:
                if not isinstance(part, dict):
                    continue
                function_call = part.get("function_call")
                if not isinstance(function_call, dict):
                    continue
                if function_call.get("name") == "optimize_quantity":
                    optimize_quantity_calls += 1

    if optimize_quantity_calls <= 1:
        return 0
    return optimize_quantity_calls - 1


def _format_number(value: float | None) -> str:
    if value is None:
        return "N/A"
    if value.is_integer():
        return str(int(value))
    return f"{value:.2f}".rstrip("0").rstrip(".")


def _format_percentage_difference(target: float | None, achieved: float | None) -> str:
    if target is None or achieved is None or target == 0:
        return "N/A"

    percent_diff = ((achieved - target) / target) * 100
    formatted = f"{percent_diff:+.2f}".rstrip("0").rstrip(".")
    return f"{formatted}%"


def _result_path_sort_key(result_path: Path) -> tuple[int, str]:
    try:
        stat = result_path.stat()
    except OSError:
        return (-1, result_path.name)
    return (stat.st_mtime_ns, result_path.name)


def _load_latest_trace_value_by_eval_id(
    trace_dir: Path,
    pattern: str,
    extract: Callable[[dict[str, Any]], Any],
) -> dict[str, Any]:
    """Apply *extract* to each matching trace file and keep the newest value per eval id.

    Args:
        trace_dir: Directory to search; a missing or non-directory path
            yields an empty result.
        pattern: Glob pattern, relative to *trace_dir*, selecting trace files.
        extract: Callable that receives the decoded trace object and returns
            the value of interest, or ``None`` when the trace has none.

    Returns:
        Mapping of eval id to the value extracted from the newest trace (by
        modification time, then file name) that produced one.

    Files that cannot be read or decoded, or whose top-level JSON value is
    not an object, are reported on stderr and skipped.
    """
    if not trace_dir.is_dir():
        return {}

    latest_by_eval_id: dict[str, Any] = {}

    # Oldest first, so values from newer traces overwrite older ones.
    for trace_path in sorted(trace_dir.glob(pattern), key=_result_path_sort_key):
        try:
            payload = json.loads(trace_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: skipping {trace_path.name}: {exc}", file=sys.stderr)
            continue
        if not isinstance(payload, dict):
            print(
                f"Warning: skipping {trace_path.name}: top-level JSON value is not an object",
                file=sys.stderr,
            )
            continue

        eval_id = payload.get("eval_id")
        if not isinstance(eval_id, str) or not eval_id or eval_id.startswith("unknown_eval_id_"):
            # Missing or placeholder id: prefer the case id recorded in the
            # inference result when there is one.
            inference_result = payload.get("inference_result")
            if isinstance(inference_result, dict):
                fallback_eval_id = inference_result.get("eval_case_id")
                if isinstance(fallback_eval_id, str) and fallback_eval_id:
                    eval_id = fallback_eval_id
        if not isinstance(eval_id, str) or not eval_id:
            continue

        value = extract(payload)
        if value is None:
            continue

        latest_by_eval_id[eval_id] = value

    return latest_by_eval_id


def _load_first_user_to_final_response_by_eval_id(trace_dir: Path, pattern: str) -> dict[str, float]:
    """Return the newest end-to-end latency in seconds for each eval id found in *trace_dir*."""
    return _load_latest_trace_value_by_eval_id(
        trace_dir, pattern, _extract_first_user_to_final_response_seconds
    )


def _load_tool_retry_count_by_eval_id(trace_dir: Path, pattern: str) -> dict[str, int]:
    """Return the newest ``optimize_quantity`` retry count for each eval id found in *trace_dir*."""
    return _load_latest_trace_value_by_eval_id(trace_dir, pattern, _extract_tool_retry_count)


def _load_rows(history_dir: Path, pattern: str, trace_dir: Path, trace_pattern: str) -> list[dict[str, str]]:
    """Build one display row per eval id from the newest matching result file.

    Args:
        history_dir: Directory holding eval-set result JSON files.
        pattern: Glob pattern, relative to *history_dir*, selecting result files.
        trace_dir: Directory holding trace JSON files used for latency and
            retry counts.
        trace_pattern: Glob pattern, relative to *trace_dir*, selecting traces.

    Returns:
        Rows whose values are already formatted as strings, ordered by the
        modification time of the file each row came from, then by eval id.

    Result files that cannot be read or decoded, or whose top-level JSON
    value is not an object, are reported on stderr and skipped.
    """
    latency_by_eval_id = _load_first_user_to_final_response_by_eval_id(trace_dir, trace_pattern)
    retries_by_eval_id = _load_tool_retry_count_by_eval_id(trace_dir, trace_pattern)

    # Stat each file exactly once so the processing order and the key stored
    # with each row cannot disagree if a file changes mid-run.
    keyed_paths = sorted(
        ((_result_path_sort_key(path), path) for path in history_dir.glob(pattern)),
        key=lambda pair: pair[0],
    )

    latest_rows_by_eval_id: dict[str, tuple[tuple[int, str], dict[str, str]]] = {}
    # Oldest first, so rows from newer files overwrite older ones.
    for result_key, result_path in keyed_paths:
        try:
            payload = json.loads(result_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
            print(f"Warning: skipping {result_path.name}: {exc}", file=sys.stderr)
            continue
        if not isinstance(payload, dict):
            print(
                f"Warning: skipping {result_path.name}: top-level JSON value is not an object",
                file=sys.stderr,
            )
            continue

        case_results = payload.get("eval_case_results")
        if not isinstance(case_results, list):
            continue

        for case_result in case_results:
            if not isinstance(case_result, dict):
                continue

            eval_id = case_result.get("eval_id")
            if not isinstance(eval_id, str) or not eval_id:
                continue

            target_calorie = _extract_target_calories(case_result)
            achieved_calorie = _extract_achieved_calories(case_result)
            cuisine_alignment_score = _extract_metric_score(case_result, CUISINE_METRIC_NAMES)
            palatability_score = _extract_metric_score(case_result, PALATABILITY_METRIC_NAMES)
            first_user_to_final_response = latency_by_eval_id.get(eval_id)
            tool_retry_count = retries_by_eval_id.get(eval_id)

            latest_rows_by_eval_id[eval_id] = (
                result_key,
                {
                    "eval_id": eval_id,
                    "target_calorie": _format_number(target_calorie),
                    "achieved_calorie": _format_number(achieved_calorie),
                    "calorie_pct_diff": _format_percentage_difference(target_calorie, achieved_calorie),
                    "cuisine_alignment_score": _format_number(cuisine_alignment_score),
                    "palatability_score": _format_number(palatability_score),
                    "first_user_to_final_response_seconds": _format_number(first_user_to_final_response),
                    "tool_retry_count": _format_number(
                        float(tool_retry_count) if tool_retry_count is not None else None
                    ),
                },
            )

    sorted_rows = sorted(
        latest_rows_by_eval_id.values(),
        key=lambda item: (item[0][0], item[1]["eval_id"]),
    )
    return [row for _, row in sorted_rows]


def _print_table(rows: list[dict[str, str]]) -> None:
    columns = [
        ("id", "eval_id"),
        ("target_kcal", "target_calorie"),
        ("achvd_kcal", "achieved_calorie"),
        ("kcal_diff_%", "calorie_pct_diff"),
        ("cuisine", "cuisine_alignment_score"),
        ("palat", "palatability_score"),
        ("latency_s", "first_user_to_final_response_seconds"),
        ("retries", "tool_retry_count"),
    ]
    widths = {label: len(label) for label, _ in columns}

    for row in rows:
        for label, key in columns:
            widths[label] = max(widths[label], len(row[key]))

    header_line = "  ".join(label.ljust(widths[label]) for label, _ in columns)
    separator_line = "  ".join("-" * widths[label] for label, _ in columns)
    print(header_line)
    print(separator_line)

    for row in rows:
        print("  ".join(row[key].ljust(widths[label]) for label, key in columns))


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the summary script.

    Defines one optional positional argument (``history_dir``) and three
    options (``--trace-dir``, ``--pattern``, ``--trace-pattern``).
    """
    cli = argparse.ArgumentParser(
        description="Summarize the latest available data per eval_id from eval history files using target calories from calculate_health_canada_dri, achieved calories from optimize_quantity, and end-to-end latency from evaluation traces."
    )

    # (name, keyword settings) for each argument, registered in this order.
    argument_specs = (
        (
            "history_dir",
            {
                "nargs": "?",
                "default": str(DEFAULT_HISTORY_DIR),
                "help": "Directory containing *.evalset_result.json files.",
            },
        ),
        (
            "--trace-dir",
            {
                "default": None,
                "help": "Directory containing *.trace.json files. Defaults to the sibling traces directory next to the history directory.",
            },
        ),
        (
            "--pattern",
            {
                "default": "*.evalset_result.json",
                "help": "Glob pattern to select result files inside the history directory. If multiple files match, the newest result for each eval_id is used.",
            },
        ),
        (
            "--trace-pattern",
            {
                "default": INFERENCE_TRACE_PATTERN,
                "help": "Glob pattern to select trace files inside the trace directory. The newest inference trace for each eval_id is used.",
            },
        ),
    )
    for argument_name, settings in argument_specs:
        cli.add_argument(argument_name, **settings)
    return cli


def main() -> int:
    """Parse the command line, load the summary rows, and print the table.

    Returns:
        ``0`` after the table has been printed.

    Raises:
        SystemExit: With a message when the history directory does not exist
            or is not a directory, or when no rows could be produced.
    """
    options = build_parser().parse_args()
    history_dir = Path(options.history_dir).expanduser().resolve()

    if options.trace_dir is None:
        # Default: a "traces" directory next to the history directory.
        trace_dir = history_dir.parent / "traces"
    else:
        trace_dir = Path(options.trace_dir).expanduser().resolve()

    if not (history_dir.exists() and history_dir.is_dir()):
        raise SystemExit(f"Eval history directory not found: {history_dir}")

    summary_rows = _load_rows(history_dir, options.pattern, trace_dir, options.trace_pattern)
    if not summary_rows:
        raise SystemExit(f"No eval result files matched in {history_dir}")

    _print_table(summary_rows)
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())