import sys
import os
import re
from typing import Any, Dict, Optional

from google.genai import types
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.models.lite_llm import LiteLlm
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext
from utils.ingredient_tool import find_ingredient
from utils.calculate_health_canada_dri import calculate_health_canada_dri
from utils.calculator import calculate_average_macro_nutrient_per_day
from utils.optimizer import _build_target_bundle_from_dri_payload
from utils.optimizer import _extract_latest_function_response_from_history
from utils.optimizer import _extract_products_from_history


def _load_agent_instruction() -> str:
    """Read the agent prompt from ``dri_prompt.txt`` beside this module.

    Returns:
        The file contents decoded as UTF-8 with surrounding whitespace
        stripped.
    """
    module_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(module_dir, "dri_prompt.txt"), "r", encoding="utf-8") as handle:
        instruction_text = handle.read()
    return instruction_text.strip()


def _format_number(value: Any, decimals: int = 2) -> str:
    try:
        number = float(value)
    except (TypeError, ValueError):
        return "N/A"

    rounded = round(number)
    if abs(number - rounded) < 1e-6:
        return str(int(rounded))
    return f"{number:.{decimals}f}".rstrip("0").rstrip(".")


def _natural_sort_key(value: str) -> list[Any]:
    return [int(part) if part.isdigit() else part.lower() for part in re.split(r"(\d+)", str(value))]


def _humanize_identifier(value: Any, fallback_prefix: str) -> str:
    text = str(value).strip()
    if not text:
        return fallback_prefix

    normalized = text.replace("::", " ").replace("_", " ")
    normalized = re.sub(r"\s+", " ", normalized).strip()
    if not normalized:
        return fallback_prefix

    return normalized.title()


def _display_product_name(product: Optional[Dict[str, Any]], index: Any) -> str:
    if isinstance(product, dict):
        name = str(product.get("name", "")).strip()
        if name:
            return name
    return f"Ingredient {index}"


def _normalize_position_label(value: Any, prefix: str, position: int) -> str:
    """Return a label that is guaranteed to start with ``prefix``.

    The humanized form of ``value`` is kept when it already begins with
    ``prefix`` (case-insensitively); otherwise the positional default
    ``"<prefix> <position>"`` is used.

    Args:
        value: Raw day or meal identifier.
        prefix: Required leading word, e.g. ``"Day"`` or ``"Meal"``.
        position: 1-based position used for the default label.

    Returns:
        The normalized label.
    """
    default_label = f"{prefix} {position}"
    candidate = _humanize_identifier(value, default_label)
    if candidate.lower().startswith(prefix.lower()):
        return candidate
    return default_label


def _first_option_items(selected_options: Any) -> list[Dict[str, Any]]:
    if not isinstance(selected_options, dict):
        return []

    for option_name in sorted(selected_options.keys(), key=_natural_sort_key):
        option_data = selected_options.get(option_name)

        if isinstance(option_data, dict):
            items = option_data.get("items")
            if isinstance(items, list) and items:
                return items
        elif isinstance(option_data, list) and option_data:
            return option_data

    return []


def _first_numeric_value(values: Dict[str, Any], keys: list[str]) -> Optional[float]:
    for key in keys:
        candidate = values.get(key)
        if isinstance(candidate, (int, float)):
            return float(candidate)
    return None


def _format_target_value(
    key: str,
    unit: str,
    point_targets: Dict[str, float],
    range_targets: Dict[str, Dict[str, float]],
) -> str:
    bounds = range_targets.get(key)
    if isinstance(bounds, dict):
        lower = bounds.get("lower")
        upper = bounds.get("upper")
        if lower is not None and upper is not None:
            return f"{_format_number(lower)}-{_format_number(upper)} {unit}"

    if key in point_targets:
        return f"{_format_number(point_targets[key])} {unit}"

    return "N/A"


