#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable


# Directory anchors derived from this script's resolved location: the script's
# own directory, its parent, and its grandparent.
SCRIPT_DIR = Path(__file__).resolve().parent
CODE_DIR = SCRIPT_DIR.parent
NUT_DIR = CODE_DIR.parent

# Default eval-history directories (one per planner) and the default paper TeX
# file; each can be overridden from the command line.
CANADA_HISTORY_DIR = CODE_DIR / "CanadaDRIPlanner" / ".adk" / "eval_history"
LLM_HISTORY_DIR = CODE_DIR / "LLMNutritionPlanner" / ".adk" / "eval_history"
MFP_HISTORY_DIR = CODE_DIR / "MFPNutritionPlanner" / ".adk" / "eval_history"
PAPER_TEX_PATH = NUT_DIR / "acl" / "latex" / "ano_acl_long.tex"

# "eval_set_id" values used to select which result files are aggregated.
DRI_EVAL_SET_ID = "dri"
MFP_MACRO_EVAL_SET_ID = "category_samples_mfp_macro"

# "metric_name" values that are accepted as the cuisine alignment score.
CUISINE_METRIC_NAMES = (
	"per_day_cuisine_alignment_score",
	"cuisine_alignment_score",
)


@dataclass
class CaseMetrics:
	"""Per-case deviations from the nutrition targets plus the cuisine score.

	Every field is ``None`` when the corresponding value could not be derived
	for the case. Macro deviations are either an absolute difference from a
	point target or the distance outside a target range, depending on the
	target mode used when the metrics were extracted.
	"""

	# Absolute difference between achieved and target calories.
	calorie_deviation: float | None
	# Protein deviation from its target (point target or range).
	protein_deviation: float | None
	# Carbohydrate deviation from its target (point target or range).
	carb_deviation: float | None
	# Fat deviation from its target (point target or range).
	fat_deviation: float | None
	# Cuisine alignment score read from the overall eval metric results.
	cuisine_score: float | None
	# Absolute difference between achieved and target fibre.
	fibre_deviation: float | None


def _to_float(value: Any) -> float | None:
	if isinstance(value, (int, float)):
		return float(value)
	if isinstance(value, str):
		cleaned = value.replace(",", "").strip()
		try:
			return float(cleaned)
		except ValueError:
			return None
	return None


def _normalize_key(text: str) -> str:
	"""Return *text* (coerced to ``str``) lower-cased with only ``a-z`` and ``0-9`` kept."""
	kept = [
		char
		for char in str(text).lower()
		if "a" <= char <= "z" or "0" <= char <= "9"
	]
	return "".join(kept)


def _find_number_by_keys(source: dict[str, Any], candidate_keys: Iterable[str]) -> float | None:
	"""Return the first numeric value in *source* matching one of *candidate_keys*.

	Keys on both sides are compared after ``_normalize_key``. Candidates are
	tried in order; a candidate whose value is missing or not convertible by
	``_to_float`` is skipped. Returns ``None`` when no candidate yields a number.
	"""
	lookup: dict[str, Any] = {}
	for raw_key, raw_value in source.items():
		# When two source keys normalize identically, the later one wins.
		lookup[_normalize_key(raw_key)] = raw_value

	for candidate in candidate_keys:
		parsed = _to_float(lookup.get(_normalize_key(candidate)))
		if parsed is not None:
			return parsed
	return None


def _iter_invocation_events(case_result: dict[str, Any]) -> Iterable[dict[str, Any]]:
	"""Yield every dict event recorded under the case's per-invocation results.

	Walks ``eval_metric_result_per_invocation[*].actual_invocation
	.intermediate_data.invocation_events`` and silently skips any level that
	does not have the expected container type.
	"""
	entries = case_result.get("eval_metric_result_per_invocation")
	if not isinstance(entries, list):
		return

	nested_keys = ("actual_invocation", "intermediate_data", "invocation_events")
	for entry in entries:
		node: Any = entry
		for key in nested_keys:
			# Descend one level; a non-dict at any level collapses to None.
			node = node.get(key) if isinstance(node, dict) else None
		if isinstance(node, list):
			yield from (event for event in node if isinstance(event, dict))


def _unwrap_result(payload: dict[str, Any]) -> dict[str, Any]:
	"""Follow nested ``"result"`` dicts (at most 10 levels) and return the innermost one reached."""
	node: dict[str, Any] = payload
	remaining = 10
	while remaining > 0:
		inner = node.get("result")
		if not isinstance(inner, dict):
			break
		node = inner
		remaining -= 1
	return node


