import asyncio
import json
import math
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

from google.adk.tools.tool_context import ToolContext


_MACRO_NUTRIENT_ALIASES = {
	"protein": ("protein", "proteins"),
	"total_fat": ("total_fat", "fat", "fats", "lipid"),
	"total_fibre": ("total_fibre", "total_fiber", "fiber", "fibre", "dietary_fiber", "dietary_fibre"),
	"calories": ("calories", "kcal", "energy_kcal", "energy"),
	"carbohydrates": ("carbohydrates", "carbs", "carbohydrate"),
}


def _to_float(value: Any, default: float = 0.0) -> float:
	try:
		if value is None:
			return default
		return float(value)
	except (TypeError, ValueError):
		return default


def _to_dict(value: Any) -> Dict[str, Any]:
	if isinstance(value, dict):
		return value
	if isinstance(value, str):
		try:
			parsed = json.loads(value)
			return parsed if isinstance(parsed, dict) else {}
		except json.JSONDecodeError:
			return {}
	return {}


def _is_product_record(value: Any) -> bool:
	if not isinstance(value, dict):
		return False
	if "index" not in value:
		return False
	return "name" in value or "nutrition" in value or "nutrition_100g" in value or "cost" in value


def _collect_products(value: Any, out: Dict[int, Dict[str, Any]]) -> None:
	"""Recursively gather product records from ``value`` into ``out``.

	Lists and dicts are walked depth-first. A dict accepted by
	``_is_product_record`` is stored under ``int(record["index"])`` and is not
	descended into; a later record with the same index overwrites an earlier one.
	A record whose index cannot be converted to an int is skipped.

	Args:
		value: Arbitrary nested structure of lists/dicts/scalars.
		out: Mapping that is filled in place (index -> product record).
	"""
	if _is_product_record(value):
		try:
			product_index = int(value["index"])
		except (TypeError, ValueError, OverflowError):
			# OverflowError: int() of an infinite float (json.loads accepts
			# Infinity / 1e999). Skip the record instead of aborting the scan.
			return
		out[product_index] = value
		return

	if isinstance(value, list):
		for element in value:
			_collect_products(element, out)
	elif isinstance(value, dict):
		for child in value.values():
			_collect_products(child, out)


def _extract_products_from_history(tool_context: ToolContext) -> Dict[int, Dict[str, Any]]:
	"""Build an index -> product mapping from earlier ``find_ingredient`` responses.

	Walks every event of the session reachable through
	``tool_context._invocation_context``, looks at each content part that carries
	a function response named ``find_ingredient`` with a dict payload, and
	collects the product records found under that payload's ``result`` key.

	Returns:
		Mapping of product index to product record. Empty when nothing matched.
	"""
	collected: Dict[int, Dict[str, Any]] = {}
	events = tool_context._invocation_context.session.events

	for event in events:
		content = getattr(event, "content", None)
		if not content:
			continue
		parts = getattr(content, "parts", None)
		if not parts:
			continue

		for part in parts:
			response = getattr(part, "function_response", None)
			if not response or getattr(response, "name", "") != "find_ingredient":
				continue

			payload = getattr(response, "response", None)
			if isinstance(payload, dict):
				_collect_products(payload.get("result"), collected)

	return collected


def _parse_items(items_json: str) -> List[Dict[str, float]]:
	"""Parse the ``items_json`` argument into index/quantity pairs.

	Each list entry may be either ``{"index": 5, "quantity": 120}`` or a
	single-key mapping ``{"5": 120}``. Entries that are not dicts, whose index
	cannot be converted to an int, or whose quantity is missing, unparseable or
	negative are skipped.

	Args:
		items_json: JSON text holding a non-empty list of entries.

	Returns:
		List of ``{"index": <float>, "quantity": <float>}`` dicts in input order.

	Raises:
		ValueError: If the text is not valid JSON, is not a non-empty list,
			or contains no usable entry.
	"""
	try:
		payload = json.loads(items_json)
	except json.JSONDecodeError as exc:
		raise ValueError("items_json must be valid JSON") from exc

	if not isinstance(payload, list) or not payload:
		raise ValueError("items_json must be a non-empty JSON list")

	parsed: List[Dict[str, float]] = []
	for entry in payload:
		if not isinstance(entry, dict):
			continue

		if "index" in entry and "quantity" in entry:
			raw_index, raw_quantity = entry["index"], entry["quantity"]
		elif len(entry) == 1:
			# Shorthand form: the only key is the index, its value the quantity.
			raw_index, raw_quantity = next(iter(entry.items()))
		else:
			continue

		try:
			# The index is stored as a float to match the declared return type.
			index_as_float = float(int(raw_index))
		except (TypeError, ValueError, OverflowError):
			# OverflowError: int() of an infinite float (json.loads accepts
			# Infinity / 1e999) or float() of an integer too large for a float.
			continue

		quantity = _to_float(raw_quantity, -1.0)
		if quantity < 0:
			continue

		parsed.append({"index": index_as_float, "quantity": quantity})

	if not parsed:
		raise ValueError(
			"No valid items. Use format like [{\"index\":5,\"quantity\":120}] or [{\"5\":120}] (quantity in grams)."
		)
	return parsed


