#!/usr/bin/env python3
"""Log tool messages for samples with non-zero calorie error or N/A calorie values.

Scans the eval history directory, identifies problematic samples (calorie_pct_diff
!= +0% or target/achieved calorie is N/A), and writes their invocation tool
messages to a timestamped file in the logs/ folder.
"""

from __future__ import annotations

import argparse
import json
import math
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

# Directory containing this script; the default paths below hang off it.
SCRIPT_DIR = Path(__file__).resolve().parent
# Default value of the positional ``history_dir`` CLI argument.
DEFAULT_HISTORY_DIR = SCRIPT_DIR / ".adk" / "eval_history"
# Default value of the ``--log-dir`` CLI option.
DEFAULT_LOG_DIR = SCRIPT_DIR / "logs"

# Keys probed, in this order, inside a DRI response's ``recommended_g_per_day``
# mapping when looking for the target calorie value.
CALORIE_KEYS = (
    "calories eer_kcal",
    "calories_eer_kcal",
    "eer_kcal",
    "calories",
    "energy_kcal",
    "energy",
)
# Key in an ``optimize_quantity`` response holding the achieved metrics.
OPTIMIZER_ACHIEVED_WRAPPER_KEY = "achieved_targets_per_day"
# Key consulted instead when the optimizer status is "no_solution".
BEST_EFFORT_WRAPPER_KEY = "best_effort_achieved_targets_per_day"
# NOTE(review): not referenced anywhere in this file; presumably carried over
# from summarize_eval_history_calories.py — confirm before removing.
CUISINE_METRIC_NAMES = (
    "per_day_cuisine_alignment_score",
    "cuisine_alignment_score",
)


# ---------------------------------------------------------------------------
# Helpers copied / adapted from summarize_eval_history_calories.py
# ---------------------------------------------------------------------------

def _to_float(value: Any) -> float | None:
    if isinstance(value, (int, float)):
        return float(value)
    if isinstance(value, str):
        try:
            return float(value.replace(",", "").strip())
        except ValueError:
            return None
    return None


def _unwrap_result(payload: dict[str, Any]) -> dict[str, Any]:
    current: dict[str, Any] = payload
    for _ in range(8):
        nested = current.get("result")
        if isinstance(nested, dict):
            current = nested
            continue
        break
    return current


def _iter_invocation_events(case_result: dict[str, Any]) -> list[dict[str, Any]]:
    invocations = case_result.get("eval_metric_result_per_invocation") or []
    if not isinstance(invocations, list):
        return []

    events: list[dict[str, Any]] = []
    for invocation in invocations:
        if not isinstance(invocation, dict):
            continue
        actual_invocation = invocation.get("actual_invocation") or {}
        if not isinstance(actual_invocation, dict):
            continue
        intermediate_data = actual_invocation.get("intermediate_data") or {}
        if not isinstance(intermediate_data, dict):
            continue
        invocation_events = intermediate_data.get("invocation_events")
        if isinstance(invocation_events, list):
            events.extend(e for e in invocation_events if isinstance(e, dict))
    return events


def _extract_tool_messages(case_result: dict[str, Any]) -> list[str]:
    """Render every tool call, tool response and text part of a case as log lines.

    Events are visited in the order given by ``_iter_invocation_events``. For
    each dict-typed part, up to three lines are emitted in this order: the
    function call, the function response, and any non-blank text (tagged with
    the event's author).
    """

    def _as_json(obj: Any) -> str:
        # default=str keeps non-JSON-native values from aborting the dump.
        return json.dumps(obj, ensure_ascii=False, default=str)

    rendered: list[str] = []
    for event in _iter_invocation_events(case_result):
        speaker = event.get("author", "unknown")
        content = event.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            continue

        for part in (candidate for candidate in parts if isinstance(candidate, dict)):
            call = part.get("function_call")
            if isinstance(call, dict):
                tool = call.get("name", "unknown_tool")
                dumped_args = _as_json(call.get("args", {}))
                rendered.append(f"  TOOL_CALL  {tool}  args={dumped_args}")

            reply = part.get("function_response")
            if isinstance(reply, dict):
                tool = reply.get("name", "unknown_tool")
                dumped_reply = _as_json(reply.get("response", {}))
                rendered.append(f"  TOOL_RESULT {tool}  response={dumped_reply}")

            message = part.get("text")
            if isinstance(message, str):
                stripped = message.strip()
                if stripped:
                    rendered.append(f"  TEXT [{speaker}]: {stripped}")

    return rendered


