import math
import os
import re
import sys
from typing import Any, Dict, Optional
from pathlib import Path

from google.genai import types
# Put the directory above this file's package on sys.path (presumably so the
# `utils.*` imports below resolve); must run before those imports.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from utils.ingredient_tool import find_ingredient
from utils.calculate_mfp_macros import calculate_mfp_macros
from utils.calculator import calculate_average_macro_nutrient_per_day
from utils.optimizer import _extract_latest_function_response_from_history
from utils.optimizer import _extract_products_from_history
from utils.optimizer_mfp import _build_target_bundle_from_mfp_payload


# Agent system instruction, stored alongside this module.
PROMPT_PATH = Path(__file__).parent / "mfp_prompt.txt"


def _format_number(value: Any, decimals: int = 2) -> str:
    """Format *value* as a compact human-readable number.

    Values within 1e-6 of an integer are shown without a fractional part.
    Everything else is shown with up to *decimals* places, with trailing zeros
    (and a dangling ".") removed. Returns ``"N/A"`` when *value* cannot be
    converted to a float or is not finite (NaN / +-inf).
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: float() of an int too large for a double.
        return "N/A"

    # round() raises on NaN / inf, so reject non-finite values up front.
    if not math.isfinite(number):
        return "N/A"

    rounded = round(number)
    if abs(number - rounded) < 1e-6:
        return str(int(rounded))

    text = f"{number:.{decimals}f}"
    # Only trim zeros after a decimal point; with decimals=0 there is none and
    # stripping would eat integer digits ("10" -> "1").
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    # Tiny negatives such as -0.001 would otherwise render as "-0".
    if text == "-0":
        text = "0"
    return text


def _natural_sort_key(value: str) -> list[Any]:
    """Build a key that sorts embedded digit runs numerically ("day2" < "day10").

    ``re.split`` with a capturing group alternates text and digit chunks, so
    odd positions are always ``\\d+`` matches and are converted to ``int``;
    even positions are lower-cased text. Position parity is used instead of
    ``str.isdigit`` because characters like "²" are ``isdigit()`` but not
    parseable by ``int()``.
    """
    chunks = re.split(r"(\d+)", str(value))
    return [int(chunk) if position % 2 else chunk.lower() for position, chunk in enumerate(chunks)]


def _humanize_identifier(value: Any, fallback_prefix: str) -> str:
    """Turn an identifier such as ``"day_1"`` or ``"meal::lunch"`` into Title Case.

    ``::`` and ``_`` become spaces and whitespace runs collapse to one space.
    *fallback_prefix* is returned when *value* is ``None`` or nothing remains
    after normalization.
    """
    # str(None) would otherwise produce the label "None".
    if value is None:
        return fallback_prefix

    text = str(value).strip()
    if not text:
        return fallback_prefix

    normalized = re.sub(r"\s+", " ", text.replace("::", " ").replace("_", " ")).strip()
    return normalized.title() if normalized else fallback_prefix


def _display_product_name(product: Optional[Dict[str, Any]], index: Any) -> str:
    """Return the product's display name, falling back to ``"Ingredient <index>"``.

    The fallback is used when *product* is not a dict or its ``"name"`` is
    missing, ``None`` or blank.
    """
    if isinstance(product, dict):
        raw_name = product.get("name")
        # Guard None explicitly: str(None) would otherwise surface as "None".
        name = "" if raw_name is None else str(raw_name).strip()
        if name:
            return name
    return f"Ingredient {index}"


def _normalize_position_label(value: Any, prefix: str, position: int) -> str:
    """Humanize *value* into a label, keeping it only if it begins with *prefix*.

    The prefix comparison is case-insensitive. Any label that does not start
    with *prefix* is replaced by the positional fallback ``"<prefix> <position>"``.
    """
    fallback = f"{prefix} {position}"
    candidate = _humanize_identifier(value, fallback)
    return candidate if candidate.lower().startswith(prefix.lower()) else fallback


def _first_option_items(selected_options: Any) -> list[Dict[str, Any]]:
    """Return the items of the first non-empty option, in natural key order.

    Each option may be either a dict carrying an ``"items"`` list or a bare
    list of items. Returns ``[]`` when *selected_options* is not a dict or
    no option has any items.
    """
    if isinstance(selected_options, dict):
        for option_key in sorted(selected_options, key=_natural_sort_key):
            option = selected_options[option_key]
            candidates = option.get("items") if isinstance(option, dict) else option
            if isinstance(candidates, list) and candidates:
                return candidates
    return []


def _first_numeric_value(values: Dict[str, Any], keys: list[str]) -> Optional[float]:
    """Return, as a float, the first int/float found under *keys* (checked in order).

    Booleans are skipped even though ``bool`` subclasses ``int``, so a flag
    such as ``True`` is never read as a difference of 1. Returns ``None``
    when no key holds a number.
    """
    for key in keys:
        candidate = values.get(key)
        if isinstance(candidate, (int, float)) and not isinstance(candidate, bool):
            return float(candidate)
    return None


def _format_target_value(
    key: str,
    unit: str,
    point_targets: Dict[str, float],
) -> str:
    """Format the target for *key* as ``"<number> <unit>"``.

    Returns a bare ``"N/A"`` (no unit) when there is no target for *key* or
    the stored value is not a formattable number, so the table never shows
    a dangling ``"N/A g"``.
    """
    if key not in point_targets:
        return "N/A"
    formatted = _format_number(point_targets[key])
    if formatted == "N/A":
        return formatted
    return f"{formatted} {unit}"


def _format_actual_value(key: str, unit: str, achieved_targets: Dict[str, Any]) -> str:
    """Format the achieved value for *key* as ``"<number> <unit>"``.

    Returns a bare ``"N/A"`` (no unit) when *key* is absent or its value is
    not a formattable number (e.g. ``None``), avoiding output like ``"N/A g"``.
    """
    if key not in achieved_targets:
        return "N/A"
    formatted = _format_number(achieved_targets[key])
    if formatted == "N/A":
        return formatted
    return f"{formatted} {unit}"


def _build_summary_table(
    point_targets: Dict[str, float],
    achieved_targets: Dict[str, Any],
) -> str:
    """Render the target-vs-actual nutrient comparison as a Markdown table.

    One row per nutrient, in this order: protein, carbs, fat, calories, fibre.
    Cells come from ``_format_target_value`` / ``_format_actual_value``.
    """
    nutrient_rows = (
        ("Protein", "protein", "g"),
        ("Carbs", "carbohydrates", "g"),
        ("Fat", "total_fat", "g"),
        ("Calories", "calories", "kcal"),
        ("Fibre", "total_fibre", "g"),
    )

    table = ["| Parameter | Target | Actual |", "| --- | --- | --- |"]
    for display_name, nutrient_key, unit in nutrient_rows:
        target_cell = _format_target_value(nutrient_key, unit, point_targets)
        actual_cell = _format_actual_value(nutrient_key, unit, achieved_targets)
        table.append(f"| {display_name} | {target_cell} | {actual_cell} |")
    return "\n".join(table)


def _group_day_meals(day_value: Any) -> Dict[Any, list[Any]]:
    """Collect one day's entry into ``{meal_name: items}``.

    Two day shapes are understood:
      * a list of item dicts, grouped by each item's ``"meal_name"``
        (``"Unspecified Meal"`` when the key is absent); non-dict items are
        dropped;
      * a dict mapping meal name to either a list of items (kept as-is) or
        an options dict, reduced to its first non-empty option via
        ``_first_option_items`` (meals without items are dropped).
    Any other value yields an empty mapping.
    """
    meal_groups: Dict[Any, list[Any]] = {}
    if isinstance(day_value, list):
        for item in day_value:
            if isinstance(item, dict):
                meal_groups.setdefault(item.get("meal_name", "Unspecified Meal"), []).append(item)
    elif isinstance(day_value, dict):
        for meal_name, meal_data in day_value.items():
            if isinstance(meal_data, list):
                meal_groups[meal_name] = meal_data
            elif isinstance(meal_data, dict):
                option_items = _first_option_items(meal_data)
                if option_items:
                    meal_groups[meal_name] = option_items
    return meal_groups


def _format_meal_ingredients(
    selected_items: list[Any],
    products_by_index: Dict[int, Dict[str, Any]],
) -> list[str]:
    """Render each usable item as ``"<name> (<qty> g)"``.

    An item is usable when it is a dict whose ``"index"`` converts to int and
    whose quantity converts to float. The quantity is ``"qty"``, or
    ``"quantity"`` when ``"qty"`` is falsy. Unusable items are skipped rather
    than aborting the report.
    """
    ingredient_texts: list[str] = []
    for item in selected_items:
        if not isinstance(item, dict):
            continue
        try:
            index = int(item.get("index"))
            qty = float(item.get("qty") or item.get("quantity"))
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers int(float("inf")) and float(<huge int>).
            continue
        ingredient_name = _display_product_name(products_by_index.get(index), index)
        ingredient_texts.append(f"{ingredient_name} ({_format_number(qty)} g)")
    return ingredient_texts


def _build_meal_lines(
    calculated_quantity_per_day: Dict[str, Any],
    products_by_index: Dict[int, Dict[str, Any]],
) -> list[str]:
    """Render the per-day meal plan as text lines.

    Days and meals are ordered with ``_natural_sort_key`` and labelled via
    ``_normalize_position_label`` (``"Day N"`` / ``"Meal N"``). The original
    meal name is appended in parentheses when it differs from the label.
    Days are separated by a blank line, and trailing blank lines are removed.
    Returns ``[]`` for a non-dict or empty *calculated_quantity_per_day*.
    """
    if not isinstance(calculated_quantity_per_day, dict) or not calculated_quantity_per_day:
        return []

    lines: list[str] = []
    ordered_days = sorted(calculated_quantity_per_day.keys(), key=_natural_sort_key)
    for day_position, day_name in enumerate(ordered_days, start=1):
        day_label = _normalize_position_label(day_name, "Day", day_position)
        lines.append(f"{day_label}:")

        # An unrecognised day shape yields no meals, but the day still gets its
        # blank separator so the next day's header is not glued onto it.
        meal_groups = _group_day_meals(calculated_quantity_per_day.get(day_name))
        ordered_meals = sorted(meal_groups.items(), key=lambda kv: _natural_sort_key(str(kv[0])))
        for meal_position, (meal_name, selected_items) in enumerate(ordered_meals, start=1):
            meal_label = _normalize_position_label(meal_name, "Meal", meal_position)
            joined_ingredients = ", ".join(_format_meal_ingredients(selected_items, products_by_index))
            if str(meal_name).strip().lower() != meal_label.lower():
                lines.append(f"- {meal_label} ({meal_name}): [{joined_ingredients}]")
            else:
                lines.append(f"- {meal_label}: [{joined_ingredients}]")

        lines.append("")

    while lines and not lines[-1].strip():
        lines.pop()
    return lines


def _build_adjustment_suggestions(
    point_targets: Dict[str, float],
    achieved_targets: Dict[str, Any],
    infeasibility_details: Dict[str, Any],
) -> list[str]:
    """Suggest portion changes that would move the plan toward its targets.

    Calorie advice is driven by the signed calorie difference in
    *infeasibility_details*: negative means raise calories, positive means
    lower them. Macro advice compares achieved vs target for protein, carbs
    and total fat, with a 1e-6 slack. Duplicates are removed while keeping
    order. If no advice applies, a single generic suggestion is returned.
    """
    collected: list[str] = []

    calorie_gap = _first_numeric_value(
        infeasibility_details,
        ["average_daily_calorie_difference", "calorie_difference", "daily_calorie_difference"],
    )
    if calorie_gap is not None:
        if calorie_gap < 0:
            collected.append("Increase overall calories by raising portions of calorie-dense carb or fat sources.")
        elif calorie_gap > 0:
            collected.append("Reduce overall calories by lowering portions of the most energy-dense ingredients.")

    macro_hints = (
        ("protein", "Increase lean protein sources.", "Reduce the most protein-dense ingredients."),
        ("carbohydrates", "Increase carb portions such as grains, fruit, or starchy vegetables.", "Reduce carb-heavy ingredients."),
        ("total_fat", "Increase healthy fat sources such as nuts, seeds, avocado, or oils.", "Reduce added fats and the richest fat sources."),
    )
    for nutrient, raise_hint, lower_hint in macro_hints:
        achieved = achieved_targets.get(nutrient)
        goal = point_targets.get(nutrient)
        if isinstance(achieved, (int, float)) and isinstance(goal, (int, float)):
            if float(achieved) < float(goal) - 1e-6:
                collected.append(raise_hint)
            elif float(achieved) > float(goal) + 1e-6:
                collected.append(lower_hint)

    unique = list(dict.fromkeys(collected))
    return unique or ["Widen ingredient quantity bounds or swap in ingredients with more compatible calorie and macro profiles."]


def _build_infeasibility_lines(
    point_targets: Dict[str, float],
    achieved_targets: Dict[str, Any],
    infeasibility_details: Dict[str, Any],
) -> list[str]:
    """Describe how a plan misses its targets and what to change.

    Returns ``[]`` when *infeasibility_details* is not a non-empty dict.
    Otherwise the result has two parts:
      1. a ``"Deviations:"`` section, emitted only when at least one signed
         calorie/macro difference is present;
      2. an ``"Adjustments:"`` section from ``_build_adjustment_suggestions``,
         which is always present.
    """
    if not isinstance(infeasibility_details, dict) or not infeasibility_details:
        return []

    deviation_lines: list[str] = []

    calorie_difference = _first_numeric_value(
        infeasibility_details,
        ["average_daily_calorie_difference", "calorie_difference", "daily_calorie_difference"],
    )
    if calorie_difference is not None:
        deviation_lines.append(f"- Daily calorie difference: {_format_number(calorie_difference)} kcal")

    deviation_labels = [
        (["daily_protein_difference", "protein_difference"], "Daily protein difference"),
        (["daily_carbohydrate_difference", "daily_carbohydrates_difference", "carbohydrates_difference"], "Daily carbohydrate difference"),
        # User-facing label: use plain words, not the internal "total_fat" key.
        (["daily_total_fat_difference", "total_fat_difference", "fat_difference"], "Daily fat difference"),
    ]
    for keys, label in deviation_labels:
        value = _first_numeric_value(infeasibility_details, keys)
        if value is not None:
            deviation_lines.append(f"- {label}: {_format_number(value)} g")

    lines: list[str] = []
    if deviation_lines:
        # Skip the header entirely rather than printing an empty section.
        lines.append("Deviations:")
        lines.extend(deviation_lines)

    lines.append("Adjustments:")
    for suggestion in _build_adjustment_suggestions(point_targets, achieved_targets, infeasibility_details):
        lines.append(f"- {suggestion}")
    return lines


def _latest_user_turn_index(contents: list[Any]) -> int:
    """Return the index of the newest user-typed message in *contents*, or -1.

    A user-typed message is a ``role == "user"`` content that has at least one
    non-blank text part and no ``function_response`` parts. Function results
    may also be replayed under the user role, hence the explicit exclusion.
    """
    for position in range(len(contents) - 1, -1, -1):
        content = contents[position]
        if getattr(content, "role", None) != "user":
            continue
        parts = list(getattr(content, "parts", []) or [])
        texts = [getattr(part, "text", None) for part in parts]
        has_text = any(isinstance(text, str) and text.strip() for text in texts)
        has_function_response = any(getattr(part, "function_response", None) for part in parts)
        if has_text and not has_function_response:
            return position
    return -1


def _extract_best_calculator_response_and_run_count(
    llm_request: LlmRequest,
) -> tuple[Optional[Dict[str, Any]], int]:
    """Pick the best calculator result of the current turn and count calculator runs.

    Only contents after the newest user-typed message are scanned. Results
    from earlier conversation turns therefore cannot short-circuit a new
    request, and the run counter restarts every turn. If no user message is
    found, the whole history is scanned.

    Ranking among dict payloads: feasible beats infeasible, then the lower
    ``feasibility_score`` wins, and ties go to the later result. A missing
    score counts as 0.0 when feasible and +inf otherwise; a NaN score counts
    as +inf.

    Returns ``(best payload or None, number of calculator function responses
    in the turn, including non-dict payloads)``.
    """
    contents = list(llm_request.contents or [])
    start = _latest_user_turn_index(contents) + 1

    best_response = None
    best_key: Optional[tuple[int, float]] = None
    run_count = 0

    for content in contents[start:]:
        for part in list(getattr(content, "parts", []) or []):
            function_response = getattr(part, "function_response", None)
            if not function_response:
                continue
            if getattr(function_response, "name", "") != "calculate_average_macro_nutrient_per_day":
                continue

            run_count += 1
            payload = getattr(function_response, "response", None)
            if not isinstance(payload, dict):
                continue

            is_feasible = bool(payload.get("is_feasible", False))
            score_value = payload.get("feasibility_score")
            if isinstance(score_value, (int, float)):
                score = float(score_value)
            else:
                score = 0.0 if is_feasible else float("inf")
            # NaN compares False against everything and would freeze the ranking.
            if math.isnan(score):
                score = float("inf")

            candidate_key = (0 if is_feasible else 1, score)
            # "<=" so the later of equally good results wins.
            if best_key is None or candidate_key <= best_key:
                best_key = candidate_key
                best_response = payload

    return best_response, run_count


def _extract_latest_calculator_call_args(
    llm_request: LlmRequest,
) -> Optional[Dict[str, Any]]:
    """Return a copy of the args of the most recent calculator function call.

    Contents and their parts are walked newest-first. The first
    ``calculate_average_macro_nutrient_per_day`` call found decides the
    result: a dict copy of its ``args``, or ``None`` if those args are not a
    dict. Returns ``None`` when there is no such call.
    """
    calls_newest_first = (
        getattr(part, "function_call", None)
        for content in reversed(llm_request.contents)
        for part in reversed(list(getattr(content, "parts", []) or []))
    )
    for call in calls_newest_first:
        if call and getattr(call, "name", "") == "calculate_average_macro_nutrient_per_day":
            call_args = getattr(call, "args", None)
            return dict(call_args) if isinstance(call_args, dict) else None
    return None


# Allowed absolute deviation from each MFP point target before a plan is said
# to miss that target. Shared by the feasibility check and the score so the
# two can never disagree.
_MACRO_TOLERANCES: Dict[str, float] = {
    "calories": 10.0,
    "protein": 5.0,
    "carbohydrates": 2.0,
    "total_fat": 5.0,
}

# Score added for a targeted macro the calculator did not report, or
# reported as a non-numeric / non-finite value.
_MISSING_MACRO_PENALTY = 1e6


def _is_finite_number(value: Any) -> bool:
    """Return True for finite int/float values; bools, NaN and +-inf are rejected."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    try:
        return math.isfinite(value)
    except OverflowError:
        # An int too large to convert to a double.
        return False


