#!/usr/bin/env python3
from __future__ import annotations

import argparse
import csv
import json
import math
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from google.protobuf import text_format
from tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig


# Directory containing this script; the default paths below are resolved relative to it.
SCRIPT_DIR = Path(__file__).resolve().parent
# Trace directories used by default: the `.adk/traces` folders of the two sibling
# optimizer planner projects.
DEFAULT_TRACE_DIRS = (
    SCRIPT_DIR.parent / "OptimizerMFPNutritionPlanner" / ".adk" / "traces",
    SCRIPT_DIR.parent / "OptimizerCanadaDRIPlanner" / ".adk" / "traces",
)
# Default destination directory for generated projector artifacts.
DEFAULT_OUTPUT_DIR = SCRIPT_DIR / "output" / "meal_projector"
# Output file names (presumably written by the writer functions later in this file).
PROJECTOR_CONFIG_NAME = "projector_config.pbtxt"
TENSOR_FILE_NAME = "meal_macro_embeddings.tsv"
METADATA_FILE_NAME = "meal_metadata.tsv"
RAW_JSON_FILE_NAME = "meal_nodes.json"
SUMMARY_FILE_NAME = "summary.json"
# Glob patterns for trace files. NOTE(review): "*.trace.json" also matches every
# "sample_*.trace.json" file, so callers presumably de-duplicate matches — confirm.
TRACE_GLOB_PATTERNS = ("sample_*.trace.json", "*.trace.json")
# Nutrient keys read from product nutrition and optimizer achieved targets, and
# summed into the per-meal feature totals.
NUMERIC_KEYS = ("calories", "protein", "carbohydrates", "total_fat", "total_fibre")


@dataclass
class MealNode:
    """One meal slot (a day/meal pair) from an optimizer trace, plus its projector data."""

    # Short display name, resolved from the model's final answer or synthesized
    # from ingredient names.
    meal_name: str
    # Unique identifier of the form "<eval_id>::<day>::<meal>".
    meal_id: str
    # Evaluation id of the trace (falls back to the trace file stem).
    eval_id: str
    # File name (not full path) of the trace this meal came from.
    trace_file: str
    # Status string reported by the optimizer result.
    status: str
    # Cuisine parsed from the user prompt ("unknown" if not found).
    cuisine: str
    # Day key and meal key as they appear in calculated_quantity_per_day.
    day: str
    meal: str
    # Key of the meal option whose items were used.
    option: str
    # The user's prompt text.
    prompt: str
    # "; "-joined "<name> (<qty>g)" entries for the option's items.
    ingredient_summary: str
    # Product names of the option's items, in item order.
    ingredient_names: list[str]
    # Summed nutrient totals plus protein/carb/fat calorie shares.
    feature_values: dict[str, float]
    # Plan-level achieved targets from the optimizer (shared by every meal of a trace).
    achieved_targets_per_day: dict[str, float]
    # Daily targets from the DRI calculation (shared by every meal of a trace).
    target_values: dict[str, float]


def _load_json(path: Path) -> Any:
    """Decode and return the UTF-8 encoded JSON document stored at *path*."""
    with path.open(encoding="utf-8") as handle:
        return json.load(handle)


def _safe_float(value: Any) -> float | None:
    """Coerce *value* to a finite float, or return ``None`` when that is impossible.

    Accepts ints, floats, and numeric strings. Surrounding whitespace and
    thousands separators (``"1,234.5"``) are tolerated in strings. Returns
    ``None`` for booleans, NaN/infinite values, ints too large to represent as a
    float, blank or non-numeric strings, and any other type.
    """
    # bool subclasses int, but a JSON true/false is never a nutrient amount.
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        try:
            parsed = float(value)
        except OverflowError:
            # Arbitrary-precision ints beyond the float range cannot be converted.
            return None
    elif isinstance(value, str):
        stripped = value.strip().replace(",", "")
        if not stripped:
            return None
        try:
            parsed = float(stripped)
        except ValueError:
            return None
    else:
        return None
    return parsed if math.isfinite(parsed) else None


def _sorted_key(name: str) -> tuple[int, str]:
    """Sort key ordering names by their trailing integer, then by the name itself.

    Names without a trailing number get the sentinel 10**9 so they sort after
    numbered ones.
    """
    trailing = re.search(r"(\d+)$", name)
    number = int(trailing.group(1)) if trailing else 10**9
    return (number, name)


def _iter_invocation_events(trace_payload: dict[str, Any]) -> list[dict[str, Any]]:
    """Flatten every dict-shaped invocation event recorded in an eval trace.

    Follows ``eval_case_result -> eval_metric_result_per_invocation[*] ->
    actual_invocation -> intermediate_data -> invocation_events``. Any level whose
    type does not match the expected dict/list shape is skipped silently.
    """
    case_result = trace_payload.get("eval_case_result")
    invocations = (
        case_result.get("eval_metric_result_per_invocation") if isinstance(case_result, dict) else None
    )
    if not isinstance(invocations, list):
        return []

    collected: list[dict[str, Any]] = []
    for entry in invocations:
        actual = entry.get("actual_invocation") if isinstance(entry, dict) else None
        intermediate = actual.get("intermediate_data") if isinstance(actual, dict) else None
        raw_events = intermediate.get("invocation_events") if isinstance(intermediate, dict) else None
        if isinstance(raw_events, list):
            collected.extend(evt for evt in raw_events if isinstance(evt, dict))
    return collected


def _extract_user_prompt(trace_payload: dict[str, Any]) -> str:
    """Return the first non-blank user text part found across the trace's invocations.

    The text is stripped of surrounding whitespace. Returns ``""`` when the trace
    has no usable user content.
    """
    case_result = trace_payload.get("eval_case_result")
    if not isinstance(case_result, dict):
        return ""
    invocations = case_result.get("eval_metric_result_per_invocation")
    if not isinstance(invocations, list):
        return ""

    def _candidate_texts():
        # Yield raw "text" values of user-content parts, in invocation/part order.
        for entry in invocations:
            actual = entry.get("actual_invocation") if isinstance(entry, dict) else None
            user_content = actual.get("user_content") if isinstance(actual, dict) else None
            parts = user_content.get("parts") if isinstance(user_content, dict) else None
            if not isinstance(parts, list):
                continue
            for part in parts:
                if isinstance(part, dict):
                    yield part.get("text")

    return next(
        (text.strip() for text in _candidate_texts() if isinstance(text, str) and text.strip()),
        "",
    )


def _extract_text_parts(content: dict[str, Any]) -> list[str]:
    """Collect the stripped, non-blank ``text`` values of ``content["parts"]`` in order."""
    parts = content.get("parts")
    if not isinstance(parts, list):
        return []
    raw_texts = (part.get("text") for part in parts if isinstance(part, dict))
    return [text.strip() for text in raw_texts if isinstance(text, str) and text.strip()]