def _extract_function_responses(case_result: dict[str, Any], function_name: str) -> list[dict[str, Any]]:
    """Collect the dict payloads of all responses from one named tool.

    Scans every part of every invocation event and keeps the ``response``
    mapping of each ``function_response`` whose ``name`` equals
    ``function_name``, in encounter order. Non-dict responses are ignored.
    """
    matches: list[dict[str, Any]] = []
    for event in _iter_invocation_events(case_result):
        if not isinstance(event, dict):
            continue
        content = event.get("content")
        if not isinstance(content, dict):
            continue
        parts = content.get("parts")
        if not isinstance(parts, list):
            continue

        for part in parts:
            reply = part.get("function_response") if isinstance(part, dict) else None
            if isinstance(reply, dict) and reply.get("name") == function_name:
                body = reply.get("response")
                if isinstance(body, dict):
                    matches.append(body)
    return matches


def _extract_target_calories(case_result: dict[str, Any]) -> float | None:
    """Find the target calorie value from the ``calculate_health_canada_dri`` tool.

    Responses are examined from last to first. Within a response the lookup
    order is: ``calories["eer_kcal"]`` when ``calories`` is a mapping (or
    ``calories`` itself otherwise), then each of ``CALORIE_KEYS`` inside
    ``recommended_g_per_day``. The first value that converts to a float wins;
    None is returned when nothing usable is found.
    """
    responses = _extract_function_responses(case_result, "calculate_health_canada_dri")
    for response in responses[::-1]:
        dri = _unwrap_result(response)

        raw_calories = dri.get("calories")
        direct = _to_float(
            raw_calories.get("eer_kcal") if isinstance(raw_calories, dict) else raw_calories
        )
        if direct is not None:
            return direct

        recommended = dri.get("recommended_g_per_day")
        if not isinstance(recommended, dict):
            continue
        for key in CALORIE_KEYS:
            candidate = _to_float(recommended.get(key))
            if candidate is not None:
                return candidate
    return None


def _extract_metric_values(payload: dict[str, Any]) -> dict[str, float]:
    """Pull the known nutrient metrics out of ``payload`` as floats.

    Only the keys listed below are considered, and a key is included in the
    result only when its value converts via ``_to_float``.
    """
    metric_names = ("calories", "total_fibre", "protein", "carbohydrates", "total_fat")
    converted = ((name, _to_float(payload.get(name))) for name in metric_names)
    return {name: number for name, number in converted if number is not None}


def _extract_metrics_from_payload(payload: dict[str, Any], *, wrapper_key: str | None) -> dict[str, float]:
    """Read nutrient metrics from the mapping stored under ``wrapper_key``.

    The payload is first stripped of ``result`` envelopes. An empty dict is
    returned when ``wrapper_key`` is None, when the wrapped entry is not a
    mapping, or when it holds no convertible metric.
    """
    unwrapped = _unwrap_result(payload)
    if wrapper_key is None:
        return {}
    wrapped = unwrapped.get(wrapper_key)
    if not isinstance(wrapped, dict):
        return {}
    return _extract_metric_values(wrapped) or {}


