from __future__ import annotations

import json
import os
import re
import sys
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.tools.tool_context import ToolContext
from google.genai import types

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.calculate_mfp_macros import calculate_mfp_macros
from utils.ingredient_tool import find_ingredient
from utils.optimizer import _extract_products_from_history
from utils.optimizer import _filter_optimizer_targets
from utils.optimizer_mfp import _build_target_bundle_from_mfp_payload
from utils.optimizer_mfp import optimize_quantity_for_mfp_targets

from .heuristics import build_ingredient_queries
from .heuristics import build_search_space
from .heuristics import infer_plan_days
from .heuristics import normalize_cuisine_name


def _extract_latest_user_text(llm_request: LlmRequest) -> str:
    """Return the most recent non-blank text typed by the user.

    Contents are scanned newest-first. Only contents whose ``role`` is ``"user"``
    are considered, and within each one the parts are scanned last-to-first.
    Parts without a string ``text`` (e.g. function responses) are ignored.

    Returns:
        The stripped text, or ``""`` when no user content carries text.
    """
    user_contents = (
        content
        for content in reversed(llm_request.contents)
        if getattr(content, "role", "") == "user"
    )
    for content in user_contents:
        parts = getattr(content, "parts", []) or []
        for part in reversed(parts):
            candidate = getattr(part, "text", None)
            if isinstance(candidate, str) and candidate.strip():
                return candidate.strip()
    return ""


def _extract_function_responses_from_request(
    llm_request: LlmRequest,
    function_name: str,
) -> List[Dict[str, Any]]:
    """Collect every dict payload returned by ``function_name``, in request order.

    Every part of every content is inspected. A part contributes its
    ``function_response.response`` when the function response exists, has the
    requested ``name``, and the payload is a ``dict``. Anything else is skipped.
    """

    def _payload_of(part: Any) -> Any:
        # Yields the payload only for a matching, non-empty function response.
        call_result = getattr(part, "function_response", None)
        if not call_result or getattr(call_result, "name", "") != function_name:
            return None
        return getattr(call_result, "response", None)

    return [
        payload
        for content in llm_request.contents
        for payload in map(_payload_of, getattr(content, "parts", []) or [])
        if isinstance(payload, dict)
    ]


def _extract_latest_function_response_from_request(
    llm_request: LlmRequest,
    function_name: str,
) -> Optional[Dict[str, Any]]:
    """Return the last dict payload returned by ``function_name``, or ``None`` if none exist."""
    matching_payloads = _extract_function_responses_from_request(llm_request, function_name)
    return matching_payloads[-1] if matching_payloads else None


def _make_function_call_response(name: str, args: Dict[str, Any]) -> LlmResponse:
    """Wrap a single function call ``name(**args)`` in a model-role ``LlmResponse``.

    Returned from the before-model callback in place of a real model reply.
    """
    call = types.FunctionCall(name=name, args=args)
    message = types.Content(role="model", parts=[types.Part(function_call=call)])
    return LlmResponse(content=message)


def _make_text_event(text: str, role: str = "user", author: str = "user") -> Any:
    """Build a lightweight stand-in for a session event holding one text part.

    The result exposes ``author`` and ``content.role`` / ``content.parts[0].text``,
    i.e. the attributes the history scanners in this module read.
    """
    text_part = SimpleNamespace(text=text)
    body = SimpleNamespace(role=role, parts=[text_part])
    return SimpleNamespace(author=author, content=body)


def _make_function_response_event(name: str, response: Any) -> Any:
    """Build a lightweight stand-in for a session event carrying one tool response.

    The result exposes ``content.parts[0].function_response`` with ``name`` and
    ``response`` attributes. It has no ``role`` or ``author`` attribute.
    """
    payload = SimpleNamespace(name=name, response=response)
    response_part = SimpleNamespace(function_response=payload)
    return SimpleNamespace(content=SimpleNamespace(parts=[response_part]))


def _build_tool_context(callback_context: CallbackContext, user_text: str) -> ToolContext:
    """Produce a minimal tool-context stand-in exposing ``_invocation_context``.

    If the callback context has no ``_invocation_context``, a fresh one is
    fabricated. Its session's only event is ``user_text``.

    Otherwise the real invocation context is reused. When its session's
    ``events`` is a list and ``user_text`` is non-blank, a user text event is
    appended unless some user-role event already has a part whose stripped text
    equals the stripped ``user_text``.

    NOTE(review): the append mutates the live session's event list in place;
    confirm that is acceptable for the session service in use.
    """
    invocation_context = getattr(callback_context, "_invocation_context", None)
    if invocation_context is None:
        standalone_session = SimpleNamespace(events=[_make_text_event(user_text)])
        return SimpleNamespace(_invocation_context=SimpleNamespace(session=standalone_session))

    def _event_contains(event: Any, wanted: str) -> bool:
        # True when a user-role event has a text part equal to ``wanted`` after stripping.
        content = getattr(event, "content", None)
        if getattr(content, "role", "") != "user":
            return False
        texts = (getattr(part, "text", None) for part in getattr(content, "parts", []) or [])
        return any(isinstance(text, str) and text.strip() == wanted for text in texts)

    session = getattr(invocation_context, "session", None)
    events = None if session is None else getattr(session, "events", None)
    if isinstance(events, list):
        wanted = user_text.strip()
        if wanted and not any(_event_contains(event, wanted) for event in events):
            events.append(_make_text_event(user_text))

    return SimpleNamespace(_invocation_context=invocation_context)


