#!/usr/bin/env python3
"""
Load OpenNutrition TSV into Postgres with structured JSON preserved.

Usage:
.venv/bin/python code/datasets/load_opennutrition_tsv.py \
    --tsv code/datasets/opennutrition-dataset-2025/opennutrition_foods.tsv \
    --table opennutrition_foods_raw

Key behavior:
  - Parses TSV with csv quoting rules so embedded "" becomes a single ".
  - Stores dictionary/array columns as JSONB (not doubly-escaped TEXT).
  - Upserts by `id` in batches.
"""
import argparse
import csv
import json
import os
from typing import Dict, Iterable, List, Optional

import psycopg2
from psycopg2.extras import Json, execute_values

try:
    from dotenv import load_dotenv
except Exception:
    load_dotenv = None


# Columns whose TSV cells are decoded as JSON by _row_to_record() and stored
# in JSONB columns; all other columns are kept as plain text.
JSON_COLUMNS = {
    "alternate_names", "source", "serving", "nutrition_100g",
    "labels", "package_size", "ingredient_analysis",
}


# Every TSV column that gets loaded, listed in the same order as the columns
# of the destination table created by create_raw_table().
ALL_COLUMNS = [
    "id", "name", "alternate_names", "description", "type",
    "source", "serving", "nutrition_100g",
    "ean_13", "labels", "package_size",
    "ingredients", "ingredient_analysis",
]


def load_env() -> None:
    """Load environment variables from a ``.env`` file, if python-dotenv is available.

    The ``.env`` file in the parent of this script's directory is preferred;
    when it does not exist, ``load_dotenv()`` is called without a path so it
    performs its own default lookup. Does nothing when the optional
    ``dotenv`` import failed at module load.
    """
    if load_dotenv is None:
        return
    script_dir = os.path.dirname(__file__)
    candidate = os.path.abspath(os.path.join(script_dir, "..", ".env"))
    if not os.path.exists(candidate):
        load_dotenv()
        return
    load_dotenv(candidate)


def get_conn(dbname: Optional[str] = None):
    """Open a psycopg2 connection configured from the environment.

    ``DATABASE_URL`` is used as the DSN when it is set. Otherwise the
    connection is built from ``PGHOST`` (default ``localhost``), ``PGPORT``
    (default ``5432``), ``PGDATABASE``, ``PGUSER`` and ``PGPASSWORD``.

    Args:
        dbname: Optional database name. When given it takes precedence over
            both the database named in ``DATABASE_URL`` and ``PGDATABASE``.

    Returns:
        A new psycopg2 connection.

    Raises:
        SystemExit: If ``DATABASE_URL`` is unset and either the database name
            or ``PGUSER`` is missing.
    """
    dsn = os.getenv("DATABASE_URL")
    if dsn:
        if dbname:
            # Keyword arguments override the matching DSN fields, so callers
            # that ask for a specific database (e.g. 'postgres') actually get
            # it instead of whatever database the DSN names.
            return psycopg2.connect(dsn, dbname=dbname)
        return psycopg2.connect(dsn)

    db = dbname or os.getenv("PGDATABASE")
    user = os.getenv("PGUSER")
    if not (db and user):
        raise SystemExit("PGDATABASE and PGUSER (and optionally PGPASSWORD) must be set or use DATABASE_URL")
    return psycopg2.connect(
        host=os.getenv("PGHOST", "localhost"),
        port=int(os.getenv("PGPORT", "5432")),
        dbname=db,
        user=user,
        password=os.getenv("PGPASSWORD"),
    )


def ensure_database_exists(db: str) -> None:
    """Create database ``db`` unless ``pg_database`` already lists it.

    Uses a connection obtained from ``get_conn(dbname="postgres")`` and
    prints whether the database was created or was already present. The
    connection is closed before returning, also when a statement fails.
    """
    # The target database may not exist yet, so ask for the 'postgres' one.
    admin_conn = get_conn(dbname="postgres")
    try:
        # Autocommit so CREATE DATABASE is not issued inside a transaction.
        admin_conn.autocommit = True
        with admin_conn.cursor() as cur:
            cur.execute("SELECT 1 FROM pg_database WHERE datname = %s;", (db,))
            already_present = cur.fetchone()
            if already_present:
                print(f"Database '{db}' already exists.")
            else:
                cur.execute(f'CREATE DATABASE "{db}";')
                print(f"Created database '{db}'.")
    finally:
        admin_conn.close()