def _extract_function_responses(case_result: dict[str, Any], function_name: str) -> list[dict[str, Any]]:
	"""Collect, in event order, the dict ``response`` payloads of *function_name*.

	Looks at ``content.parts[*].function_response`` of every invocation event
	and keeps responses whose ``name`` equals *function_name*. Malformed
	events, parts and non-dict payloads are skipped.
	"""
	matches: list[dict[str, Any]] = []
	for event in _iter_invocation_events(case_result):
		content = event.get("content")
		parts = content.get("parts") if isinstance(content, dict) else None
		if not isinstance(parts, list):
			continue
		for part in parts:
			response = part.get("function_response") if isinstance(part, dict) else None
			if not isinstance(response, dict) or response.get("name") != function_name:
				continue
			body = response.get("response")
			if isinstance(body, dict):
				matches.append(body)
	return matches


def _extract_latest_function_response(case_result: dict[str, Any], function_name: str) -> dict[str, Any] | None:
	"""Return the last response of *function_name*, unwrapped, or ``None`` if it was never called."""
	matching = _extract_function_responses(case_result, function_name)
	return _unwrap_result(matching[-1]) if matching else None


def _extract_cuisine_score(case_result: dict[str, Any]) -> float | None:
	"""Return the first numeric cuisine alignment score of the case, if any.

	Scans ``overall_eval_metric_results`` for entries whose ``metric_name`` is
	listed in ``CUISINE_METRIC_NAMES`` and returns the first ``score`` that
	converts to a float; ``None`` when there is none.
	"""
	entries = case_result.get("overall_eval_metric_results")
	if not isinstance(entries, list):
		return None
	for entry in entries:
		if isinstance(entry, dict) and entry.get("metric_name") in CUISINE_METRIC_NAMES:
			score = _to_float(entry.get("score"))
			if score is not None:
				return score
	return None


def _extract_dri_targets(case_result: dict[str, Any]) -> dict[str, Any] | None:
	"""Read the DRI targets from the latest ``calculate_health_canada_dri`` response.

	Returns a dict with ``calories`` and ``fibre`` point targets and
	``protein_range`` / ``carb_range`` / ``fat_range`` as ``(lower, upper)``
	tuples. Returns ``None`` unless every one of those five values is found.
	"""
	payload = _extract_latest_function_response(case_result, "calculate_health_canada_dri")
	if not payload:
		return None

	daily = payload.get("recommended_g_per_day")
	ranges = payload.get("recommended_macronutrient_ranges_g_per_day")
	if not (isinstance(daily, dict) and isinstance(ranges, dict)):
		return None

	# Macro names are matched as substrings of the normalized range key, in
	# this order; a later matching key overwrites an earlier one.
	macro_order = ("protein", "carb", "fat")
	found: dict[str, tuple[float, float] | None] = {macro: None for macro in macro_order}
	for raw_key, bounds in ranges.items():
		if not isinstance(bounds, dict):
			continue
		low = _to_float(bounds.get("lower"))
		high = _to_float(bounds.get("upper"))
		if low is None or high is None:
			continue
		normalized = _normalize_key(raw_key)
		for macro in macro_order:
			if macro in normalized:
				found[macro] = (low, high)
				break

	calories = _find_number_by_keys(
		daily,
		("calories eer_kcal", "calories_eer_kcal", "eer_kcal", "calories", "energy_kcal", "energy"),
	)
	fibre = _find_number_by_keys(
		daily,
		("total fibre", "total_fibre", "fibre", "fiber", "dietary_fiber"),
	)
	if calories is None or fibre is None:
		return None
	if any(found[macro] is None for macro in macro_order):
		return None

	return {
		"calories": calories,
		"fibre": fibre,
		"protein_range": found["protein"],
		"carb_range": found["carb"],
		"fat_range": found["fat"],
	}


def _extract_mfp_targets(case_result: dict[str, Any]) -> dict[str, Any] | None:
	"""Read point targets from the latest ``calculate_mfp_macros`` response.

	Returns ``calories``, ``protein``, ``carbohydrates`` and ``total_fat``
	as floats, or ``None`` when the tool was not called or any value is
	missing or non-numeric.
	"""
	payload = _extract_latest_function_response(case_result, "calculate_mfp_macros")
	if not payload:
		return None

	macros = payload.get("macros_g_per_day")
	calories = _to_float(payload.get("target_calories_kcal"))
	if calories is None or not isinstance(macros, dict):
		return None

	targets: dict[str, Any] = {"calories": calories}
	# (output key, key inside macros_g_per_day)
	for out_key, source_key in (
		("protein", "protein"),
		("carbohydrates", "carbohydrates"),
		("total_fat", "fat"),
	):
		amount = _to_float(macros.get(source_key))
		if amount is None:
			return None
		targets[out_key] = amount
	return targets


