import sys
import os
import re
import time
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Dict
from typing import Optional
from google.adk.models.lite_llm import LiteLlm
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from google.genai import types
from google.adk.agents.llm_agent import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.tool_context import ToolContext

from utils.ingredient_tool import find_ingredient
from utils.optimizer import _build_target_bundle_from_dri_payload
from utils.optimizer import _extract_latest_function_response_from_history
from utils.optimizer import _extract_products_from_history
from utils.optimizer import _filter_optimizer_target_ranges
from utils.optimizer import _filter_optimizer_targets
from utils.optimizer import optimize_quantity
# from utils.log_chat import analyze_history
# from utils.calculator import calculate_nutrients, calculate_average_macro_nutrient_per_day
from utils.calculate_health_canada_dri import calculate_health_canada_dri


def _load_agent_instruction() -> str:
  """Return the agent instruction text stored next to this module.

  Reads ``dri_prompt.txt`` from the directory containing this file as UTF-8
  and strips leading and trailing whitespace.
  """
  module_dir = os.path.dirname(os.path.abspath(__file__))
  with open(os.path.join(module_dir, "dri_prompt.txt"), "r", encoding="utf-8") as handle:
    instruction = handle.read()
  return instruction.strip()


def _format_number(value: Any, decimals: int = 2) -> str:
  """Format a numeric value compactly for display.

  Args:
    value: Anything ``float()`` accepts, e.g. a number or a numeric string.
    decimals: Maximum number of fractional digits to keep.

  Returns:
    The integer form when the value is within 1e-6 of a whole number;
    otherwise the value with at most ``decimals`` fractional digits and
    trailing fractional zeros removed. ``"N/A"`` when the value is not
    numeric or is NaN/infinite.
  """
  try:
    number = float(value)
    # round() raises ValueError for NaN and OverflowError for infinity, so
    # non-finite values are reported the same way as non-numeric input.
    rounded = round(number)
  except (TypeError, ValueError, OverflowError):
    return "N/A"

  if abs(number - rounded) < 1e-6:
    return str(rounded)

  text = f"{number:.{decimals}f}"
  # Strip zeros only from a fractional part; with decimals=0 there is none and
  # stripping would corrupt a value such as "100" into "1".
  if "." in text:
    text = text.rstrip("0").rstrip(".")
  # A small negative can format as "-0"; report it as plain zero.
  return "0" if text == "-0" else text


def _natural_sort_key(value: str) -> list[Any]:
  """Build a sort key that orders embedded digit runs numerically.

  Args:
    value: The identifier to sort by; it is converted with ``str()`` first.

  Returns:
    A list alternating lower-cased text fragments and integers, so that for
    example ``"day_2"`` sorts before ``"day_10"``.
  """
  fragments = re.split(r"(\d+)", str(value))
  # With a single capture group, re.split alternates text, digits, text, ...,
  # so odd positions are exactly the digit runs. Deciding by position instead
  # of str.isdigit() avoids int() failing on characters such as a superscript
  # two, which isdigit() accepts but the pattern and int() do not.
  return [
    int(fragment) if position % 2 else fragment.lower()
    for position, fragment in enumerate(fragments)
  ]


def _humanize_identifier(value: Any, fallback_prefix: str) -> str:
  """Turn an identifier-like value into a title-cased label.

  ``::`` and ``_`` are treated as word separators and runs of whitespace are
  collapsed to single spaces.

  Args:
    value: The raw identifier; it is converted with ``str()`` first.
    fallback_prefix: Text returned when nothing printable remains.

  Returns:
    The title-cased label, or ``fallback_prefix`` when the cleaned text is
    empty.
  """
  words = str(value).replace("::", " ").replace("_", " ")
  # Collapsing whitespace and trimming the ends also covers blank input.
  label = re.sub(r"\s+", " ", words).strip()
  return label.title() if label else fallback_prefix