def create_raw_table(conn, table: str) -> None:
    """Create the staging table and its indexes if they do not exist, then commit.

    First tries to enable the ``pg_trgm`` extension. If that statement fails
    the transaction is rolled back and the trigram index on ``name`` is
    skipped; the table and the remaining indexes are still created.

    Args:
        conn: Open psycopg2 connection.
        table: Destination table name. It is double-quoted in every
            statement, and so are the index names derived from it, so names
            that require quoting (for example ones containing ``-``) work
            for the indexes as well as for the table.
    """
    with conn.cursor() as cur:
        try:
            cur.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
            has_pg_trgm = True
        except psycopg2.Error:
            # Best effort: roll back the failed statement so the DDL below
            # can still run, and remember to skip the trigram index.
            conn.rollback()
            has_pg_trgm = False

        cur.execute(
            f'''
            CREATE TABLE IF NOT EXISTS "{table}" (
              id TEXT PRIMARY KEY,
              name TEXT,
              alternate_names JSONB,
              description TEXT,
              type TEXT,
              source JSONB,
              serving JSONB,
              nutrition_100g JSONB,
              ean_13 TEXT,
              labels JSONB,
              package_size JSONB,
              ingredients TEXT,
              ingredient_analysis JSONB
            );
            '''
        )

        # (index-name suffix, index definition) pairs, in creation order.
        index_defs = []
        if has_pg_trgm:
            index_defs.append(("name_trgm_idx", "USING gin (name gin_trgm_ops)"))
        index_defs.extend(
            [
                ("type_idx", "(type)"),
                ("ean_idx", "(ean_13)"),
                ("nutrition_gin_idx", "USING gin (nutrition_100g jsonb_path_ops)"),
            ]
        )
        for suffix, definition in index_defs:
            # Quote the index name exactly like the table name; an unquoted
            # name derived from a table that needs quoting is a syntax error.
            cur.execute(
                f'CREATE INDEX IF NOT EXISTS "{table}_{suffix}" ON "{table}" {definition};'
            )
    conn.commit()


def _empty_to_none(value: Optional[str]) -> Optional[str]:
    """Return ``value`` with surrounding whitespace removed, or ``None`` if nothing is left.

    ``None`` is passed through unchanged.
    """
    if value is None:
        return None
    return value.strip() or None


def _parse_json_cell(raw: Optional[str]):
    """Decode one TSV cell as JSON.

    Returns ``None`` for a missing or blank cell, the decoded value when the
    stripped text is valid JSON, and the stripped text itself otherwise.
    """
    text = _empty_to_none(raw)
    if text is None:
        return None
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        # A malformed payload is kept as a plain string so that one bad cell
        # does not abort the whole import.
        return text
    return decoded


def _row_to_record(row: Dict[str, str]) -> Dict[str, object]:
    """Build a record holding every column in ``ALL_COLUMNS`` from one TSV row.

    Columns listed in ``JSON_COLUMNS`` are decoded with ``_parse_json_cell``;
    the others are normalized with ``_empty_to_none``. A column missing from
    ``row`` ends up as ``None``.
    """
    return {
        col: (
            _parse_json_cell(row.get(col))
            if col in JSON_COLUMNS
            else _empty_to_none(row.get(col))
        )
        for col in ALL_COLUMNS
    }


def iter_records(tsv_path: str) -> Iterable[Dict[str, object]]:
    """Yield one normalized record per data row of the TSV file at ``tsv_path``.

    The file is read as UTF-8 with the ``csv`` module using a tab delimiter,
    and each row is converted with ``_row_to_record``. The file stays open
    for as long as the generator is being consumed.
    """
    with open(tsv_path, "r", encoding="utf-8", newline="") as handle:
        rows = csv.DictReader(handle, delimiter="\t")
        header = rows.fieldnames
        if header:
            # Remove a leading byte-order mark from header names so the
            # column lookups by name still match.
            rows.fieldnames = [column.lstrip("\ufeff") for column in header]
        yield from map(_row_to_record, rows)