def _extract_achieved_from_optimizer(case_result: dict[str, Any], function_name: str) -> dict[str, float] | None:
	"""Return the achieved per-day nutrients from the newest usable optimizer response.

	Responses of *function_name* are scanned newest first. Within a response,
	``achieved_targets_per_day`` is preferred over
	``best_effort_achieved_targets_per_day``. A wrapper is usable when its
	``calories`` value is numeric; nutrients that are not numeric are omitted
	from the returned dict. Returns ``None`` when no response is usable.
	"""
	nutrient_keys = ("calories", "protein", "carbohydrates", "total_fat", "total_fibre")
	wrapper_keys = ("achieved_targets_per_day", "best_effort_achieved_targets_per_day")

	for payload in reversed(_extract_function_responses(case_result, function_name)):
		unwrapped = _unwrap_result(payload)
		for wrapper_key in wrapper_keys:
			wrapper = unwrapped.get(wrapper_key)
			if isinstance(wrapper, dict):
				parsed = {name: _to_float(wrapper.get(name)) for name in nutrient_keys}
				if parsed["calories"] is not None:
					return {name: number for name, number in parsed.items() if number is not None}
	return None


def _has_strict_optimizer_feasible_result(case_result: dict[str, Any], function_name: str) -> bool:
	"""Tell whether any response of *function_name* has numeric calories under ``achieved_targets_per_day``.

	Unlike ``_extract_achieved_from_optimizer``, the best-effort wrapper is
	not considered here.
	"""
	for payload in reversed(_extract_function_responses(case_result, function_name)):
		achieved = _unwrap_result(payload).get("achieved_targets_per_day")
		if isinstance(achieved, dict) and _to_float(achieved.get("calories")) is not None:
			return True
	return False


def _extract_achieved_from_average_tool(case_result: dict[str, Any]) -> dict[str, float] | None:
	"""Return achieved per-day nutrients from the latest ``calculate_average_macro_nutrient_per_day`` response.

	Values are read from ``average_macro_nutrient_from_calculated_quantity_per_day``.
	Returns ``None`` when the tool was not called, the wrapper is missing, or
	``calories`` is not numeric; other non-numeric nutrients are omitted.
	"""
	payload = _extract_latest_function_response(case_result, "calculate_average_macro_nutrient_per_day")
	if not payload:
		return None

	averages = payload.get("average_macro_nutrient_from_calculated_quantity_per_day")
	if not isinstance(averages, dict):
		return None

	nutrient_keys = ("calories", "protein", "carbohydrates", "total_fat", "total_fibre")
	parsed = {name: _to_float(averages.get(name)) for name in nutrient_keys}
	if parsed["calories"] is None:
		return None
	return {name: number for name, number in parsed.items() if number is not None}


def _outside_range_deviation(value: float | None, bounds: tuple[float, float] | None) -> float | None:
	if value is None or bounds is None:
		return None
	lower, upper = bounds
	if value < lower:
		return lower - value
	if value > upper:
		return value - upper
	return 0.0


def _is_llm_feasible(metrics: CaseMetrics) -> bool:
	"""Return True when every deviation is present and within its fixed absolute limit.

	Limits: calories 10, protein 5, carbohydrates 2, fat 5, fibre 5. A missing
	(``None``) deviation makes the case infeasible.
	"""
	limits = (
		("calorie_deviation", 10.0),
		("protein_deviation", 5.0),
		("carb_deviation", 2.0),
		("fat_deviation", 5.0),
		("fibre_deviation", 5.0),
	)
	for field_name, limit in limits:
		deviation = getattr(metrics, field_name)
		if deviation is None or not deviation <= limit:
			return False
	return True


