#!/usr/bin/env python3
"""Remove nutritionally invalid rows from the opennutrition_foods_raw Postgres table.

Applies the same sanity rules used by `utils/ingredient_tool._passes_nutrition_sanity_filters`:
  - nutrition_100g payload must be present with energy, fat, carbohydrates, and protein fields
  - all values must be non-negative
  - energy must not exceed MAX_ENERGY_KCAL_100G
  - each macro must not exceed MAX_SINGLE_MACRO_G_100G
  - macro sum must not exceed MAX_MACRO_SUM_G_100G
  - stated energy must be within ±CALORIES_TOLERANCE_RATIO (relative) of the calculated calorie
    total 9*fat + 4*carbs + 4*protein; this check is skipped when that total is zero

Usage:
  python code/datasets/pre_process_opennutrition.py
  python code/datasets/pre_process_opennutrition.py --table opennutrition_foods_raw --dry-run
"""

from __future__ import annotations

import argparse
import json
import math
import os
import re
import sys
from typing import Any

import psycopg2
from psycopg2 import sql

try:
    from dotenv import load_dotenv
except Exception:
    load_dotenv = None


# ---------------------------------------------------------------------------
# Nutrition sanity thresholds
# ---------------------------------------------------------------------------

# Upper bound on stated energy, in kcal per 100 g.
MAX_ENERGY_KCAL_100G = 900.0
# Upper bound on each individual macro (fat, carbohydrates, protein), in g per 100 g.
MAX_SINGLE_MACRO_G_100G = 100.0
# Upper bound on fat + carbohydrates + protein, in g per 100 g.
MAX_MACRO_SUM_G_100G = 100.0
# Allowed relative deviation (0.4 == ±40%) of the stated energy from the
# 9/4/4 kcal-per-gram estimate computed from fat, carbohydrates and protein.
CALORIES_TOLERANCE_RATIO = 0.4

# Canonical nutrient -> payload keys that may carry it. Aliases are tried in
# list order and the first one that is present AND parses as a number wins,
# so the order of each list is significant.
# NOTE(review): the generic "energy"/"kcal" aliases are read as kcal; if any
# payloads store kilojoules under "energy", those rows will be misjudged —
# confirm against the source data.
NUTRITION_KEY_CANDIDATES: dict[str, list[str]] = {
    "energy_kcal_100g": [
        "energy-kcal_100g",
        "energy_kcal_100g",
        "energy_kcal",
        "energy-kcal",
        "calories",
        "kcal",
        "energy",
    ],
    "fat_100g": ["fat_100g", "fat", "fats", "total_fat", "total-fat_100g", "lipid"],
    "carbohydrates_100g": ["carbohydrates_100g", "carbohydrates", "carbs", "carbohydrate"],
    "proteins_100g": ["proteins_100g", "proteins", "protein"],
}


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _to_float(value: Any) -> float | None:
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        parsed = float(value)
        return parsed if math.isfinite(parsed) else None
    if isinstance(value, str):
        text = value.strip()
        if not text:
            return None
        match = re.search(r"[-+]?\d*\.?\d+", text)
        if not match:
            return None
        try:
            parsed = float(match.group(0))
            return parsed if math.isfinite(parsed) else None
        except ValueError:
            return None
    return None


def _read_nutrition_value(nutrition_obj: dict[str, Any], key_aliases: list[str]) -> float | None:
    """Return the first parseable number stored under any of *key_aliases*.

    Aliases are tried in the order given. An alias that is absent from
    *nutrition_obj*, or whose value ``_to_float`` cannot parse, is skipped.
    Returns None when no alias yields a number.
    """
    # Lazily parse only keys that exist, in alias order; stop at the first hit.
    parsed_values = (
        _to_float(nutrition_obj[key]) for key in key_aliases if key in nutrition_obj
    )
    return next((number for number in parsed_values if number is not None), None)