def _extract_achieved_calories(case_result: dict[str, Any]) -> float | None:
    """Find the achieved calorie value reported by the ``optimize_quantity`` tool.

    Two passes are made over the responses, each from last to first:

    1. responses whose status is not "no_solution", reading calories from
       ``OPTIMIZER_ACHIEVED_WRAPPER_KEY``;
    2. responses whose status is "no_solution", reading calories from
       ``BEST_EFFORT_WRAPPER_KEY``.

    The first calorie value found is returned, or None if there is none.
    """
    candidates: list[tuple[dict[str, Any], bool]] = []
    for response in reversed(_extract_function_responses(case_result, "optimize_quantity")):
        payload = _unwrap_result(response)
        raw_status = payload.get("status")
        # A missing or non-string status counts as "not no_solution".
        infeasible = isinstance(raw_status, str) and raw_status.lower() == "no_solution"
        candidates.append((payload, infeasible))

    passes = (
        (False, OPTIMIZER_ACHIEVED_WRAPPER_KEY),
        (True, BEST_EFFORT_WRAPPER_KEY),
    )
    for wanted_infeasible, wrapper_key in passes:
        for payload, infeasible in candidates:
            if infeasible != wanted_infeasible:
                continue
            metrics = _extract_metrics_from_payload(payload, wrapper_key=wrapper_key)
            calories = metrics.get("calories")
            if calories is not None:
                return calories
    return None


def _format_number(value: float | None) -> str:
    if value is None:
        return "N/A"
    if value.is_integer():
        return str(int(value))
    return f"{value:.2f}".rstrip("0").rstrip(".")


def _format_pct_diff(target: float | None, achieved: float | None) -> str:
    if target is None or achieved is None or target == 0:
        return "N/A"
    pct = ((achieved - target) / target) * 100
    formatted = f"{pct:+.2f}".rstrip("0").rstrip(".")
    return f"{formatted}%"


def _result_path_sort_key(result_path: Path) -> tuple[int, str]:
    try:
        stat = result_path.stat()
    except OSError:
        return (-1, result_path.name)
    return (stat.st_mtime_ns, result_path.name)


# ---------------------------------------------------------------------------
# Core logic
# ---------------------------------------------------------------------------

def _is_problematic(target: float | None, achieved: float | None) -> bool:
    """Return True if calorie is N/A or the percentage error is non-zero."""
    if target is None or achieved is None:
        return True
    if target == 0:
        return True
    pct = abs((achieved - target) / target) * 100
    return pct > 0.0


