from __future__ import annotations

import json
import math
import os
import re
import sys
from types import SimpleNamespace
from typing import Any, Dict, List, Optional, Tuple

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.tools.tool_context import ToolContext
from google.genai import types

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from utils.calculate_health_canada_dri import calculate_health_canada_dri
from utils.ingredient_tool import find_ingredient
from utils.optimizer import _build_target_bundle_from_dri_payload
from utils.optimizer import _extract_products_from_history
from utils.optimizer import _filter_optimizer_target_ranges
from utils.optimizer import _filter_optimizer_targets
from utils.optimizer import optimize_quantity

from .heuristics import build_ingredient_queries
from .heuristics import build_search_space
from .heuristics import infer_plan_days
from .heuristics import normalize_cuisine_name


def _extract_latest_user_text(llm_request: LlmRequest) -> str:
    for content in reversed(llm_request.contents):
        if getattr(content, "role", "") != "user":
            continue

        for part in reversed(getattr(content, "parts", []) or []):
            text = getattr(part, "text", None)
            if isinstance(text, str) and text.strip():
                return text.strip()

    return ""


def _extract_function_responses_from_request(
    llm_request: LlmRequest,
    function_name: str,
) -> List[Dict[str, Any]]:
    responses: List[Dict[str, Any]] = []
    for content in llm_request.contents:
        for part in getattr(content, "parts", []) or []:
            function_response = getattr(part, "function_response", None)
            if not function_response:
                continue
            if getattr(function_response, "name", "") != function_name:
                continue
            payload = getattr(function_response, "response", None)
            if isinstance(payload, dict):
                responses.append(payload)
    return responses


def _extract_latest_function_response_from_request(
    llm_request: LlmRequest,
    function_name: str,
) -> Optional[Dict[str, Any]]:
    """Return the most recent dict payload from ``function_name``, or ``None`` if there is none."""
    matches = _extract_function_responses_from_request(llm_request, function_name)
    return matches[-1] if matches else None


def _make_function_call_response(name: str, args: Dict[str, Any]) -> LlmResponse:
    """Wrap a single tool call (``name`` with ``args``) in a model-authored ``LlmResponse``."""
    call = types.FunctionCall(name=name, args=args)
    call_part = types.Part(function_call=call)
    model_content = types.Content(role="model", parts=[call_part])
    return LlmResponse(content=model_content)


def _make_text_event(text: str, role: str = "user", author: str = "user") -> Any:
    return SimpleNamespace(
        author=author,
        content=SimpleNamespace(role=role, parts=[SimpleNamespace(text=text)]),
    )


def _make_function_response_event(name: str, response: Any) -> Any:
    return SimpleNamespace(
        content=SimpleNamespace(
            parts=[SimpleNamespace(function_response=SimpleNamespace(name=name, response=response))]
        )
    )


def _build_tool_context(callback_context: CallbackContext, user_text: str) -> ToolContext:
    """Build a lightweight stand-in for ``ToolContext`` that exposes ``_invocation_context``.

    If ``callback_context`` has no ``_invocation_context``, a fake one is built
    whose session holds a single user text event for ``user_text``.

    Otherwise the real invocation context is reused. When its session has a
    list of events and ``user_text`` is non-blank, a user text event is
    appended unless a user event with the same stripped text already exists.
    NOTE(review): that append mutates the live session event list with a
    ``SimpleNamespace``. Confirm this is safe for the session backend in use.

    The return value is a ``SimpleNamespace``, not a real ``ToolContext``.
    """
    invocation_context = getattr(callback_context, "_invocation_context", None)
    if invocation_context is None:
        fallback_session = SimpleNamespace(events=[_make_text_event(user_text)])
        return SimpleNamespace(_invocation_context=SimpleNamespace(session=fallback_session))

    session = getattr(invocation_context, "session", None)
    events = None if session is None else getattr(session, "events", None)
    wanted = user_text.strip()

    def _is_matching_user_event(event: Any) -> bool:
        # True when the event is a user message containing exactly ``wanted`` (after strip).
        content = getattr(event, "content", None)
        if getattr(content, "role", "") != "user":
            return False
        for part in getattr(content, "parts", []) or []:
            text = getattr(part, "text", None)
            if isinstance(text, str) and text.strip() == wanted:
                return True
        return False

    if isinstance(events, list) and wanted:
        if not any(_is_matching_user_event(event) for event in events):
            events.append(_make_text_event(user_text))

    return SimpleNamespace(_invocation_context=invocation_context)