def _display_product_name(product: Optional[Dict[str, Any]], index: Any) -> str:
  """Return a display name for an ingredient.

  Args:
    product: Product record expected to carry a ``"name"`` entry, or ``None``.
    index: Ingredient index used to build the fallback label.

  Returns:
    The stripped product name, or ``"Ingredient <index>"`` when the product is
    not a dict or its name is missing, ``None`` or blank.
  """
  if isinstance(product, dict):
    raw_name = product.get("name")
    # An explicit None must not be stringified into the literal "None".
    if raw_name is not None:
      name = str(raw_name).strip()
      if name:
        return name
  return f"Ingredient {index}"


def _format_target_value(
  key: str,
  unit: str,
  point_targets: Dict[str, float],
  range_targets: Dict[str, Dict[str, float]],
) -> str:
  """Render the target cell for one nutrient.

  A range with both ``lower`` and ``upper`` set takes precedence and is shown
  as ``"<lower>-<upper> <unit>"``; otherwise a point target is shown as
  ``"<value> <unit>"``.

  Args:
    key: Nutrient key to look up in both target mappings.
    unit: Unit suffix appended to the formatted value.
    point_targets: Single-value targets by nutrient key.
    range_targets: ``{"lower": ..., "upper": ...}`` bounds by nutrient key.

  Returns:
    The formatted target, or ``"N/A"`` when neither mapping provides one.
  """
  bounds = range_targets.get(key)
  has_full_range = (
    isinstance(bounds, dict)
    and bounds.get("lower") is not None
    and bounds.get("upper") is not None
  )
  if has_full_range:
    return f"{_format_number(bounds['lower'])}-{_format_number(bounds['upper'])} {unit}"

  if key not in point_targets:
    return "N/A"
  return f"{_format_number(point_targets[key])} {unit}"


def _format_actual_value(key: str, unit: str, achieved_targets: Dict[str, Any]) -> str:
  """Render the achieved value for one nutrient.

  Args:
    key: Nutrient key to look up.
    unit: Unit suffix appended to the formatted value.
    achieved_targets: Achieved values by nutrient key.

  Returns:
    ``"<value> <unit>"`` when the key is present, otherwise ``"N/A"``.
  """
  if key in achieved_targets:
    achieved = achieved_targets[key]
    return f"{_format_number(achieved)} {unit}"
  return "N/A"


def _build_summary_table(
  point_targets: Dict[str, float],
  range_targets: Dict[str, Dict[str, float]],
  achieved_targets: Dict[str, Any],
) -> str:
  """Build the Markdown target-vs-actual table for the tracked nutrients.

  Args:
    point_targets: Single-value targets by nutrient key.
    range_targets: Lower/upper bounds by nutrient key.
    achieved_targets: Achieved values by nutrient key.

  Returns:
    A Markdown table with a header and one row each for protein, carbs, fat,
    calories and fibre, joined with newlines.
  """
  # (display label, lookup key, unit) in the order the rows are emitted.
  parameters = (
    ("Protein", "protein", "g"),
    ("Carbs", "carbohydrates", "g"),
    ("Fat", "total_fat", "g"),
    ("Calories", "calories", "kcal"),
    ("Fibre", "total_fibre", "g"),
  )

  table = ["| Parameter | Target | Actual |", "| --- | --- | --- |"]
  for label, key, unit in parameters:
    target_cell = _format_target_value(key, unit, point_targets, range_targets)
    actual_cell = _format_actual_value(key, unit, achieved_targets)
    table.append(f"| {label} | {target_cell} | {actual_cell} |")
  return "\n".join(table)