def _format_actual_value(key: str, unit: str, achieved_targets: Dict[str, Any]) -> str:
    if key not in achieved_targets:
        return "N/A"
    return f"{_format_number(achieved_targets[key])} {unit}"


def _build_summary_table(
    point_targets: Dict[str, float],
    range_targets: Dict[str, Dict[str, float]],
    achieved_targets: Dict[str, Any],
) -> str:
    """Build the Markdown "Parameter | Target | Actual" summary table.

    One row is emitted for each of protein, carbs, fat, calories and fibre,
    in that order.

    Args:
        point_targets: Single-value targets by nutrient key.
        range_targets: Lower/upper bounds by nutrient key.
        achieved_targets: Achieved values by nutrient key.

    Returns:
        The table as a newline-joined string.
    """
    header = [
        "| Parameter | Target | Actual |",
        "| --- | --- | --- |",
    ]
    # (row label, nutrient key, unit)
    row_specs = (
        ("Protein", "protein", "g"),
        ("Carbs", "carbohydrates", "g"),
        ("Fat", "total_fat", "g"),
        ("Calories", "calories", "kcal"),
        ("Fibre", "total_fibre", "g"),
    )

    body_rows = []
    for label, key, unit in row_specs:
        target_cell = _format_target_value(key, unit, point_targets, range_targets)
        actual_cell = _format_actual_value(key, unit, achieved_targets)
        body_rows.append(f"| {label} | {target_cell} | {actual_cell} |")
    return "\n".join(header + body_rows)


def _build_meal_lines(
    calculated_quantity_per_day: Dict[str, Any],
    products_by_index: Dict[int, Dict[str, Any]],
) -> list[str]:
    if not isinstance(calculated_quantity_per_day, dict) or not calculated_quantity_per_day:
        return []

    lines: list[str] = []
    for day_position, day_name in enumerate(sorted(calculated_quantity_per_day.keys(), key=_natural_sort_key), start=1):
        day_value = calculated_quantity_per_day.get(day_name)

        day_label = _normalize_position_label(day_name, "Day", day_position)
        lines.append(f"{day_label}:")

        if isinstance(day_value, list):
            meal_groups: Dict[str, list[Dict[str, Any]]] = {}
            for item in day_value:
                if not isinstance(item, dict):
                    continue
                meal_name = item.get("meal_name", "Unspecified Meal")
                if meal_name not in meal_groups:
                    meal_groups[meal_name] = []
                meal_groups[meal_name].append(item)
        elif isinstance(day_value, dict):
            meal_groups = {}
            for meal_name, meal_data in day_value.items():
                if isinstance(meal_data, list):
                    meal_groups[meal_name] = meal_data
                elif isinstance(meal_data, dict):
                    items = _first_option_items(meal_data)
                    if items:
                        meal_groups[meal_name] = items
        else:
            continue

        for meal_position, (meal_name, selected_items) in enumerate(sorted(meal_groups.items(), key=lambda kv: _natural_sort_key(str(kv[0]))), start=1):
            if not isinstance(selected_items, list):
                continue

            meal_label = _normalize_position_label(meal_name, "Meal", meal_position)
            ingredient_texts: list[str] = []

            for item in selected_items:
                if not isinstance(item, dict):
                    continue
                raw_index = item.get("index")
                raw_qty = item.get("qty") or item.get("quantity")
                try:
                    index = int(raw_index)
                    qty = float(raw_qty)
                except (TypeError, ValueError):
                    continue

                ingredient_name = _display_product_name(products_by_index.get(index), index)
                ingredient_texts.append(f"{ingredient_name} ({_format_number(qty)} g)")

            joined_ingredients = ", ".join(ingredient_texts)
            if str(meal_name).strip().lower() != meal_label.lower():
                lines.append(f"- {meal_label} ({meal_name}): [{joined_ingredients}]")
            else:
                lines.append(f"- {meal_label}: [{joined_ingredients}]")

        if lines and lines[-1] != "":
            lines.append("")

    while lines and not lines[-1].strip():
        lines.pop()
    return lines