def _append_tool_response(tool_context: ToolContext, name: str, response: Any) -> None:
    """Record ``response`` as a ``name`` function-response event in the context's session history."""
    session_events = tool_context._invocation_context.session.events
    session_events.append(_make_function_response_event(name, response))


def _try_float(text: str) -> Optional[float]:
    try:
        return float(text)
    except (TypeError, ValueError):
        return None


def _parse_age(text: str) -> Tuple[float, str]:
    mixed_year_month_match = re.search(
        r"(\d+(?:\.\d+)?)\s*year[s]?\s*and\s*(\d+(?:\.\d+)?)\s*month[s]?\s*old",
        text,
        re.IGNORECASE,
    )
    if mixed_year_month_match:
        years = float(mixed_year_month_match.group(1))
        months = float(mixed_year_month_match.group(2))
        return (years * 12.0) + months, "months"

    month_match = re.search(r"(\d+(?:\.\d+)?)\s*[- ]?month[s]?[- ]old", text, re.IGNORECASE)
    if month_match:
        return float(month_match.group(1)), "months"

    year_match = re.search(r"(\d+(?:\.\d+)?)\s*[- ]?year[s]?[- ]old", text, re.IGNORECASE)
    if year_match:
        return float(year_match.group(1)), "years"

    raise ValueError("Could not parse age from user query")


def _parse_weight(text: str) -> Tuple[float, str]:
    match = re.search(r"I weigh\s+(\d+(?:\.\d+)?)\s*(kg|lb)\b", text, re.IGNORECASE)
    if not match:
        raise ValueError("Could not parse weight from user query")
    return float(match.group(1)), match.group(2).lower()


def _parse_height(text: str) -> Tuple[float, str, Optional[float]]:
    feet_inches_match = re.search(
        r"I(?:'m| am)\s+(\d+(?:\.\d+)?)\s*ft\s+(\d+(?:\.\d+)?)\s*in\s+tall",
        text,
        re.IGNORECASE,
    )
    if feet_inches_match:
        return float(feet_inches_match.group(1)), "ft_in", float(feet_inches_match.group(2))

    cm_match = re.search(r"I(?:'m| am)\s+(\d+(?:\.\d+)?)\s*cm\s+tall", text, re.IGNORECASE)
    if cm_match:
        return float(cm_match.group(1)), "cm", None

    raise ValueError("Could not parse height from user query")


def _parse_activity_level(text: str) -> str:
    match = re.search(r"activity level is\s+([A-Za-z ]+?)\.", text, re.IGNORECASE)
    if not match:
        raise ValueError("Could not parse activity level from user query")
    return match.group(1).strip().title()


def _parse_cuisine(text: str) -> str:
    match = re.search(r"want to eat\s+(.+?)\s+food", text, re.IGNORECASE)
    if not match:
        return "american"
    return normalize_cuisine_name(match.group(1).strip())


def _parse_sex(text: str) -> str:
    lowered = text.lower()
    if "girl" in lowered or "woman" in lowered or "female" in lowered:
        return "Female"
    if "boy" in lowered or "man" in lowered or "male" in lowered:
        return "Male"
    raise ValueError("Could not parse sex from user query")


def _parse_profile(text: str) -> Dict[str, Any]:
    """Parse the user text into keyword arguments for ``calculate_health_canada_dri``.

    Fields are parsed in this order: age, weight, height, sex, activity level.
    The first failing parser's ValueError therefore propagates. Pregnancy and
    lactation are always set to "not pregnant" / ``"none"``.
    """
    age, age_unit = _parse_age(text)
    weight, weight_unit = _parse_weight(text)
    height, height_unit, height_inches = _parse_height(text)
    sex = _parse_sex(text)
    activity_level = _parse_activity_level(text)
    return dict(
        sex=sex,
        age=age,
        age_unit=age_unit,
        weight=weight,
        weight_unit=weight_unit,
        height=height,
        height_unit=height_unit,
        height_inches=height_inches,
        activity_level=activity_level,
        is_pregnant=False,
        lactation_status="none",
    )