def _collect_problematic(history_dir: Path, pattern: str) -> list[dict[str, Any]]:
    """Collect one record per eval_id whose newest result is problematic.

    Result files matching ``pattern`` under ``history_dir`` are processed from
    oldest to newest (ordered by ``_result_path_sort_key``), so a later result
    for an ``eval_id`` always overrides an earlier one: a later problematic
    result replaces the stored record, and a later clean result removes it.

    Files that cannot be read, decoded or parsed are skipped with a warning on
    stderr. Files whose top level is not a mapping, or which have no
    ``eval_case_results`` list, are skipped silently.

    Args:
        history_dir: Directory to search for result files.
        pattern: Glob pattern, relative to ``history_dir``.

    Returns:
        Records sorted by (source file modification time, eval_id). Each has
        the keys ``eval_id``, ``target_calorie``, ``achieved_calorie``,
        ``calorie_pct_diff``, ``source_file`` and ``tool_messages``.
    """
    # Compute each sort key exactly once so the processing order and the key
    # stored with a record cannot disagree if a file is touched mid-run.
    keyed_paths = sorted(
        ((_result_path_sort_key(path), path) for path in history_dir.glob(pattern)),
        key=lambda pair: pair[0],
    )

    latest_by_eval_id: dict[str, tuple[tuple[int, str], dict[str, Any]]] = {}

    for result_key, result_path in keyed_paths:
        try:
            payload = json.loads(result_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers both json.JSONDecodeError and the
            # UnicodeDecodeError raised for files that are not valid UTF-8.
            print(f"Warning: skipping {result_path.name}: {exc}", file=sys.stderr)
            continue

        # Valid JSON is not necessarily an object (could be a list or scalar).
        case_results = payload.get("eval_case_results") if isinstance(payload, dict) else None
        if not isinstance(case_results, list):
            continue

        for case_result in case_results:
            if not isinstance(case_result, dict):
                continue
            eval_id = case_result.get("eval_id")
            if not isinstance(eval_id, str) or not eval_id:
                continue

            target = _extract_target_calories(case_result)
            achieved = _extract_achieved_calories(case_result)

            if not _is_problematic(target, achieved):
                # The newest result wins: a clean run supersedes any older
                # problematic record for the same eval_id.
                latest_by_eval_id.pop(eval_id, None)
                continue

            record: dict[str, Any] = {
                "eval_id": eval_id,
                "target_calorie": _format_number(target),
                "achieved_calorie": _format_number(achieved),
                "calorie_pct_diff": _format_pct_diff(target, achieved),
                "source_file": result_path.name,
                "tool_messages": _extract_tool_messages(case_result),
            }
            latest_by_eval_id[eval_id] = (result_key, record)

    sorted_records = sorted(
        latest_by_eval_id.values(),
        key=lambda item: (item[0][0], item[1]["eval_id"]),
    )
    return [rec for _, rec in sorted_records]


def _write_log(records: list[dict[str, Any]], log_dir: Path) -> Path:
    log_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.now(tz=timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    log_path = log_dir / f"problematic_samples_{ts}.txt"

    lines: list[str] = [
        f"Problematic Samples Log — {ts}",
        f"Total problematic samples: {len(records)}",
        "=" * 72,
        "",
    ]

    for rec in records:
        lines.append(f"eval_id          : {rec['eval_id']}")
        lines.append(f"target_calorie   : {rec['target_calorie']}")
        lines.append(f"achieved_calorie : {rec['achieved_calorie']}")
        lines.append(f"calorie_pct_diff : {rec['calorie_pct_diff']}")
        lines.append(f"source_file      : {rec['source_file']}")
        lines.append("tool_messages:")
        if rec["tool_messages"]:
            lines.extend(rec["tool_messages"])
        else:
            lines.append("  (none recorded)")
        lines.append("-" * 72)
        lines.append("")

    log_path.write_text("\n".join(lines), encoding="utf-8")
    return log_path


def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script.

    Defines an optional positional ``history_dir`` plus the ``--pattern`` and
    ``--log-dir`` options, each with a default.
    """
    summary = (
        "Log tool messages for eval samples with non-zero calorie error "
        "or N/A calorie values. Output is saved to the logs/ folder."
    )
    parser = argparse.ArgumentParser(description=summary)

    # Declared as data so the three arguments read as a table; registration
    # order is kept because it determines the order shown in --help.
    argument_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        (
            "history_dir",
            {
                "nargs": "?",
                "default": str(DEFAULT_HISTORY_DIR),
                "help": "Directory containing *.evalset_result.json files.",
            },
        ),
        (
            "--pattern",
            {
                "default": "*.evalset_result.json",
                "help": "Glob pattern to select result files (default: *.evalset_result.json).",
            },
        ),
        (
            "--log-dir",
            {
                "default": str(DEFAULT_LOG_DIR),
                "help": "Directory to write log files (default: <script_dir>/logs).",
            },
        ),
    )
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)
    return parser


def main() -> int:
    """Run the CLI: scan the history directory and write a log if needed.

    Exits via ``SystemExit`` with a message when the history directory does
    not exist or is not a directory. Otherwise returns 0, whether or not any
    problematic samples were found; a log file is written only if some were.
    """
    args = build_parser().parse_args()
    history_dir = Path(args.history_dir).expanduser().resolve()
    log_dir = Path(args.log_dir).expanduser().resolve()

    if not (history_dir.exists() and history_dir.is_dir()):
        raise SystemExit(f"Eval history directory not found: {history_dir}")

    records = _collect_problematic(history_dir, args.pattern)
    if records:
        log_path = _write_log(records, log_dir)
        print(f"Found {len(records)} problematic sample(s).")
        print(f"Log written to: {log_path}")
    else:
        print("No problematic samples found.")
    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
