import os
import re
import json
from typing import Any, Dict, List, Optional, Tuple
from google.adk.tools.tool_context import ToolContext

try:
	from dotenv import load_dotenv
	load_dotenv()
except ImportError:
	pass

DEFAULT_DB_TABLE = "opennutrition_foods_raw"


def _get_db_table() -> str:
	return os.getenv("DB_TABLE") or DEFAULT_DB_TABLE

# Internal columns read from the DB table (in this order, keeping only those the
# table actually has) before rows are projected into the public tool output.
INTERNAL_COLUMNS_ORDER = [
	"name",
	"alternate_names",
	"description",
	"type",
	"serving",
	"nutrition_100g",
	"labels",
	"package_size",
	"ingredients",
	"ingredient_analysis",
	"cost",
]

# Public tool response fields, in the order _project_output_row emits them.
# NOTE(review): not referenced by the code visible in this file — confirm it is
# used elsewhere before relying on (or removing) it.
OUTPUT_COLUMNS_ORDER = [
	"name",
	"description",
	"serving",
	"nutrition",
	"cost",
	"index",
]

# Per-query result cap used when a query omits max_results (or passes null).
DEFAULT_MAX_RESULTS = 3

# Only keep nutrition fields that downstream optimization/bounds logic consumes.
# Maps each canonical output key to the source keys tried, in priority order;
# the first value that is neither None nor a blank string is kept as-is.
OPTIMIZATION_NUTRITION_KEY_CANDIDATES = {
	"protein": ["protein", "proteins"],
	"carbohydrates": ["carbohydrates", "carbs", "carbohydrate"],
	"total_fat": ["total_fat", "fat", "fats", "lipid"],
	"total_fibre": ["total_fibre", "total_fiber", "fiber", "fibre", "dietary_fiber", "dietary_fibre"],
	"calories": ["calories", "kcal", "energy_kcal", "energy"],
}


def _split_top_level_alternatives(pattern: str) -> List[str]:
	text = pattern.strip()
	if not text:
		return [pattern]

	# Remove one layer of wrapping parentheses if they cover the whole pattern.
	if text.startswith("(") and text.endswith(")"):
		depth = 0
		balanced = True
		for i, ch in enumerate(text):
			if ch == "(" and (i == 0 or text[i - 1] != "\\"):
				depth += 1
			elif ch == ")" and (i == 0 or text[i - 1] != "\\"):
				depth -= 1
				if depth == 0 and i != len(text) - 1:
					balanced = False
					break
		if balanced and depth == 0:
			text = text[1:-1]

	parts: List[str] = []
	buf: List[str] = []
	depth = 0
	in_char_class = False
	escaped = False
	for ch in text:
		if escaped:
			buf.append(ch)
			escaped = False
			continue
		if ch == "\\":
			buf.append(ch)
			escaped = True
			continue
		if ch == "[" and not in_char_class:
			in_char_class = True
			buf.append(ch)
			continue
		if ch == "]" and in_char_class:
			in_char_class = False
			buf.append(ch)
			continue
		if not in_char_class:
			if ch == "(":
				depth += 1
			elif ch == ")" and depth > 0:
				depth -= 1
			elif ch == "|" and depth == 0:
				parts.append("".join(buf).strip())
				buf = []
				continue
		buf.append(ch)

	parts.append("".join(buf).strip())
	parts = [p if p else pattern for p in parts]
	return parts if parts else [pattern]


def _has_top_level_alternation(pattern: str) -> bool:
	"""Report whether *pattern* splits into two or more top-level ``|`` branches."""
	alternatives = _split_top_level_alternatives(pattern)
	return len(alternatives) >= 2


def _value_text_for_rank(value: object) -> str:
	if value is None:
		return ""
	if isinstance(value, (dict, list)):
		return json.dumps(value, ensure_ascii=False)
	return str(value)


def _pattern_rank_for_text(pattern: str, text: str) -> int:
	"""Return the position of the first top-level alternative of *pattern* found in *text*.

	Matching uses case-insensitive ``re.search``. When no alternative matches,
	the number of alternatives is returned so such text sorts after every match.
	"""
	alternatives = _split_top_level_alternatives(pattern)
	for position, alternative in enumerate(alternatives):
		try:
			hit = re.search(alternative, text, re.IGNORECASE)
		except re.error:
			# A single branch may not compile on its own even though the whole
			# pattern was validated (e.g. a backreference into another branch).
			hit = None
		if hit is not None:
			return position
	return len(alternatives)