def _build_meal_lines(
  calculated_quantity_per_day: Dict[str, Any],
  products_by_index: Dict[int, Dict[str, Any]],
) -> list[str]:
  """Render the per-day meal plan as text lines.

  Days and meals are visited in natural-sort order of their keys. For each
  meal, the first option (in natural-sort order) holding a non-empty list is
  used; items that are not dicts or whose ``index``/``qty`` cannot be
  converted to ``int``/``float`` are skipped.

  Args:
    calculated_quantity_per_day: Mapping of day -> meal -> option -> items,
      where each item is expected to carry ``index`` and ``qty``.
    products_by_index: Product records by ingredient index, used for names.

  Returns:
    Lines of the form ``"Day N:"`` followed by ``"- Meal N: [a (x g), ...]"``,
    with a blank line between days. Empty when the input is not a non-empty
    dict.
  """
  if not (isinstance(calculated_quantity_per_day, dict) and calculated_quantity_per_day):
    return []

  def _label(raw_name: Any, prefix: str, position: int) -> str:
    # Keep the humanized name only if it already reads like "<prefix> ...".
    default = f"{prefix} {position}"
    candidate = _humanize_identifier(raw_name, default)
    return candidate if candidate.lower().startswith(prefix.lower()) else default

  def _first_non_empty_option(options: Any) -> list[Any]:
    if isinstance(options, dict):
      for option_key in sorted(options.keys(), key=_natural_sort_key):
        candidate = options.get(option_key)
        if isinstance(candidate, list) and candidate:
          return candidate
    return []

  def _describe(item: Any) -> Optional[str]:
    if not isinstance(item, dict):
      return None
    raw_index = item.get("index")
    raw_qty = item.get("qty")
    try:
      index = int(raw_index)
      qty = float(raw_qty)
    except (TypeError, ValueError):
      return None
    name = _display_product_name(products_by_index.get(index), index)
    return f"{name} ({_format_number(qty)} g)"

  rendered: list[str] = []
  ordered_days = sorted(calculated_quantity_per_day.keys(), key=_natural_sort_key)
  for day_position, day_name in enumerate(ordered_days, start=1):
    meals = calculated_quantity_per_day.get(day_name)
    if not isinstance(meals, dict):
      continue

    # A blank line separates consecutive days; none trails the last one.
    if rendered:
      rendered.append("")
    day_label = _label(day_name, "Day", day_position)
    rendered.append(f"{day_label}:")

    ordered_meals = sorted(meals.keys(), key=_natural_sort_key)
    for meal_position, meal_name in enumerate(ordered_meals, start=1):
      meal_label = _label(meal_name, "Meal", meal_position)
      descriptions = [_describe(item) for item in _first_non_empty_option(meals.get(meal_name))]
      joined_ingredients = ", ".join(text for text in descriptions if text is not None)
      rendered.append(f"- {meal_label}: [{joined_ingredients}]")

  return rendered