def _append_tool_response(tool_context: ToolContext, name: str, response: Any) -> None:
    """Record ``response`` from tool ``name`` as a function-response event in the session history."""
    session_events = tool_context._invocation_context.session.events
    session_events.append(_make_function_response_event(name, response))


def _parse_age(text: str) -> Tuple[float, str]:
    """Extract the user's age and its unit from free text.

    Checked in priority order (case-insensitive):
      1. "N year(s) and M month(s) old" -> (N * 12 + M, "months")
      2. "N month(s) old" / "N-month-old" -> (N, "months")
      3. "N year(s) old" / "N-year-old"   -> (N, "years")

    Raises:
        ValueError: if none of the patterns match.
    """
    flags = re.IGNORECASE
    combined = re.search(
        r"(\d+(?:\.\d+)?)\s*year[s]?\s*and\s*(\d+(?:\.\d+)?)\s*month[s]?\s*old",
        text,
        flags,
    )
    if combined is not None:
        whole_years, extra_months = (float(group) for group in combined.groups())
        return (whole_years * 12.0) + extra_months, "months"

    single_unit_patterns = (
        (r"(\d+(?:\.\d+)?)\s*[- ]?month[s]?[- ]old", "months"),
        (r"(\d+(?:\.\d+)?)\s*[- ]?year[s]?[- ]old", "years"),
    )
    for pattern, unit in single_unit_patterns:
        found = re.search(pattern, text, flags)
        if found is not None:
            return float(found.group(1)), unit

    raise ValueError("Could not parse age from user query")


def _parse_weight(text: str) -> Tuple[float, str]:
    """Extract body weight from a phrase like "I weigh 70 kg" / "I weigh 150 lbs".

    Accepted units (case-insensitive): kg, kgs, kilogram(s), lb, lbs, pound(s).
    The unit is normalized to ``"kg"`` or ``"lb"``, which are the only two values
    the original parser ever produced, so downstream callers are unaffected.

    Returns:
        ``(weight, unit)`` with unit in ``{"kg", "lb"}``.

    Raises:
        ValueError: if no weight phrase is found.
    """
    match = re.search(
        r"I weigh\s+(\d+(?:\.\d+)?)\s*(kilograms?|kgs?|lbs?|pounds?)\b",
        text,
        re.IGNORECASE,
    )
    if not match:
        raise ValueError("Could not parse weight from user query")
    # Every metric spelling starts with "k"; everything else is imperial pounds.
    unit = "kg" if match.group(2).lower().startswith("k") else "lb"
    return float(match.group(1)), unit


def _parse_height(text: str) -> Tuple[float, str, Optional[float]]:
    """Extract height from "I'm X ft Y in tall" or "I am X cm tall" (case-insensitive).

    Returns:
        ``(feet, "ft_in", inches)`` for the imperial form, or
        ``(centimetres, "cm", None)`` for the metric form.

    Raises:
        ValueError: if neither form is present.
    """
    flags = re.IGNORECASE
    imperial = re.search(
        r"I(?:'m| am)\s+(\d+(?:\.\d+)?)\s*ft\s+(\d+(?:\.\d+)?)\s*in\s+tall",
        text,
        flags,
    )
    if imperial is not None:
        feet, inches = imperial.groups()
        return float(feet), "ft_in", float(inches)

    metric = re.search(r"I(?:'m| am)\s+(\d+(?:\.\d+)?)\s*cm\s+tall", text, flags)
    if metric is not None:
        return float(metric.group(1)), "cm", None

    raise ValueError("Could not parse height from user query")


def _parse_activity_level(text: str) -> str:
    """Return the lowercased words between "activity level is" and the next period.

    Raises:
        ValueError: if the phrase is absent or not terminated by a period.
    """
    found = re.search(r"activity level is\s+([A-Za-z ]+?)\.", text, re.IGNORECASE)
    if found is None:
        raise ValueError("Could not parse activity level from user query")
    level = found.group(1)
    return level.strip().lower()


def _parse_goal(text: str) -> str:
    """Map goal phrases in the query to a canonical goal string.

    Phrases are tested case-insensitively in priority order:
    "lose weight", then "gain weight", then "maintain my weight" /
    "no planned weekly change" (both -> "maintain").

    Raises:
        ValueError: if no goal phrase is present.
    """
    haystack = text.lower()
    goal_table = (
        (("lose weight",), "lose weight"),
        (("gain weight",), "gain weight"),
        (("maintain my weight", "no planned weekly change"), "maintain"),
    )
    for phrases, canonical_goal in goal_table:
        if any(phrase in haystack for phrase in phrases):
            return canonical_goal
    raise ValueError("Could not parse goal from user query")