def _row_rank_by_filters(row: Dict[str, str], filters: List[Tuple[str, str]]) -> Tuple[int, ...]:
	"""Build a sort key holding one alternative rank per ``(column, pattern)`` filter.

	The ``name`` filter takes the better (smaller) rank of the row's ``name`` and
	``alternate_names`` values; every other filter ranks its own column.
	"""
	def _rank(col: str, pattern: str) -> int:
		if col != "name":
			return _pattern_rank_for_text(pattern, _value_text_for_rank(row.get(col, "")))
		return min(
			_pattern_rank_for_text(pattern, _value_text_for_rank(row.get(field, "")))
			for field in ("name", "alternate_names")
		)

	return tuple(_rank(col, pattern) for col, pattern in filters)


def _sort_results_by_regex_order(rows: List[Dict[str, str]], filters: List[Tuple[str, str]]) -> List[Dict[str, str]]:
	if not rows or not filters:
		return rows
	if not any(_has_top_level_alternation(pattern) for _, pattern in filters):
		return rows
	return sorted(rows, key=lambda row: _row_rank_by_filters(row, filters))


def _is_product_record(value: object) -> bool:
	if not isinstance(value, dict):
		return False
	if "name" not in value:
		return False
	return any(k in value for k in ("cost", "nutrition", "nutrition_100g", "serving"))


def _parse_nutrition_obj(value: object) -> Dict[str, Any]:
	if isinstance(value, dict):
		return value
	if isinstance(value, str):
		text = value.strip()
		if not text:
			return {}
		try:
			parsed = json.loads(text)
			return parsed if isinstance(parsed, dict) else {}
		except Exception:
			return {}
	return {}


def _project_optimization_nutrition(value: object) -> Dict[str, Any]:
	"""Reduce a nutrition object to the canonical keys the optimizer consumes.

	``value`` is normalized with ``_parse_nutrition_obj``. For each canonical key
	in ``OPTIMIZATION_NUTRITION_KEY_CANDIDATES`` the candidate source keys are
	tried in order; the first value that is neither ``None`` nor a blank string
	is copied unchanged. Canonical keys with no usable candidate are omitted.
	"""
	source = _parse_nutrition_obj(value)
	missing = object()

	def _usable(candidate: object) -> bool:
		if candidate is None:
			return False
		return not (isinstance(candidate, str) and not candidate.strip())

	projected: Dict[str, Any] = {}
	for canonical_key, candidates in OPTIMIZATION_NUTRITION_KEY_CANDIDATES.items():
		chosen = next(
			(found for found in (source.get(key) for key in candidates) if _usable(found)),
			missing,
		)
		if chosen is not missing:
			projected[canonical_key] = chosen
	return projected


def _project_output_row(row: Dict[str, Any]) -> Dict[str, Any]:
	"""Project an internal DB row onto the public tool-response fields.

	Nutrition comes from ``nutrition_100g`` when that key is present, else from
	``nutrition``, and is reduced to the optimizer keys. Missing text fields
	default to ``""``; a missing ``index`` becomes ``None``.
	"""
	if "nutrition_100g" in row:
		raw_nutrition = row["nutrition_100g"]
	else:
		raw_nutrition = row.get("nutrition", "")

	projected: Dict[str, Any] = {key: row.get(key, "") for key in ("name", "description", "serving")}
	projected["nutrition"] = _project_optimization_nutrition(raw_nutrition)
	projected["cost"] = row.get("cost", "")
	projected["index"] = row.get("index")
	return projected


def _count_products_in_payload(value: object) -> int:
	"""Recursively count product records inside a JSON-like payload.

	A value that is itself a product record counts as one (its contents are not
	searched). Lists and dict values are searched; other values count as zero.
	"""
	if _is_product_record(value):
		return 1
	if isinstance(value, dict):
		children = value.values()
	elif isinstance(value, list):
		children = value
	else:
		return 0
	return sum(map(_count_products_in_payload, children))


