import os
import sys
import json
import glob
import base64
import mimetypes
import argparse
import re
from typing import List, Dict, Any

from PIL import Image
from openai import OpenAI

try:
    from tqdm.auto import tqdm
except Exception:  # pragma: no cover
    tqdm = None

def _mime_for(path: str) -> str:
    mt, _ = mimetypes.guess_type(path)
    # Fallback to PNG if unknown
    return mt or "image/png"

def _maybe_resize(image_path: str, max_long_edge: int | None) -> bytes:
    """
    Load an image, optionally downscale it, and return it re-encoded as PNG bytes.

    The image is always re-encoded as PNG so every page reaches the model in one
    consistent format. Images with an alpha channel (or a transparency key) are
    flattened onto a white background. A plain ``convert("RGB")`` simply drops
    alpha, which can turn transparent page backgrounds black and hide dark text.

    Args:
        image_path: Path to a file Pillow can open.
        max_long_edge: When a positive int, downscale (never upscale) so the
            longer side is at most this many pixels, preserving aspect ratio.
            ``None`` or a non-positive value disables resizing.

    Returns:
        The PNG file contents.
    """
    from io import BytesIO

    with Image.open(image_path) as src:
        if "A" in src.getbands() or "transparency" in src.info:
            rgba = src.convert("RGBA")
            im = Image.new("RGB", rgba.size, (255, 255, 255))
            im.paste(rgba, mask=rgba.getchannel("A"))
        else:
            im = src.convert("RGB")

    if max_long_edge is not None and max_long_edge > 0:
        w, h = im.size
        long_edge = max(w, h)
        if long_edge > max_long_edge:
            scale = max_long_edge / float(long_edge)
            # Clamp to 1px so extreme aspect ratios never yield a zero-sized image.
            new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
            im = im.resize(new_size, Image.LANCZOS)

    buf = BytesIO()
    im.save(buf, format="PNG", optimize=True)
    return buf.getvalue()

def _path_to_data_url(image_path: str, max_long_edge: int | None) -> str:
    """
    Read a local image file and return a ``data:`` URL for an ``image_url`` part.

    When ``max_long_edge`` is set, the image goes through ``_maybe_resize``
    (optional downscale plus PNG re-encode), which controls token usage.

    Otherwise, the original bytes are sent untouched only when their type is one
    the vision API documents as accepted directly: PNG, JPEG or WEBP. Other
    formats (e.g. TIFF, BMP, or possibly-animated GIF) are re-encoded to PNG at
    full size instead of being sent raw.
    """
    passthrough = {"image/png", "image/jpeg", "image/webp"}
    mime = _mime_for(image_path)
    if max_long_edge is not None or mime not in passthrough:
        # Normalize to PNG when resizing or when the source format isn't accepted as-is.
        data = _maybe_resize(image_path, max_long_edge)
        mime = "image/png"
    else:
        with open(image_path, "rb") as f:
            data = f.read()
    b64 = base64.b64encode(data).decode("ascii")
    return f"data:{mime};base64,{b64}"

def _make_system_prompt(facts_min: int, facts_max: int) -> str:
    # Enforce self-contained, declarative sentences + source type.
    return f"""
You extract concise, verifiable FACTS from a single document page image.

STRICT requirements:
- Output a JSON object with exactly one key: "facts": {{ "fact": string, "source": "text" | "figure" }}[].
- Produce {facts_min}–{facts_max} atomic facts when possible; fewer only if the page has little content.
- Each fact MUST be a single self-contained declarative sentence with its subject and any units/dates, and MUST end with a period.
- Avoid deictic/layout words and vague pronouns: do NOT use "this", "that", "these", "those", "above", "below", "left", "right", "the following", "see figure/table".
- A fact is "text" if it comes from paragraphs, headings, lists, tables, or captions.
- A fact is "figure" if it comes from a chart/diagram/photo/graphic, including labels/legends/axes inside the graphic.
- Do not invent information; only state what appears on the page.
- If the page has no legible content, return {{ "facts": [] }}.
- Use the page’s language (do not translate).
""".strip()

def _make_user_instruction(facts_min: int, facts_max: int) -> str:
    return f"""
Extract {facts_min}–{facts_max} self-contained facts present on this page.
For each fact, also label its source as "text" or "figure" using the rules above.
Return ONLY JSON like:
{{"facts":[{{"fact":"...", "source":"text"}}, {{"fact":"...", "source":"figure"}}]}}
""".strip()