def _extract_quantity_entries(value: Any, out: List[Dict[str, float]]) -> None:
	"""Recursively collect ``{"index", "quantity"}`` pairs from a nested structure.

	A dict counts as an entry when it has an ``index`` key that converts to an
	int and a non-negative quantity under the first present key among ``qty``,
	``quantity`` and ``quantity_g``. Such a dict is appended to ``out`` and not
	descended into. Any other dict, and every list, is searched recursively;
	scalars are ignored.

	Args:
		value: Arbitrary nested structure of lists/dicts/scalars.
		out: List that is extended in place with the entries found.
	"""
	if isinstance(value, list):
		for element in value:
			_extract_quantity_entries(element, out)
		return

	if not isinstance(value, dict):
		return

	if "index" in value:
		quantity_key = next(
			(name for name in ("qty", "quantity", "quantity_g") if name in value),
			None,
		)
		if quantity_key is not None:
			quantity = _to_float(value[quantity_key], -1.0)
			if quantity >= 0:
				try:
					index_as_float = float(int(value["index"]))
				except (TypeError, ValueError, OverflowError):
					# OverflowError: int() of an infinite float or float() of a
					# huge integer. Treat like any other unusable index.
					index_as_float = None

				if index_as_float is not None:
					out.append({"index": index_as_float, "quantity": quantity})
					return

	# Not a usable entry itself: keep looking inside its values.
	for child in value.values():
		_extract_quantity_entries(child, out)


def _parse_calculated_quantity_by_day(calculated_quantity_json: str) -> Dict[str, List[Dict[str, float]]]:
	"""Group quantity entries of a day/meal JSON object by top-level key.

	Each top-level key whose value is a dict or list and contains at least one
	quantity entry becomes a group. The grouping is returned only if at least
	one such key starts with "day" (case-insensitive, surrounding whitespace
	ignored). Otherwise every entry found anywhere in the payload is returned
	under the single key ``"day1"``.

	Raises:
		ValueError: If the text is not valid JSON, is not a non-empty JSON
			object, or contains no quantity entries.
	"""
	try:
		payload = json.loads(calculated_quantity_json)
	except json.JSONDecodeError as exc:
		raise ValueError("calculated_quantity_json must be valid JSON") from exc

	if not (isinstance(payload, dict) and payload):
		raise ValueError("calculated_quantity_json must be a non-empty JSON object")

	grouped: Dict[str, List[Dict[str, float]]] = {}
	saw_day_key = False

	for raw_key, raw_value in payload.items():
		if not isinstance(raw_value, (dict, list)):
			continue

		entries: List[Dict[str, float]] = []
		_extract_quantity_entries(raw_value, entries)
		if entries:
			label = str(raw_key)
			saw_day_key = saw_day_key or label.strip().lower().startswith("day")
			grouped[label] = entries

	if grouped and saw_day_key:
		return grouped

	# No day-style keys: treat the whole payload as one day.
	flattened: List[Dict[str, float]] = []
	_extract_quantity_entries(payload, flattened)
	if flattened:
		return {"day1": flattened}

	raise ValueError(
		"No valid quantities found. Expected entries like {'index': 5, 'qty': 120} inside a day/meal JSON object."
	)