def _results_by_label(ingredient_queries: List[Dict[str, Any]], ingredient_response: Any) -> Dict[str, List[Dict[str, Any]]]:
    grouped: Dict[str, List[Dict[str, Any]]] = {}
    if not isinstance(ingredient_response, list):
        return grouped
    for item in ingredient_response:
        if not isinstance(item, dict):
            continue
        query_index = item.get("query_index")
        if not isinstance(query_index, int):
            continue
        if query_index < 0 or query_index >= len(ingredient_queries):
            continue
        label = str(ingredient_queries[query_index].get("label", "")).strip().lower()
        if not label:
            continue
        results = item.get("results")
        if isinstance(results, list):
            grouped[label] = [result for result in results if isinstance(result, dict)]
    return grouped


def _format_number(value: Any, decimals: int = 2) -> str:
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "N/A"

    rounded = round(number)
    if abs(number - rounded) < 1e-6:
        return str(int(rounded))
    return f"{number:.{decimals}f}".rstrip("0").rstrip(".")


def _build_summary_table(
    point_targets: Dict[str, float],
    range_targets: Dict[str, Dict[str, float]],
    achieved_targets: Dict[str, Any],
) -> str:
    """Build a markdown table of target vs. achieved values for the headline nutrients.

    A range target ("lower-upper unit") takes precedence over a point target.
    Missing targets or achieved values are shown as "N/A".
    """
    summary_rows = (
        ("Protein", "protein", "g"),
        ("Carbs", "carbohydrates", "g"),
        ("Fat", "total_fat", "g"),
        ("Calories", "calories", "kcal"),
        ("Fibre", "total_fibre", "g"),
    )

    def _target_cell(key: str, unit: str) -> str:
        if key in range_targets:
            bounds = range_targets[key]
            return f"{_format_number(bounds.get('lower'))}-{_format_number(bounds.get('upper'))} {unit}"
        if key in point_targets:
            return f"{_format_number(point_targets[key])} {unit}"
        return "N/A"

    def _actual_cell(key: str, unit: str) -> str:
        if key not in achieved_targets:
            return "N/A"
        return f"{_format_number(achieved_targets.get(key))} {unit}"

    header = ["| Parameter | Target | Actual |", "| --- | --- | --- |"]
    body = [
        f"| {label} | {_target_cell(key, unit)} | {_actual_cell(key, unit)} |"
        for label, key, unit in summary_rows
    ]
    return "\n".join(header + body)


def _build_meal_lines(calculated_quantity_per_day: Dict[str, Any], tool_context: Any) -> List[str]:
    products_by_index = _extract_products_from_history(tool_context)
    lines: List[str] = []
    for day_name in sorted(calculated_quantity_per_day.keys()):
        day_payload = calculated_quantity_per_day.get(day_name)
        if not isinstance(day_payload, dict):
            continue
        lines.append(day_name.replace("_", " ").title() + ":")
        for meal_name in sorted(day_payload.keys()):
            option_payload = day_payload.get(meal_name)
            items: List[Dict[str, Any]] = []
            if isinstance(option_payload, dict):
                for option_name in sorted(option_payload.keys()):
                    candidate_items = option_payload.get(option_name)
                    if isinstance(candidate_items, list) and candidate_items:
                        items = candidate_items
                        break
            rendered: List[str] = []
            for item in items:
                if not isinstance(item, dict):
                    continue
                try:
                    index = int(item.get("index"))
                    qty = float(item.get("qty"))
                except (TypeError, ValueError):
                    continue
                product = products_by_index.get(index, {})
                name = str(product.get("name", f"Ingredient {index}")).strip() or f"Ingredient {index}"
                rendered.append(f"{name} ({_format_number(qty)} g)")
            lines.append(f"- {meal_name.replace('_', ' ').title()}: [{', '.join(rendered)}]")

        # Keep day sections visually separated in markdown/plain-text rendering.
        lines.append("")

    while lines and not lines[-1].strip():
        lines.pop()
    return lines