def _history_product_count(tool_context: Optional[ToolContext]) -> int:
	"""Count products returned by earlier ``find_ingredient`` calls in this session.

	Walks every session event's content parts, keeps the function responses
	named ``find_ingredient`` and counts the product records in their payloads.
	Returns 0 when no tool context is supplied.
	"""
	if tool_context is None:
		return 0

	# NOTE(review): relies on the private `_invocation_context` attribute.
	events = tool_context._invocation_context.session.events

	def _find_ingredient_payloads():
		for event in events:
			content = getattr(event, "content", None)
			parts = getattr(content, "parts", None) if content else None
			for part in parts or ():
				response = getattr(part, "function_response", None)
				if response and getattr(response, "name", None) == "find_ingredient":
					yield getattr(response, "response", None)

	return sum(_count_products_in_payload(payload) for payload in _find_ingredient_payloads())


def _coerce_optional_str(value: object) -> Optional[str]:
	if value is None:
		return None
	text = str(value).strip()
	return text if text else None


def _coerce_optional_list_of_str(value: object, field_name: str) -> Optional[List[str]]:
	if value is None:
		return None
	if not isinstance(value, list):
		raise ValueError(f"{field_name} must be a list of strings")
	coerced: List[str] = []
	for item in value:
		text = _coerce_optional_str(item)
		if text is not None:
			coerced.append(text)
	return coerced or None


def _find_ingredient_single(
	name_query: Optional[str],
	description: Optional[str],
	serving: Optional[str],
	nutrition: Optional[str],
	nutrition_ranges: Optional[List[str]],
	ingredients: Optional[str],
	ingredient_analysis: Optional[str],
	cost: Optional[str],
	cost_ranges: Optional[List[str]],
	max_results: Optional[int],
) -> List[Dict[str, str]]:
	"""Validate one ingredient query and run it against the configured table.

	Truthy regex arguments become ``(column, pattern)`` filters (``nutrition``
	targets the ``nutrition_100g`` column), each checked to compile with
	``re.IGNORECASE``. Range specs are parsed and checked for min <= max before
	the database is queried.

	Raises:
		ValueError: For an invalid regex or range spec.
	"""
	candidate_filters = (
		("name", name_query),
		("description", description),
		("serving", serving),
		("nutrition_100g", nutrition),
		("ingredients", ingredients),
		("ingredient_analysis", ingredient_analysis),
		("cost", cost),
	)
	filters = [(column, pattern) for column, pattern in candidate_filters if pattern]

	# Fail fast on malformed patterns before touching the database.
	for col, pat in filters:
		try:
			re.compile(pat, re.IGNORECASE)
		except re.error as e:
			raise ValueError(f"Invalid regex for column '{col}': {e}") from e

	parsed_nutrition = _parse_range_specs(nutrition_ranges, "nutrition_ranges")
	parsed_cost = _parse_range_specs(cost_ranges, "cost_ranges")
	_validate_named_ranges(parsed_nutrition, "nutrition_value_ranges")
	_validate_named_ranges(parsed_cost, "cost_value_ranges")

	return _find_ingredient_db_multi(
		filters=filters,
		nutrition_value_ranges=parsed_nutrition,
		cost_value_ranges=parsed_cost,
		max_results=max_results,
		table=_get_db_table(),
	)