def _parse_weekly_rate_lbs(text: str, goal: str) -> float:
    """Return the planned weekly weight change in pounds.

    - ``0.0`` when the goal is "maintain" or the text says "no planned weekly change".
    - Otherwise the number from "at about N lb per week". ``lb``, ``lbs``,
      ``pound`` and ``pounds`` are all accepted, case-insensitively.
    - ``0.5`` when no rate is stated.
    """
    if goal == "maintain" or "no planned weekly change" in text.lower():
        return 0.0

    # Previously only the singular "lb" matched, so "1 lbs per week" silently
    # fell back to the 0.5 default.
    match = re.search(
        r"at about\s+(\d+(?:\.\d+)?)\s*(?:lbs?|pounds?)\s+per\s+week",
        text,
        re.IGNORECASE,
    )
    if match:
        return float(match.group(1))

    return 0.5


def _parse_carb_fat_preference(text: str) -> str:
    """Return the requested macro split ("balanced", "lower carb", "higher carb",
    "higher protein"), lowercased. Defaults to ``"balanced"`` when the phrase
    "<preference> macro split" is absent."""
    found = re.search(
        r"(balanced|lower carb|higher carb|higher protein)\s+macro split",
        text,
        re.IGNORECASE,
    )
    return found.group(1).strip().lower() if found else "balanced"


def _parse_cuisine(text: str) -> str:
    """Return the normalized cuisine from "want to eat <X> food".

    Defaults to ``"american"`` when the phrase is absent. The captured text is
    passed through ``normalize_cuisine_name``.
    """
    found = re.search(r"want to eat\s+(.+?)\s+food", text, re.IGNORECASE)
    if found is None:
        return "american"
    raw_cuisine = found.group(1).strip()
    return normalize_cuisine_name(raw_cuisine)


def _parse_sex(text: str) -> str:
    """Infer the user's sex from whole-word keywords in the query.

    Female keywords (girl(s), woman/women, female(s)) are checked before male
    keywords (boy(s), man/men, male(s)). Matching is case-insensitive and on
    whole words, so words such as "girlfriend", "German" or "human" no longer
    trigger a false match. Plain substring tests did.

    Raises:
        ValueError: if no keyword is present.
    """
    if re.search(r"\b(?:girls?|wom[ae]n|females?)\b", text, re.IGNORECASE):
        return "female"
    if re.search(r"\b(?:boys?|m[ae]n|males?)\b", text, re.IGNORECASE):
        return "male"
    raise ValueError("Could not parse sex from user query")


def _parse_profile(text: str) -> Dict[str, Any]:
    """Parse every profile field needed by ``calculate_mfp_macros`` from the query.

    Fields are parsed in a fixed order: age, weight, height, goal, sex, activity
    level, weekly rate, macro split. The first missing required field therefore
    determines which ``ValueError`` is raised.

    Returns:
        A dict with keys sex, age, age_unit, weight, weight_unit, height,
        height_unit, height_inches, activity_level, goal, weekly_rate_lbs and
        carb_fat_preference.
    """
    age, age_unit = _parse_age(text)
    weight, weight_unit = _parse_weight(text)
    height, height_unit, height_inches = _parse_height(text)
    goal = _parse_goal(text)
    sex = _parse_sex(text)
    activity_level = _parse_activity_level(text)
    weekly_rate_lbs = _parse_weekly_rate_lbs(text, goal)
    carb_fat_preference = _parse_carb_fat_preference(text)
    return dict(
        sex=sex,
        age=age,
        age_unit=age_unit,
        weight=weight,
        weight_unit=weight_unit,
        height=height,
        height_unit=height_unit,
        height_inches=height_inches,
        activity_level=activity_level,
        goal=goal,
        weekly_rate_lbs=weekly_rate_lbs,
        carb_fat_preference=carb_fat_preference,
    )


def _results_by_label(ingredient_queries: List[Dict[str, Any]], ingredient_response: Any) -> Dict[str, List[Dict[str, Any]]]:
    """Group ``find_ingredient`` results under their query's lowercased label.

    Each response entry must be a dict with an int ``query_index`` that indexes
    ``ingredient_queries``. The referenced query must have a non-blank ``label``,
    and the entry's ``results`` must be a list. Only dict results are kept.
    When two entries resolve to the same label, the later one wins.
    A non-list response yields an empty mapping.
    """
    if not isinstance(ingredient_response, list):
        return {}

    grouped: Dict[str, List[Dict[str, Any]]] = {}
    for entry in ingredient_response:
        if not isinstance(entry, dict):
            continue
        position = entry.get("query_index")
        if not isinstance(position, int) or not 0 <= position < len(ingredient_queries):
            continue
        label = str(ingredient_queries[position].get("label", "")).strip().lower()
        if not label:
            continue
        hits = entry.get("results")
        if isinstance(hits, list):
            grouped[label] = [hit for hit in hits if isinstance(hit, dict)]
    return grouped