def _extract_final_model_text(trace_payload: dict[str, Any]) -> str:
    """Return the model's final answer text from an eval trace.

    Preference order:
    1. the ``final_response`` of the latest invocation that has one with text;
    2. otherwise the latest ``model``-role invocation event containing text.
    Multiple text parts are joined with blank lines. Returns ``""`` if neither exists.
    """
    case_result = trace_payload.get("eval_case_result")
    invocations = (
        case_result.get("eval_metric_result_per_invocation") if isinstance(case_result, dict) else None
    )
    if isinstance(invocations, list):
        for entry in invocations[::-1]:
            actual = entry.get("actual_invocation") if isinstance(entry, dict) else None
            final_response = actual.get("final_response") if isinstance(actual, dict) else None
            if not isinstance(final_response, dict):
                continue
            texts = _extract_text_parts(final_response)
            if texts:
                return "\n\n".join(texts)

    for event in _iter_invocation_events(trace_payload)[::-1]:
        content = event.get("content")
        if isinstance(content, dict) and content.get("role") == "model":
            texts = _extract_text_parts(content)
            if texts:
                return "\n\n".join(texts)
    return ""


def _extract_cuisine(prompt: str) -> str:
    """Pull the cuisine from a prompt phrased "... I want to eat <cuisine> food ...".

    Matching is case-insensitive and the result is lower-cased. Returns
    ``"unknown"`` when the phrase is absent.
    """
    found = re.search(r"i want to eat\s+(.+?)\s+food", prompt, flags=re.IGNORECASE)
    return found.group(1).strip().lower() if found else "unknown"


def _extract_numeric_suffix(value: str) -> int | None:
    """Return the integer formed by the trailing digits of *value*, or ``None`` if none."""
    trailing = re.search(r"(\d+)$", value)
    return None if trailing is None else int(trailing.group(1))


def _normalize_trace_line(line: str) -> str:
    """Strip markdown decoration from one line of model output.

    Removes surrounding whitespace, a leading heading marker (``#`` to ``######``),
    every ``**``, ``__`` and backtick, and then one leading ``*`` or ``-`` bullet.
    """
    text = re.sub(r"^\s*#{1,6}\s*", "", line.strip())
    for marker in ("**", "__", "`"):
        text = text.replace(marker, "")
    text = re.sub(r"^\s*[*-]\s*", "", text)
    return text.strip()


def _clean_meal_name(name: str) -> str:
    """Tidy a candidate meal name taken from model output.

    After markdown clean-up via ``_normalize_trace_line`` this removes
    "(same as day ...)" remarks and parenthesised quantity notes such as
    "(approx. 150 g cooked)" or "(2 eggs)". It then collapses whitespace runs and
    trims leading/trailing spaces, periods, colons and hyphens.
    """
    same_as_note = r"\(same as day[^)]*\)"
    quantity_note = r"\((?:approx\.?|about|around|~)?\s*\d+(?:\.\d+)?\s*(?:g|kg|mg|ml|l|oz|lb|lbs|slice|slices|egg|eggs)\b[^)]*\)"
    text = _normalize_trace_line(name)
    for pattern in (same_as_note, quantity_note):
        text = re.sub(pattern, "", text, flags=re.IGNORECASE)
    return re.sub(r"\s+", " ", text).strip(" .:-")


def _meal_name_word_count(name: str) -> int:
    """Count alphanumeric/apostrophe word runs in *name* (punctuation is ignored)."""
    return sum(1 for _ in re.finditer(r"[A-Za-z0-9']+", name))


def _is_generic_meal_name(name: str) -> bool:
    """Return True when *name* is empty or merely names a slot rather than a dish.

    Slot-like names are plain meal/time-of-day words ("Lunch", "Evening", ...) or
    "Meal <n>", compared case-insensitively after ``_clean_meal_name``.
    """
    normalized = _clean_meal_name(name).lower()
    slot_words = frozenset(
        ("breakfast", "lunch", "dinner", "snack", "brunch", "morning", "mid-day", "midday", "evening")
    )
    return (
        not normalized
        or normalized in slot_words
        or re.fullmatch(r"meal\s+\d+", normalized) is not None
    )


def _looks_like_ingredient_list(text: str) -> bool:
    """Heuristically decide whether *text* lists ingredients rather than naming a dish.

    True when the normalized text contains a metric quantity ("150 g"), has at
    least two commas, or contains a colon without starting with a meal-slot word.
    """
    normalized = _normalize_trace_line(text)
    if not normalized:
        return False
    return bool(
        re.search(r"\b\d+(?:\.\d+)?\s*(?:g|kg|mg|ml|l|oz|lb|lbs)\b", normalized, flags=re.IGNORECASE)
        or normalized.count(",") >= 2
        or (
            ":" in normalized
            and not re.match(r"^(Meal|Breakfast|Lunch|Dinner|Snack|Brunch)\b", normalized, flags=re.IGNORECASE)
        )
    )


def _clean_ingredient_label(text: str) -> str:
    """Reduce an ingredient description to a bare label.

    Drops parenthesised notes, a trailing "by <brand>" clause, metric quantities,
    "approx", preparation words (minced, cut small, ...), and trailing "mixed
    with / cooked with / for aroma / as soup base" clauses. Whitespace is then
    collapsed and spaces and ``,;:-`` are trimmed from the ends.
    """
    label = re.sub(r"\([^)]*\)", "", _normalize_trace_line(text))
    # Applied in order, each case-insensitively.
    removals = (
        r"\s+by\s+.+$",
        r"\b\d+(?:\.\d+)?\s*(?:g|kg|mg|ml|l|oz|lb|lbs)\b",
        r"\bapprox\.?\b",
        r"\bfinely minced\b|\bminced\b|\bcut small\b|\bas a soft base\b",
        r"\bmixed with\b.*$",
        r"\bcooked with\b.*$",
        r"\bfor aroma\b.*$",
        r"\bas soup base\b.*$",
    )
    for pattern in removals:
        label = re.sub(pattern, "", label, flags=re.IGNORECASE)
    return re.sub(r"\s+", " ", label).strip(" ,;:-")


# Lower-case words that carry little identifying information in product names.
_LOW_VALUE_INGREDIENT_WORDS = frozenset(
    """
    a an and ara as blend bowl brown breakfast cage chunk chunks cooked dha diced
    dinner dish extra fat feeding food foods fresh free frozen in inspired instant
    large low lowfat lunch mashed meal medley medium mini natural naturals nonfat
    of on organic original plain premium product pure puree pureed powder raised
    ready roasted salt salted salty seasoned select shell skinless snack soft
    sprouted style the toddler to toasted traditional oil whole with brand boy
    cinnamon golden hormel nutmeg sugar
    """.split()
)


def _ingredient_content_tokens(text: str) -> list[str]:
    """Tokenize a cleaned ingredient label and drop low-information words.

    Tokens are alphanumeric/apostrophe runs of ``_clean_ingredient_label(text)``.
    Words in ``_LOW_VALUE_INGREDIENT_WORDS`` are removed (compared
    case-insensitively). If that would remove every token, the unfiltered tokens
    are returned so that a non-empty label never vanishes entirely.
    """
    tokens = re.findall(r"[A-Za-z0-9']+", _clean_ingredient_label(text))
    meaningful = [tok for tok in tokens if tok.lower() not in _LOW_VALUE_INGREDIENT_WORDS]
    return meaningful if meaningful else tokens