def _nutrition_failure_reason(nutrition_obj: dict[str, Any] | str | None) -> str | None:
    """Explain why a nutrition_100g payload fails the sanity rules.

    Args:
        nutrition_obj: the row's nutrition_100g value, usually a dict (psycopg2
            decodes json/jsonb columns). A JSON string is decoded first, so a
            text column or a double-encoded jsonb value is judged on its
            contents rather than being reported as a missing payload (and
            deleted).

    Returns:
        None when the payload passes every check, otherwise a short
        human-readable reason for the first check that failed.
    """
    if isinstance(nutrition_obj, str):
        try:
            nutrition_obj = json.loads(nutrition_obj)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return "missing nutrition payload"
    if not nutrition_obj or not isinstance(nutrition_obj, dict):
        return "missing nutrition payload"

    energy = _read_nutrition_value(nutrition_obj, NUTRITION_KEY_CANDIDATES["energy_kcal_100g"])
    fat = _read_nutrition_value(nutrition_obj, NUTRITION_KEY_CANDIDATES["fat_100g"])
    carbs = _read_nutrition_value(nutrition_obj, NUTRITION_KEY_CANDIDATES["carbohydrates_100g"])
    protein = _read_nutrition_value(nutrition_obj, NUTRITION_KEY_CANDIDATES["proteins_100g"])

    missing = [
        name
        for name, val in (("energy", energy), ("fat", fat), ("carbohydrates", carbs), ("protein", protein))
        if val is None
    ]
    if missing:
        return f"missing required nutrition fields: {', '.join(missing)}"

    # Narrowing for type checkers only; the `missing` check above guarantees it.
    assert energy is not None and fat is not None and carbs is not None and protein is not None

    if energy < 0 or fat < 0 or carbs < 0 or protein < 0:
        return "contains negative nutrition values"

    if energy > MAX_ENERGY_KCAL_100G:
        return f"energy {energy:g} exceeds {MAX_ENERGY_KCAL_100G:g} kcal/100g"

    if fat > MAX_SINGLE_MACRO_G_100G or carbs > MAX_SINGLE_MACRO_G_100G or protein > MAX_SINGLE_MACRO_G_100G:
        return f"single macro exceeds {MAX_SINGLE_MACRO_G_100G:g} g/100g"

    macro_sum = fat + carbs + protein
    if macro_sum > MAX_MACRO_SUM_G_100G:
        return f"macro sum {macro_sum:g} exceeds {MAX_MACRO_SUM_G_100G:g} g/100g"

    # 9 kcal/g for fat, 4 kcal/g for carbohydrates and protein.
    calculated_calories = 9.0 * fat + 4.0 * carbs + 4.0 * protein

    # Skip calorie-mismatch check for zero-macro products (e.g. artificial sweeteners).
    if calculated_calories <= 0:
        return None

    lower = calculated_calories * (1.0 - CALORIES_TOLERANCE_RATIO)
    upper = calculated_calories * (1.0 + CALORIES_TOLERANCE_RATIO)
    if energy < lower or energy > upper:
        return (
            f"energy {energy:g} kcal/100g outside [{lower:g}, {upper:g}] "
            f"(calculated {calculated_calories:g})"
        )

    return None


# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------

def load_env() -> None:
    """Populate os.environ from a .env file when python-dotenv is available.

    Prefers ``../.env`` relative to this script's directory. If that file does
    not exist, falls back to python-dotenv's own default lookup
    (``load_dotenv()`` with no path). Does nothing when python-dotenv could not
    be imported.
    """
    if load_dotenv is not None:
        script_dir = os.path.dirname(__file__)
        env_file = os.path.abspath(os.path.join(script_dir, "..", ".env"))
        if not os.path.exists(env_file):
            load_dotenv()
        else:
            load_dotenv(env_file)


def get_conn():
    """Open a psycopg2 connection configured from environment variables.

    ``DATABASE_URL`` is used as the connection string when set and non-empty.
    Otherwise PGDATABASE and PGUSER are required, and PGPASSWORD, PGHOST
    (default "localhost") and PGPORT (default 5432) are optional.

    Raises:
        SystemExit: when neither DATABASE_URL nor PGDATABASE + PGUSER is set,
            or when PGPORT is not an integer.
    """
    dsn = os.getenv("DATABASE_URL")
    if dsn:
        return psycopg2.connect(dsn)
    host = os.getenv("PGHOST", "localhost")
    # A bare `PGPORT=` line in a .env file yields "", which int() rejects;
    # treat it as unset instead of crashing with a raw ValueError.
    port_text = os.getenv("PGPORT") or "5432"
    try:
        port = int(port_text)
    except ValueError:
        raise SystemExit(f"PGPORT must be an integer, got {port_text!r}") from None
    db = os.getenv("PGDATABASE")
    user = os.getenv("PGUSER")
    password = os.getenv("PGPASSWORD")
    if not (db and user):
        raise SystemExit(
            "Set DATABASE_URL or PGDATABASE + PGUSER (+ optional PGPASSWORD/PGHOST/PGPORT)"
        )
    return psycopg2.connect(host=host, port=port, dbname=db, user=user, password=password)


# ---------------------------------------------------------------------------
# Core logic
# ---------------------------------------------------------------------------