def _format_number(value: Any, decimals: int = 2) -> str:
    """Format ``value`` compactly for display.

    - Values that cannot be converted to a finite float (``None``, non-numeric
      strings, NaN, +/-inf, ints too large for a float) become ``"N/A"``.
      NaN and inf previously crashed inside ``round()``.
    - Values within 1e-6 of an integer are shown as that integer ("3").
    - Otherwise the value is rounded to ``decimals`` places and trailing zeros
      after the decimal point are removed ("2.50" -> "2.5").
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return "N/A"
    if not math.isfinite(number):
        return "N/A"

    rounded = round(number)
    if abs(number - rounded) < 1e-6:
        return str(int(rounded))

    text = f"{number:.{decimals}f}"
    # Only strip zeros that follow a decimal point. With decimals=0 there is no
    # point, and stripping would turn "10" into "1".
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    # Tiny negatives such as -0.001 would otherwise display as "-0".
    return "0" if text == "-0" else text


def _natural_sort_key(value: str) -> List[Any]:
    """Build a natural-order sort key ("day2" sorts before "day10").

    The string form of ``value`` is split on runs of decimal digits. Digit runs
    become ints and the remaining text is lowercased. The key therefore
    alternates str, int, str, ..., so keys always compare position-by-position
    with matching types.

    ``str.isdecimal`` is used instead of ``isdigit`` because it agrees exactly
    with the regex ``\\d`` used for splitting. ``isdigit`` also accepts characters
    such as "²", which ``int()`` rejects with ValueError.
    """
    return [
        int(piece) if piece.isdecimal() else piece.lower()
        for piece in re.split(r"(\d+)", str(value))
    ]


def _humanize_identifier(value: Any, fallback_prefix: str) -> str:
    """Turn an identifier such as ``"day_1"`` or ``"meal::lunch"`` into title-cased words.

    ``"::"`` and ``"_"`` become spaces and whitespace runs collapse to a single
    space. If nothing remains, ``fallback_prefix`` is returned unchanged.
    """
    cleaned = str(value).strip()
    if cleaned:
        for separator in ("::", "_"):
            cleaned = cleaned.replace(separator, " ")
        cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned.title() if cleaned else fallback_prefix


def _display_product_name(product: Optional[Dict[str, Any]], index: Any) -> str:
    """Return the product's stripped ``name``, or ``"Ingredient <index>"`` when unavailable."""
    name = str(product.get("name", "")).strip() if isinstance(product, dict) else ""
    return name or f"Ingredient {index}"


def _format_target_value(key: str, unit: str, point_targets: Dict[str, float]) -> str:
    """Format ``point_targets[key]`` with its unit, e.g. ``"120 g"``.

    Returns a bare ``"N/A"`` when the key is missing or its value is not numeric.
    The non-numeric case previously rendered as ``"N/A g"``.
    """
    if key not in point_targets:
        return "N/A"
    formatted = _format_number(point_targets[key])
    # Don't attach a unit to the placeholder.
    return "N/A" if formatted == "N/A" else f"{formatted} {unit}"


def _format_actual_value(key: str, unit: str, achieved_targets: Dict[str, Any]) -> str:
    """Format ``achieved_targets[key]`` with its unit, e.g. ``"118.5 g"``.

    Returns a bare ``"N/A"`` when the key is missing or its value is not numeric
    (e.g. ``None`` from the optimizer). The non-numeric case previously rendered
    as ``"N/A g"``.
    """
    if key not in achieved_targets:
        return "N/A"
    formatted = _format_number(achieved_targets[key])
    # Don't attach a unit to the placeholder.
    return "N/A" if formatted == "N/A" else f"{formatted} {unit}"


def _build_summary_table(point_targets: Dict[str, float], achieved_targets: Dict[str, Any]) -> str:
    """Render a Markdown table of target vs. actual daily nutrients.

    Rows are protein, carbs, fat, calories and fibre, in that order. A missing
    value is shown as "N/A".
    """
    table_rows = (
        ("Protein", "protein", "g"),
        ("Carbs", "carbohydrates", "g"),
        ("Fat", "total_fat", "g"),
        ("Calories", "calories", "kcal"),
        ("Fibre", "total_fibre", "g"),
    )
    header = ["| Parameter | Target | Actual |", "| --- | --- | --- |"]
    body = [
        f"| {label} | {_format_target_value(key, unit, point_targets)} | {_format_actual_value(key, unit, achieved_targets)} |"
        for label, key, unit in table_rows
    ]
    return "\n".join(header + body)