def _build_adjustment_suggestions(
  achieved_targets: Dict[str, Any],
  range_targets: Dict[str, Dict[str, float]],
  infeasibility_details: Dict[str, Any],
) -> list[str]:
  """Derive plain-language adjustment hints from the reported deviations.

  Calorie and fibre hints come from the sign of the reported difference
  (negative -> increase, positive -> decrease). Macro hints compare the
  achieved value with its range bounds, using a 1e-6 tolerance.

  Args:
    achieved_targets: Achieved values by nutrient key.
    range_targets: Lower/upper bounds by nutrient key.
    infeasibility_details: Deviation figures reported for the plan.

  Returns:
    The suggestions in calorie, fibre, macro order with duplicates removed;
    a single generic suggestion when nothing specific applies.
  """
  numeric = (int, float)

  def _first_number(primary: str, fallback: str) -> Any:
    # Prefer the primary key; fall back only when it is not numeric.
    value = infeasibility_details.get(primary)
    if not isinstance(value, numeric):
      value = infeasibility_details.get(fallback)
    return value if isinstance(value, numeric) else None

  # ((primary key, fallback key), hint when negative, hint when positive)
  signed_rules = (
    (
      ("average_daily_calorie_difference", "calorie_difference"),
      "Increase overall calories by raising portions of calorie-dense carb or fat sources.",
      "Reduce overall calories by lowering portions of the most energy-dense ingredients.",
    ),
    (
      ("daily_total_fibre_difference", "fibre_difference"),
      "Add more fibre-rich foods such as legumes, vegetables, fruit, or whole grains.",
      "Reduce the highest-fibre ingredients or swap some whole-grain or legume portions for lower-fibre alternatives.",
    ),
  )
  # (nutrient key, hint when below range, hint when above range)
  range_rules = (
    ("protein", "Increase lean protein sources.", "Reduce the most protein-dense ingredients."),
    ("carbohydrates", "Increase carb portions such as grains, fruit, or starchy vegetables.", "Reduce carb-heavy ingredients."),
    ("total_fat", "Increase healthy fat sources such as nuts, seeds, avocado, or oils.", "Reduce added fats and the richest fat sources."),
  )

  collected: list[str] = []

  for (primary, fallback), when_negative, when_positive in signed_rules:
    difference = _first_number(primary, fallback)
    if difference is None:
      continue
    if difference < 0:
      collected.append(when_negative)
    elif difference > 0:
      collected.append(when_positive)

  for nutrient, when_below, when_above in range_rules:
    bounds = range_targets.get(nutrient)
    actual = achieved_targets.get(nutrient)
    if isinstance(bounds, dict) and isinstance(actual, numeric):
      lower = bounds.get("lower")
      upper = bounds.get("upper")
      if isinstance(lower, numeric) and actual < float(lower) - 1e-6:
        collected.append(when_below)
      elif isinstance(upper, numeric) and actual > float(upper) + 1e-6:
        collected.append(when_above)

  # dict.fromkeys keeps the first occurrence of each suggestion, in order.
  unique = list(dict.fromkeys(collected))
  if unique:
    return unique
  return ["Widen ingredient quantity bounds or swap in ingredients with more compatible calorie and macro profiles."]


def _build_infeasibility_lines(
  achieved_targets: Dict[str, Any],
  range_targets: Dict[str, Dict[str, float]],
  infeasibility_details: Dict[str, Any],
) -> list[str]:
  """Render the deviation report and adjustment hints for an infeasible plan.

  Args:
    achieved_targets: Achieved values by nutrient key.
    range_targets: Lower/upper bounds by nutrient key.
    infeasibility_details: Deviation figures reported for the plan.

  Returns:
    ``"Deviations:"`` followed by one bullet per numeric deviation found,
    then ``"Adjustments:"`` followed by one bullet per suggestion. Empty when
    ``infeasibility_details`` is not a non-empty dict.
  """
  if not (isinstance(infeasibility_details, dict) and infeasibility_details):
    return []

  # (candidate keys in priority order, bullet label, unit)
  metrics = (
    (("average_daily_calorie_difference", "calorie_difference"), "Daily calorie difference", "kcal"),
    (("daily_total_fibre_difference", "fibre_difference"), "Daily total_fibre difference", "g"),
    (("daily_protein_range_deviation", "protein_range_deviation"), "Daily protein range deviation", "g"),
    (
      ("daily_carbohydrate_range_deviation", "daily_carbohydrates_range_deviation", "carbohydrates_range_deviation"),
      "Daily carbohydrate range deviation",
      "g",
    ),
    (
      ("daily_total_fat_range_deviation", "total_fat_range_deviation", "fat_range_deviation"),
      "Daily total_fat range deviation",
      "g",
    ),
  )

  report = ["Deviations:"]
  for candidate_keys, label, unit in metrics:
    # The first key holding a number wins; later aliases are ignored.
    value = next(
      (found for found in map(infeasibility_details.get, candidate_keys) if isinstance(found, (int, float))),
      None,
    )
    if value is not None:
      report.append(f"- {label}: {_format_number(value)} {unit}")

  report.append("Adjustments:")
  suggestions = _build_adjustment_suggestions(achieved_targets, range_targets, infeasibility_details)
  report.extend(f"- {suggestion}" for suggestion in suggestions)
  return report