def upsert_batch(conn, table: str, batch: List[Dict[str, object]]) -> None:
    """Insert or update one batch of records in ``table`` and commit.

    The whole batch is sent as a single ``INSERT ... ON CONFLICT (id) DO
    UPDATE`` statement. Values for columns in ``JSON_COLUMNS`` are wrapped in
    ``Json`` unless they are ``None``.

    Args:
        conn: Open psycopg2 connection.
        table: Destination table name (double-quoted in the statement).
        batch: Records keyed by the names in ``ALL_COLUMNS``. An empty batch
            is a no-op. When several records share an ``id``, only the last
            one is written.
    """
    if not batch:
        return

    # One INSERT ... ON CONFLICT DO UPDATE statement may not affect the same
    # row twice, so duplicate ids inside a batch would make the whole
    # statement fail. Keep the last occurrence, which is also what happens
    # when duplicates fall into different batches.
    unique_records = {rec["id"]: rec for rec in batch}

    values = [
        tuple(
            Json(rec[col]) if col in JSON_COLUMNS and rec[col] is not None else rec[col]
            for col in ALL_COLUMNS
        )
        for rec in unique_records.values()
    ]

    # Derive both column lists from ALL_COLUMNS so they cannot drift apart
    # from the order of the value tuples built above.
    column_list = ", ".join(ALL_COLUMNS)
    assignments = ", ".join(
        f"{col} = EXCLUDED.{col}" for col in ALL_COLUMNS if col != "id"
    )
    insert_sql = (
        f'INSERT INTO "{table}" ({column_list}) VALUES %s '
        f"ON CONFLICT (id) DO UPDATE SET {assignments};"
    )

    with conn.cursor() as cur:
        # page_size covers every row so the batch goes out as one statement.
        execute_values(cur, insert_sql, values, page_size=len(values))
    conn.commit()


def load_tsv(conn, table: str, tsv_path: str, batch_size: int) -> int:
    """Stream the TSV at ``tsv_path`` into ``table`` in batches.

    Records without an ``id`` are skipped. A batch is handed to
    ``upsert_batch`` as soon as it holds ``batch_size`` records, and any
    remainder is flushed after the last row.

    Returns:
        The number of records passed to ``upsert_batch``.
    """
    total = 0
    pending: List[Dict[str, object]] = []

    for record in iter_records(tsv_path):
        if record.get("id"):
            pending.append(record)
            if len(pending) >= batch_size:
                upsert_batch(conn, table, pending)
                total += len(pending)
                pending = []

    # Flush whatever is left over after the final full batch.
    if pending:
        upsert_batch(conn, table, pending)
        total += len(pending)

    return total


def main():
    """Command-line entry point: parse arguments, prepare the database, load the TSV.

    Exits with an error when the TSV file does not exist or ``PGDATABASE`` is
    not set. Otherwise it ensures the database and the destination table
    exist, optionally truncates the table, loads the file, and prints the
    number of rows loaded. The connection is closed on the way out.
    """
    cli = argparse.ArgumentParser(
        description="Load OpenNutrition TSV into Postgres (typed JSONB staging table)"
    )
    cli.add_argument("--tsv", required=True, help="Path to opennutrition TSV file")
    cli.add_argument("--table", default="opennutrition_foods_raw", help="Destination table name")
    cli.add_argument("--batch-size", type=int, default=2000, help="Rows per batch upsert")
    cli.add_argument("--truncate", action="store_true", help="Truncate destination table before loading")
    opts = cli.parse_args()
    tsv_path, table = opts.tsv, opts.table

    if not os.path.exists(tsv_path):
        raise SystemExit(f"TSV not found: {tsv_path}")

    load_env()
    target_db = os.getenv("PGDATABASE")
    if not target_db:
        raise SystemExit("PGDATABASE must be set")
    ensure_database_exists(target_db)

    conn = get_conn()
    try:
        create_raw_table(conn, table)
        if opts.truncate:
            with conn.cursor() as cur:
                cur.execute(f'TRUNCATE TABLE "{table}";')
            conn.commit()

        # Clamp the batch size so a zero or negative value still makes progress.
        loaded = load_tsv(conn, table, tsv_path, max(1, opts.batch_size))
        print(f"Loaded {loaded} rows from {tsv_path} into table '{table}'.")
    finally:
        conn.close()


# Run the loader only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