def _build_meal_lines(
    calculated_quantity_per_day: Dict[str, Any],
    products_by_index: Dict[int, Dict[str, Any]],
) -> List[str]:
    """Render the optimizer's per-day quantities as human-readable lines.

    Days and meals are visited in natural sort order. Days whose value is not a
    dict are skipped, but they still consume a position number. Day and meal
    labels come from the humanized key when it starts with "day"/"meal",
    otherwise from "Day N"/"Meal N". Only the first option (in natural order)
    holding a non-empty list is shown for each meal. Items lacking an int-able
    ``index`` or float-able ``qty`` are dropped. Day sections are separated by a
    single blank line, with no trailing blank.

    Returns:
        The lines, or ``[]`` for a non-dict or empty input.
    """
    if not isinstance(calculated_quantity_per_day, dict) or not calculated_quantity_per_day:
        return []

    def _first_non_empty_option(options: Any) -> List[Any]:
        # The chosen option is the first, in natural order, whose value is a non-empty list.
        if not isinstance(options, dict):
            return []
        for option_key in sorted(options.keys(), key=_natural_sort_key):
            candidate = options.get(option_key)
            if isinstance(candidate, list) and candidate:
                return candidate
        return []

    def _ingredient_text(entry: Any) -> Optional[str]:
        # "<name> (<qty> g)", or None when the entry can't be interpreted.
        if not isinstance(entry, dict):
            return None
        try:
            product_index = int(entry.get("index"))
            grams = float(entry.get("qty"))
        except (TypeError, ValueError):
            return None
        product_name = _display_product_name(products_by_index.get(product_index), product_index)
        return f"{product_name} ({_format_number(grams)} g)"

    day_sections: List[List[str]] = []
    ordered_days = sorted(calculated_quantity_per_day.keys(), key=_natural_sort_key)
    for day_number, day_key in enumerate(ordered_days, start=1):
        meals = calculated_quantity_per_day.get(day_key)
        if not isinstance(meals, dict):
            continue

        day_title = _humanize_identifier(day_key, f"Day {day_number}")
        if not day_title.lower().startswith("day"):
            day_title = f"Day {day_number}"
        section = [f"{day_title}:"]

        ordered_meals = sorted(meals.keys(), key=_natural_sort_key)
        for meal_number, meal_key in enumerate(ordered_meals, start=1):
            meal_title = _humanize_identifier(meal_key, f"Meal {meal_number}")
            if not meal_title.lower().startswith("meal"):
                meal_title = f"Meal {meal_number}"
            chosen_items = _first_non_empty_option(meals.get(meal_key))
            texts = [text for text in map(_ingredient_text, chosen_items) if text is not None]
            section.append(f"- {meal_title}: [{', '.join(texts)}]")

        day_sections.append(section)

    rendered: List[str] = []
    for section in day_sections:
        if rendered:
            rendered.append("")
        rendered.extend(section)
    return rendered


def _build_adjustment_suggestions(
    point_targets: Dict[str, float],
    achieved_targets: Dict[str, Any],
    infeasibility_details: Dict[str, Any],
) -> List[str]:
    """Suggest how to close the gap between achieved and target nutrients.

    - A negative ``daily_calorie_difference`` suggests adding calories, and a
      positive one suggests cutting them.
    - Protein, carbohydrates and total fat get an increase or decrease hint when
      the achieved value differs from the target by more than 1e-6.

    Duplicates are removed while keeping order. If nothing applies, a single
    generic suggestion is returned.
    """
    picks: List[str] = []

    calorie_gap = infeasibility_details.get("daily_calorie_difference")
    if isinstance(calorie_gap, (int, float)):
        if calorie_gap < 0:
            picks.append("Increase overall calories by raising portions of calorie-dense carb or fat sources.")
        elif calorie_gap > 0:
            picks.append("Reduce overall calories by lowering portions of the most energy-dense ingredients.")

    macro_hints = (
        ("protein", "Increase lean protein sources.", "Reduce the most protein-dense ingredients."),
        ("carbohydrates", "Increase carb portions such as grains, fruit, or starchy vegetables.", "Reduce carb-heavy ingredients."),
        ("total_fat", "Increase healthy fat sources such as nuts, seeds, avocado, or oils.", "Reduce added fats and the richest fat sources."),
    )
    tolerance = 1e-6
    for macro, raise_hint, lower_hint in macro_hints:
        achieved = achieved_targets.get(macro)
        wanted = point_targets.get(macro)
        if isinstance(achieved, (int, float)) and isinstance(wanted, (int, float)):
            if float(achieved) < float(wanted) - tolerance:
                picks.append(raise_hint)
            elif float(achieved) > float(wanted) + tolerance:
                picks.append(lower_hint)

    unique_picks = list(dict.fromkeys(picks))
    return unique_picks or [
        "Widen ingredient quantity bounds or swap in ingredients with more compatible calorie and macro profiles."
    ]