def _achieved_macros(calculator_response: Any) -> Optional[Dict[str, Any]]:
    """Return the calculator's achieved daily averages dict, or None if malformed."""
    if not isinstance(calculator_response, dict):
        return None
    achieved = calculator_response.get("average_macro_nutrient_from_calculated_quantity_per_day")
    return achieved if isinstance(achieved, dict) else None


def _targets_from_mfp_payload(mfp_payload: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Build point targets from an MFP payload; ``{}`` when no payload is available."""
    return _build_target_bundle_from_mfp_payload(mfp_payload) if mfp_payload else {}


def _macro_tolerance_excess(
    achieved: Dict[str, Any],
    point_targets: Dict[str, Any],
) -> Dict[str, float]:
    """Map each targeted macro to how far it falls outside its tolerance.

    A value of 0.0 means within tolerance. Macros without a finite target
    are omitted. A missing or non-finite achieved value gets
    ``_MISSING_MACRO_PENALTY``, so a NaN result can never pass as within
    tolerance.
    """
    excess: Dict[str, float] = {}
    for key, tolerance in _MACRO_TOLERANCES.items():
        target = point_targets.get(key)
        if not _is_finite_number(target):
            continue
        actual = achieved.get(key)
        if not _is_finite_number(actual):
            excess[key] = _MISSING_MACRO_PENALTY
            continue
        excess[key] = max(0.0, abs(float(actual) - float(target)) - tolerance)
    return excess


def _check_calculator_feasibility(
    calculator_response: Dict[str, Any],
    mfp_payload: Optional[Dict[str, Any]],
) -> bool:
    """Return True when every targeted macro is within its tolerance.

    Returns False for a malformed calculator response. With no MFP payload
    there are no targets, so any well-formed response is feasible.
    """
    achieved = _achieved_macros(calculator_response)
    if achieved is None:
        return False
    excess = _macro_tolerance_excess(achieved, _targets_from_mfp_payload(mfp_payload))
    return all(value == 0.0 for value in excess.values())


def _calculate_feasibility_score(
    calculator_response: Dict[str, Any],
    mfp_payload: Optional[Dict[str, Any]],
) -> float:
    """Return the total out-of-tolerance deviation (lower is better).

    0.0 exactly when ``_check_calculator_feasibility`` is True; +inf for a
    malformed calculator response.
    """
    achieved = _achieved_macros(calculator_response)
    if achieved is None:
        return float("inf")
    excess = _macro_tolerance_excess(achieved, _targets_from_mfp_payload(mfp_payload))
    return sum(excess.values(), 0.0)


def _build_baseline_formatted_response(
    tool_response: Dict[str, Any],
    tool_context: ToolContext,
) -> str:
    """Render the user-facing plan report for one calculator result.

    Sections, separated by blank lines:
      1. ``Summary:`` followed by the target-vs-actual table;
      2. either a success line (when ``_check_calculator_feasibility`` passes)
         or the deviations/adjustments section, built from signed
         ``actual - target`` differences for calories and the three macros;
      3. the per-day meal listing from ``calculated_quantity_per_day``.
    The joined text is stripped before returning.
    """
    mfp_payload = _extract_latest_function_response_from_history(tool_context, "calculate_mfp_macros")
    point_targets: Dict[str, float] = (
        _build_target_bundle_from_mfp_payload(mfp_payload) if mfp_payload else {}
    )

    raw_achieved = tool_response.get("average_macro_nutrient_from_calculated_quantity_per_day")
    achieved_targets: Dict[str, Any] = raw_achieved if isinstance(raw_achieved, dict) else {}

    sections: list[str] = ["Summary:", _build_summary_table(point_targets, achieved_targets)]

    if _check_calculator_feasibility(tool_response, mfp_payload):
        sections += ["", "- The plan meets the current calorie and macro targets."]
    else:
        difference_keys = (
            ("calories", "average_daily_calorie_difference"),
            ("protein", "daily_protein_difference"),
            ("carbohydrates", "daily_carbohydrate_difference"),
            ("total_fat", "daily_total_fat_difference"),
        )
        differences: Dict[str, float] = {}
        for nutrient, detail_key in difference_keys:
            goal = point_targets.get(nutrient)
            achieved = achieved_targets.get(nutrient)
            if isinstance(goal, (int, float)) and isinstance(achieved, (int, float)):
                differences[detail_key] = float(achieved) - float(goal)

        deviation_section = _build_infeasibility_lines(point_targets, achieved_targets, differences)
        if deviation_section:
            sections += ["", *deviation_section]

    products_by_index = _extract_products_from_history(tool_context)
    meal_section = _build_meal_lines(tool_response.get("calculated_quantity_per_day", {}), products_by_index)
    if meal_section:
        sections += ["", *meal_section]

    return "\n".join(sections).strip()


def _baseline_before_model_callback(
    callback_context: CallbackContext,
    llm_request: LlmRequest,
) -> Optional[LlmResponse]:
    """Answer with the calculator's pre-formatted report instead of calling the model.

    A canned ``LlmResponse`` is returned when the best calculator result has
    a non-blank string ``formatted_response`` and either that result is
    marked feasible or the calculator has already run at least 3 times.
    Otherwise ``None`` is returned and the model is called normally.
    """
    del callback_context  # Unused; part of the callback signature.

    best_result, runs = _extract_best_calculator_response_and_run_count(llm_request)
    if not isinstance(best_result, dict):
        return None

    report = best_result.get("formatted_response", "")
    ready = runs >= 3 or best_result.get("is_feasible", False)
    if not ready or not isinstance(report, str) or not report.strip():
        return None

    return LlmResponse(
        content=types.Content(
            role="model",
            parts=[types.Part(text=report.strip())],
        )
    )


def _baseline_after_tool_callback(
    tool: BaseTool,
    args: Dict[str, Any],
    tool_context: ToolContext,
    tool_response: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
    """Enrich calculator results with a formatted report and feasibility info.

    Only dict results of ``calculate_average_macro_nutrient_per_day`` are
    touched; for anything else ``None`` is returned. On a copy of the result:
      * ``calculated_quantity_per_day`` is set from the call's
        ``calculated_quantity_json`` argument when it parses to a dict
        (unparseable JSON is ignored);
      * ``formatted_response``, ``is_feasible`` and ``feasibility_score`` are
        added, in that order.
    """
    import json

    if getattr(tool, "name", "") != "calculate_average_macro_nutrient_per_day":
        return None
    if not isinstance(tool_response, dict):
        return None

    mfp_payload = _extract_latest_function_response_from_history(tool_context, "calculate_mfp_macros")
    enriched = dict(tool_response)

    raw_quantities = args.get("calculated_quantity_json", "")
    if isinstance(raw_quantities, str) and raw_quantities.strip():
        try:
            parsed_quantities = json.loads(raw_quantities)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass.
            parsed_quantities = None
        if isinstance(parsed_quantities, dict):
            enriched["calculated_quantity_per_day"] = parsed_quantities

    enriched["formatted_response"] = _build_baseline_formatted_response(
        tool_response=enriched,
        tool_context=tool_context,
    )
    enriched["is_feasible"] = _check_calculator_feasibility(tool_response, mfp_payload)
    enriched["feasibility_score"] = _calculate_feasibility_score(tool_response, mfp_payload)

    return enriched


# LLM backend: an OpenRouter-hosted model accessed through LiteLLM. The model id
# and API key are read from the environment at import time (os.getenv returns
# None when a variable is unset).
_openrouter_llm = LiteLlm(
    model=os.getenv("OPENROUTER_API_MODEL"),
    api_key=os.getenv("OPENROUTER_API_KEY"),
    api_base="https://openrouter.ai/api/v1",
    extra_body={"session_id": "LLMNutritionPlanner"},
)

# The planner agent. after_tool_callback enriches calculator results with a
# formatted report; before_model_callback may return that report directly
# instead of calling the model.
agent = Agent(
    model=_openrouter_llm,
    name='LLMNutritionPlanner',
    description=(
        "Tells the nutrients, descriptions, name, cost of food items and helps with meal planning."
    ),
    instruction=PROMPT_PATH.read_text(encoding="utf-8"),
    tools=[find_ingredient, calculate_mfp_macros, calculate_average_macro_nutrient_per_day],
    before_model_callback=_baseline_before_model_callback,
    after_tool_callback=_baseline_after_tool_callback,
)

# Module-level alias, presumably the name ADK's agent loader looks up.
root_agent = agent