def _extract_macro_values(nutrition: Dict[str, Any]) -> Dict[str, float]:
	"""Pick the macro nutrients out of a nutrition mapping.

	Keys are compared after ``str()``, ``strip()`` and ``lower()``. For each
	canonical macro in ``_MACRO_NUTRIENT_ALIASES`` the first alias present
	supplies the value; macros with no matching alias are left out.
	Values that cannot be converted to float become 0.0.
	"""
	values_by_key: Dict[str, float] = {}
	for raw_key, raw_value in nutrition.items():
		values_by_key[str(raw_key).strip().lower()] = _to_float(raw_value, 0.0)

	found: Dict[str, float] = {}
	for canonical_name, spellings in _MACRO_NUTRIENT_ALIASES.items():
		matched = next((name for name in spellings if name in values_by_key), None)
		if matched is not None:
			found[canonical_name] = values_by_key[matched]
	return found


def _sum_cost_for_quantity(product: Dict[str, Any], quantity_g: float) -> float:
	"""Price ``quantity_g`` grams of ``product`` from its ``cost`` entry.

	The ``cost`` entry (dict or JSON string) supplies ``price``, ``unit`` and
	``quantity`` (default 1.0), read as "``price`` buys ``quantity`` ``unit``s".
	Weight units (kg, g, lb and their spelled-out forms) are converted to grams.
	For per-piece units (each/unit/piece) the weight of one piece is taken from
	``product["serving"]["metric"]`` and is used only when that metric is in grams.

	Returns:
		The cost for ``quantity_g`` grams, or 0.0 when the price is negative or
		unparseable, the cost quantity is not positive, the unit is not
		recognised, or the per-piece weight is unavailable.
	"""
	cost_info = _to_dict(product.get("cost"))
	price = _to_float(cost_info.get("price"), -1.0)
	pack_size = _to_float(cost_info.get("quantity"), 1.0)
	if price < 0 or pack_size <= 0:
		return 0.0

	unit_name = str(cost_info.get("unit", "")).strip().lower()

	grams_per_unit = {
		"kg": 1000.0,
		"kilogram": 1000.0,
		"kilograms": 1000.0,
		"g": 1.0,
		"gram": 1.0,
		"grams": 1.0,
		"lb": 453.59237,
		"lbs": 453.59237,
		"pound": 453.59237,
		"pounds": 453.59237,
	}
	if unit_name in grams_per_unit:
		return quantity_g * (price / (pack_size * grams_per_unit[unit_name]))

	if unit_name not in {"each", "unit", "piece"}:
		return 0.0

	# Per-piece pricing needs the weight of one piece from the serving metric.
	metric = _to_dict(product.get("serving")).get("metric")
	if not isinstance(metric, dict):
		return 0.0

	grams_per_piece = _to_float(metric.get("quantity"), 0.0)
	metric_unit = str(metric.get("unit", "")).strip().lower()
	if grams_per_piece > 0 and metric_unit in {"g", "gram", "grams"}:
		piece_count = quantity_g / grams_per_piece
		return piece_count * (price / pack_size)
	return 0.0


async def _calculate_nutrients_impl(items_json: str, tool_context: ToolContext) -> Dict[str, Any]:
	"""
	Sum nutrients and cost for the requested products, using earlier `find_ingredient` results.

	Each product's nutrition table (`nutrition_100g`, or `nutrition` when the
	former is empty) is scaled by `quantity / 100`. Nutrients whose scaled
	amount is zero are left out.

	Args:
		items_json: JSON list of product+quantity pairs, quantity in grams.
			Supported shapes:
			- [{"index": 5, "quantity": 120}, {"index": 8, "quantity": 60}]
			- [{"5": 120}, {"8": 60}]
		tool_context: ADK tool context containing session history.

	Returns:
		Dict with `items` (per-product breakdown), `totals` (summed nutrients
		plus a `cost` key), `missing_indices`, `input_count` and `matched_count`.

	Raises:
		ValueError: If `tool_context` is None, the items cannot be parsed, or
			no `find_ingredient` results are found in the session history.
	"""
	if tool_context is None:
		raise ValueError("tool_context is required")

	wanted = _parse_items(items_json)
	catalog = _extract_products_from_history(tool_context)
	if not catalog:
		raise ValueError("No prior find_ingredient results found in conversation history")

	nutrient_sums: Dict[str, float] = {}
	cost_sum = 0.0
	unmatched: List[int] = []
	breakdown: List[Dict[str, Any]] = []

	for request in wanted:
		product_index = int(request["index"])
		grams = _to_float(request.get("quantity"), 0.0)

		product = catalog.get(product_index)
		if not product:
			unmatched.append(product_index)
			continue

		# Prefer the per-100g table; fall back to the generic nutrition table.
		nutrition_table = _to_dict(product.get("nutrition_100g")) or _to_dict(product.get("nutrition"))

		scale = grams / 100.0
		scaled_nutrients: Dict[str, float] = {}
		for nutrient_name, raw_value in nutrition_table.items():
			contribution = _to_float(raw_value, 0.0) * scale
			if contribution == 0:
				continue
			scaled_nutrients[nutrient_name] = round(contribution, 4)
			nutrient_sums[nutrient_name] = nutrient_sums.get(nutrient_name, 0.0) + contribution

		product_cost = _sum_cost_for_quantity(product, grams)
		cost_sum += product_cost

		breakdown.append(
			{
				"index": product_index,
				"name": product.get("name", f"product_{product_index}"),
				"quantity_g": round(grams, 2),
				"cost": round(product_cost, 4),
				"nutrients": scaled_nutrients,
			}
		)

	totals_out = {name: round(total, 4) for name, total in nutrient_sums.items()}
	totals_out["cost"] = round(cost_sum, 4)

	return {
		"items": breakdown,
		"totals": totals_out,
		"missing_indices": sorted(set(unmatched)),
		"input_count": len(wanted),
		"matched_count": len(breakdown),
	}