def _find_ingredient_impl(
	ingredient_queries: List[Dict[str, Any]],
	tool_context: Optional[ToolContext] = None,
) -> List[Dict[str, Any]]:
	"""
	Query Postgres DB and return rows filtered by regex patterns for multiple ingredient queries.

	Each ingredient query object can include regex filter parameters (name_query, description,
	serving, nutrition, ingredients, ingredient_analysis, cost), combined with AND logic.
	All matching is always case-insensitive.
	Range filters are available for nutrition_100g and cost by key, and are combined with AND logic.

	Args:
		ingredient_queries: List of per-ingredient filter objects. Each item supports:
			name_query, description, serving, nutrition, nutrition_ranges,
			ingredients, ingredient_analysis, cost, cost_ranges,
			optional max_results.
			When max_results is omitted or null for a query, DEFAULT_MAX_RESULTS is used;
			otherwise it must convert to a non-negative integer.
		tool_context: Optional ADK tool context. When given, result indexes start after
			the number of products returned by earlier find_ingredient calls recorded
			in the session.

	Returns:
		A list of dictionaries with keys
		`query_index`, `query`, and `results`. Every result carries an `index`
		that keeps increasing across all queries of this call.

	Raises:
		ValueError: If the input shape is wrong, a regex is invalid, a numeric range is
			invalid, or max_results is not a non-negative integer.
		RuntimeError: If DB connection requirements are not met.
	"""
	if not isinstance(ingredient_queries, list):
		raise ValueError("ingredient_queries must be a list of per-ingredient query objects")

	# Continue numbering after products already handed out earlier in the session
	# so each index can be referenced unambiguously by later tool calls.
	running_index = _history_product_count(tool_context)

	combined: List[Dict[str, Any]] = []
	for query_index, query in enumerate(ingredient_queries):
		if not isinstance(query, dict):
			raise ValueError(f"ingredient_queries[{query_index}] must be an object")

		query_name = _coerce_optional_str(query.get("name_query"))
		query_description = _coerce_optional_str(query.get("description"))
		query_serving = _coerce_optional_str(query.get("serving"))
		query_nutrition = _coerce_optional_str(query.get("nutrition"))
		query_nutrition_ranges = _coerce_optional_list_of_str(query.get("nutrition_ranges"), "nutrition_ranges")
		query_ingredients = _coerce_optional_str(query.get("ingredients"))
		query_ingredient_analysis = _coerce_optional_str(query.get("ingredient_analysis"))
		query_cost = _coerce_optional_str(query.get("cost"))
		query_cost_ranges = _coerce_optional_list_of_str(query.get("cost_ranges"), "cost_ranges")

		query_max_results_value = query.get("max_results")
		query_max_results: int = DEFAULT_MAX_RESULTS
		if query_max_results_value is not None:
			try:
				query_max_results = int(query_max_results_value)
			except (TypeError, ValueError, OverflowError) as e:
				raise ValueError("max_results must be an integer when provided") from e
			# A negative value would reach SQL as "LIMIT -n" and fail inside the DB.
			if query_max_results < 0:
				raise ValueError("max_results must be a non-negative integer")

		rows = _find_ingredient_single(
			name_query=query_name,
			description=query_description,
			serving=query_serving,
			nutrition=query_nutrition,
			nutrition_ranges=query_nutrition_ranges,
			ingredients=query_ingredients,
			ingredient_analysis=query_ingredient_analysis,
			cost=query_cost,
			cost_ranges=query_cost_ranges,
			max_results=query_max_results,
		)

		for row in rows:
			row["index"] = running_index
			running_index += 1

		query_payload: Dict[str, Any] = {
			"name_query": query_name,
			"description": query_description,
			"serving": query_serving,
			"nutrition": query_nutrition,
			"nutrition_ranges": query_nutrition_ranges,
			"ingredients": query_ingredients,
			"ingredient_analysis": query_ingredient_analysis,
			"cost": query_cost,
			"cost_ranges": query_cost_ranges,
			"max_results": query_max_results,
		}
		# Echo back only the parameters that were actually in effect.
		query_payload = {k: v for k, v in query_payload.items() if v is not None}

		combined.append(
			{
				"query_index": query_index,
				"query": query_payload,
				"results": [_project_output_row(row) for row in rows],
			}
		)

	return combined