def _extract_case_metrics(
	case_result: dict[str, Any],
	*,
	target_mode: str,
	achieved_mode: str,
	allow_dri_fallback_for_mfp: bool,
) -> CaseMetrics | None:
	"""Compute the deviations of one eval case from its nutrition targets.

	Args:
		case_result: One entry of ``eval_case_results``.
		target_mode: ``"dri"`` scores macros against DRI ranges; ``"mfp"``
			scores them against point targets.
		achieved_mode: Where the achieved nutrients are read from:
			``"optimizer_dri"``, ``"optimizer_mfp"`` or ``"average_tool"``.
		allow_dri_fallback_for_mfp: In ``"mfp"`` mode, when no MFP targets
			were recorded, derive them from the DRI targets (calorie target
			plus the lower bound of each macro range).

	Returns:
		A ``CaseMetrics``, or ``None`` when the achieved values or the
		required targets are unavailable, or when a mode is not recognized.
	"""

	def _abs_gap(actual: float | None, target: float) -> float | None:
		"""Absolute difference to *target*, or None when *actual* is missing."""
		return None if actual is None else abs(actual - target)

	cuisine_score = _extract_cuisine_score(case_result)
	dri_targets = _extract_dri_targets(case_result)
	mfp_targets = _extract_mfp_targets(case_result)

	optimizer_functions = {
		"optimizer_dri": "optimize_quantity",
		"optimizer_mfp": "optimize_quantity_for_mfp_targets",
	}
	if achieved_mode == "average_tool":
		achieved = _extract_achieved_from_average_tool(case_result)
	elif achieved_mode in optimizer_functions:
		achieved = _extract_achieved_from_optimizer(case_result, optimizer_functions[achieved_mode])
	else:
		achieved = None
	if achieved is None:
		return None

	calories = achieved.get("calories")
	protein = achieved.get("protein")
	carbs = achieved.get("carbohydrates")
	fat = achieved.get("total_fat")
	fibre = achieved.get("total_fibre")

	if target_mode == "dri":
		if not dri_targets:
			return None
		return CaseMetrics(
			calorie_deviation=_abs_gap(calories, dri_targets["calories"]),
			protein_deviation=_outside_range_deviation(protein, dri_targets["protein_range"]),
			carb_deviation=_outside_range_deviation(carbs, dri_targets["carb_range"]),
			fat_deviation=_outside_range_deviation(fat, dri_targets["fat_range"]),
			cuisine_score=cuisine_score,
			fibre_deviation=_abs_gap(fibre, dri_targets["fibre"]),
		)

	if target_mode != "mfp":
		return None

	point_targets = mfp_targets
	if point_targets is None and allow_dri_fallback_for_mfp and dri_targets is not None:
		# Fallback keeps the table populated for traces where the LLM-only
		# agent did not call calculate_mfp_macros.
		point_targets = {
			"calories": dri_targets["calories"],
			"protein": dri_targets["protein_range"][0],
			"carbohydrates": dri_targets["carb_range"][0],
			"total_fat": dri_targets["fat_range"][0],
		}
	if not point_targets:
		return None

	# MFP targets carry no fibre value, so fibre is scored against DRI when known.
	fibre_target = None if dri_targets is None else dri_targets["fibre"]

	return CaseMetrics(
		calorie_deviation=_abs_gap(calories, point_targets["calories"]),
		protein_deviation=_abs_gap(protein, point_targets["protein"]),
		carb_deviation=_abs_gap(carbs, point_targets["carbohydrates"]),
		fat_deviation=_abs_gap(fat, point_targets["total_fat"]),
		cuisine_score=cuisine_score,
		fibre_deviation=_abs_gap(fibre, fibre_target) if fibre_target is not None else None,
	)


def _result_path_sort_key(path: Path) -> tuple[int, str]:
	"""Return ``(mtime in ns, file name)`` for sorting; the mtime is ``-1`` if the file cannot be stat'ed."""
	try:
		modified_ns = path.stat().st_mtime_ns
	except OSError:
		modified_ns = -1
	return (modified_ns, path.name)


def _load_latest_case_results(history_dir: Path, eval_set_id: str) -> dict[str, dict[str, Any]]:
	"""Load the most recent result of every eval case of one eval set.

	Reads every ``*.evalset_result.json`` file in *history_dir*, keeps those
	whose ``eval_set_id`` equals *eval_set_id*, and maps each non-empty string
	``eval_id`` to its case result. Files are processed oldest first (by
	modification time, then name), so a case found in several files ends up
	with the result from the newest file.

	Files that cannot be read, decoded or parsed, and files whose top-level
	JSON value is not an object, are skipped instead of aborting the run.
	"""
	latest_by_eval_id: dict[str, dict[str, Any]] = {}
	for path in sorted(history_dir.glob("*.evalset_result.json"), key=_result_path_sort_key):
		try:
			payload = json.loads(path.read_text(encoding="utf-8"))
		except (OSError, ValueError):
			# ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
			continue

		# A valid JSON file may hold a list or scalar, which has no .get().
		if not isinstance(payload, dict) or payload.get("eval_set_id") != eval_set_id:
			continue

		case_results = payload.get("eval_case_results")
		if not isinstance(case_results, list):
			continue

		for case_result in case_results:
			if not isinstance(case_result, dict):
				continue
			eval_id = case_result.get("eval_id")
			if isinstance(eval_id, str) and eval_id:
				# Later (newer) files overwrite earlier entries.
				latest_by_eval_id[eval_id] = case_result

	return latest_by_eval_id