async def calculate_nutrients(items_json: str, tool_context: ToolContext) -> Any:
	"""
	Safely compute total nutrients and cost for selected products and quantities.

	Products are looked up by `index` in earlier `find_ingredient` results from
	this conversation.

	Args:
		items_json: JSON list of product+quantity pairs, quantity in grams.
			Supported shapes:
			- [{"index": 5, "quantity": 120}, {"index": 8, "quantity": 60}]
			- [{"5": 120}, {"8": 60}]

	Returns:
		On success, returns a dictionary with `items`, `totals`,
		`missing_indices`, `input_count` and `matched_count`.
		On invalid input, returns an error string in the format "Error: <message>".
	"""
	try:
		return await _calculate_nutrients_impl(items_json=items_json, tool_context=tool_context)
	except (TypeError, ValueError, OverflowError) as exc:
		# OverflowError covers numeric conversions of out-of-range values
		# (e.g. an infinite index), so they are reported like other bad input.
		return f"Error: {exc}"


async def _calculate_average_macro_nutrient_per_day_impl(
	calculated_quantity_json: str,
	tool_context: ToolContext,
) -> Dict[str, Any]:
	"""
	Compute per-day macro totals and their average from a day/meal quantity payload.

	For every day, each item's macro values (per 100 g, taken from earlier
	`find_ingredient` results) are scaled by `qty / 100` and summed. Day totals
	are rounded to 4 decimals; the average over days is taken from those rounded
	totals and rounded to 2 decimals.

	Args:
		calculated_quantity_json: JSON object containing day-wise meals with item `index` and `qty`.
			Example:
			{
			  "day1": {"meal:one": [{"index": 5, "qty": 40.0}]},
			  "day2": {"meal:one": [{"index": 5, "qty": 60.0}]}
			}
		tool_context: ADK tool context containing prior `find_ingredient` responses.

	Returns:
		Dict with `average_macro_nutrient_from_calculated_quantity_per_day`,
		`per_day_macro_nutrient_from_calculated_quantity`, `day_count` and
		`missing_indices`.

	Raises:
		ValueError: If `tool_context` is None, the payload cannot be parsed, or
			no `find_ingredient` results are found in the session history.
	"""
	if tool_context is None:
		raise ValueError("tool_context is required")

	plan_by_day = _parse_calculated_quantity_by_day(calculated_quantity_json)
	catalog = _extract_products_from_history(tool_context)
	if not catalog:
		raise ValueError("No prior find_ingredient results found in conversation history")

	macro_keys = ("protein", "total_fat", "calories", "carbohydrates", "total_fibre")
	per_day: Dict[str, Dict[str, float]] = {}
	unmatched: List[int] = []

	for day_label, entries in plan_by_day.items():
		running = dict.fromkeys(macro_keys, 0.0)

		for entry in entries:
			product_index = int(entry["index"])
			grams = _to_float(entry.get("quantity"), 0.0)

			product = catalog.get(product_index)
			if not product:
				unmatched.append(product_index)
				continue

			# Prefer the per-100g table; fall back to the generic nutrition table.
			nutrition_table = _to_dict(product.get("nutrition_100g")) or _to_dict(product.get("nutrition"))
			macros_per_100g = _extract_macro_values(nutrition_table)

			scale = grams / 100.0
			for macro_name in macro_keys:
				density = _to_float(macros_per_100g.get(macro_name), 0.0)
				if density != 0:
					running[macro_name] += density * scale

		per_day[day_label] = {name: round(total, 4) for name, total in running.items()}

	# max(..., 1) guards the division should no day be present.
	divisor = max(len(per_day), 1)
	averages: Dict[str, float] = {}
	for macro_name in macro_keys:
		combined = sum(day_totals.get(macro_name, 0.0) for day_totals in per_day.values())
		averages[macro_name] = round(combined / divisor, 2)

	return {
		"average_macro_nutrient_from_calculated_quantity_per_day": averages,
		"per_day_macro_nutrient_from_calculated_quantity": per_day,
		"day_count": len(per_day),
		"missing_indices": sorted(set(unmatched)),
	}