def _is_condiment_like_phrase(text: str) -> bool:
    """Return True when an ingredient phrase looks like a condiment, not a main item.

    Any content token that is a dressing/dip/salsa/sauce/seasoning word qualifies.
    Otherwise the phrase qualifies only if every content token is a condiment term.

    NOTE(review): the token filter in ``_ingredient_content_tokens`` drops "oil",
    so "Olive oil" reduces to ["Olive"] and is *not* reported as condiment-like.
    Confirm whether that is intended.
    """
    lowered = [tok.lower() for tok in _ingredient_content_tokens(text)]
    if not lowered:
        return False
    strong_markers = {"dressing", "dip", "salsa", "sauce", "seasoning"}
    if strong_markers.intersection(lowered):
        return True
    condiment_vocabulary = {
        "dressing", "ghee", "miso", "oil", "paste", "salsa",
        "sauce", "seasoning", "soy", "tamari", "tahini",
    }
    return set(lowered) <= condiment_vocabulary


def _compact_phrase(text: str, max_words: int) -> str:
    """Reduce an ingredient phrase to at most *max_words* content tokens.

    Tokens come from ``_ingredient_content_tokens``. A phrase that is already
    short enough is returned whole. Otherwise the trailing tokens are kept, since
    the head noun of a food name usually comes last. Generic trailing words such
    as "bowl" or "mix" are dropped, and the freed room is refilled with the
    non-generic tokens immediately before the window. Returns ``""`` when
    *max_words* <= 0 or no tokens remain.
    """
    if max_words <= 0:
        return ""

    tokens = _ingredient_content_tokens(text)
    if not tokens:
        return ""
    if len(tokens) <= max_words:
        return " ".join(tokens)

    generic_tail_words = {"blend", "bowl", "dish", "food", "meal", "medley", "mix", "plate"}
    if max_words == 1:
        # Last non-generic token, or the very last token if all are generic.
        for token in reversed(tokens):
            if token.lower() not in generic_tail_words:
                return token
        return tokens[-1]

    # Track the kept window as tokens[start:end]. Growing it leftwards must take
    # the token just before the window's current start. Deriving that index from
    # the end of *tokens* (after tail words were dropped) would re-insert a token
    # already in the window, e.g. "Chicken Rice Plate" -> "Rice Rice".
    end = len(tokens)
    start = end - max_words
    while end - start > 1 and tokens[end - 1].lower() in generic_tail_words:
        end -= 1
    while end - start < max_words and start > 0 and tokens[start - 1].lower() not in generic_tail_words:
        start -= 1
    return " ".join(tokens[start:end])


def _join_name_parts(parts: list[str]) -> str:
    """Join name fragments as "A", "A & B", or "A, B & C" (``""`` for no parts)."""
    if len(parts) < 2:
        return parts[0] if parts else ""
    leading = ", ".join(parts[:-1])
    return f"{leading} & {parts[-1]}"


def _remove_forbidden_name_terms(name: str) -> str:
    """Remove the words oil/salt/salted/salty from a meal name and tidy separators.

    The name is cleaned with ``_clean_meal_name`` first. Separators are
    normalized to ", " and " & ". Runs of separators left by the removed words
    are collapsed, and dangling separators are trimmed from the ends.
    """
    if not name:
        return ""

    text = re.sub(r"\b(?:oil|salt|salted|salty)\b", "", _clean_meal_name(name), flags=re.IGNORECASE)
    for pattern, replacement in (
        (r"\s*,\s*", ", "),
        (r"\s*&\s*", " & "),
        (r"\s+", " "),
    ):
        text = re.sub(pattern, replacement, text)
    text = text.strip(" ,&:-")

    # A deleted word can leave neighbouring separators ("Rice, , Beans"); merge them.
    text = re.sub(r"(?:\s*&\s*){2,}", " & ", text)
    text = re.sub(r"(?:\s*,\s*){2,}", ", ", text)
    return text.strip(" ,&:-")


def _shorten_meal_name(name: str, ingredient_names: list[str]) -> str:
    """Return a display name of at most four words for a meal.

    A name of four words or fewer is returned cleaned but otherwise unchanged.
    A longer name is rebuilt from up to four distinct ingredient phrases.
    Condiment-like phrases are excluded; if none remain, the name's own
    comma/"&"-separated parts are used instead. Each phrase is first compacted to
    one word, then widened to two words while the four-word budget allows. If
    the result still does not fit, the name's first four content tokens are used.
    """
    cleaned_name = _clean_meal_name(name)
    if _meal_name_word_count(cleaned_name) <= 4:
        return cleaned_name

    source_phrases = [phrase for phrase in ingredient_names if phrase and not _is_condiment_like_phrase(phrase)]
    if not source_phrases:
        source_phrases = [phrase.strip() for phrase in re.split(r",\s*|\s*&\s*", cleaned_name) if phrase.strip()]

    # Keep each chosen part paired with the phrase it came from. Phrases can be
    # skipped (empty or duplicate compact form), so positions in source_phrases
    # do not line up with positions in the chosen parts.
    selected: list[tuple[str, str]] = []
    seen: set[str] = set()
    for phrase in source_phrases:
        compact = _compact_phrase(phrase, 1)
        if not compact or compact.lower() in seen:
            continue
        selected.append((phrase, compact))
        seen.add(compact.lower())
        if len(selected) == 4:
            break

    compact_parts = [compact for _, compact in selected]
    remaining_budget = 4 - sum(_meal_name_word_count(part) for part in compact_parts)
    for index, (phrase, current) in enumerate(selected):
        if remaining_budget <= 0:
            break
        expanded = _compact_phrase(phrase, 2)
        if not expanded:
            continue
        added_words = _meal_name_word_count(expanded) - _meal_name_word_count(current)
        if added_words <= 0 or added_words > remaining_budget:
            continue
        compact_parts[index] = expanded
        remaining_budget -= added_words

    shortened = _join_name_parts(compact_parts)
    if shortened and _meal_name_word_count(shortened) <= 4:
        return shortened

    fallback_tokens = _ingredient_content_tokens(cleaned_name)
    return " ".join(fallback_tokens[:4]) if fallback_tokens else cleaned_name


def _extract_ingredient_label_from_line(text: str) -> str:
    """Extract an ingredient name from a line such as "Chicken breast: 150 g".

    Takes the text before the first colon, cuts it at the first parenthesis and
    at the first metric quantity, then cleans it with ``_clean_ingredient_label``.
    Returns ``""`` for blank input.
    """
    normalized = _normalize_trace_line(text)
    if not normalized:
        return ""

    head = normalized.partition(":")[0]
    head = re.sub(r"\([^)]*\).*", "", head)
    head = re.split(
        r"\b\d+(?:\.\d+)?\s*(?:g|kg|mg|ml|l|oz|lb|lbs)\b",
        head,
        maxsplit=1,
        flags=re.IGNORECASE,
    )[0]
    return _clean_ingredient_label(head)