def _format_final_response(
    optimizer_response: Dict[str, Any],
    dri_response: Dict[str, Any],
    tool_context: Any,
) -> str:
    """Render the user-facing answer: target/actual table, status note, and meals.

    Targets come from the DRI payload, filtered to what the optimizer uses.
    Achieved values come from ``achieved_targets_per_day``. The status note
    claims success only for status "feasible", the same success signal the
    retry logic in this module uses. "infeasible" shows the optimizer's
    message plus sorted per-key deviations. Any other or missing status
    shows the optimizer's message, or a generic warning.
    """
    point_targets_raw, range_targets_raw = _build_target_bundle_from_dri_payload(dri_response)
    point_targets = _filter_optimizer_targets(point_targets_raw)
    range_targets = _filter_optimizer_target_ranges(range_targets_raw)
    achieved_targets = optimizer_response.get("achieved_targets_per_day")
    if not isinstance(achieved_targets, dict):
        achieved_targets = {}

    lines = [
        "Summary:",
        _build_summary_table(point_targets, range_targets, achieved_targets),
        "",
    ]

    status = str(optimizer_response.get("status", "")).strip().lower()
    if status == "feasible":
        lines.append("The plan meets the current calorie, fibre, and macro targets.")
    elif status == "infeasible":
        message = optimizer_response.get("message") or "Optimization remained infeasible after 3 attempts."
        lines.append(str(message).strip())
        infeasibility = optimizer_response.get("infeasibility_details")
        if isinstance(infeasibility, dict) and infeasibility:
            lines.append("")
            lines.append("Deviations:")
            for key in sorted(infeasibility.keys()):
                lines.append(f"- {key}: {_format_number(infeasibility[key])}")
    else:
        # Previously any non-"infeasible" status (missing, "error", ...) was
        # reported as a success; be explicit that the targets are not confirmed.
        message = optimizer_response.get("message")
        if message:
            lines.append(str(message).strip())
        else:
            lines.append(
                f"The optimizer did not report a feasible plan (status: {status or 'unknown'}); "
                "the targets may not be met."
            )

    # ``or {}`` also covers an explicit None for the meal payload.
    meal_lines = _build_meal_lines(optimizer_response.get("calculated_quantity_per_day") or {}, tool_context)
    if meal_lines:
        lines.extend(["", *meal_lines])

    return "\n".join(lines).strip()


def _call_dri(profile: Dict[str, Any], tool_context: ToolContext) -> Dict[str, Any]:
    """Call ``calculate_health_canada_dri`` with ``profile`` and record the result in session history.

    Raises:
        ValueError: carrying ``str(result)`` when the tool does not return a dict.
    """
    result = calculate_health_canada_dri(tool_context=tool_context, **profile)
    if isinstance(result, dict):
        _append_tool_response(tool_context, "calculate_health_canada_dri", result)
        return result
    raise ValueError(str(result))


def _call_find_ingredient(cuisine: str, tool_context: ToolContext) -> Tuple[List[Dict[str, Any]], Dict[str, List[Dict[str, Any]]]]:
    """Run ``find_ingredient`` for the cuisine's heuristic queries.

    The ``label`` key is stripped from each query before it is sent to the
    tool. The raw list is recorded in session history as ``{"result": [...]}``.

    Returns:
        ``(ingredient_queries, results_by_label)``, where the queries still
        carry their labels.

    Raises:
        ValueError: carrying ``str(result)`` when the tool does not return a list.
    """
    ingredient_queries = build_ingredient_queries(cuisine)

    def _without_label(query: Dict[str, Any]) -> Dict[str, Any]:
        return {key: value for key, value in query.items() if key != "label"}

    raw_results = find_ingredient(
        ingredient_queries=list(map(_without_label, ingredient_queries)),
        tool_context=tool_context,
    )
    if not isinstance(raw_results, list):
        raise ValueError(str(raw_results))

    _append_tool_response(tool_context, "find_ingredient", {"result": raw_results})
    return ingredient_queries, _results_by_label(ingredient_queries, raw_results)