def _expand_inputs(inputs: List[str], recursive: bool) -> List[str]:
    """
    Accept files, directories, and glob patterns.
    """
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp", ".gif"}
    files: list[str] = []

    for item in inputs:
        # Glob first (handles both explicit globs and plain paths)
        matched = glob.glob(item)
        if not matched:
            matched = [item]

        for p in matched:
            if os.path.isdir(p):
                if recursive:
                    for root, _, fnames in os.walk(p):
                        for fn in fnames:
                            if os.path.splitext(fn.lower())[1] in exts:
                                files.append(os.path.join(root, fn))
                else:
                    for fn in os.listdir(p):
                        if os.path.splitext(fn.lower())[1] in exts:
                            files.append(os.path.join(p, fn))
            else:
                if os.path.splitext(p.lower())[1] in exts and os.path.isfile(p):
                    files.append(p)

    # Deduplicate preserving order
    seen = set()
    deduped = []
    for f in files:
        if f not in seen:
            seen.add(f)
            deduped.append(f)
    return deduped

def _clean_sentence(s: str) -> str:
    s = (s or "").strip().lstrip("-•·* ").replace("\n", " ")
    s = re.sub(r"\s{2,}", " ", s).strip(' "\'[]')
    if s and not re.search(r"[.!?]$", s):
        s += "."
    return s

def _coerce_and_clean_items(items: Any) -> List[Dict[str, str]]:
    """
    Normalize the model's "facts" payload into ``[{"fact": str, "source": "text"|"figure"}, ...]``.

    - A payload that is not a list yields ``[]``.
    - Dict entries read ``"fact"`` (falling back to ``"text"``) and ``"source"``.
    - Any other entry is stringified and labelled ``"text"``.
    - Every fact goes through ``_clean_sentence``; entries that clean to an
      empty string are dropped.
    - A source other than ``"text"``/``"figure"`` (after lowercasing) is coerced
      to ``"text"`` as a conservative fallback.
    """
    if not isinstance(items, list):
        return []

    cleaned: List[Dict[str, str]] = []
    for entry in items:
        if isinstance(entry, dict):
            raw_fact = entry.get("fact") or entry.get("text") or ""
            label = str(entry.get("source") or "").lower().strip()
        else:
            raw_fact = entry
            label = "text"
        sentence = _clean_sentence(str(raw_fact))
        if not sentence:
            continue
        cleaned.append({"fact": sentence, "source": label if label in ("text", "figure") else "text"})
    return cleaned

def _page_messages(
    system_prompt: str, user_instruction: str, data_url: str, detail: str
) -> List[Dict[str, Any]]:
    """Build the chat messages for one page: system rules, then instruction + image."""
    return [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": user_instruction},
                {
                    "type": "image_url",
                    "image_url": {"url": data_url, "detail": detail},
                },
            ],
        },
    ]

def extract_facts_from_images(
    image_paths: List[str],
    *,
    model: str,
    detail: str,
    temperature: float,
    max_long_edge: int | None,
    facts_min: int,
    facts_max: int,
    client: OpenAI,
    show_progress: bool = True,
) -> Dict[str, List[Dict[str, str]]]:
    """
    Make one vision chat-completion call per local page image and collect its facts.

    Args:
        image_paths: Local image files to process, in order.
        model: Chat-completions model name (the CLI default is ``o4-mini``).
        detail: Vision detail level sent with each image.
        temperature: Sampling temperature. ``None`` omits it from requests.
        max_long_edge: Forwarded to ``_path_to_data_url`` to control downscaling.
        facts_min: Lower soft bound on facts per page, interpolated into the prompts.
        facts_max: Upper soft bound on facts per page, interpolated into the prompts.
        client: Configured OpenAI client.
        show_progress: Show a tqdm progress bar when tqdm is importable.

    Returns:
        ``{path: [{"fact": str, "source": "text"|"figure"}, ...]}`` with an entry
        for every path. A page whose request or parsing fails maps to ``[]``, and
        a warning is written to stderr.

    Reasoning models such as ``o4-mini`` reject any non-default ``temperature``
    with a 400 error. If the API rejects the temperature parameter, the request
    is retried once without it, and the parameter is dropped for the remaining
    pages. Without this, every page would fail.
    """
    from openai import BadRequestError

    results: Dict[str, List[Dict[str, str]]] = {}

    system_prompt = _make_system_prompt(facts_min, facts_max)
    user_instruction = _make_user_instruction(facts_min, facts_max)

    use_bar = bool(show_progress and tqdm is not None)
    successes = 0
    failures = 0
    send_temperature = temperature is not None

    def warn(msg: str) -> None:
        # tqdm.write prints above an active bar instead of corrupting it.
        if use_bar:
            tqdm.write(msg, file=sys.stderr)
        else:
            print(msg, file=sys.stderr)

    iterator = (
        tqdm(image_paths, total=len(image_paths), unit="img", desc="Processing pages")
        if use_bar else image_paths
    )

    for path in iterator:
        try:
            data_url = _path_to_data_url(path, max_long_edge)
            request: Dict[str, Any] = {
                "model": model,
                "response_format": {"type": "json_object"},
                "messages": _page_messages(system_prompt, user_instruction, data_url, detail),
            }
            if send_temperature:
                request["temperature"] = temperature

            try:
                resp = client.chat.completions.create(**request)
            except BadRequestError as e:
                rejected_temperature = (
                    getattr(e, "param", None) == "temperature" or "temperature" in str(e)
                )
                if not (send_temperature and rejected_temperature):
                    raise
                send_temperature = False
                request.pop("temperature", None)
                warn(f"[WARN] Model {model!r} rejected temperature={temperature}; retrying without it.")
                resp = client.chat.completions.create(**request)

            content = resp.choices[0].message.content
            data = json.loads(content) if content else {"facts": []}
            # json_object mode guarantees valid JSON, not its shape: accept a bare list too.
            raw_items = data.get("facts", []) if isinstance(data, dict) else data
            results[path] = _coerce_and_clean_items(raw_items)
            successes += 1

        except Exception as e:  # per-page boundary: record the failure and keep going
            results[path] = []
            failures += 1
            warn(f"[WARN] Failed on {path}: {e}")

        # Update postfix after each item
        if use_bar:
            iterator.set_postfix_str(f"ok={successes} fail={failures}")

    return results