def _split_ingredient_text(text: str) -> list[str]:
    """Split an inline ingredient list (optionally after "label:") into clean labels.

    Pieces are separated by a comma followed by whitespace. Each piece is cleaned
    with ``_clean_ingredient_label``, and pieces that clean to nothing are dropped.
    """
    normalized = _normalize_trace_line(text)
    if not normalized:
        return []

    before, colon, after = normalized.partition(":")
    body = after if colon else before
    labels = (_clean_ingredient_label(piece.strip()) for piece in re.split(r",\s+", body) if piece.strip())
    return [label for label in labels if label]


def _synthesize_meal_name_from_ingredients(ingredients: list[str]) -> str:
    """Build a short meal name from up to three ingredient labels.

    Condiment-like ingredients (oils, sauces, ghee, miso, pastes, dressings, dips,
    seasonings) are skipped unless every ingredient is one. The first three
    remaining labels are joined as "A", "A & B" or "A, B & C". Returns ``""`` for
    an empty list.
    """
    if not ingredients:
        return ""

    # Whole-word (optionally plural) matching, so that "Boiled eggs" or
    # "Foil-baked salmon" are not mistaken for oils as plain substring tests were.
    condiment_pattern = re.compile(
        r"\b(?:oil|sauce|ghee|miso|paste|dressing|dip|seasoning)s?\b",
        flags=re.IGNORECASE,
    )
    preferred = [ingredient for ingredient in ingredients if not condiment_pattern.search(ingredient)]
    selected = (preferred or ingredients)[:3]

    if len(selected) == 1:
        return selected[0]
    if len(selected) == 2:
        return f"{selected[0]} & {selected[1]}"
    return f"{selected[0]}, {selected[1]} & {selected[2]}"


def _extract_following_ingredient_labels(lines: list[str], start_index: int) -> list[str]:
    """Collect ingredient labels from the lines that follow a meal header.

    Scanning starts at ``lines[start_index]``. It stops at the next day header,
    the next meal header, or a line starting with "Note", "Notes", "Preparation
    Tip" or "Important Note". Blank lines and lines yielding no label are skipped.
    """
    labels: list[str] = []
    for raw_line in lines[start_index:]:
        line = _normalize_trace_line(raw_line)
        if not line:
            continue
        starts_new_section = _parse_day_numbers(line) is not None or _parse_meal_header(line) is not None
        starts_note = (
            re.match(r"^(Note|Notes|Preparation Tip|Important Note)\b", line, flags=re.IGNORECASE) is not None
        )
        if starts_new_section or starts_note:
            break
        label = _extract_ingredient_label_from_line(line)
        if label:
            labels.append(label)
    return labels


def _parse_day_numbers(line: str) -> list[int] | None:
    """Parse the day numbers a "Day"/"Days" header line refers to.

    "Day(s) a-b" (hyphen or en dash) with a <= b expands to the inclusive range.
    Otherwise every integer on the line is returned in order. Returns ``None``
    for lines not starting with "Day"/"Days", or lines with no numbers.
    """
    if re.match(r"^Days?\b", line, flags=re.IGNORECASE) is None:
        return None

    span = re.match(r"^Days?\s+(\d+)\s*[\-–]\s*(\d+)\b", line, flags=re.IGNORECASE)
    if span is not None:
        first, last = (int(group) for group in span.groups())
        if first <= last:
            return list(range(first, last + 1))

    numbers = list(map(int, re.findall(r"\d+", line)))
    return numbers if numbers else None


def _parse_meal_header(line: str) -> tuple[int | None, str, str] | None:
    """Recognize a meal header line and split it into ``(slot, label, detail)``.

    Supported shapes (case-insensitive):

    * ``Meal <n> (<label>): <detail>`` -> ``(n, label or "", detail)``
    * ``<Slot> (<label>): <detail>``   -> ``(slot index, label, detail)``
    * ``<Slot>: <detail>``             -> ``(slot index, <Slot> as written, detail)``

    <Slot> is Breakfast/Lunch/Dinner/Snack/Brunch, mapped to 1/2/3/4/2.
    Returns ``None`` for any other line.
    """
    numbered = re.match(
        r"^Meal\s+(\d+)\s*(?:\(([^)]+)\))?\s*:\s*(.*)$",
        line,
        flags=re.IGNORECASE,
    )
    if numbered:
        return (int(numbered.group(1)), numbered.group(2) or "", numbered.group(3) or "")

    slot_numbers = {"breakfast": 1, "lunch": 2, "dinner": 3, "snack": 4, "brunch": 2}

    labelled = re.match(
        r"^(Breakfast|Lunch|Dinner|Snack|Brunch)\s*\(([^)]+)\)\s*:\s*(.*)$",
        line,
        flags=re.IGNORECASE,
    )
    if labelled:
        slot_word, label, detail = labelled.groups()
        return (slot_numbers.get(slot_word.lower()), label, detail or "")

    bare = re.match(
        r"^(Breakfast|Lunch|Dinner|Snack|Brunch)\s*:\s*(.*)$",
        line,
        flags=re.IGNORECASE,
    )
    if bare:
        slot_word, detail = bare.groups()
        return (slot_numbers.get(slot_word.lower()), slot_word, detail)

    return None


def _resolve_meal_name_from_message(lines: list[str], index: int, label_text: str, detail_text: str) -> str:
    """Choose a meal name for the header at ``lines[index]``.

    Tried in order:
    1. the header detail, if it is a specific dish name rather than an
       ingredient list;
    2. a name synthesized from ingredients listed inline in the detail;
    3. the header label, if not generic;
    4. a name synthesized from the ingredient lines that follow the header;
    5. the cleaned detail, then the cleaned label (possibly ``""``).
    """
    detail_name = _clean_meal_name(detail_text)
    if detail_name and not _is_generic_meal_name(detail_name) and not _looks_like_ingredient_list(detail_text):
        return detail_name

    from_detail = _synthesize_meal_name_from_ingredients(_split_ingredient_text(detail_text))
    if from_detail:
        return from_detail

    label_name = _clean_meal_name(label_text)
    if label_name and not _is_generic_meal_name(label_name):
        return label_name

    from_following = _synthesize_meal_name_from_ingredients(
        _extract_following_ingredient_labels(lines, index + 1)
    )
    return from_following or detail_name or label_name