def collect_invalid_ids(conn, table: str, batch_size: int = 5000) -> list[tuple[str, str, str]]:
    """Return (id, name, reason) for every row whose nutrition_100g fails sanity checks.

    The table is read through a single server-side (named) cursor, so the scan
    is one query over one snapshot and rows arrive ``batch_size`` at a time.
    This replaces LIMIT/OFFSET paging, where every page re-read all preceding
    rows (quadratic in table size) and each page saw a different snapshot.

    Args:
        conn: open psycopg2 connection, not in autocommit mode (named cursors
            exist only inside a transaction). The transaction is left open for
            the caller to commit or discard.
        table: table name, quoted as a single SQL identifier.
        batch_size: rows fetched per round trip; must be >= 1.

    Raises:
        ValueError: if ``batch_size`` < 1, which would otherwise scan nothing.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    query = sql.SQL("SELECT id, name, nutrition_100g FROM {} ORDER BY id").format(
        sql.Identifier(table)
    )
    invalid: list[tuple[str, str, str]] = []
    scanned = 0
    with conn.cursor(name="opennutrition_sanity_scan") as cur:
        cur.execute(query)
        while True:
            rows = cur.fetchmany(batch_size)
            if not rows:
                break
            for row_id, name, nutrition_obj in rows:
                reason = _nutrition_failure_reason(nutrition_obj)
                if reason is not None:
                    invalid.append((row_id, name or "", reason))
            scanned += len(rows)
            print(f"  scanned {scanned} rows, found {len(invalid)} invalid so far …", end="\r", flush=True)
    print()
    return invalid


def delete_rows(conn, table: str, ids: list[str], batch_size: int = 2000) -> int:
    """Delete rows whose ``id`` is in *ids* and return how many were removed.

    Deletes are issued in chunks of ``batch_size`` ids within one transaction.
    psycopg2's ``with conn:`` commits once after the last chunk, or rolls back
    every chunk if any statement fails, so the table is never left
    half-cleaned. The connection itself stays open.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    statement = sql.SQL("DELETE FROM {} WHERE id = ANY(%s)").format(sql.Identifier(table))
    deleted = 0
    with conn, conn.cursor() as cur:
        for start in range(0, len(ids), batch_size):
            chunk = ids[start : start + batch_size]
            cur.execute(statement, (chunk,))
            deleted += cur.rowcount
    return deleted


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

def main() -> int:
    """Command-line entry point: scan the table, report invalid rows, and delete them.

    With ``--dry-run`` the invalid rows are only reported. Returns the process
    exit status (0 on completion). Argument errors and missing database
    settings exit via SystemExit instead.
    """
    parser = argparse.ArgumentParser(
        description="Remove nutritionally invalid rows from the opennutrition_foods_raw table."
    )
    parser.add_argument("--table", default="opennutrition_foods_raw", help="Target table name.")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Report invalid rows without deleting them.",
    )
    parser.add_argument("--batch-size", type=int, default=5000, help="Rows per scan batch.")
    args = parser.parse_args()
    # A batch size below 1 cannot page through the table (LIMIT 0 returns
    # nothing, so the scan would falsely report a clean table); reject it
    # before connecting.
    if args.batch_size < 1:
        parser.error("--batch-size must be a positive integer")

    preview_limit = 20  # number of invalid rows listed before summarising the rest

    load_env()
    conn = get_conn()
    try:
        print(f"Scanning table '{args.table}' for nutrition sanity violations …")
        invalid = collect_invalid_ids(conn, args.table, batch_size=args.batch_size)

        if not invalid:
            print("No invalid rows found.")
            return 0

        print(f"\nFound {len(invalid)} invalid row(s).")
        for row_id, name, reason in invalid[:preview_limit]:
            print(f"  id={row_id!r}  name={name!r}  reason={reason}")
        if len(invalid) > preview_limit:
            print(f"  … and {len(invalid) - preview_limit} more")

        if args.dry_run:
            print("\nDry-run mode — no rows deleted.")
            return 0

        ids_to_delete = [row_id for row_id, _, __ in invalid]
        deleted = delete_rows(conn, args.table, ids_to_delete)
        print(f"\nDeleted {deleted} row(s) from '{args.table}'.")
    finally:
        conn.close()

    return 0


if __name__ == "__main__":
    try:
        raise SystemExit(main())
    except BrokenPipeError:
        # stdout was piped into a reader that exited early (e.g. `| head`).
        # Point stdout at devnull so the interpreter's final flush does not
        # raise a second BrokenPipeError and print "Exception ignored ..."
        # to stderr.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        raise SystemExit(0)