def _build_adjustment_suggestions(
    achieved_targets: Dict[str, Any],
    range_targets: Dict[str, Dict[str, float]],
    infeasibility_details: Dict[str, Any],
) -> list[str]:
    """Suggest plan adjustments based on the reported deviations.

    Calorie and fibre advice follows the sign of the corresponding
    difference in ``infeasibility_details``; macro advice compares each
    achieved macro with its range bounds (with a 1e-6 tolerance).

    Args:
        achieved_targets: Achieved values by nutrient key.
        range_targets: Lower/upper bounds by macro key.
        infeasibility_details: Mapping that may hold calorie and fibre
            difference entries.

    Returns:
        De-duplicated suggestions in the order they were produced, or a
        single generic suggestion when none applies.
    """
    suggestions: list[str] = []

    # (detail keys, advice when the difference is negative, advice when positive)
    signed_difference_rules = (
        (
            ["average_daily_calorie_difference", "calorie_difference"],
            "Increase overall calories by raising portions of calorie-dense carb or fat sources.",
            "Reduce overall calories by lowering portions of the most energy-dense ingredients.",
        ),
        (
            ["daily_total_fibre_difference", "fibre_difference"],
            "Add more fibre-rich foods such as legumes, vegetables, fruit, or whole grains.",
            "Reduce the highest-fibre ingredients or swap some whole-grain or legume portions for lower-fibre alternatives.",
        ),
    )
    for detail_keys, below_text, above_text in signed_difference_rules:
        difference = _first_numeric_value(infeasibility_details, detail_keys)
        if difference is None:
            continue
        if difference < 0:
            suggestions.append(below_text)
        elif difference > 0:
            suggestions.append(above_text)

    # (macro key, advice when below the range, advice when above the range)
    macro_rules = (
        ("protein", "Increase lean protein sources.", "Reduce the most protein-dense ingredients."),
        ("carbohydrates", "Increase carb portions such as grains, fruit, or starchy vegetables.", "Reduce carb-heavy ingredients."),
        ("total_fat", "Increase healthy fat sources such as nuts, seeds, avocado, or oils.", "Reduce added fats and the richest fat sources."),
    )
    for macro_key, increase_text, decrease_text in macro_rules:
        bounds = range_targets.get(macro_key)
        actual = achieved_targets.get(macro_key)
        if isinstance(bounds, dict) and isinstance(actual, (int, float)):
            lower = bounds.get("lower")
            upper = bounds.get("upper")
            if isinstance(lower, (int, float)) and actual < float(lower) - 1e-6:
                suggestions.append(increase_text)
            elif isinstance(upper, (int, float)) and actual > float(upper) + 1e-6:
                suggestions.append(decrease_text)

    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_suggestions = list(dict.fromkeys(suggestions))
    if unique_suggestions:
        return unique_suggestions
    return ["Widen ingredient quantity bounds or swap in ingredients with more compatible calorie and macro profiles."]


def _build_infeasibility_lines(
    achieved_targets: Dict[str, Any],
    range_targets: Dict[str, Dict[str, float]],
    infeasibility_details: Dict[str, Any],
) -> list[str]:
    if not isinstance(infeasibility_details, dict) or not infeasibility_details:
        return []

    lines = ["Deviations:"]
    calorie_difference = _first_numeric_value(
        infeasibility_details,
        ["average_daily_calorie_difference", "calorie_difference"],
    )
    if calorie_difference is not None:
        lines.append(f"- Daily calorie difference: {_format_number(calorie_difference)} kcal")

    fibre_difference = _first_numeric_value(
        infeasibility_details,
        ["daily_total_fibre_difference", "fibre_difference"],
    )
    if fibre_difference is not None:
        lines.append(f"- Daily total_fibre difference: {_format_number(fibre_difference)} g")

    deviation_labels = [
        (["daily_protein_range_deviation", "protein_range_deviation"], "Daily protein range deviation"),
        (["daily_carbohydrate_range_deviation", "daily_carbohydrates_range_deviation", "carbohydrates_range_deviation"], "Daily carbohydrate range deviation"),
        (["daily_total_fat_range_deviation", "total_fat_range_deviation", "fat_range_deviation"], "Daily total_fat range deviation"),
    ]
    for keys, label in deviation_labels:
        value = _first_numeric_value(infeasibility_details, keys)
        if value is not None:
            lines.append(f"- {label}: {_format_number(value)} g")

    lines.append("Adjustments:")
    for suggestion in _build_adjustment_suggestions(achieved_targets, range_targets, infeasibility_details):
        lines.append(f"- {suggestion}")
    return lines