async def calculate_average_macro_nutrient_per_day(
	calculated_quantity_json: str,
	tool_context: ToolContext,
) -> Any:
	"""
	Safely compute average daily macro nutrients from calculated quantity JSON.

	Products are looked up by `index` in earlier `find_ingredient` results from
	this conversation.

	Args:
		calculated_quantity_json: JSON payload containing day/meal entries with item
			indices and quantities in grams.
			Example:
			{
			  "day1": {"meal:one": [{"index": 5, "qty": 120}, {"index": 8, "qty": 180}]},
			  "day2": {"meal:one": [{"index": 5, "qty": 140}, {"index": 8, "qty": 160}]}
			}

	Returns:
		On success, returns the computed macro summary dictionary.
		On invalid input, returns an error string in the format "Error: <message>".
	"""
	try:
		return await _calculate_average_macro_nutrient_per_day_impl(
			calculated_quantity_json=calculated_quantity_json,
			tool_context=tool_context,
		)
	except (TypeError, ValueError, OverflowError) as exc:
		# OverflowError covers numeric conversions of out-of-range values
		# (e.g. an infinite index), so they are reported like other bad input.
		return f"Error: {exc}"


async def sample_call_calculate_average_macro_nutrient_per_day() -> Any:
	"""Sample call demonstrating how to invoke calculate_average_macro_nutrient_per_day.

	Builds a stand-in tool context out of ``SimpleNamespace`` objects whose
	session holds one ``find_ingredient`` response with two products, prints the
	sample input, and returns the tool's result for a two-day plan.
	"""
	chicken = {
		"index": 5,
		"name": "Chicken Breast",
		"nutrition_100g": {"protein": 31, "fat": 3.6, "kcal": 165, "carbs": 0, "fiber": 0},
	}
	rice = {
		"index": 8,
		"name": "Brown Rice",
		"nutrition_100g": {"protein": 2.6, "fat": 0.9, "kcal": 111, "carbohydrates": 23, "fibre": 1.8},
	}

	# Mirror only the attribute chain the history scan reads:
	# _invocation_context.session.events[].content.parts[].function_response
	fake_response = SimpleNamespace(name="find_ingredient", response={"result": [chicken, rice]})
	fake_part = SimpleNamespace(function_response=fake_response)
	fake_event = SimpleNamespace(content=SimpleNamespace(parts=[fake_part]))
	fake_session = SimpleNamespace(events=[fake_event])
	fake_tool_context = SimpleNamespace(
		_invocation_context=SimpleNamespace(session=fake_session)
	)

	plan = {
		"day1": {"meal:one": [{"index": 5, "qty": 120}, {"index": 8, "qty": 180}]},
		"day2": {"meal:one": [{"index": 5, "qty": 140}, {"index": 8, "qty": 160}]},
	}
	plan_json = json.dumps(plan)

	print("Sample input (calculated_quantity_json):")
	print(json.dumps(json.loads(plan_json), indent=2))

	return await calculate_average_macro_nutrient_per_day(
		calculated_quantity_json=plan_json,
		tool_context=fake_tool_context,
	)


# Script entry point: run the sample call and pretty-print its result as JSON.
if __name__ == "__main__":
	_sample_output = asyncio.run(sample_call_calculate_average_macro_nutrient_per_day())
	print(json.dumps(_sample_output, indent=2))