def _build_optimizer_formatted_response(
  tool_response: Dict[str, Any],
  tool_context: ToolContext,
) -> str:
  """Compose the user-facing text for an ``optimize_quantity`` result.

  The text contains a summary table, then either the deviation report (status
  ``"infeasible"``) or a confirmation line (any other status), then the
  per-day meal lines when there are any.

  Args:
    tool_response: Payload returned by the optimizer tool.
    tool_context: Context passed to the history helpers to recover the latest
      ``calculate_health_canada_dri`` payload and the product records.

  Returns:
    The sections joined with newlines and stripped.
  """
  point_targets: Dict[str, float] = {}
  range_targets: Dict[str, Dict[str, float]] = {}
  dri_payload = _extract_latest_function_response_from_history(tool_context, "calculate_health_canada_dri")
  if dri_payload:
    raw_point_targets, raw_range_targets = _build_target_bundle_from_dri_payload(dri_payload)
    point_targets = _filter_optimizer_targets(raw_point_targets)
    range_targets = _filter_optimizer_target_ranges(raw_range_targets)

  reported = tool_response.get("achieved_targets_per_day")
  achieved_targets = reported if isinstance(reported, dict) else {}

  products_by_index = _extract_products_from_history(tool_context)
  meal_lines = _build_meal_lines(tool_response.get("calculated_quantity_per_day", {}), products_by_index)

  sections = ["Summary:", _build_summary_table(point_targets, range_targets, achieved_targets)]

  if str(tool_response.get("status", "")).strip().lower() != "infeasible":
    sections += ["", "- The plan meets the current calorie, fibre, and macro targets."]
  else:
    infeasibility_lines = _build_infeasibility_lines(
      achieved_targets,
      range_targets,
      tool_response.get("infeasibility_details", {}),
    )
    if infeasibility_lines:
      sections.append("")
      sections.extend(infeasibility_lines)

  if meal_lines:
    sections.append("")
    sections.extend(meal_lines)

  return "\n".join(sections).strip()


def _extract_latest_optimizer_response_from_request(
  llm_request: LlmRequest,
) -> Optional[Dict[str, Any]]:
  """Return the optimizer payload if it is the latest meaningful content.

  Contents are walked from newest to oldest. Contents with no function
  response, function call or non-blank text are skipped. The walk stops at
  the first content that has any of those.

  Args:
    llm_request: Request whose ``contents`` hold the conversation so far.

  Returns:
    The dict payload of the last ``optimize_quantity`` function response in
    the newest meaningful content; ``None`` when that content holds only other
    tools' responses, a function call or text, when the payload is not a dict,
    or when nothing meaningful is found.
  """
  for content in reversed(llm_request.contents):
    parts = list(getattr(content, "parts", []) or [])

    # Function responses of this content, newest part first.
    responses = [
      response
      for response in (getattr(part, "function_response", None) for part in reversed(parts))
      if response
    ]
    if responses:
      for response in responses:
        if getattr(response, "name", "") == "optimize_quantity":
          payload = getattr(response, "response", None)
          return payload if isinstance(payload, dict) else None
      # Only other tools responded here; the optimizer result is not the latest.
      return None

    for part in reversed(parts):
      if getattr(part, "function_call", None):
        return None
      text = getattr(part, "text", None)
      if text and str(text).strip():
        return None

  return None