def _extract_meal_names_by_day(final_response_text: str) -> tuple[dict[tuple[int, int], str], dict[int, str]]:
    """Map meal names found in the model's answer to (day, slot) positions.

    A day header sets the days that subsequent meal headers apply to; each meal
    header is then recorded for every active day. Headers seen before any day
    header go into the second mapping, keyed by slot only. Later headers
    overwrite earlier ones for the same key.
    """
    by_day: dict[tuple[int, int], str] = {}
    undated: dict[int, str] = {}
    active_days: list[int] = []
    lines = final_response_text.splitlines()

    for position, raw_line in enumerate(lines):
        line = _normalize_trace_line(raw_line)
        if not line:
            continue

        days = _parse_day_numbers(line)
        if days is not None:
            active_days = days
            continue

        header = _parse_meal_header(line)
        if header is None or header[0] is None:
            continue
        slot_index, label_text, detail_text = header

        meal_name = _resolve_meal_name_from_message(lines, position, label_text, detail_text)
        if not meal_name:
            continue
        if active_days:
            by_day.update(((day, slot_index), meal_name) for day in active_days)
        else:
            undated[slot_index] = meal_name

    return by_day, undated


def _resolve_meal_name(
    day_key: str,
    meal_key: str,
    meal_names_by_day: dict[tuple[int, int], str],
    fallback_names: dict[int, str],
) -> str:
    """Look up the name parsed for a (day, meal) slot, falling back progressively.

    The slot number is the meal key's trailing integer, or its slot word
    (breakfast=1, lunch=2, dinner=3, snack=4, brunch=2). The per-day name is
    tried first, then the slot-only name. Otherwise the meal key is title-cased
    with underscores turned into spaces.
    """
    def trailing_int(text: str) -> int | None:
        found = re.search(r"(\d+)$", text)
        return int(found.group(1)) if found else None

    day_index = trailing_int(day_key)
    meal_index = trailing_int(meal_key)
    if meal_index is None:
        slot_numbers = {"breakfast": 1, "lunch": 2, "dinner": 3, "snack": 4, "brunch": 2}
        meal_index = slot_numbers.get(meal_key.strip().lower())

    if meal_index is not None:
        if day_index is not None:
            per_day_name = meal_names_by_day.get((day_index, meal_index))
            if per_day_name:
                return per_day_name
        slot_name = fallback_names.get(meal_index)
        if slot_name:
            return slot_name
    return meal_key.replace("_", " ").title()


def _fallback_meal_name_from_ingredients(ingredient_names: list[str]) -> str:
    """Synthesize a meal name from raw product names after cleaning each label.

    Names that clean down to an empty string are discarded before synthesis.
    """
    cleaned_labels = (_clean_ingredient_label(raw_name) for raw_name in ingredient_names)
    return _synthesize_meal_name_from_ingredients([label for label in cleaned_labels if label])


def _extract_function_responses(
    events: list[dict[str, Any]],
    function_name: str,
) -> list[dict[str, Any]]:
    """Return the dict ``response`` of every ``function_response`` part named *function_name*.

    Responses are returned in event order. Malformed events or parts, and
    non-dict responses, are ignored.
    """
    matches: list[dict[str, Any]] = []
    for event in events:
        content = event.get("content")
        parts = content.get("parts") if isinstance(content, dict) else None
        if not isinstance(parts, list):
            continue
        for part in parts:
            call_result = part.get("function_response") if isinstance(part, dict) else None
            if not isinstance(call_result, dict) or call_result.get("name") != function_name:
                continue
            response = call_result.get("response")
            if isinstance(response, dict):
                matches.append(response)
    return matches


def _unwrap_result(payload: dict[str, Any]) -> dict[str, Any]:
    """Return ``payload["result"]`` when it is a dict, otherwise *payload* itself."""
    inner = payload.get("result")
    if isinstance(inner, dict):
        return inner
    return payload


def _extract_product_lookup(events: list[dict[str, Any]]) -> dict[int, dict[str, Any]]:
    """Index every product returned by ``find_ingredient`` calls by its ``index`` field.

    Each (unwrapped) response is expected to hold
    ``result: [{results: [product, ...]}, ...]``. Products whose ``index`` is not
    an int are ignored. When an index repeats, the later product wins.
    """
    lookup: dict[int, dict[str, Any]] = {}
    for response in _extract_function_responses(events, "find_ingredient"):
        query_results = _unwrap_result(response).get("result")
        if not isinstance(query_results, list):
            continue
        for query_result in query_results:
            products = query_result.get("results") if isinstance(query_result, dict) else None
            if not isinstance(products, list):
                continue
            lookup.update(
                (product["index"], product)
                for product in products
                if isinstance(product, dict) and isinstance(product.get("index"), int)
            )
    return lookup


def _extract_target_values(events: list[dict[str, Any]]) -> dict[str, float]:
    """Extract daily nutrient targets from the last ``calculate_health_canada_dri`` response.

    The returned flat mapping may contain:

    * ``calories``: from ``calories.eer_kcal`` (or a scalar ``calories``). It may
      be overridden by a calorie key inside ``recommended_g_per_day``.
    * ``protein``, ``carbohydrates``, ``total_fat``, ``total_fibre``: from
      ``recommended_g_per_day``.
    * ``<macro>_range_lower`` / ``<macro>_range_upper``: from entries of
      ``recommended_macronutrient_ranges_g_per_day`` that have both ``lower``
      and ``upper``.

    Values that ``_safe_float`` cannot parse are omitted. Returns ``{}`` when the
    tool was never called.
    """
    responses = _extract_function_responses(events, "calculate_health_canada_dri")
    if not responses:
        return {}

    # Only the most recent DRI calculation is used.
    payload = _unwrap_result(responses[-1])
    targets: dict[str, float] = {}

    calories = payload.get("calories")
    if isinstance(calories, dict):
        calorie_value = _safe_float(calories.get("eer_kcal"))
        if calorie_value is not None:
            targets["calories"] = calorie_value
    else:
        calorie_value = _safe_float(calories)
        if calorie_value is not None:
            targets["calories"] = calorie_value

    recommended = payload.get("recommended_g_per_day")
    if isinstance(recommended, dict):
        # Order matters: a later alias overwrites an earlier one mapping to the same
        # target key (e.g. "fat" wins over "total_fat", "fibre" over "total_fibre").
        for source_key, target_key in (
            ("calories eer_kcal", "calories"),
            ("calories_eer_kcal", "calories"),
            ("eer_kcal", "calories"),
            ("protein", "protein"),
            ("carbohydrates", "carbohydrates"),
            ("total_fat", "total_fat"),
            ("fat", "total_fat"),
            ("total_fibre", "total_fibre"),
            ("fibre", "total_fibre"),
        ):
            value = _safe_float(recommended.get(source_key))
            if value is not None:
                targets[target_key] = value

    macro_ranges = payload.get("recommended_macronutrient_ranges_g_per_day")
    if isinstance(macro_ranges, dict):
        # As above, "fat" overwrites "total_fat" when both are present.
        for source_key, target_key in (
            ("protein", "protein_range"),
            ("carbohydrates", "carbohydrates_range"),
            ("total_fat", "total_fat_range"),
            ("fat", "total_fat_range"),
        ):
            range_value = macro_ranges.get(source_key)
            if isinstance(range_value, dict):
                lower = _safe_float(range_value.get("lower"))
                upper = _safe_float(range_value.get("upper"))
                # A range is recorded only when both bounds parse.
                if lower is not None and upper is not None:
                    targets[f"{target_key}_lower"] = lower
                    targets[f"{target_key}_upper"] = upper

    return targets