def _bucket_distribution(values: list[float], upper_bounds: tuple[float, float, float]) -> list[float]:
	"""Return the percentage of *values* in each of four buckets.

	The first three buckets have inclusive upper bounds taken from
	*upper_bounds*, checked in order; the fourth holds everything else.
	An empty *values* list yields four zeros.
	"""
	if not values:
		return [0.0, 0.0, 0.0, 0.0]

	first, second, third = upper_bounds
	edges = (first, second, third)
	tally = [0, 0, 0, 0]
	for sample in values:
		# Index of the first bound the sample does not exceed, else the overflow bucket.
		slot = next((index for index, edge in enumerate(edges) if sample <= edge), 3)
		tally[slot] += 1

	total = float(len(values))
	return [(hits / total) * 100.0 for hits in tally]


def _cuisine_distribution(scores: list[float]) -> list[float]:
	"""Return the percentage of *scores* in ``<0.25``, ``<0.50``, ``<0.75`` and ``>=0.75``.

	Upper edges are exclusive. An empty list yields four zeros.
	"""
	if not scores:
		return [0.0, 0.0, 0.0, 0.0]

	edges = (0.25, 0.50, 0.75)
	tally = [0, 0, 0, 0]
	for score in scores:
		slot = next((index for index, edge in enumerate(edges) if score < edge), 3)
		tally[slot] += 1

	total = float(len(scores))
	return [(hits / total) * 100.0 for hits in tally]


def _format_pct(value: float | None) -> str:
	if value is None:
		return "N/A"
	return f"{value:.1f}\\%"


def _distribution_strings(values: list[float], upper_bounds: tuple[float, float, float]) -> list[str]:
	"""Return the four formatted bucket percentages for *values*, or four ``"N/A"`` when there are none."""
	if values:
		return [_format_pct(share) for share in _bucket_distribution(values, upper_bounds)]
	return ["N/A"] * 4


def _cuisine_strings(values: list[float]) -> list[str]:
	"""Return the four formatted cuisine-score bucket percentages, or four ``"N/A"`` when there are no scores."""
	if values:
		return [_format_pct(share) for share in _cuisine_distribution(values)]
	return ["N/A"] * 4


def _build_column_stats(
	metrics: list[CaseMetrics],
	n_total: int,
	*,
	feasible_found_count: int | None = None,
) -> dict[str, list[str]]:
	"""Build the formatted cells of one table column.

	Args:
		metrics: Metrics of the cases that could be evaluated.
		n_total: Number of cases in the column, used as the denominator of
			the feasibility percentages.
		feasible_found_count: Number of feasible cases; defaults to
			``len(metrics)`` when omitted.

	Returns:
		A mapping from metric key to its list of formatted cells: four bucket
		cells per deviation/cuisine metric and one cell each for
		``feasible_found`` and ``feasible_not_found``.
	"""

	def _present(field_name: str) -> list[float]:
		"""Values of one CaseMetrics field with the missing (None) entries dropped."""
		return [
			getattr(item, field_name)
			for item in metrics
			if getattr(item, field_name) is not None
		]

	found = len(metrics) if feasible_found_count is None else feasible_found_count
	missing = n_total - found
	if n_total > 0:
		found_pct = found / n_total * 100.0
		missing_pct = missing / n_total * 100.0
	else:
		found_pct = 0.0
		missing_pct = 0.0

	return {
		"calorie": _distribution_strings(_present("calorie_deviation"), (10.0, 50.0, 100.0)),
		"protein": _distribution_strings(_present("protein_deviation"), (5.0, 25.0, 50.0)),
		"carb": _distribution_strings(_present("carb_deviation"), (2.0, 5.0, 50.0)),
		"fat": _distribution_strings(_present("fat_deviation"), (5.0, 10.0, 25.0)),
		"cuisine": _cuisine_strings(_present("cuisine_score")),
		"fibre": _distribution_strings(_present("fibre_deviation"), (5.0, 10.0, 15.0)),
		"feasible_found": [_format_pct(found_pct)],
		"feasible_not_found": [_format_pct(missing_pct)],
	}