def _range_deviation(
    actual: Any,
    bounds: Optional[Dict[str, float]],
) -> Optional[float]:
    if not isinstance(actual, (int, float)):
        return None
    if not isinstance(bounds, dict):
        return None

    lower = bounds.get("lower")
    upper = bounds.get("upper")
    if isinstance(lower, (int, float)) and float(actual) < float(lower):
        return float(lower) - float(actual)
    if isinstance(upper, (int, float)) and float(actual) > float(upper):
        return float(actual) - float(upper)
    return 0.0


def _extract_best_calculator_response_and_run_count(
    llm_request: LlmRequest,
) -> tuple[Optional[Dict[str, Any]], int]:
    """Pick the best calculator response in the request and count attempts.

    Every ``calculate_average_macro_nutrient_per_day`` function response in
    ``llm_request.contents`` counts as one attempt. Dict payloads are ranked
    feasible-first and then by ascending ``feasibility_score``; on a tie the
    later response wins. A payload without a numeric score is ranked as 0.0
    when feasible and as infinity otherwise.

    Args:
        llm_request: The request whose contents are scanned.

    Returns:
        ``(best_payload, attempt_count)``; ``best_payload`` is ``None`` when
        no dict payload was found.
    """
    calculator_tool = "calculate_average_macro_nutrient_per_day"
    best_payload = None
    best_rank: Optional[tuple[int, float]] = None
    attempts = 0

    for content in llm_request.contents:
        for part in list(getattr(content, "parts", []) or []):
            response = getattr(part, "function_response", None)
            if not response or getattr(response, "name", "") != calculator_tool:
                continue

            # Count the attempt even when its payload turns out to be unusable.
            attempts += 1
            payload = getattr(response, "response", None)
            if not isinstance(payload, dict):
                continue

            feasible = bool(payload.get("is_feasible", False))
            raw_score = payload.get("feasibility_score")
            if isinstance(raw_score, (int, float)):
                score = float(raw_score)
            elif feasible:
                score = 0.0
            else:
                score = float("inf")

            rank = (0 if feasible else 1, score)
            # "<=" lets a later response replace an equally ranked earlier one.
            if best_rank is None or rank <= best_rank:
                best_rank = rank
                best_payload = payload

    return best_payload, attempts


def _extract_latest_calculator_call_args(
    llm_request: LlmRequest,
) -> Optional[Dict[str, Any]]:
    """Return a copy of the arguments of the most recent calculator call.

    Contents and their parts are scanned from last to first for a
    ``calculate_average_macro_nutrient_per_day`` function call.

    Args:
        llm_request: The request whose contents are scanned.

    Returns:
        A shallow copy of the call's ``args`` dict, or ``None`` when no such
        call exists or the latest call's ``args`` is not a dict.
    """
    calculator_tool = "calculate_average_macro_nutrient_per_day"
    for content in reversed(llm_request.contents):
        for part in reversed(list(getattr(content, "parts", []) or [])):
            call = getattr(part, "function_call", None)
            if call and getattr(call, "name", "") == calculator_tool:
                # Only the latest matching call is considered; older calls
                # are not consulted even if this one has unusable args.
                call_args = getattr(call, "args", None)
                return dict(call_args) if isinstance(call_args, dict) else None
    return None