def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="Extract self-contained facts + source (text|figure) per page using o4-mini."
    )
    parser.add_argument(
        "inputs",
        nargs="+",
        help="Image files, directories, or glob patterns (e.g., scans/*.png).",
    )
    parser.add_argument(
        "-r",
        "--recursive",
        action="store_true",
        help="Recurse into directories.",
    )
    parser.add_argument(
        "-o",
        "--out",
        default="-",
        help="Output JSON file path (default: '-' for stdout).",
    )
    parser.add_argument(
        "--model",
        default="o4-mini",
        help="Model name (default: o4-mini).",
    )
    parser.add_argument(
        "--detail",
        choices=["low", "high", "auto"],
        default="high",
        help="Vision detail level (default: high).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (default: 0.0).",
    )
    parser.add_argument(
        "--max-long-edge",
        type=int,
        default=2000,
        help=(
            "Downscale so the longer edge is at most this many px (smaller images are "
            "not enlarged). Use 0 to disable resizing. Default: 2000."
        ),
    )
    parser.add_argument(
        "--facts-min",
        type=int,
        default=1,
        help="Minimum facts to request (soft guidance). Default: 1.",
    )
    parser.add_argument(
        "--facts-max",
        type=int,
        default=10,
        help="Maximum facts to request (soft guidance). Default: 10.",
    )
    parser.add_argument(
        "--api-key",
        default=os.getenv("OPENAI_API_KEY"),
        help="OpenAI API key (default: from OPENAI_API_KEY env var).",
    )
    grp = parser.add_mutually_exclusive_group()
    grp.add_argument(
        "--progress",
        dest="progress",
        action="store_true",
        help="Show a progress bar (default).",
    )
    grp.add_argument(
        "--no-progress",
        dest="progress",
        action="store_false",
        help="Disable the progress bar.",
    )
    parser.set_defaults(progress=True)
    return parser

def main():
    """
    CLI entry point: parse and validate arguments, extract facts from every
    discovered page image, and write the JSON mapping to stdout or ``--out``.

    Exits with status 2 on a missing API key, invalid numeric options, or when
    no images are found.
    """
    args = _build_arg_parser().parse_args()

    # Validate and use the *stripped* key so stray whitespace/newlines from an
    # env var or shell paste never reach the client.
    api_key = (args.api_key or "").strip()
    if not api_key:
        print("ERROR: Missing OpenAI API key. Set --api-key or OPENAI_API_KEY.", file=sys.stderr)
        sys.exit(2)

    if args.facts_min < 0 or args.facts_max < 0:
        print("ERROR: facts-min and facts-max must be non-negative.", file=sys.stderr)
        sys.exit(2)
    if args.facts_min > args.facts_max:
        print("NOTE: facts-min > facts-max, swapping values for you.", file=sys.stderr)
        args.facts_min, args.facts_max = args.facts_max, args.facts_min

    # A negative limit would otherwise produce a negative resize target later.
    if args.max_long_edge < 0:
        print("ERROR: max-long-edge must be non-negative (0 disables resizing).", file=sys.stderr)
        sys.exit(2)
    max_long_edge: int | None = None if args.max_long_edge == 0 else args.max_long_edge

    image_paths = _expand_inputs(args.inputs, recursive=args.recursive)
    if not image_paths:
        print("ERROR: No images found from the given inputs.", file=sys.stderr)
        sys.exit(2)

    client = OpenAI(api_key=api_key)

    results = extract_facts_from_images(
        image_paths=image_paths,
        model=args.model,
        detail=args.detail,
        temperature=args.temperature,
        max_long_edge=max_long_edge,
        facts_min=args.facts_min,
        facts_max=args.facts_max,
        client=client,
        show_progress=args.progress,
    )

    # Write output
    if args.out == "-" or args.out.lower() == "stdout":
        print(json.dumps(results, ensure_ascii=False, indent=2))
    else:
        os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
        with open(args.out, "w", encoding="utf-8") as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        print(f"Wrote: {os.path.abspath(args.out)}")

# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