def _cell(cols: list[dict[str, list[str]]], metric_key: str, bucket_index: int) -> str:
	"""Join one bucket of one metric across all columns with the LaTeX cell separator ``" & "``."""
	return " & ".join(column[metric_key][bucket_index] for column in cols)


def _should_bold_for_metric(metric_key: str, llm_value: str, ano_value: str, bucket_index: int) -> bool:
	"""Decide whether the ANO cell should be bolded, i.e. ANO beats LLM-only.

	Both values are formatted percentages such as ``"12.3\\%"``; if either one
	cannot be parsed as a number (for example ``"N/A"``) nothing is bolded.

	Which direction counts as "better" depends on the row:
	- ``feasible_found``: a higher share is better.
	- ``feasible_not_found``: a lower share is better.
	- ``cuisine``: rows are score buckets from worst (bucket 0, ``< 0.25``) to
	  best (bucket 3, ``>= 0.75``). Only in the best bucket is a higher share
	  better; a higher share in any lower bucket means worse alignment.
	- deviation metrics: bucket 0 is the strictest band, so a higher share
	  is better there and a lower share is better in the looser bands.
	"""
	try:
		# rstrip removes the trailing "\" and "%" characters of the LaTeX percent.
		llm_num = float(llm_value.strip().rstrip("\\%"))
		ano_num = float(ano_value.strip().rstrip("\\%"))
	except ValueError:
		return False

	if metric_key == "feasible_found":
		higher_is_better = True
	elif metric_key == "feasible_not_found":
		higher_is_better = False
	elif metric_key == "cuisine":
		# Bucket 3 is the top score band (>= 0.75).
		higher_is_better = bucket_index == 3
	else:
		# Bucket 0 is the tightest deviation band.
		higher_is_better = bucket_index == 0

	return ano_num > llm_num if higher_is_better else ano_num < llm_num


def _cell_with_bold(cols: list[dict[str, list[str]]], metric_key: str, bucket_index: int) -> str:
	"""Join one bucket of one metric across columns, bolding ANO cells that beat LLM-only.

	Column order is DRI LLM, DRI ANO, MFP LLM, MFP ANO: each ANO cell
	(positions 1 and 3) is compared with the LLM-only cell just before it.
	"""
	raw = [column[metric_key][bucket_index] for column in cols]
	# ANO column position -> position of the LLM-only column it is compared with.
	baseline_for = {1: 0, 3: 2}

	rendered = []
	for position, text in enumerate(raw):
		baseline = baseline_for.get(position)
		if baseline is not None and _should_bold_for_metric(metric_key, raw[baseline], text, bucket_index):
			text = f"\\textbf{{{text}}}"
		rendered.append(text)
	return " & ".join(rendered)


def _build_table_rows(cols: list[dict[str, list[str]]]) -> list[str]:
	"""Return the LaTeX body lines of the results table.

	Each section is a heading row followed by its data rows, and every row is
	followed by an ``\\hline`` line. Data cells come from ``_cell_with_bold``.
	"""
	row_end = r"\\"
	rule = "\\hline"

	def _banded(metric_key: str, labels: tuple[str, ...]) -> tuple[tuple[str, str, int], ...]:
		"""One (label, metric key, bucket index) entry per bucket label, in order."""
		return tuple((label, metric_key, index) for index, label in enumerate(labels))

	sections = (
		(
			"All",
			(
				("Feasible solution found", "feasible_found", 0),
				("Feasible solution not found", "feasible_not_found", 0),
			),
		),
		(
			"Calorie Deviation (kcal)",
			_banded("calorie", (r"$\leq 10$ kcal", "$10$--$50$ kcal", "$50$--$100$ kcal", "$> 100$ kcal")),
		),
		(
			"Protein Deviation (g)",
			_banded("protein", (r"$\leq 5$ g", "$5$--$25$ g", "$25$--$50$ g", "$> 50$ g")),
		),
		(
			"Carbohydrate Deviation (g)",
			_banded("carb", (r"$\leq 2$ g", "$2$--$5$ g", "$5$--$50$ g", "$> 50$ g")),
		),
		(
			"Fat Deviation (g)",
			_banded("fat", (r"$\leq 5$ g", "$5$--$10$ g", "$10$--$25$ g", "$> 25$ g")),
		),
		(
			"Cuisine Alignment Score",
			_banded("cuisine", ("$< 0.25$", "$0.25$--$0.50$", "$0.50$--$0.75$", r"$\geq 0.75$")),
		),
		(
			"Dietary Fiber Deviation (g)",
			_banded("fibre", (r"$\leq 5$ g", "$5$--$10$ g", "$10$--$15$ g", "$> 15$ g")),
		),
	)

	rows: list[str] = []
	for heading, entries in sections:
		rows.append(f"{heading} & & & & {row_end}")
		rows.append(rule)
		for label, metric_key, bucket_index in entries:
			cells = _cell_with_bold(cols, metric_key, bucket_index)
			rows.append(f"\\quad {label} & {cells} {row_end}")
			rows.append(rule)
	return rows