def _is_macro_in_range(
    achieved: float,
    bounds: Optional[Dict[str, float]],
    point_target: Optional[float],
) -> bool:
    if isinstance(bounds, dict):
        lower = bounds.get("lower")
        upper = bounds.get("upper")
        if lower is not None and upper is not None:
            return float(lower) <= achieved <= float(upper)

    if point_target is not None:
        return abs(achieved - float(point_target)) < 1e-6

    return True


def _check_calculator_feasibility(
    calculator_response: Dict[str, Any],
    dri_payload: Optional[Dict[str, Any]],
) -> bool:
    if not isinstance(calculator_response, dict):
        return False

    achieved = calculator_response.get("average_macro_nutrient_from_calculated_quantity_per_day")
    if not isinstance(achieved, dict):
        return False

    if dri_payload:
        point_targets, range_targets = _build_target_bundle_from_dri_payload(dri_payload)
    else:
        point_targets = {}
        range_targets = {}

    fibre_target = point_targets.get("total_fibre")
    fibre_achieved = achieved.get("total_fibre", 0)
    if fibre_target is not None:
        fibre_error = abs(float(fibre_achieved) - float(fibre_target))
        if fibre_error > 5.0:
            return False

    calorie_target = point_targets.get("calories")
    calorie_achieved = achieved.get("calories", 0)
    if calorie_target is not None:
        calorie_error = abs(float(calorie_achieved) - float(calorie_target))
        if calorie_error > 10.0:
            return False

    macros_to_check = {
        "protein": achieved.get("protein", 0),
        "carbohydrates": achieved.get("carbohydrates", 0),
        "total_fat": achieved.get("total_fat", 0),
    }

    for macro_key, macro_achieved in macros_to_check.items():
        bounds = range_targets.get(macro_key)
        point_target = point_targets.get(macro_key)
        if not _is_macro_in_range(float(macro_achieved), bounds, point_target):
            return False

    return True


def _calculate_feasibility_score(
    calculator_response: Dict[str, Any],
    dri_payload: Optional[Dict[str, Any]],
) -> float:
    if not isinstance(calculator_response, dict):
        return float("inf")

    achieved = calculator_response.get("average_macro_nutrient_from_calculated_quantity_per_day")
    if not isinstance(achieved, dict):
        return float("inf")

    if dri_payload:
        point_targets, range_targets = _build_target_bundle_from_dri_payload(dri_payload)
    else:
        point_targets = {}
        range_targets = {}

    score = 0.0

    fibre_target = point_targets.get("total_fibre")
    fibre_achieved = achieved.get("total_fibre", 0)
    if isinstance(fibre_target, (int, float)):
        fibre_error = abs(float(fibre_achieved) - float(fibre_target))
        score += max(0.0, fibre_error - 5.0)

    calorie_target = point_targets.get("calories")
    calorie_achieved = achieved.get("calories", 0)
    if isinstance(calorie_target, (int, float)):
        calorie_error = abs(float(calorie_achieved) - float(calorie_target))
        score += max(0.0, calorie_error - 10.0)

    for macro_key in ["protein", "carbohydrates", "total_fat"]:
        macro_achieved = achieved.get(macro_key)
        if not isinstance(macro_achieved, (int, float)):
            score += 1e6
            continue

        bounds = range_targets.get(macro_key)
        if isinstance(bounds, dict):
            lower = bounds.get("lower")
            upper = bounds.get("upper")
            if isinstance(lower, (int, float)) and float(macro_achieved) < float(lower):
                score += float(lower) - float(macro_achieved)
            if isinstance(upper, (int, float)) and float(macro_achieved) > float(upper):
                score += float(macro_achieved) - float(upper)
            continue

        point_target = point_targets.get(macro_key)
        if isinstance(point_target, (int, float)):
            score += abs(float(macro_achieved) - float(point_target))

    return score