def find_ingredient(
	ingredient_queries: List[Dict[str, Any]],
	tool_context: Optional[ToolContext] = None,
) -> Any:
	"""Search the nutrition database for food ingredients matching the given queries.

	Executes one or more ingredient searches against the Postgres database. Each
	query can filter by name, description, serving, nutrition content, ingredient
	list, ingredient analysis, and cost using case-insensitive regex patterns.
	Range filters are also supported for nutrition and cost fields.

	Args:
		ingredient_queries: List of query objects. Each object may contain:
			- name_query (str): Regex pattern matched against the food name and
			  alternate names.
			- description (str): Regex pattern matched against the description.
			- serving (str): Regex pattern matched against the serving field.
			- nutrition (str): Regex pattern matched against the nutrition_100g
			  field.
			- nutrition_ranges (List[str]): Per-key nutrition bounds in
			  "key:min:max" format (use "none" for open bounds), e.g.
			  ["protein:20:40", "calories:none:500"].
			- ingredients (str): Regex pattern matched against the ingredients
			  list.
			- ingredient_analysis (str): Regex pattern matched against the
			  ingredient analysis field.
			- cost (str): Regex pattern matched against the cost field.
			- cost_ranges (List[str]): Per-key cost bounds in "key:min:max"
			  format.
			- max_results (int): Maximum number of results to return for this
			  query (default: 3).
			Example::

				ingredient_queries = [
					{
						"name_query": "chicken|tofu",
						"nutrition_ranges": ["protein:15:none", "calories:none:250"],
						"cost_ranges": ["price:none:5"],
						"max_results": 5,
					}
				]
	Returns:
		On success, a list of dicts, one per query, each with keys:
			- ``query_index`` (int): Position of this query in the input list.
			- ``query`` (dict): The effective query parameters used.
			- ``results`` (list): Matching food records, each containing:
				- ``name``, ``description``, ``serving``, ``nutrition``,
				  ``cost``, ``index``.
				  ``nutrition`` is projected to: protein, carbohydrates,
				  total_fat, total_fibre, calories per 100 gram of the ingredient.
				  ``index`` is a session-unique integer for referencing the
				  product in subsequent optimizer/calculator tool calls.
		On validation error, returns a string starting with ``"Error: "``.
	"""
	# NOTE(review): if this function is registered as an ADK tool, the docstring
	# above is presumably what the model sees as the tool description — treat
	# edits to it as prompt changes.
	try:
		outcome = _find_ingredient_impl(
			ingredient_queries=ingredient_queries,
			tool_context=tool_context,
		)
	except (TypeError, ValueError, RuntimeError) as exc:
		# Expected failures are reported back as text rather than raised.
		return f"Error: {exc}"
	return outcome


def _validate_named_ranges(
	ranges: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]],
	range_name: str,
) -> None:
	if not ranges:
		return
	for key, bounds in ranges.items():
		if not isinstance(bounds, tuple) or len(bounds) != 2:
			raise ValueError(f"{range_name}[{key!r}] must be a (min, max) tuple")
		lower, upper = bounds
		if lower is not None and upper is not None and lower > upper:
			raise ValueError(f"{range_name}[{key!r}] has min greater than max")


def _parse_range_specs(
	specs: Optional[List[str]],
	param_name: str,
) -> Dict[str, Tuple[Optional[float], Optional[float]]]:
	parsed: Dict[str, Tuple[Optional[float], Optional[float]]] = {}
	if not specs:
		return parsed

	def _to_float_or_none(value: str) -> Optional[float]:
		if value.lower() in {"none", "null", "na"}:
			return None
		return float(value)

	for spec in specs:
		parts = spec.split(":")
		if len(parts) != 3:
			raise ValueError(f"{param_name} entries must be 'key:min:max'. Got: {spec!r}")
		key, lower_raw, upper_raw = parts[0].strip(), parts[1].strip(), parts[2].strip()
		if not key:
			raise ValueError(f"{param_name} key cannot be empty in: {spec!r}")
		parsed[key] = (_to_float_or_none(lower_raw), _to_float_or_none(upper_raw))

	return parsed


def _get_db_conn():
	try:
		import psycopg2  # lazy import to avoid hard dependency unless DB is used
	except ImportError:
		raise RuntimeError("psycopg2 is required for DB backend. Install with: pip install psycopg2-binary")

	dsn = os.getenv("DATABASE_URL")
	if dsn:
		return psycopg2.connect(dsn)

	# Accept common aliases to make local/dev setups easier.
	host = os.getenv("PGHOST") or os.getenv("POSTGRES_HOST") or os.getenv("DB_HOST") or "localhost"
	port_raw = os.getenv("PGPORT") or os.getenv("POSTGRES_PORT") or os.getenv("DB_PORT") or "5432"
	db = os.getenv("PGDATABASE") or os.getenv("POSTGRES_DB") or os.getenv("DB_NAME") or os.getenv("USER") or "postgres"
	user = os.getenv("PGUSER") or os.getenv("POSTGRES_USER") or os.getenv("DB_USER") or os.getenv("USER") or "postgres"
	password = os.getenv("PGPASSWORD") or os.getenv("POSTGRES_PASSWORD") or os.getenv("DB_PASSWORD")

	try:
		port = int(port_raw)
	except ValueError as e:
		raise RuntimeError(f"Invalid Postgres port: {port_raw!r}") from e

	try:
		return psycopg2.connect(host=host, port=port, dbname=db, user=user, password=password)
	except Exception as e:
		raise RuntimeError(
			f"Could not connect to Postgres using host={host!r}, port={port}, dbname={db!r}, user={user!r}. "
			"Set DATABASE_URL or PGHOST/PGPORT/PGDATABASE/PGUSER/PGPASSWORD."
		) from e