async def _run_optimizer_with_retries(
    profile: Dict[str, Any],
    cuisine: str,
    number_of_days: int,
    results_by_label: Dict[str, List[Dict[str, Any]]],
    tool_context: ToolContext,
) -> Dict[str, Any]:
    """Call ``optimize_quantity`` up to three times, stopping at the first feasible result.

    Each attempt passes its number (0, 1, 2) to ``build_search_space``, so
    the heuristics can vary the search space per retry. Every optimizer
    response is recorded in session history.

    Returns:
        The first response whose status is "feasible", otherwise the last response.

    Raises:
        ValueError: if a search space is empty, or if the optimizer returns a non-dict.
    """
    latest: Dict[str, Any] = {}
    attempt = 0
    while attempt < 3:
        space = build_search_space(
            profile=profile,
            cuisine=cuisine,
            number_of_days=number_of_days,
            results_by_label=results_by_label,
            attempt=attempt,
        )
        if not (space and any(space.values())):
            raise ValueError("Could not build a valid search space from ingredient heuristics")

        response = await optimize_quantity(
            meals_json=json.dumps(space),
            number_of_days=number_of_days,
            tool_context=tool_context,
        )
        if not isinstance(response, dict):
            raise ValueError(str(response))

        _append_tool_response(tool_context, "optimize_quantity", response)
        latest = response
        if str(response.get("status", "")).strip().lower() == "feasible":
            break
        attempt += 1

    return latest


async def _run_pipeline(callback_context: CallbackContext, user_text: str) -> str:
    """Run the whole deterministic plan in-process and return the formatted answer.

    Steps: parse the request, compute DRI targets, look up ingredients, run
    the optimizer (with retries), then format. Parse or tool errors propagate.
    """
    profile = _parse_profile(user_text)
    cuisine = _parse_cuisine(user_text)
    number_of_days = infer_plan_days(user_text)
    tool_context = _build_tool_context(callback_context, user_text)

    dri_response = _call_dri(profile, tool_context)
    results_by_label = _call_find_ingredient(cuisine, tool_context)[1]
    optimizer_response = await _run_optimizer_with_retries(
        profile=profile,
        cuisine=cuisine,
        number_of_days=number_of_days,
        results_by_label=results_by_label,
        tool_context=tool_context,
    )
    return _format_final_response(optimizer_response, dri_response, tool_context)


def _latest_user_turn_start(contents: List[Any]) -> int:
    """Return the index of the newest user content that carries non-blank text.

    Everything from this index onward belongs to the current user turn. This
    is the same content ``_extract_latest_user_text`` reads from. Returns 0
    (the whole history) when no such content exists.
    """
    for index in range(len(contents) - 1, -1, -1):
        content = contents[index]
        if getattr(content, "role", "") != "user":
            continue
        for part in getattr(content, "parts", []) or []:
            text = getattr(part, "text", None)
            if isinstance(text, str) and text.strip():
                return index
    return 0