def _build_baseline_formatted_response(
    tool_response: Dict[str, Any],
    tool_context: ToolContext,
) -> str:
    """Build the plain-text report for a calculator tool response.

    The report holds a summary table, then either a success note or the
    deviation/adjustment sections, then the per-day meal listing. Targets
    come from the latest ``calculate_health_canada_dri`` response in the
    history, and product names from the products found in the history.

    Args:
        tool_response: Calculator response; read for the achieved values and
            for ``"calculated_quantity_per_day"``.
        tool_context: Context handed to the history-extraction helpers.

    Returns:
        The report as a newline-joined, stripped string.
    """
    dri_payload = _extract_latest_function_response_from_history(tool_context, "calculate_health_canada_dri")
    point_targets: Dict[str, float] = {}
    range_targets: Dict[str, Dict[str, float]] = {}
    if dri_payload:
        raw_point_targets, raw_range_targets = _build_target_bundle_from_dri_payload(dri_payload)
        # Keep only the nutrients that the report displays.
        point_keys = {"calories", "protein", "carbohydrates", "total_fat", "total_fibre"}
        range_keys = {"protein", "carbohydrates", "total_fat"}
        point_targets = {k: v for k, v in raw_point_targets.items() if k in point_keys}
        range_targets = {k: v for k, v in raw_range_targets.items() if k in range_keys}

    achieved_targets = tool_response.get("average_macro_nutrient_from_calculated_quantity_per_day")
    if not isinstance(achieved_targets, dict):
        achieved_targets = {}

    lines = [
        "Summary:",
        _build_summary_table(point_targets, range_targets, achieved_targets),
    ]

    if _check_calculator_feasibility(tool_response, dri_payload):
        lines.extend(["", "- The plan meets the current calorie, fibre, and macro targets."])
    else:

        def _signed_difference(key: str) -> Optional[float]:
            # Achieved minus target; a missing achieved value counts as 0.
            # An unusable value yields None so the report omits that line
            # instead of raising.
            target = point_targets.get(key)
            if target is None:
                return None
            try:
                return float(achieved_targets.get(key, 0)) - float(target)
            except (TypeError, ValueError):
                return None

        infeasibility_details: Dict[str, Any] = {}
        fibre_difference = _signed_difference("total_fibre")
        if fibre_difference is not None:
            infeasibility_details["daily_total_fibre_difference"] = fibre_difference
        calorie_difference = _signed_difference("calories")
        if calorie_difference is not None:
            infeasibility_details["average_daily_calorie_difference"] = calorie_difference

        # (detail key written, macro key read)
        macro_detail_keys = (
            ("daily_protein_range_deviation", "protein"),
            ("daily_carbohydrate_range_deviation", "carbohydrates"),
            ("daily_total_fat_range_deviation", "total_fat"),
        )
        for detail_key, macro_key in macro_detail_keys:
            deviation = _range_deviation(achieved_targets.get(macro_key), range_targets.get(macro_key))
            if deviation is not None:
                infeasibility_details[detail_key] = deviation

        infeasibility_lines = _build_infeasibility_lines(
            achieved_targets,
            range_targets,
            infeasibility_details,
        )
        if infeasibility_lines:
            lines.extend(["", *infeasibility_lines])

    products_by_index = _extract_products_from_history(tool_context)
    meal_lines = _build_meal_lines(tool_response.get("calculated_quantity_per_day", {}), products_by_index)
    if meal_lines:
        lines.extend(["", *meal_lines])

    return "\n".join(lines).strip()