def _build_infeasibility_lines(
    point_targets: Dict[str, float],
    achieved_targets: Dict[str, Any],
    infeasibility_details: Dict[str, Any],
) -> List[str]:
    """Render the "Deviations:" and "Adjustments:" sections for an infeasible plan.

    Each numeric deviation present in ``infeasibility_details`` is listed
    (calories in kcal, macros in g), followed by the suggestions from
    ``_build_adjustment_suggestions``.

    Returns:
        The lines, or ``[]`` when ``infeasibility_details`` is not a non-empty dict.
    """
    if not isinstance(infeasibility_details, dict) or not infeasibility_details:
        return []

    output = ["Deviations:"]
    kcal_gap = infeasibility_details.get("daily_calorie_difference")
    if isinstance(kcal_gap, (int, float)):
        output.append(f"- Daily calorie difference: {_format_number(kcal_gap)} kcal")

    gram_gaps = (
        ("daily_protein_difference", "Daily protein difference"),
        ("daily_carbohydrate_difference", "Daily carbohydrate difference"),
        ("daily_total_fat_difference", "Daily total_fat difference"),
    )
    for field, caption in gram_gaps:
        gap = infeasibility_details.get(field)
        if isinstance(gap, (int, float)):
            output.append(f"- {caption}: {_format_number(gap)} g")

    output.append("Adjustments:")
    tips = _build_adjustment_suggestions(point_targets, achieved_targets, infeasibility_details)
    output.extend(f"- {tip}" for tip in tips)
    return output


def _format_final_response(
    optimizer_response: Dict[str, Any],
    mfp_response: Dict[str, Any],
    tool_context: ToolContext,
) -> str:
    """Assemble the user-facing plan: summary table, status section, and meal lines.

    Targets come from the MFP payload via ``_build_target_bundle_from_mfp_payload``
    and ``_filter_optimizer_targets``. Achieved values come from the optimizer's
    ``achieved_targets_per_day``. Meal lines name ingredients using the products
    that ``_extract_products_from_history`` finds in ``tool_context``.

    The status section depends on the optimizer's ``status``:
    - "feasible": a confirmation line;
    - "infeasible": deviations and adjustment suggestions (if details exist);
    - anything else: an explicit "no feasible plan" line. Previously any
      non-"infeasible" status, including errors or a missing status, was
      reported as meeting the targets.
    """
    point_targets = _filter_optimizer_targets(_build_target_bundle_from_mfp_payload(mfp_response))
    achieved_targets = optimizer_response.get("achieved_targets_per_day")
    if not isinstance(achieved_targets, dict):
        achieved_targets = {}

    lines = [
        "Summary:",
        _build_summary_table(point_targets, achieved_targets),
    ]

    status = str(optimizer_response.get("status", "")).strip().lower()
    if status == "feasible":
        lines.extend(["", "- The plan meets the current calorie and macro targets."])
    elif status == "infeasible":
        infeasibility_lines = _build_infeasibility_lines(
            point_targets,
            achieved_targets,
            optimizer_response.get("infeasibility_details", {}),
        )
        if infeasibility_lines:
            lines.extend(["", *infeasibility_lines])
    else:
        # Unknown/error/missing status: never claim success. "feasible" is the only
        # success status recognized elsewhere in this module.
        lines.extend(["", f"- The optimizer did not produce a feasible plan (status: {status or 'unknown'})."])

    products_by_index = _extract_products_from_history(tool_context)
    meal_lines = _build_meal_lines(optimizer_response.get("calculated_quantity_per_day", {}), products_by_index)
    if meal_lines:
        lines.extend(["", *meal_lines])

    return "\n".join(lines).strip()


def _call_mfp(profile: Dict[str, Any], tool_context: ToolContext) -> Dict[str, Any]:
    """Compute macro targets for ``profile`` and record the result in the tool history.

    Ages stored in months are converted to years before the call.

    Raises:
        ValueError: if ``calculate_mfp_macros`` returns anything other than a dict.
            The message is that value's string form.
    """
    call_kwargs = {
        "sex": profile["sex"],
        "age": profile["age"] if profile.get("age_unit") == "years" else float(profile["age"]) / 12.0,
        "weight": profile["weight"],
        "height": profile["height"],
        "activity_level": profile["activity_level"],
        "goal": profile["goal"],
        "weekly_rate_lbs": profile["weekly_rate_lbs"],
        "carb_fat_preference": profile["carb_fat_preference"],
        "weight_unit": profile["weight_unit"],
        "height_unit": profile["height_unit"],
        "height_inches": profile["height_inches"],
    }
    macros = calculate_mfp_macros(**call_kwargs)
    if not isinstance(macros, dict):
        raise ValueError(str(macros))
    _append_tool_response(tool_context, "calculate_mfp_macros", macros)
    return macros


