#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any


# Directory containing this script; the default history path is resolved relative to it.
SCRIPT_DIR = Path(__file__).resolve().parent
# Default directory holding *.evalset_result.json files when no history_dir argument is given.
DEFAULT_HISTORY_DIR = SCRIPT_DIR / ".adk" / "eval_history"
# Keys tried, in order, inside the DRI tool's "recommended_g_per_day" mapping when the
# response has no usable top-level "calories" value.
# NOTE(review): "calories eer_kcal" (with a space) may mirror a key emitted verbatim by
# the DRI tool, or may be a typo of "calories_eer_kcal" — confirm before removing it.
CALORIE_KEYS = (
    "calories eer_kcal",
    "calories_eer_kcal",
    "eer_kcal",
    "calories",
    "energy_kcal",
    "energy",
)
# Wrapper keys under which optimizer payloads presumably nest per-day metrics (intended as
# the `wrapper_key` argument of _extract_metrics_from_payload). They are not referenced
# elsewhere in the code shown here — verify usage before changing them.
OPTIMIZER_ACHIEVED_WRAPPER_KEY = "achieved_targets_per_day"
BEST_EFFORT_WRAPPER_KEY = "best_effort_achieved_targets_per_day"
# Accepted `metric_name` values for the cuisine-alignment score column.
CUISINE_METRIC_NAMES = (
    "per_day_cuisine_alignment_score",
    "cuisine_alignment_score",
)
# Accepted `metric_name` values for the palatability score column.
PALATABILITY_METRIC_NAMES = (
    "per_day_palatability_score",
    "palatability_score",
)
# Default glob (recursive) for inference trace files under the trace directory.
INFERENCE_TRACE_PATTERN = "**/*.inference.trace.json"


def _to_float(value: Any) -> float | None:
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(",", "").strip())
        except ValueError:
            return None
    return None