def _extract_best_optimizer_response_and_run_count(
  llm_request: LlmRequest,
) -> tuple[Optional[Dict[str, Any]], int]:
  """Pick the best optimize_quantity result from the current user turn.

  Only contents after the most recent user text message are scanned. The
  request carries the conversation so far, so without this scoping a feasible
  plan (or three attempts) from an earlier turn would keep satisfying the
  before-model callback's finish condition on every later model call and the
  old plan would be replayed instead of answering the new message.

  Ranking: a feasible result beats any infeasible one; infeasible results are
  ordered by the sum of their absolute reported deviations (lower is better);
  ties go to the most recent result.

  Args:
    llm_request: Request whose ``contents`` hold the conversation so far.

  Returns:
    ``(best_response, run_count)`` where ``best_response`` is the best dict
    payload returned by ``optimize_quantity`` in the current turn (``None``
    when there is none) and ``run_count`` is the number of
    ``optimize_quantity`` function responses in the current turn.
  """
  deviation_keys = (
    "average_daily_calorie_difference",
    "calorie_difference",
    "daily_total_fibre_difference",
    "fibre_difference",
    "daily_protein_range_deviation",
    "protein_range_deviation",
    "daily_carbohydrate_range_deviation",
    "daily_carbohydrates_range_deviation",
    "carbohydrates_range_deviation",
    "daily_total_fat_range_deviation",
    "total_fat_range_deviation",
    "fat_range_deviation",
  )

  def _infeasibility_score(payload: Dict[str, Any]) -> float:
    """Return 0 for a feasible payload, else its summed absolute deviations."""
    if str(payload.get("status", "")).strip().lower() == "feasible":
      return 0.0

    details = payload.get("infeasibility_details")
    if not isinstance(details, dict):
      return float("inf")

    # NOTE(review): every listed key is summed independently, so a payload that
    # reports the same deviation under two alias names is counted twice -
    # confirm against the optimizer's actual output keys.
    magnitudes = [
      abs(float(details[key]))
      for key in deviation_keys
      if isinstance(details.get(key), (int, float))
    ]
    # No usable figures at all ranks worst.
    return sum(magnitudes) if magnitudes else float("inf")

  def _starts_user_turn(content: Any) -> bool:
    """Return True for a user-role content that carries typed text."""
    if getattr(content, "role", None) != "user":
      return False
    parts = list(getattr(content, "parts", []) or [])
    # Contents carrying function responses are tool output, not a new message.
    if any(getattr(part, "function_response", None) for part in parts):
      return False
    return any(str(getattr(part, "text", None) or "").strip() for part in parts)

  contents = list(llm_request.contents or [])

  # Find where the current turn begins; with no user message, scan everything.
  turn_start = 0
  for position in range(len(contents) - 1, -1, -1):
    if _starts_user_turn(contents[position]):
      turn_start = position + 1
      break

  best_response: Optional[Dict[str, Any]] = None
  best_key: Optional[tuple[int, float]] = None
  run_count = 0

  for content in contents[turn_start:]:
    for part in list(getattr(content, "parts", []) or []):
      function_response = getattr(part, "function_response", None)
      if not function_response:
        continue
      if getattr(function_response, "name", "") != "optimize_quantity":
        continue

      # Every optimizer response counts as an attempt, even a malformed one.
      run_count += 1
      payload = getattr(function_response, "response", None)
      if not isinstance(payload, dict):
        continue

      is_feasible = str(payload.get("status", "")).strip().lower() == "feasible"
      candidate_key = (0 if is_feasible else 1, _infeasibility_score(payload))
      # "<=" lets a later result replace an equally ranked earlier one.
      if best_key is None or candidate_key <= best_key:
        best_key = candidate_key
        best_response = payload

  return best_response, run_count


def _extract_latest_optimizer_call_args(
  llm_request: LlmRequest,
) -> Optional[Dict[str, Any]]:
  """Return the arguments of the most recent ``optimize_quantity`` call.

  Args:
    llm_request: Request whose ``contents`` hold the conversation so far.

  Returns:
    A shallow copy of the newest ``optimize_quantity`` function call's
    ``args`` when they are a dict; ``None`` when they are not, or when no such
    call exists.
  """
  # All function_call attributes, newest content and newest part first.
  newest_first_calls = (
    getattr(part, "function_call", None)
    for content in reversed(llm_request.contents)
    for part in reversed(list(getattr(content, "parts", []) or []))
  )
  for function_call in newest_first_calls:
    if function_call and getattr(function_call, "name", "") == "optimize_quantity":
      call_args = getattr(function_call, "args", None)
      # Only the newest optimizer call is considered, even if its args are unusable.
      return dict(call_args) if isinstance(call_args, dict) else None
  return None