def _baseline_before_model_callback(
    callback_context: CallbackContext,
    llm_request: LlmRequest,
) -> Optional[LlmResponse]:
    """Short-circuit the model call once a final plan is available.

    When the best calculator response in the request is feasible, or at
    least three calculator responses are present, its pre-built
    ``"formatted_response"`` text is returned as the model reply.

    Args:
        callback_context: Unused.
        llm_request: The pending request; its contents are scanned for
            calculator responses.

    Returns:
        An ``LlmResponse`` carrying the formatted text, or ``None`` to let
        the model call proceed.
    """
    del callback_context

    # NOTE(review): the scan covers every entry in llm_request.contents, so
    # if earlier turns remain in the request a later, unrelated turn may be
    # answered with this same text -- confirm that is intended.
    best_response, attempts = _extract_best_calculator_response_and_run_count(llm_request)
    if not isinstance(best_response, dict):
        return None

    report_text = best_response.get("formatted_response", "")
    feasible = best_response.get("is_feasible", False)

    # Finalize on feasibility or once the third attempt has been made.
    should_finalize = attempts >= 3 or feasible
    if not (should_finalize and isinstance(report_text, str) and report_text.strip()):
        return None

    reply = types.Content(
        role="model",
        parts=[types.Part(text=report_text.strip())],
    )
    return LlmResponse(content=reply)


def _baseline_after_tool_callback(
    tool: BaseTool,
    args: Dict[str, Any],
    tool_context: ToolContext,
    tool_response: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
    """Enrich calculator tool responses with a report and feasibility data.

    Only dict responses of ``calculate_average_macro_nutrient_per_day`` are
    handled. The returned copy gains ``"formatted_response"``,
    ``"is_feasible"`` and ``"feasibility_score"``, plus
    ``"calculated_quantity_per_day"`` when the ``calculated_quantity_json``
    argument parses to a dict.

    Args:
        tool: The tool that just ran.
        args: Arguments the tool was called with.
        tool_context: Context handed to the history-extraction helpers.
        tool_response: The tool's raw response.

    Returns:
        The enriched copy of the response, or ``None`` when the call is not
        handled.
    """
    is_calculator = getattr(tool, "name", "") == "calculate_average_macro_nutrient_per_day"
    if not is_calculator or not isinstance(tool_response, dict):
        return None

    dri_payload = _extract_latest_function_response_from_history(tool_context, "calculate_health_canada_dri")
    enriched = dict(tool_response)

    quantity_json = args.get("calculated_quantity_json", "")
    if isinstance(quantity_json, str) and quantity_json.strip():
        import json

        try:
            parsed_quantities = json.loads(quantity_json)
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass; malformed JSON
            # is ignored on purpose and the response is left as is.
            parsed_quantities = None
        if isinstance(parsed_quantities, dict):
            enriched["calculated_quantity_per_day"] = parsed_quantities

    enriched["formatted_response"] = _build_baseline_formatted_response(
        tool_response=enriched,
        tool_context=tool_context,
    )
    enriched["is_feasible"] = _check_calculator_feasibility(tool_response, dri_payload)
    enriched["feasibility_score"] = _calculate_feasibility_score(tool_response, dri_payload)
    return enriched


# The model name and API key are read from the environment when this module
# is imported.
agent = Agent(
    model=LiteLlm(
        model=os.getenv("OPENROUTER_API_MODEL"),
        api_key=os.getenv("OPENROUTER_API_KEY"),
        api_base="https://openrouter.ai/api/v1",
        extra_body={
            "session_id": "LLMNutritionPlanner",  # Enables immediate provider sticky routing
        },
    ),
    name="LLMNutritionPlanner",
    description=(
        "Tells the nutrients, descriptions, name, cost of food items and helps with meal planning."
    ),
    instruction=_load_agent_instruction(),
    tools=[
        find_ingredient,
        calculate_health_canada_dri,
        calculate_average_macro_nutrient_per_day,
    ],
    before_model_callback=_baseline_before_model_callback,
    after_tool_callback=_baseline_after_tool_callback,
)

# Second module-level name bound to the same agent object.
root_agent = agent