def _table_columns(conn, table: str) -> List[str]:
	with conn.cursor() as cur:
		cur.execute(
			"""
			SELECT column_name
			FROM information_schema.columns
			WHERE table_schema = 'public' AND table_name = %s
			ORDER BY ordinal_position
			""",
			(table,),
		)
		return [r[0] for r in cur.fetchall()]


def _table_column_types(conn, table: str) -> Dict[str, str]:
	with conn.cursor() as cur:
		cur.execute(
			"""
			SELECT column_name, data_type
			FROM information_schema.columns
			WHERE table_schema = 'public' AND table_name = %s
			""",
			(table,),
		)
		return {r[0]: r[1] for r in cur.fetchall()}


def _find_ingredient_db_multi(
	filters: List[Tuple[str, str]],
	nutrition_value_ranges: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]],
	cost_value_ranges: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]],
	max_results: int,
	table: str,
) -> List[Dict[str, str]]:
	"""Run one filtered SELECT against *table* and return up to *max_results* rows.

	Args:
		filters: ``(column, pattern)`` pairs, each applied as a case-insensitive
			Postgres regex match (``~*``) and AND-ed together. The ``name`` filter
			also matches ``alternate_names`` when that column exists.
		nutrition_value_ranges: Per-key ``(min, max)`` bounds on the first number
			stored under that key in ``nutrition_100g``; ``None`` bounds are open.
			Skipped when the table has no ``nutrition_100g`` column.
		cost_value_ranges: Same as above for the ``cost`` column.
		max_results: Maximum number of rows to return (expected >= 0).
		table: Table name, looked up in the ``public`` schema.

	Returns:
		One dict per row keyed by the ``INTERNAL_COLUMNS_ORDER`` columns the table
		has (SQL NULL becomes ``""``). When any filter has top-level alternation,
		rows are re-ranked by which alternative they match first.

	Raises:
		ValueError: If the table lacks ``name`` or a filtered column, or Postgres
			rejects query data (e.g. a regex Python accepted but Postgres does not).
		RuntimeError: On connection failure or any other database error.
	"""
	conn = _get_db_conn()
	try:
		# Importable here: _get_db_conn() has already imported it successfully.
		import psycopg2

		# Column types also provide the column set, so one metadata query suffices.
		col_types = _table_column_types(conn, table)
		colset = set(col_types)
		if "name" not in colset:
			raise ValueError("'name' column not found in DB table")
		# Never empty: "name" is present and listed in INTERNAL_COLUMNS_ORDER.
		selected_cols = [c for c in INTERNAL_COLUMNS_ORDER if c in colset]

		def _regex_expr(col_name: str) -> str:
			# Postgres regex operators do not work on json/jsonb directly.
			if col_types.get(col_name) in {"json", "jsonb"}:
				return f'"{col_name}"::text'
			return f'"{col_name}"'

		where_clauses: List[str] = []
		params: List[object] = []
		for col, pat in filters:
			if col not in colset:
				raise ValueError(f"Column '{col}' not found in DB table")
			if col == "name":
				# Special: match on name or alternate_names
				clause = f'({_regex_expr("name")} ~* %s'
				params.append(pat)
				if "alternate_names" in colset:
					clause += f' OR {_regex_expr("alternate_names")} ~* %s'
					params.append(pat)
				clause += ")"
				where_clauses.append(clause)
			else:
				where_clauses.append(f'{_regex_expr(col)} ~* %s')
				params.append(pat)

		def _add_range_clauses(
			column: str,
			value_ranges: Optional[Dict[str, Tuple[Optional[float], Optional[float]]]],
		) -> None:
			"""Append numeric bound clauses on values extracted from *column* by key."""
			if column not in colset:
				return
			for key, (lower, upper) in (value_ranges or {}).items():
				# Captures the first number stored under "key" in the column's text
				# form; re.escape keeps regex metacharacters in the key literal.
				pattern = f'"{re.escape(key)}"\\s*:\\s*"?([0-9]+(?:\\.[0-9]+)?)"?'
				for bound, op in ((lower, ">="), (upper, "<=")):
					if bound is not None:
						where_clauses.append(f'substring("{column}"::text FROM %s)::double precision {op} %s')
						params.append(pattern)
						params.append(bound)

		_add_range_clauses("nutrition_100g", nutrition_value_ranges)
		_add_range_clauses("cost", cost_value_ranges)

		where = " AND ".join(where_clauses) if where_clauses else "TRUE"
		select_list = ", ".join(f'"{c}"' for c in selected_cols)
		if any(_has_top_level_alternation(pattern) for _, pattern in filters):
			# Over-fetch so the Python-side re-ranking has candidates to choose
			# from; capped at 5000 rows but never fewer than were requested.
			fetch_limit = max(max_results, min(max_results * 50, 5000))
		else:
			fetch_limit = max_results

		sql = f'SELECT {select_list} FROM "{table}" WHERE {where} LIMIT %s'
		params.append(fetch_limit)

		with conn.cursor() as cur:
			try:
				cur.execute(sql, params)
			except psycopg2.DataError as e:
				# Data-class errors (e.g. invalid regular expression) are bad input.
				raise ValueError(f"Database rejected the query: {e}") from e
			except psycopg2.Error as e:
				raise RuntimeError(f"Database query failed: {e}") from e
			rows = cur.fetchall()

		results: List[Dict[str, str]] = [
			{c: ("" if v is None else v) for c, v in zip(selected_cols, r)}
			for r in rows
		]
		results = _sort_results_by_regex_order(results, filters)
		return results[:max_results]
	finally:
		conn.close()