def _optimizer_before_model_callback(
  callback_context: CallbackContext,
  llm_request: LlmRequest,
) -> Optional[LlmResponse]:
  """Answer with the prepared optimizer text instead of calling the model.

  Uses the best optimizer response and run count reported by
  ``_extract_best_optimizer_response_and_run_count``. When that response is
  feasible, or at least three runs were counted, and it carries a non-blank
  ``formatted_response`` string, that text is returned as the model's reply.

  Args:
    callback_context: Unused.
    llm_request: The request about to be sent to the model.

  Returns:
    An ``LlmResponse`` holding the formatted text, or ``None`` to let the
    model call proceed.
  """
  del callback_context

  best_response, attempts = _extract_best_optimizer_response_and_run_count(llm_request)
  if not isinstance(best_response, dict):
    return None

  is_feasible = str(best_response.get("status", "")).strip().lower() == "feasible"
  # Keep iterating until a feasible plan exists or the third attempt is in.
  if not (is_feasible or attempts >= 3):
    return None

  message = best_response.get("formatted_response")
  if not isinstance(message, str):
    return None

  final_text = message.strip()
  if not final_text:
    return None

  reply = types.Content(role="model", parts=[types.Part(text=final_text)])
  return LlmResponse(content=reply)


def _optimizer_after_tool_callback(
  tool: BaseTool,
  args: Dict[str, Any],
  tool_context: ToolContext,
  tool_response: Dict[str, Any],
) -> Optional[Dict[str, Any]]:
  """Attach the formatted text and timestamps to an optimizer result.

  Args:
    tool: The tool that just ran; only ``optimize_quantity`` is handled.
    args: Unused.
    tool_context: Context forwarded to the response formatter.
    tool_response: The tool's raw result.

  Returns:
    A copy of ``tool_response`` extended with ``formatted_response``,
    ``final_message_timestamp`` (epoch seconds) and
    ``final_message_timestamp_utc`` (ISO 8601, UTC); ``None`` when the tool is
    a different one, the response is not a dict, or its status is neither
    ``"feasible"`` nor ``"infeasible"``.
  """
  del args

  is_optimizer = getattr(tool, "name", "") == "optimize_quantity"
  if not is_optimizer or not isinstance(tool_response, dict):
    return None

  normalized_status = str(tool_response.get("status", "")).strip().lower()
  if normalized_status not in ("feasible", "infeasible"):
    return None

  formatted = _build_optimizer_formatted_response(
    tool_response=tool_response,
    tool_context=tool_context,
  )
  # Build a new dict so the tool's original response object is left untouched.
  return {
    **tool_response,
    "formatted_response": formatted,
    "final_message_timestamp": time.time(),
    "final_message_timestamp_utc": datetime.now(timezone.utc).isoformat(),
  }


# Planner agent definition. The model name and API key are read from the
# OPENROUTER_API_MODEL and OPENROUTER_API_KEY environment variables and the
# requests are sent to the OpenRouter base URL below.
agent = Agent(
  model=LiteLlm(
    model=os.getenv("OPENROUTER_API_MODEL"),
    api_key=os.getenv("OPENROUTER_API_KEY"),
    api_base="https://openrouter.ai/api/v1",
    extra_body={"session_id": "CanadaDRIPlanner1"},
  ),
  name="CanadaDRIPlanner",
  description="Tells the nutrients, descriptions, name of food items and helps with meal planning.",
  # The instruction text is loaded from dri_prompt.txt at import time.
  instruction=_load_agent_instruction(),
  tools=[
    find_ingredient,
    optimize_quantity,
    calculate_health_canada_dri,
  ],
  # The callbacks format optimizer results and may answer without a model call.
  before_model_callback=_optimizer_before_model_callback,
  after_tool_callback=_optimizer_after_tool_callback,
)

# Second module-level name for the same agent instance.
# NOTE(review): presumably the name the ADK loader looks up - confirm.
root_agent = agent