async def _deterministic_before_model_callback(
    callback_context: CallbackContext,
    llm_request: LlmRequest,
) -> Optional[LlmResponse]:
    """Drive the DRI -> find_ingredient -> optimize_quantity pipeline without the LLM.

    Each invocation inspects the tool results already present in the current
    turn and returns either:
      * a function call for the next missing tool (DRI, then find_ingredient,
        then optimize_quantity, retried up to 3 times until "feasible"), or
      * a final text response with the formatted plan or an error message.

    Only tool results after the latest user text message are considered, so a
    new request in the same session re-runs the pipeline instead of reusing
    the previous turn's results.
    """
    user_text = _extract_latest_user_text(llm_request)
    if not user_text:
        return LlmResponse(
            content=types.Content(role="model", parts=[types.Part(text="No user input found.")])
        )

    # Restrict tool-result lookups to the current turn. The helpers only read
    # ``.contents``, so a lightweight view object is enough.
    contents = list(llm_request.contents or [])
    current_turn = SimpleNamespace(contents=contents[_latest_user_turn_start(contents):])

    try:
        profile = _parse_profile(user_text)
        cuisine = _parse_cuisine(user_text)
        number_of_days = infer_plan_days(user_text)

        dri_response = _extract_latest_function_response_from_request(
            current_turn,
            "calculate_health_canada_dri",
        )
        if not isinstance(dri_response, dict):
            return _make_function_call_response("calculate_health_canada_dri", profile)

        find_response_payload = _extract_latest_function_response_from_request(
            current_turn,
            "find_ingredient",
        )
        ingredient_queries = build_ingredient_queries(cuisine)
        if not isinstance(find_response_payload, dict):
            # The tool does not take the internal "label" key; labels are kept
            # locally to map results back by query index.
            tool_queries = [{k: v for k, v in query.items() if k != "label"} for query in ingredient_queries]
            return _make_function_call_response(
                "find_ingredient",
                {"ingredient_queries": tool_queries},
            )

        raw_find_result = find_response_payload.get("result")
        if not isinstance(raw_find_result, list):
            raw_find_result = []
        results_by_label = _results_by_label(ingredient_queries, raw_find_result)
        if not results_by_label:
            return LlmResponse(
                content=types.Content(
                    role="model",
                    parts=[types.Part(text="Error: Could not build ingredient results from find_ingredient output.")],
                )
            )

        optimize_responses = _extract_function_responses_from_request(current_turn, "optimize_quantity")
        optimize_call_count = len(optimize_responses)
        optimize_response = optimize_responses[-1] if optimize_responses else None
        optimize_status = (
            str(optimize_response.get("status", "")).strip().lower() if optimize_response is not None else ""
        )
        needs_another_attempt = optimize_status != "feasible" and optimize_call_count < 3
        if optimize_response is None or needs_another_attempt:
            # Attempt 0, 1, 2 lets the heuristics vary the search space per retry.
            attempt = min(optimize_call_count, 2)
            search_space = build_search_space(
                profile=profile,
                cuisine=cuisine,
                number_of_days=number_of_days,
                results_by_label=results_by_label,
                attempt=attempt,
            )
            if not search_space or not any(search_space.values()):
                raise ValueError("Could not build a valid search space from ingredient heuristics")
            return _make_function_call_response(
                "optimize_quantity",
                {
                    "meals_json": json.dumps(search_space),
                    "number_of_days": number_of_days,
                },
            )

        # Here optimize_response is a dict: either feasible, or the last of 3 attempts.
        tool_context = _build_tool_context(callback_context, user_text)
        final_text = _format_final_response(optimize_response, dri_response, tool_context)
    except Exception as exc:  # Boundary: surface any pipeline failure as a chat reply.
        final_text = f"Error: {exc}"

    return LlmResponse(
        content=types.Content(
            role="model",
            parts=[types.Part(text=final_text.strip())],
        )
    )


# OpenRouter-backed model required by Agent. Every normal path of
# ``_deterministic_before_model_callback`` returns an LlmResponse, which ADK
# uses in place of a model call. NOTE(review): so this model is presumably
# never queried; confirm against the ADK version in use.
_planner_model = LiteLlm(
    model="openrouter/google/gemini-3-flash-preview",
    api_key=os.getenv("OPENROUTER_API_KEY"),
    api_base="https://openrouter.ai/api/v1",
)

# Entry point discovered by ADK through ``root_agent``; ``agent`` is kept as an alias.
agent = Agent(
    name="NoLLMCanadaDRIPlanner",
    description="Deterministic non-LLM DRI planner that uses the existing tools and heuristic search space generation.",
    instruction="Parse the user request, call the deterministic planning pipeline, and return the final formatted plan.",
    model=_planner_model,
    tools=[find_ingredient, calculate_health_canada_dri, optimize_quantity],
    before_model_callback=_deterministic_before_model_callback,
)

root_agent = agent