def _update_last_table(tex_path: Path, rows: list[str]) -> None:
	"""Rewrite the tabular of the ``tab:category-results-combined`` table in place.

	The target ``table*`` environment is the one enclosing the label
	``\\label{tab:category-results-combined}``. Inside it, the tabular opened
	by ``\\begin{tabular}{|l|c|c|c|c|}`` is replaced by a fixed two-line
	header followed by *rows*; everything else in the file is left untouched.

	Args:
		tex_path: UTF-8 encoded LaTeX file to read and overwrite.
		rows: Body lines of the table, one per output line.

	Raises:
		ValueError: If the label, the enclosing ``table*`` environment, or the
			expected tabular cannot be found.
	"""
	original = tex_path.read_text(encoding="utf-8")

	label_token = "\\label{tab:category-results-combined}"
	label_index = original.find(label_token)
	if label_index < 0:
		raise ValueError("Could not locate tab:category-results-combined label in TeX file.")

	# The table environment is the nearest begin before and end after the label.
	table_start = original.rfind("\\begin{table*}", 0, label_index)
	table_end = original.find("\\end{table*}", label_index)
	if table_start < 0 or table_end < 0:
		raise ValueError("Could not isolate the target table block in TeX file.")
	table_end += len("\\end{table*}")

	table_block = original[table_start:table_end]
	tabular_start_token = "\\begin{tabular}{|l|c|c|c|c|}"
	tabular_end_token = "\\end{tabular}"
	tabular_start = table_block.find(tabular_start_token)
	# Look for the closing token only after the opening one; otherwise an
	# earlier tabular in the same table would be matched and the splice below
	# would duplicate part of the table.
	tabular_end = table_block.find(tabular_end_token, tabular_start) if tabular_start >= 0 else -1
	if tabular_start < 0 or tabular_end < 0:
		raise ValueError("Could not find tabular block inside target table.")

	tabular_end += len(tabular_end_token)
	tabular_header = (
		tabular_start_token
		+ "\n"
		+ "\\hline\n"
		+ r"\multicolumn{1}{|c|}{} & \multicolumn{2}{c|}{\textbf{Canada Health DRI}} & \multicolumn{2}{c|}{\textbf{MyFitnessPal}} \\"
		+ "\n"
		+ "\\hline\n"
		+ r"\textbf{Metric} & \textbf{LLM Only} & \textbf{ANO} & \textbf{LLM Only} & \textbf{ANO} \\"
		+ "\n"
		+ "\\hline\n"
	)
	new_body = "\n".join(rows) + "\n"
	updated_tabular = tabular_header + new_body + tabular_end_token
	updated_table_block = table_block[:tabular_start] + updated_tabular + table_block[tabular_end:]

	updated = original[:table_start] + updated_table_block + original[table_end:]
	tex_path.write_text(updated, encoding="utf-8")


def _collect_metrics_for_column(
	case_results: dict[str, dict[str, Any]],
	*,
	target_mode: str,
	achieved_mode: str,
	allow_dri_fallback_for_mfp: bool,
) -> tuple[list[CaseMetrics], int]:
	"""Extract metrics for every case of one table column.

	Cases are visited in sorted ``eval_id`` order and the keyword arguments
	are forwarded to ``_extract_case_metrics``.

	Returns:
		``(metrics, total)``: the metrics of the cases that produced any,
		and the total number of cases examined.
	"""
	collected: list[CaseMetrics] = []
	for eval_id in sorted(case_results):
		extracted = _extract_case_metrics(
			case_results[eval_id],
			target_mode=target_mode,
			achieved_mode=achieved_mode,
			allow_dri_fallback_for_mfp=allow_dri_fallback_for_mfp,
		)
		if extracted is not None:
			collected.append(extracted)
	return collected, len(case_results)