def _extract_latest_optimizer_payload(events: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Return the most recent accepted ``optimize_quantity`` result, or ``None``.

    A result is accepted when it has a dict ``calculated_quantity_per_day`` and a
    string status of feasible/optimal/ok (case-insensitive). An earlier accepted
    result is still returned if later ones were rejected.
    """
    accepted_statuses = {"feasible", "optimal", "ok"}
    payloads = [_unwrap_result(response) for response in _extract_function_responses(events, "optimize_quantity")]
    for payload in reversed(payloads):
        if not isinstance(payload.get("calculated_quantity_per_day"), dict):
            continue
        status = payload.get("status")
        if isinstance(status, str) and status.lower() in accepted_statuses:
            return payload
    return None


def _extract_achieved_targets(payload: dict[str, Any]) -> dict[str, float]:
    """Read the parseable ``NUMERIC_KEYS`` values from ``achieved_targets_per_day``.

    Returns ``{}`` when that field is missing or not a dict; unparseable values
    are omitted.
    """
    achieved = payload.get("achieved_targets_per_day")
    if not isinstance(achieved, dict):
        return {}
    parsed = ((key, _safe_float(achieved.get(key))) for key in NUMERIC_KEYS)
    return {key: number for key, number in parsed if number is not None}


def _nutrition_for_product(product: dict[str, Any]) -> dict[str, float]:
    """Read the parseable ``NUMERIC_KEYS`` values from a product's ``nutrition`` dict.

    Returns ``{}`` when the product has no ``nutrition`` dict; unparseable values
    are omitted.
    """
    nutrition = product.get("nutrition")
    if not isinstance(nutrition, dict):
        return {}
    parsed = ((key, _safe_float(nutrition.get(key))) for key in NUMERIC_KEYS)
    return {key: number for key, number in parsed if number is not None}


def _build_ingredient_summary(items: list[dict[str, Any]], lookup: dict[int, dict[str, Any]]) -> tuple[str, list[str]]:
    """Describe a meal option's ingredients.

    Returns ``(summary, names)``:

    * *summary* joins ``"<name> (<qty>g)"`` entries with ``"; "``, using
      ``"qty unknown"`` when the quantity is not a finite number.
    * *names* lists the ingredient names in item order.

    Items whose product is missing from *lookup*, or whose product has no
    non-blank string ``name``, are labelled ``index_<index>``. Non-dict items
    are skipped.
    """
    labels: list[str] = []
    names: list[str] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        index_value = item.get("index")
        qty_value = _safe_float(item.get("qty"))
        product = lookup.get(index_value) if isinstance(index_value, int) else None
        product_name = product.get("name") if isinstance(product, dict) else None
        # Fall back to the index label for nameless products so the summary never
        # shows "None (...)" and every item is represented in *names*.
        if isinstance(product_name, str) and product_name.strip():
            name = product_name
        else:
            name = f"index_{index_value}"
        names.append(name)
        qty_label = f"{qty_value:.1f}g" if qty_value is not None else "qty unknown"
        labels.append(f"{name} ({qty_label})")
    return "; ".join(labels), names


def _compute_meal_features(items: list[dict[str, Any]], lookup: dict[int, dict[str, Any]]) -> dict[str, float]:
    """Sum nutrient totals for a meal option and derive macro calorie shares.

    Each item's product nutrition is scaled by ``qty / 100`` (i.e. treated as
    per-100 g values). Items with a non-int index or unparseable qty are skipped.
    The protein, carbohydrate and fat shares use 4, 4 and 9 kcal/g against total
    calories, and are 0.0 when calories are not positive.
    """
    totals = dict.fromkeys(NUMERIC_KEYS, 0.0)
    for item in items:
        if not isinstance(item, dict):
            continue
        index_value = item.get("index")
        grams = _safe_float(item.get("qty"))
        if grams is None or not isinstance(index_value, int):
            continue
        per_100g = _nutrition_for_product(lookup.get(index_value, {}))
        for key in NUMERIC_KEYS:
            totals[key] += grams * per_100g.get(key, 0.0) / 100.0

    calories = totals["calories"]

    def calorie_share(grams_total: float, kcal_per_gram: float) -> float:
        return (grams_total * kcal_per_gram / calories) if calories > 0 else 0.0

    features = dict(totals)
    features["protein_calorie_share"] = calorie_share(totals["protein"], 4.0)
    features["carb_calorie_share"] = calorie_share(totals["carbohydrates"], 4.0)
    features["fat_calorie_share"] = calorie_share(totals["total_fat"], 9.0)
    return features


def _extract_meal_nodes(trace_path: Path) -> list[MealNode]:
    """Build one ``MealNode`` per (day, meal) slot of a trace's accepted optimizer plan.

    Steps:

    1. Read the trace JSON at *trace_path*.
    2. Take the latest accepted ``optimize_quantity`` result.
    3. For each day and meal (ordered by trailing number), use the first
       option's items to compute nutrient features and an ingredient summary.
    4. Derive a display name, aiming for at most four words. Names come from
       the model's final answer when possible, otherwise from the ingredients.

    Returns ``[]`` when the document is not a JSON object or has no accepted plan.
    """
    payload = _load_json(trace_path)
    # Everything below relies on dict access; a non-object document holds no plan.
    if not isinstance(payload, dict):
        return []
    events = _iter_invocation_events(payload)
    optimizer_payload = _extract_latest_optimizer_payload(events)
    if optimizer_payload is None:
        return []

    product_lookup = _extract_product_lookup(events)
    quantities = optimizer_payload.get("calculated_quantity_per_day")
    if not isinstance(quantities, dict):
        return []

    prompt = _extract_user_prompt(payload)
    final_response_text = _extract_final_model_text(payload)
    meal_names_by_day, fallback_meal_names = _extract_meal_names_by_day(final_response_text)
    cuisine = _extract_cuisine(prompt)
    # "eval_case_result" may be missing or null, so only read its eval_id when it
    # is a dict (calling .get on a null value would raise AttributeError).
    case_result = payload.get("eval_case_result")
    case_eval_id = case_result.get("eval_id") if isinstance(case_result, dict) else None
    eval_id = str(payload.get("eval_id") or case_eval_id or trace_path.stem)
    status = str(optimizer_payload.get("status") or "unknown")
    # Plan-level values, shared by every meal node produced from this trace.
    achieved_targets = _extract_achieved_targets(optimizer_payload)
    target_values = _extract_target_values(events)

    nodes: list[MealNode] = []
    for day_key in sorted(quantities, key=_sorted_key):
        day_payload = quantities.get(day_key)
        if not isinstance(day_payload, dict):
            continue
        for meal_key in sorted(day_payload, key=_sorted_key):
            meal_payload = day_payload.get(meal_key)
            if not isinstance(meal_payload, dict) or not meal_payload:
                continue
            # Only the first option listed for a meal slot is used.
            option_name, option_items = next(iter(meal_payload.items()))
            if not isinstance(option_items, list):
                continue
            feature_values = _compute_meal_features(option_items, product_lookup)
            ingredient_summary, ingredient_names = _build_ingredient_summary(option_items, product_lookup)
            meal_id = f"{eval_id}::{day_key}::{meal_key}"
            meal_name = _resolve_meal_name(day_key, meal_key, meal_names_by_day, fallback_meal_names)
            if _is_generic_meal_name(meal_name):
                fallback_meal_name = _fallback_meal_name_from_ingredients(ingredient_names)
                if fallback_meal_name:
                    meal_name = fallback_meal_name
            meal_name = _shorten_meal_name(meal_name, ingredient_names)
            meal_name = _remove_forbidden_name_terms(meal_name)

            # Removing forbidden terms can empty the name; rebuild it from ingredients.
            if not meal_name:
                fallback_meal_name = _fallback_meal_name_from_ingredients(ingredient_names)
                meal_name = _shorten_meal_name(fallback_meal_name or meal_key.replace("_", " "), ingredient_names)
                meal_name = _remove_forbidden_name_terms(meal_name)

            if _meal_name_word_count(meal_name) > 4:
                meal_name = _remove_forbidden_name_terms(_shorten_meal_name(meal_name, ingredient_names))
            nodes.append(
                MealNode(
                    meal_name=meal_name,
                    meal_id=meal_id,
                    eval_id=eval_id,
                    trace_file=trace_path.name,
                    status=status,
                    cuisine=cuisine,
                    day=day_key,
                    meal=meal_key,
                    option=str(option_name),
                    prompt=prompt,
                    ingredient_summary=ingredient_summary,
                    ingredient_names=ingredient_names,
                    feature_values=feature_values,
                    achieved_targets_per_day=achieved_targets,
                    target_values=target_values,
                )
            )
    return nodes


def _zscore_rows(nodes: list[MealNode], feature_names: list[str]) -> list[list[float]]:
    """Standardize each feature column to zero mean and unit (population) variance.

    Args:
        nodes: Meals whose ``feature_values`` supply the raw columns. A feature
            missing from a node is treated as ``0.0``.
        feature_names: Feature keys to emit, in column order.

    Returns:
        One row per node (in node order) with one value per feature name, each
        computed as ``(raw - mean) / std`` over that column. Columns whose
        standard deviation is (near) zero use ``std = 1.0``, so they become all
        zeros instead of dividing by zero. Returns ``[]`` when ``nodes`` is empty.
    """
    if not nodes:
        return []

    count = len(nodes)
    # (mean, scale) per feature, aligned with ``feature_names``.
    column_stats: list[tuple[float, float]] = []
    for feature_name in feature_names:
        values = [node.feature_values.get(feature_name, 0.0) for node in nodes]
        mean = math.fsum(values) / count
        variance = math.fsum((value - mean) ** 2 for value in values) / count
        std = math.sqrt(variance)
        # Constant columns would otherwise divide by ~0; scale by 1 so they map to 0.
        column_stats.append((mean, std if std > 1e-9 else 1.0))

    return [
        [
            (node.feature_values.get(feature_name, 0.0) - mean) / scale
            for feature_name, (mean, scale) in zip(feature_names, column_stats)
        ]
        for node in nodes
    ]


def _write_tensor_file(output_dir: Path, rows: list[list[float]]) -> Path:
    """Write ``rows`` as the projector's tab-separated embedding tensor.

    Every row becomes one line of tab-separated values formatted with eight
    decimal places. No header line is written.

    Returns:
        Path of the ``TENSOR_FILE_NAME`` file created inside ``output_dir``.
    """
    destination = output_dir / TENSOR_FILE_NAME
    lines = ("\t".join(format(cell, ".8f") for cell in row) + "\n" for row in rows)
    with destination.open("w", encoding="utf-8", newline="") as stream:
        stream.writelines(lines)
    return destination


def _write_metadata_file(output_dir: Path, nodes: list[MealNode]) -> Path:
    """Write one tab-separated metadata row per meal, preceded by a header row.

    Numeric columns read ``node.feature_values`` (a missing feature becomes
    ``0.0``). Gram and calorie columns use four decimal places; calorie-share
    columns use six.

    Returns:
        Path of the ``METADATA_FILE_NAME`` file created inside ``output_dir``.
    """
    # (column name, feature key, decimal places) for every numeric column, in output order.
    numeric_columns = (
        ("calories", "calories", 4),
        ("protein_g", "protein", 4),
        ("carbohydrates_g", "carbohydrates", 4),
        ("total_fat_g", "total_fat", 4),
        ("total_fibre_g", "total_fibre", 4),
        ("protein_calorie_share", "protein_calorie_share", 6),
        ("carb_calorie_share", "carb_calorie_share", 6),
        ("fat_calorie_share", "fat_calorie_share", 6),
    )
    header = ["meal_name", "trace_file", "status", "cuisine"]
    header.extend(column for column, _, _ in numeric_columns)
    header.append("ingredient_summary")

    destination = output_dir / METADATA_FILE_NAME
    with destination.open("w", encoding="utf-8", newline="") as stream:
        tsv_writer = csv.DictWriter(stream, fieldnames=header, delimiter="\t")
        tsv_writer.writeheader()
        for meal in nodes:
            record = {
                "meal_name": meal.meal_name,
                "trace_file": meal.trace_file,
                "status": meal.status,
                "cuisine": meal.cuisine,
                "ingredient_summary": meal.ingredient_summary,
            }
            for column, feature_key, places in numeric_columns:
                record[column] = f"{meal.feature_values.get(feature_key, 0.0):.{places}f}"
            tsv_writer.writerow(record)
    return destination


def _write_projector_config(output_dir: Path, tensor_path: Path, metadata_path: Path) -> Path:
    """Write a single-embedding TensorBoard projector config in protobuf text format.

    The embedding is named ``meal_macro_embeddings``. It references the tensor
    and metadata files by bare file name, not by full path.

    Returns:
        Path of the ``PROJECTOR_CONFIG_NAME`` file created inside ``output_dir``.
    """
    projector_config = ProjectorConfig()
    entry = projector_config.embeddings.add()
    entry.tensor_name = "meal_macro_embeddings"
    # Bare names only. Presumably the projector resolves them relative to the
    # directory holding the config (all files are written to output_dir) — confirm.
    entry.tensor_path = tensor_path.name
    entry.metadata_path = metadata_path.name

    serialized = text_format.MessageToString(projector_config)
    destination = output_dir / PROJECTOR_CONFIG_NAME
    destination.write_text(serialized, encoding="utf-8")
    return destination


def _write_raw_json(output_dir: Path, nodes: list[MealNode], feature_names: list[str]) -> Path:
    """Dump every meal node, plus the feature-name list, as indented JSON.

    Each meal object holds the node's fields under keys identical to the
    attribute names, in a fixed order.

    Returns:
        Path of the ``RAW_JSON_FILE_NAME`` file created inside ``output_dir``.
    """
    # Keys double as attribute names on MealNode; order here is the JSON key order.
    exported_fields = (
        "meal_name",
        "meal_id",
        "eval_id",
        "trace_file",
        "status",
        "cuisine",
        "day",
        "meal",
        "option",
        "prompt",
        "ingredient_summary",
        "ingredient_names",
        "feature_values",
        "achieved_targets_per_day",
        "target_values",
    )
    meals = [{field: getattr(node, field) for field in exported_fields} for node in nodes]
    document = {"feature_names": feature_names, "meals": meals}

    destination = output_dir / RAW_JSON_FILE_NAME
    destination.write_text(json.dumps(document, indent=2), encoding="utf-8")
    return destination


def _write_summary(
    output_dir: Path,
    trace_dirs: list[Path],
    trace_files: list[Path],
    nodes: list[MealNode],
    feature_names: list[str],
) -> Path:
    """Write a JSON overview of the export run.

    The summary records the trace directories (as full path strings), the trace
    file names, the trace and meal counts, the feature names, and the sorted
    distinct cuisines and statuses found in ``nodes``.

    Returns:
        Path of the ``SUMMARY_FILE_NAME`` file created inside ``output_dir``.
    """
    distinct_cuisines = sorted({meal.cuisine for meal in nodes})
    distinct_statuses = sorted({meal.status for meal in nodes})
    overview = {
        "trace_directories": list(map(str, trace_dirs)),
        "trace_files": [trace.name for trace in trace_files],
        "trace_count": len(trace_files),
        "meal_count": len(nodes),
        "feature_names": feature_names,
        "cuisines": distinct_cuisines,
        "statuses": distinct_statuses,
    }
    destination = output_dir / SUMMARY_FILE_NAME
    destination.write_text(json.dumps(overview, indent=2), encoding="utf-8")
    return destination


def _discover_trace_files(trace_dirs: list[Path]) -> list[Path]:
    """Collect trace files from every directory, ordered by file name.

    For each directory, the patterns in ``TRACE_GLOB_PATTERNS`` are tried in
    order. Only the first pattern that matches at least one file is used for
    that directory.
    """
    discovered: list[Path] = []
    for directory in trace_dirs:
        # Patterns are globbed lazily, so later patterns run only if earlier ones found nothing.
        pattern_hits = (sorted(directory.glob(pattern)) for pattern in TRACE_GLOB_PATTERNS)
        discovered += next((hits for hits in pattern_hits if hits), [])
    discovered.sort(key=lambda candidate: candidate.name)
    return discovered


def build_projector_dataset(trace_dirs: list[Path], output_dir: Path, clean: bool) -> dict[str, Any]:
    """Extract meals from trace files and write every TensorBoard projector artifact.

    Args:
        trace_dirs: Directories searched for trace files.
        output_dir: Directory that receives the tensor, metadata, projector
            config, raw JSON and summary files.
        clean: When true, delete ``output_dir`` before writing fresh files.
            The deletion happens only after meals were successfully extracted.

    Returns:
        The trace and meal counts, plus the string path of each file written.

    Raises:
        FileNotFoundError: No trace files were found in ``trace_dirs``.
        RuntimeError: The traces yielded no meals.
    """
    trace_files = _discover_trace_files(trace_dirs)
    if not trace_files:
        joined = ", ".join(str(path) for path in trace_dirs)
        raise FileNotFoundError(f"No trace files found in: {joined}")

    # Extract and validate BEFORE touching the output directory. A run that
    # yields no meals (or fails while extracting) must not wipe the previous export.
    nodes: list[MealNode] = []
    for trace_path in trace_files:
        nodes.extend(_extract_meal_nodes(trace_path))

    if not nodes:
        raise RuntimeError("No optimized meals were extracted from the provided traces")

    if clean and output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    feature_names = [
        "calories",
        "protein",
        "carbohydrates",
        "total_fat",
        "total_fibre",
        "protein_calorie_share",
        "carb_calorie_share",
        "fat_calorie_share",
    ]
    rows = _zscore_rows(nodes, feature_names)
    tensor_path = _write_tensor_file(output_dir, rows)
    metadata_path = _write_metadata_file(output_dir, nodes)
    config_path = _write_projector_config(output_dir, tensor_path, metadata_path)
    raw_json_path = _write_raw_json(output_dir, nodes, feature_names)
    summary_path = _write_summary(output_dir, trace_dirs, trace_files, nodes, feature_names)

    return {
        "trace_count": len(trace_files),
        "meal_count": len(nodes),
        "tensor_path": str(tensor_path),
        "metadata_path": str(metadata_path),
        "config_path": str(config_path),
        "raw_json_path": str(raw_json_path),
        "summary_path": str(summary_path),
    }


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the projector export.

    ``--trace-dir`` may be repeated. When given at least once, the supplied
    directories *replace* the built-in default trace directories. A plain
    ``action="append"`` with a non-empty default always keeps the defaults and
    appends user values after them, which leaves no way to exclude them.
    """

    class _ReplaceDefaultAppend(argparse.Action):
        """Like ``action="append"``, but discards the default list on first use."""

        def __call__(self, parser, namespace, values, option_string=None):
            current = getattr(namespace, self.dest, None)
            # argparse seeds the namespace with ``self.default`` itself. Seeing that
            # exact object means this is the first explicit use of the flag.
            items = [] if current is None or current is self.default else list(current)
            items.append(values)
            setattr(namespace, self.dest, items)

    parser = argparse.ArgumentParser(
        description="Extract optimized meals from old ADK eval traces and build TensorBoard projector files."
    )
    parser.add_argument(
        "--trace-dir",
        action=_ReplaceDefaultAppend,
        default=[str(path) for path in DEFAULT_TRACE_DIRS],
        help=(
            "Trace directory to include. Pass multiple times to combine datasets; "
            "explicit values replace the defaults (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--output-dir",
        default=str(DEFAULT_OUTPUT_DIR),
        help="Directory where projector files will be written.",
    )
    parser.add_argument(
        "--no-clean",
        action="store_true",
        help="Do not delete the output directory before writing fresh files.",
    )
    return parser


def main() -> int:
    """Parse CLI arguments, build the projector dataset and print the result as JSON.

    Trace and output directories have ``~`` expanded and are resolved to
    absolute paths before use.

    Returns:
        Process exit status: always ``0``. Failures propagate as exceptions.
    """
    options = build_parser().parse_args()
    resolved_trace_dirs = [Path(raw).expanduser().resolve() for raw in options.trace_dir]
    resolved_output_dir = Path(options.output_dir).expanduser().resolve()
    should_clean = not options.no_clean

    report = build_projector_dataset(resolved_trace_dirs, resolved_output_dir, clean=should_clean)
    print(json.dumps(report, indent=2))
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)