def _call_find_ingredient(cuisine: str, tool_context: ToolContext) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Search ingredients for ``cuisine`` and group the hits by query label.

    The heuristic queries are sent to ``find_ingredient`` with their ``label``
    key removed. The raw list is recorded in the history wrapped as
    ``{"result": ...}``.

    Returns:
        ``(queries_with_labels, results_by_label)``.

    Raises:
        ValueError: if ``find_ingredient`` does not return a list.
    """
    labelled_queries = build_ingredient_queries(cuisine)
    outgoing_queries = []
    for query in labelled_queries:
        outgoing_queries.append({field: setting for field, setting in query.items() if field != "label"})

    hits = find_ingredient(ingredient_queries=outgoing_queries, tool_context=tool_context)
    if not isinstance(hits, list):
        raise ValueError(str(hits))
    _append_tool_response(tool_context, "find_ingredient", {"result": hits})
    return labelled_queries, _results_by_label(labelled_queries, hits)


async def _run_optimizer_with_retries(
    profile: Dict[str, Any],
    cuisine: str,
    number_of_days: int,
    results_by_label: Dict[str, List[Dict[str, Any]]],
    tool_context: ToolContext,
) -> Dict[str, Any]:
    """Run the optimizer up to three times, stopping at the first "feasible" result.

    Each attempt builds a search space with ``attempt`` set to 0, 1, then 2.
    Every optimizer response is recorded in the tool history.

    Returns:
        The first feasible response, or the last response when none is feasible.

    Raises:
        ValueError: if a search space is empty or the optimizer returns a non-dict.
    """
    latest: Dict[str, Any] = {}
    for attempt in range(3):
        space = build_search_space(
            profile=profile,
            cuisine=cuisine,
            number_of_days=number_of_days,
            results_by_label=results_by_label,
            attempt=attempt,
        )
        if not space or not any(space.values()):
            raise ValueError("Could not build a valid search space from ingredient heuristics")

        outcome = await optimize_quantity_for_mfp_targets(
            meals_json=json.dumps(space),
            number_of_days=number_of_days,
            tool_context=tool_context,
        )
        if not isinstance(outcome, dict):
            raise ValueError(str(outcome))

        _append_tool_response(tool_context, "optimize_quantity_for_mfp_targets", outcome)
        latest = outcome
        if str(outcome.get("status", "")).strip().lower() == "feasible":
            break

    return latest


async def _run_pipeline(callback_context: CallbackContext, user_text: str) -> str:
    """Run parsing, macro calculation, ingredient search and optimization in one pass.

    All tool calls are made directly, without model round-trips. Returns the
    formatted plan text. Parse and tool errors propagate as exceptions.
    """
    parsed_profile = _parse_profile(user_text)
    target_cuisine = _parse_cuisine(user_text)
    days = infer_plan_days(user_text)
    context = _build_tool_context(callback_context, user_text)

    macro_targets = _call_mfp(parsed_profile, context)
    _unused_queries, ingredients_by_label = _call_find_ingredient(target_cuisine, context)
    best_attempt = await _run_optimizer_with_retries(
        profile=parsed_profile,
        cuisine=target_cuisine,
        number_of_days=days,
        results_by_label=ingredients_by_label,
        tool_context=context,
    )
    return _format_final_response(best_attempt, macro_targets, context)


def _restrict_to_current_turn(llm_request: LlmRequest) -> Any:
    """Return a request-like object whose ``contents`` start at the latest user text.

    The current turn begins at the most recent ``user`` content that has a
    non-blank text part. Function-response parts carry no text, so they never
    start a turn. Earlier contents (previous questions in the same session and
    their tool responses) are dropped. If no such content exists, all contents
    are kept. Only ``contents`` is exposed, which is the only attribute the
    ``_extract_*_from_request`` helpers read.

    NOTE(review): assumes contents are in chronological order with tool responses
    following the user message that triggered them — confirm for the ADK version in use.
    """
    contents = list(getattr(llm_request, "contents", None) or [])
    start = 0
    for position in range(len(contents) - 1, -1, -1):
        content = contents[position]
        if getattr(content, "role", "") != "user":
            continue
        texts = (getattr(part, "text", None) for part in getattr(content, "parts", []) or [])
        if any(isinstance(text, str) and text.strip() for text in texts):
            start = position
            break
    return SimpleNamespace(contents=contents[start:])


async def _deterministic_before_model_callback(
    callback_context: CallbackContext,
    llm_request: LlmRequest,
) -> Optional[LlmResponse]:
    """Replace the model call with a deterministic, one-step-at-a-time planning pipeline.

    Each invocation looks at the tool responses already present for the current
    user turn and returns the next step:

    1. a ``calculate_mfp_macros`` function call if no macro result exists yet;
    2. a ``find_ingredient`` function call if no ingredient result exists yet;
    3. an ``optimize_quantity_for_mfp_targets`` function call until the latest
       response is "feasible" or three optimizer responses exist;
    4. otherwise, a text response containing the formatted plan.

    Any exception is reported as ``"Error: <message>"`` text. Every path returns
    an ``LlmResponse``, so under ADK callback semantics the model is never
    called. Tool responses from earlier turns in the same session are ignored,
    so a new question is not answered from a previous question's results.
    """
    user_text = _extract_latest_user_text(llm_request)
    if not user_text:
        return LlmResponse(
            content=types.Content(role="model", parts=[types.Part(text="No user input found.")])
        )

    # Scope all tool-response lookups to the current user turn.
    turn_request = _restrict_to_current_turn(llm_request)

    try:
        profile = _parse_profile(user_text)
        cuisine = _parse_cuisine(user_text)
        number_of_days = infer_plan_days(user_text)

        # Step 1: macro targets.
        mfp_response = _extract_latest_function_response_from_request(
            turn_request,
            "calculate_mfp_macros",
        )
        if not isinstance(mfp_response, dict):
            return _make_function_call_response(
                "calculate_mfp_macros",
                {
                    "sex": profile["sex"],
                    "age": profile["age"] if profile.get("age_unit") == "years" else float(profile["age"]) / 12.0,
                    "weight": profile["weight"],
                    "height": profile["height"],
                    "activity_level": profile["activity_level"],
                    "goal": profile["goal"],
                    "weekly_rate_lbs": profile["weekly_rate_lbs"],
                    "carb_fat_preference": profile["carb_fat_preference"],
                    "weight_unit": profile["weight_unit"],
                    "height_unit": profile["height_unit"],
                    "height_inches": profile["height_inches"],
                },
            )

        # Step 2: ingredient search. The "label" key is kept locally for grouping
        # but is not sent to the tool.
        find_response_payload = _extract_latest_function_response_from_request(
            turn_request,
            "find_ingredient",
        )
        ingredient_queries = build_ingredient_queries(cuisine)
        tool_queries = [{k: v for k, v in query.items() if k != "label"} for query in ingredient_queries]
        if not isinstance(find_response_payload, dict):
            return _make_function_call_response(
                "find_ingredient",
                {"ingredient_queries": tool_queries},
            )

        # The list result is expected under "result" (the shape _call_find_ingredient also records).
        raw_find_result = find_response_payload.get("result")
        if not isinstance(raw_find_result, list):
            raw_find_result = []
        results_by_label = _results_by_label(ingredient_queries, raw_find_result)
        if not results_by_label:
            return LlmResponse(
                content=types.Content(
                    role="model",
                    parts=[types.Part(text="Error: Could not build ingredient results from find_ingredient output.")],
                )
            )

        # Step 3: optimizer, retried with progressively different search spaces
        # (attempt 0, 1, 2) until feasible or three responses exist.
        optimize_responses = _extract_function_responses_from_request(turn_request, "optimize_quantity_for_mfp_targets")
        optimize_response = optimize_responses[-1] if optimize_responses else None
        optimize_call_count = len(optimize_responses)
        optimize_status = (
            str(optimize_response.get("status", "")).strip().lower()
            if isinstance(optimize_response, dict)
            else ""
        )
        if not isinstance(optimize_response, dict) or (optimize_status != "feasible" and optimize_call_count < 3):
            attempt = min(optimize_call_count, 2)
            search_space = build_search_space(
                profile=profile,
                cuisine=cuisine,
                number_of_days=number_of_days,
                results_by_label=results_by_label,
                attempt=attempt,
            )
            if not search_space or not any(search_space.values()):
                raise ValueError("Could not build a valid search space from ingredient heuristics")
            return _make_function_call_response(
                "optimize_quantity_for_mfp_targets",
                {
                    "meals_json": json.dumps(search_space),
                    "number_of_days": number_of_days,
                },
            )

        # Step 4: optimize_response is a dict here (the branch above returns otherwise).
        tool_context = _build_tool_context(callback_context, user_text)
        final_text = _format_final_response(optimize_response, mfp_response, tool_context)
    except Exception as exc:
        # Top-level boundary: surface any parse/pipeline failure to the user as text.
        final_text = f"Error: {exc}"

    return LlmResponse(
        content=types.Content(
            role="model",
            parts=[types.Part(text=final_text.strip())],
        )
    )


# Module-level ADK agent. ``_deterministic_before_model_callback`` returns an
# ``LlmResponse`` on every path, so the configured LiteLlm model is not expected
# to be called (ADK skips the model when this callback returns a response).
# ``root_agent`` is an alias for the same object.
root_agent = agent = Agent(
    name="NoLLMMFPNutritionPlanner",
    description="Deterministic non-LLM MFP planner that uses ingredient search, macro target calculation, and the MFP optimizer.",
    instruction="Parse the user request, call the deterministic MFP planning pipeline, and return the final formatted plan.",
    model=LiteLlm(
        model="openrouter/google/gemini-3-flash-preview",
        api_key=os.getenv("OPENROUTER_API_KEY"),
        api_base="https://openrouter.ai/api/v1",
    ),
    tools=[find_ingredient, calculate_mfp_macros, optimize_quantity_for_mfp_targets],
    before_model_callback=_deterministic_before_model_callback,
)