def build_parser() -> argparse.ArgumentParser:
	"""Create the command-line parser: three history directories, the paper path and ``--dry-run``."""
	parser = argparse.ArgumentParser(
		description="Aggregate eval-history chats and update the final threshold table in the LaTeX paper.",
	)
	path_options = (
		("--canada-history-dir", CANADA_HISTORY_DIR),
		("--llm-history-dir", LLM_HISTORY_DIR),
		("--mfp-history-dir", MFP_HISTORY_DIR),
		("--paper-tex", PAPER_TEX_PATH),
	)
	for flag, default_path in path_options:
		parser.add_argument(flag, default=str(default_path))
	parser.add_argument(
		"--dry-run",
		action="store_true",
		help="Print computed table rows without editing the TeX file.",
	)
	return parser


def main() -> int:
	"""Aggregate the eval histories and write (or, with ``--dry-run``, print) the table rows.

	Exits through ``SystemExit`` with a message when a history directory or
	the LaTeX file does not exist. Returns ``0`` on success.
	"""
	args = build_parser().parse_args()

	canada_history_dir, llm_history_dir, mfp_history_dir, paper_tex_path = (
		Path(raw).expanduser().resolve()
		for raw in (
			args.canada_history_dir,
			args.llm_history_dir,
			args.mfp_history_dir,
			args.paper_tex,
		)
	)

	for history_dir in (canada_history_dir, llm_history_dir, mfp_history_dir):
		if not (history_dir.exists() and history_dir.is_dir()):
			raise SystemExit(f"Eval history directory not found: {history_dir}")
	if not paper_tex_path.exists():
		raise SystemExit(f"LaTeX file not found: {paper_tex_path}")

	canada_target_cases = _load_latest_case_results(canada_history_dir, DRI_EVAL_SET_ID)
	llm_target_cases = _load_latest_case_results(llm_history_dir, DRI_EVAL_SET_ID)
	llm_mfp_cases = _load_latest_case_results(llm_history_dir, MFP_MACRO_EVAL_SET_ID)
	mfp_cases = _load_latest_case_results(mfp_history_dir, MFP_MACRO_EVAL_SET_ID)

	# One entry per table column, in table order:
	# (title, cases, target mode, achieved mode, DRI fallback, optimizer function).
	# A missing optimizer function marks an LLM-only column, whose feasibility
	# is judged by the deviation thresholds instead of the optimizer output.
	column_specs = (
		("Canada DRI / LLM Only", llm_target_cases, "dri", "average_tool", False, None),
		("Canada DRI / ANO", canada_target_cases, "dri", "optimizer_dri", False, "optimize_quantity"),
		("MyFitnessPal / LLM Only", llm_mfp_cases, "mfp", "average_tool", True, None),
		("MyFitnessPal / ANO", mfp_cases, "mfp", "optimizer_mfp", False, "optimize_quantity_for_mfp_targets"),
	)

	column_stats = []
	summary_lines = []
	for title, cases, target_mode, achieved_mode, dri_fallback, optimizer_function in column_specs:
		metrics, total = _collect_metrics_for_column(
			cases,
			target_mode=target_mode,
			achieved_mode=achieved_mode,
			allow_dri_fallback_for_mfp=dri_fallback,
		)
		if optimizer_function is None:
			feasible_label = "threshold-feasible"
			feasible = sum(1 for item in metrics if _is_llm_feasible(item))
		else:
			feasible_label = "strict optimizer feasible"
			feasible = sum(
				1
				for case in cases.values()
				if _has_strict_optimizer_feasible_result(case, optimizer_function)
			)
		column_stats.append(_build_column_stats(metrics, total, feasible_found_count=feasible))
		summary_lines.append(f"- {title}: {len(metrics)} / {total} ({feasible_label}: {feasible})")

	rows = _build_table_rows(column_stats)

	if args.dry_run:
		print("\n".join(rows))
	else:
		_update_last_table(paper_tex_path, rows)
		print(f"Updated table tab:category-results-combined in {paper_tex_path}")

	print("\nComputed case counts per column:")
	for line in summary_lines:
		print(line)
	return 0


# When run as a script, execute main() and use its return value as the exit status.
if __name__ == "__main__":
	raise SystemExit(main())