if __name__ == "__main__":
	import argparse
	import sys

	def _pretty_json(value: object) -> str:
		return json.dumps(value, indent=2, ensure_ascii=False)

	# Manual smoke-test CLI: runs one find_ingredient query (name defaults to
	# "apple") and prints the tool input and output as JSON. Exits with status 1
	# on any failure so it can be used from scripts.
	parser = argparse.ArgumentParser(description="Filter dataset by multiple columns (Postgres DB)")
	parser.add_argument("--name", dest="name", help="Regex for name column")
	parser.add_argument("--description", dest="description", help="Regex for description column")
	parser.add_argument("--serving", dest="serving", help="Regex for serving column")
	parser.add_argument("--nutrition", dest="nutrition", help="Regex for nutrition_100g column")
	parser.add_argument("--ingredients", dest="ingredients", help="Regex for ingredients column")
	parser.add_argument("--ingredient-analysis", dest="ingredient_analysis", help="Regex for ingredient_analysis column")
	parser.add_argument("--cost", dest="cost", help="Regex for cost column")
	parser.add_argument(
		"--nutrition-range",
		dest="nutrition_ranges",
		action="append",
		help="Per-key nutrition range. Repeatable format: key:min:max (use 'none' for open bounds), e.g. --nutrition-range protein:20:40",
	)
	parser.add_argument(
		"--cost-range",
		dest="cost_ranges",
		action="append",
		help="Per-key cost range. Repeatable format: key:min:max (use 'none' for open bounds), e.g. --cost-range price:2:10",
	)
	# Share the tool's default instead of duplicating the literal.
	parser.add_argument("--limit", dest="limit", type=int, default=DEFAULT_MAX_RESULTS, help="Max results to return")

	args = parser.parse_args()
	query_payload = {
		"name_query": args.name or "apple",
		"description": args.description,
		"serving": args.serving,
		"nutrition": args.nutrition,
		"nutrition_ranges": args.nutrition_ranges,
		"ingredients": args.ingredients,
		"ingredient_analysis": args.ingredient_analysis,
		"cost": args.cost,
		"cost_ranges": args.cost_ranges,
		"max_results": args.limit,
	}
	query_payload = {k: v for k, v in query_payload.items() if v is not None}
	tool_input = {"ingredient_queries": [query_payload]}

	print("find_ingredient input:")
	print(_pretty_json(tool_input))

	try:
		grouped = find_ingredient(**tool_input)
	except Exception as e:
		print(f"Error: {e}", file=sys.stderr)
		sys.exit(1)

	if isinstance(grouped, str):
		# find_ingredient reports handled failures as an "Error: ..." string.
		print(grouped, file=sys.stderr)
		sys.exit(1)

	print("\nfind_ingredient output:")
	print(_pretty_json(grouped))