def _iter_invocation_events(case_result: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten the dict events recorded for every invocation of one eval case.

    Walks ``eval_metric_result_per_invocation[*].actual_invocation
    .intermediate_data.invocation_events`` and returns the dict-typed events
    in order. Any level that has the wrong type is skipped silently.
    """
    invocations = case_result.get("eval_metric_result_per_invocation") or []
    if not isinstance(invocations, list):
        return []

    collected: list[dict[str, Any]] = []
    for invocation in (item for item in invocations if isinstance(item, dict)):
        actual = invocation.get("actual_invocation") or {}
        # A non-dict actual_invocation leaves nothing to descend into.
        intermediate = (actual.get("intermediate_data") or {}) if isinstance(actual, dict) else None
        if not isinstance(intermediate, dict):
            continue

        raw_events = intermediate.get("invocation_events")
        if not isinstance(raw_events, list):
            continue
        collected += [event for event in raw_events if isinstance(event, dict)]

    return collected


def _unwrap_result(payload: dict[str, Any]) -> dict[str, Any]:
    """Descend through nested ``"result"`` dicts and return the innermost one.

    Unwraps at most 8 levels. It stops early at the first level whose
    ``"result"`` value is missing or not a dict.
    """
    # Some ADK payloads nest result objects multiple levels deep.
    node = payload
    depth = 0
    while depth < 8 and isinstance(node.get("result"), dict):
        node = node["result"]
        depth += 1
    return node


def _extract_function_responses(case_result: dict[str, Any], function_name: str) -> list[dict[str, Any]]:
    """Collect the ``response`` dicts of every function_response named ``function_name``.

    Looks through ``content.parts`` of each invocation event of ``case_result``.
    Results are returned in event/part order. Parts whose response is not a
    dict are ignored.
    """
    matches: list[dict[str, Any]] = []
    for event in _iter_invocation_events(case_result):
        content = event.get("content") if isinstance(event, dict) else None
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, list):
            continue

        for part in filter(lambda candidate: isinstance(candidate, dict), parts):
            call_result = part.get("function_response")
            if not (isinstance(call_result, dict) and call_result.get("name") == function_name):
                continue
            response_payload = call_result.get("response")
            if isinstance(response_payload, dict):
                matches.append(response_payload)
    return matches


def _extract_target_calories(case_result: dict[str, Any]) -> float | None:
    """Return the calorie target from the last DRI tool response that has one.

    Responses to ``calculate_health_canada_dri`` are scanned from last to
    first. For each response:
    - If ``calories`` is a dict, its ``eer_kcal`` value is used.
    - Otherwise ``calories`` itself is parsed as a number.
    - Failing that, ``recommended_g_per_day`` is searched using CALORIE_KEYS
      in order.

    Returns None when no response yields a number.
    """
    dri_responses = _extract_function_responses(case_result, "calculate_health_canada_dri")
    for raw_response in reversed(dri_responses):
        dri = _unwrap_result(raw_response)

        calories_field = dri.get("calories")
        candidate = (
            _to_float(calories_field.get("eer_kcal"))
            if isinstance(calories_field, dict)
            else _to_float(calories_field)
        )
        if candidate is not None:
            return candidate

        recommended = dri.get("recommended_g_per_day")
        if not isinstance(recommended, dict):
            continue
        for key in CALORIE_KEYS:
            candidate = _to_float(recommended.get(key))
            if candidate is not None:
                return candidate

    return None


def _extract_metric_values(payload: dict[str, Any]) -> dict[str, float]:
    """Return the numeric macro values present in ``payload``.

    Only these keys are considered, in this order: calories, total_fibre,
    protein, carbohydrates, total_fat. Keys whose value does not parse as a
    number are omitted.
    """
    macro_keys = ("calories", "total_fibre", "protein", "carbohydrates", "total_fat")
    parsed = ((key, _to_float(payload.get(key))) for key in macro_keys)
    return {key: number for key, number in parsed if number is not None}


def _extract_metrics_from_payload(payload: dict[str, Any], *, wrapper_key: str | None) -> dict[str, float]:
    """Return the macro metrics nested under ``wrapper_key`` in an unwrapped payload.

    Nested ``"result"`` dicts are unwrapped first. Returns {} in any of these
    cases:
    - ``wrapper_key`` is None;
    - the wrapped value is missing or not a dict;
    - the wrapped value contains no parseable metrics.
    """
    unwrapped = _unwrap_result(payload)
    nested = unwrapped.get(wrapper_key) if wrapper_key is not None else None
    if not isinstance(nested, dict):
        return {}
    return _extract_metric_values(nested) or {}


def _extract_achieved_calories_all_calls(case_result: dict[str, Any]) -> list[float | None]:
    """Extract achieved calories from all calls to calculate_average_macro_nutrient_per_day.

    The list has one entry per tool response, in call order. An entry is None
    when the response lacks a dict under
    ``average_macro_nutrient_from_calculated_quantity_per_day`` or its
    ``calories`` value is not numeric.
    """
    responses = _extract_function_responses(case_result, "calculate_average_macro_nutrient_per_day")

    def _calories_of(response: dict[str, Any]) -> float | None:
        averages = _unwrap_result(response).get("average_macro_nutrient_from_calculated_quantity_per_day")
        return _to_float(averages.get("calories")) if isinstance(averages, dict) else None

    return [_calories_of(response) for response in responses]


def _extract_achieved_calories(case_result: dict[str, Any]) -> float | None:
    """Return achieved calories from the latest call (kept for backward compatibility).

    Returns None when the tool was never called.
    """
    per_call = _extract_achieved_calories_all_calls(case_result)
    if not per_call:
        return None
    return per_call[-1]


def _extract_metric_score(case_result: dict[str, Any], metric_names: tuple[str, ...]) -> float | None:
    """Return the first numeric score whose ``metric_name`` is in ``metric_names``.

    ``overall_eval_metric_results`` entries are checked in list order, and
    entries without a numeric score are skipped. Returns None if nothing
    matches.
    """
    metric_results = case_result.get("overall_eval_metric_results") or []
    if not isinstance(metric_results, list):
        return None

    candidate_scores = (
        _to_float(metric.get("score"))
        for metric in metric_results
        if isinstance(metric, dict) and metric.get("metric_name") in metric_names
    )
    return next((score for score in candidate_scores if score is not None), None)


def _parse_utc_timestamp(value: str) -> float | None:
    try:
        return datetime.strptime(value, "%Y%m%dT%H%M%S.%fZ").replace(tzinfo=timezone.utc).timestamp()
    except ValueError:
        return None


def _extract_first_user_to_final_response_seconds(trace_payload: dict[str, Any]) -> float | None:
    """Return the elapsed seconds recorded by one inference trace, or None.

    - End point: the trace's ``created_at_utc`` stamp, parsed with
      _parse_utc_timestamp.
    - Start point: the smallest numeric ``creation_timestamp`` among
      ``inference_result["inferences"]``.
    - The difference is clamped to be non-negative.

    Returns None when ``phase`` is not "inference" or when a required field is
    missing or unparseable.

    NOTE(review): this treats created_at_utc as the final-response time and the
    earliest creation_timestamp as the first user message, as the name
    suggests, and presumably both are Unix-epoch seconds — confirm against the
    trace writer.
    """
    if trace_payload.get("phase") != "inference":
        return None

    created_at_utc = trace_payload.get("created_at_utc")
    final_response_ts = _parse_utc_timestamp(created_at_utc) if isinstance(created_at_utc, str) else None
    if final_response_ts is None:
        return None

    inference_result = trace_payload.get("inference_result")
    inferences = inference_result.get("inferences") if isinstance(inference_result, dict) else None
    if not isinstance(inferences, list):
        return None

    creation_times = [
        ts
        for ts in (_to_float(item.get("creation_timestamp")) for item in inferences if isinstance(item, dict))
        if ts is not None
    ]
    if not creation_times:
        return None

    return max(0.0, final_response_ts - min(creation_times))


def _format_number(value: float | None) -> str:
    if value is None:
        return "N/A"
    if value.is_integer():
        return str(int(value))
    return f"{value:.2f}".rstrip("0").rstrip(".")


def _format_percentage_difference(target: float | None, achieved: float | None) -> str:
    if target is None or achieved is None or target == 0:
        return "N/A"

    percent_diff = ((achieved - target) / target) * 100
    formatted = f"{percent_diff:+.2f}".rstrip("0").rstrip(".")
    return f"{formatted}%"


def _result_path_sort_key(result_path: Path) -> tuple[int, str]:
    """Return ``(mtime in nanoseconds, file name)`` for ordering files oldest first.

    Files that cannot be stat'ed get mtime -1, so they sort before everything
    else.
    """
    name = result_path.name
    try:
        mtime_ns = result_path.stat().st_mtime_ns
    except OSError:
        mtime_ns = -1
    return (mtime_ns, name)


def _load_first_user_to_final_response_by_eval_id(trace_dir: Path, pattern: str) -> dict[str, float]:
    if not trace_dir.exists() or not trace_dir.is_dir():
        return {}

    latest_latency_by_eval_id: dict[str, tuple[tuple[int, str], float]] = {}

    for trace_path in sorted(trace_dir.glob(pattern), key=_result_path_sort_key):
        try:
            payload = json.loads(trace_path.read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError) as exc:
            print(f"Warning: skipping {trace_path.name}: {exc}", file=sys.stderr)
            continue

        eval_id = payload.get("eval_id")
        if not isinstance(eval_id, str) or not eval_id or eval_id.startswith("unknown_eval_id_"):
            inference_result = payload.get("inference_result")
            if isinstance(inference_result, dict):
                fallback_eval_id = inference_result.get("eval_case_id")
                if isinstance(fallback_eval_id, str) and fallback_eval_id:
                    eval_id = fallback_eval_id
        if not isinstance(eval_id, str) or not eval_id:
            continue

        latency = _extract_first_user_to_final_response_seconds(payload)
        if latency is None:
            continue

        latest_latency_by_eval_id[eval_id] = (_result_path_sort_key(trace_path), latency)

    return {eval_id: latency for eval_id, (_, latency) in latest_latency_by_eval_id.items()}


def _build_summary_row(
    eval_id: str,
    case_result: dict[str, Any],
    first_user_to_final_response: float | None,
    call_count: int = 3,
) -> dict[str, str]:
    """Format one eval case as a dict of display strings keyed by the table's column keys.

    Adds an achieved-calorie column and a percent-difference column for each of
    the first ``call_count`` tool calls. Calls that did not happen render as
    "N/A".
    """
    target_calorie = _extract_target_calories(case_result)
    achieved_calories_list = _extract_achieved_calories_all_calls(case_result)

    row: dict[str, str] = {
        "eval_id": eval_id,
        "target_calorie": _format_number(target_calorie),
        "cuisine_alignment_score": _format_number(_extract_metric_score(case_result, CUISINE_METRIC_NAMES)),
        "palatability_score": _format_number(_extract_metric_score(case_result, PALATABILITY_METRIC_NAMES)),
        "first_user_to_final_response_seconds": _format_number(first_user_to_final_response),
    }
    for call_num in range(1, call_count + 1):
        achieved_calorie = achieved_calories_list[call_num - 1] if call_num <= len(achieved_calories_list) else None
        row[f"call_{call_num}_achieved"] = _format_number(achieved_calorie)
        row[f"call_{call_num}_pct_diff"] = _format_percentage_difference(target_calorie, achieved_calorie)
    return row


def _load_rows_by_eval_set(
    history_dir: Path,
    pattern: str,
    trace_dir: Path,
    trace_pattern: str,
) -> dict[str, list[dict[str, str]]]:
    """Build per-eval-set summary rows from eval history result files.

    Processing order and precedence:
    - Result files matching ``pattern`` under ``history_dir`` are read in
      ascending (mtime, file name) order.
    - When several files contain the same (eval_set_id, eval_id) pair, the
      last one read wins.
    - Cases without an eval_set_id are grouped under "unknown_eval_set".
    - Cases without an eval_id are dropped.
    - Latency is looked up by eval_id alone in the traces under ``trace_dir``.

    Files that cannot be read, are not valid UTF-8/JSON, or whose top-level
    JSON value is not an object are skipped with a warning on stderr.

    Returns:
        eval_set_id -> rows. Within each group, rows are ordered by their
        source file's mtime, then by eval_id.
    """
    latest_rows_by_eval_set_and_id: dict[tuple[str, str], tuple[tuple[int, str], dict[str, str]]] = {}
    first_user_to_final_response_by_eval_id = _load_first_user_to_final_response_by_eval_id(trace_dir, trace_pattern)

    # stat() each file once; key=item[0] keeps the sort stable for exact ties.
    keyed_paths = sorted(
        ((_result_path_sort_key(path), path) for path in history_dir.glob(pattern)),
        key=lambda item: item[0],
    )

    for result_key, result_path in keyed_paths:
        try:
            payload = json.loads(result_path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as exc:
            print(f"Warning: skipping {result_path.name}: {exc}", file=sys.stderr)
            continue
        if not isinstance(payload, dict):
            print(f"Warning: skipping {result_path.name}: top-level JSON value is not an object", file=sys.stderr)
            continue

        case_results = payload.get("eval_case_results")
        if not isinstance(case_results, list):
            continue

        for case_result in case_results:
            if not isinstance(case_result, dict):
                continue

            eval_set_id = case_result.get("eval_set_id") or payload.get("eval_set_id")
            if not isinstance(eval_set_id, str) or not eval_set_id:
                eval_set_id = "unknown_eval_set"

            eval_id = case_result.get("eval_id")
            if not isinstance(eval_id, str) or not eval_id:
                continue

            row = _build_summary_row(eval_id, case_result, first_user_to_final_response_by_eval_id.get(eval_id))
            latest_rows_by_eval_set_and_id[(eval_set_id, eval_id)] = (result_key, row)

    # Order by eval set, then source-file mtime, then eval_id.
    sorted_rows = sorted(
        latest_rows_by_eval_set_and_id.items(),
        key=lambda item: (item[0][0], item[1][0][0], item[0][1]),
    )

    rows_by_eval_set: dict[str, list[dict[str, str]]] = {}
    for (eval_set_id, _eval_id), (_result_key, row) in sorted_rows:
        rows_by_eval_set.setdefault(eval_set_id, []).append(row)
    return rows_by_eval_set


def _print_table(rows: list[dict[str, str]]) -> None:
    """Print ``rows`` as a left-aligned text table with two-space column gaps.

    - Headers are compact labels, while cell values are read by their row keys.
    - Missing keys print as "N/A".
    - Each column is as wide as its longest label or value.
    - Nothing is printed for an empty list.
    """
    if not rows:
        return

    # (row key, compact display label) pairs, in display order.
    columns: list[tuple[str, str]] = [
        ("eval_id", "id"),
        ("target_calorie", "target_kcal"),
        ("cuisine_alignment_score", "cuisine"),
        ("palatability_score", "palat"),
        ("first_user_to_final_response_seconds", "time_s"),
    ]
    for n in range(1, 4):
        columns += [
            (f"call_{n}_achieved", f"c{n}_kcal"),
            (f"call_{n}_pct_diff", f"c{n}_pct"),
        ]

    widths = {
        key: max(len(label), *(len(row.get(key, "N/A")) for row in rows))
        for key, label in columns
    }

    print("  ".join(label.ljust(widths[key]) for key, label in columns))
    print("  ".join("-" * widths[key] for key, _ in columns))
    for row in rows:
        print("  ".join(row.get(key, "N/A").ljust(widths[key]) for key, _ in columns))


def _print_grouped_tables(rows_by_eval_set: dict[str, list[dict[str, str]]]) -> None:
    """Print one titled table per eval set, in sorted eval_set_id order.

    Eval sets with no rows are skipped. Consecutive printed tables are
    separated by exactly one blank line, and there is no leading blank line
    even if the first eval set is skipped.
    """
    printed_any = False
    for eval_set_id in sorted(rows_by_eval_set):
        rows = rows_by_eval_set[eval_set_id]
        if not rows:
            continue

        # Separate from the previous *printed* table, not the previous dict entry.
        if printed_any:
            print()
        print(f"eval_set: {eval_set_id}")
        _print_table(rows)
        printed_any = True


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Arguments:
    - an optional positional ``history_dir`` (defaults to DEFAULT_HISTORY_DIR);
    - ``--trace-dir``, ``--pattern`` and ``--trace-pattern``.
    """
    description = "Summarize latest available calorie data from eval history files. Results are shown separately for each eval_set_id."
    parser = argparse.ArgumentParser(description=description)

    # (positional name or flag, keyword options), in registration order.
    argument_specs: list[tuple[str, dict[str, Any]]] = [
        (
            "history_dir",
            {
                "nargs": "?",
                "default": str(DEFAULT_HISTORY_DIR),
                "help": "Directory containing *.evalset_result.json files.",
            },
        ),
        (
            "--trace-dir",
            {
                "default": None,
                "help": "Directory containing inference trace files. Defaults to the sibling traces directory next to the history directory.",
            },
        ),
        (
            "--pattern",
            {
                "default": "*.evalset_result.json",
                "help": "Glob pattern to select result files inside the history directory. If multiple files match, the newest result for each (eval_set_id, eval_id) pair is used.",
            },
        ),
        (
            "--trace-pattern",
            {
                "default": INFERENCE_TRACE_PATTERN,
                "help": "Glob pattern to select inference trace files inside the trace directory (recursive by default). The newest inference trace for each eval_id is used.",
            },
        ),
    ]
    for name, options in argument_specs:
        parser.add_argument(name, **options)
    return parser


def main() -> int:
    """Entry point: print the per-eval-set calorie summary tables.

    Returns 0 after printing. Raises SystemExit with a message in two cases:
    - the history directory does not exist or is not a directory;
    - no rows could be built from the history directory.
    """
    args = build_parser().parse_args()

    history_dir = Path(args.history_dir).expanduser().resolve()
    # Without --trace-dir, look for a "traces" directory beside the history directory.
    if args.trace_dir is None:
        trace_dir = history_dir.parent / "traces"
    else:
        trace_dir = Path(args.trace_dir).expanduser().resolve()

    if not (history_dir.exists() and history_dir.is_dir()):
        raise SystemExit(f"Eval history directory not found: {history_dir}")

    rows_by_eval_set = _load_rows_by_eval_set(history_dir, args.pattern, trace_dir, args.trace_pattern)
    if not rows_by_eval_set:
        raise SystemExit(f"No eval result files matched in {history_dir}")

    _print_grouped_tables(rows_by_eval_set)
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
