#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
openrouter_predict_demo.py

Use a vision-capable model on OpenRouter to predict pos / angle / size for a tangram piece
from a PNG (outline) while keeping other fields from an existing JSON unchanged.
It then writes a prediction JSON and (optionally) evaluates basic errors against the ground-truth JSON.

Requirements:
  pip install requests pillow numpy matplotlib

Env:
  export OPENROUTER_API_KEY='sk-or-...'

Example:
  python openrouter_predict_demo.py \
    --model openai/gpt-4o-mini \
    --mode pos \
    --image /path/to/page7-72_piece5.png \
    --gt_json /path/to/page7-72_piece5.json \
    --out_json /tmp/page7-72_piece5_pred_pos.json

Batch example:
  python openrouter_predict_demo.py \
    --model openai/gpt-4o-mini \
    --mode all \
    --in_dir /path/to/pngs \
    --gt_dir /path/to/jsons \
    --out_dir /tmp/preds

Notes:
- Choose a vision-capable model on OpenRouter (e.g., openai/gpt-4o, google/gemini-1.5-flash, qwen/qwen2.5-vl-7b-instruct if available).
- This script uses the OpenAI-compatible Chat Completions schema supported by OpenRouter.
"""

import os, re, json, math, base64, argparse, glob
from dataclasses import dataclass
from typing import Dict, Any, Optional, Tuple, List
import requests
import numpy as np
import subprocess, shutil, tempfile
import sys  # for invoking external geometry script
import matplotlib
from PIL import Image, ImageDraw, ImageChops

# --- Geometry rendering helpers ----------------------------------------------
def _run_geom_step(label: str, cmd: List[str]) -> None:
    """Run one external geometry-script command, logging its output.

    On success, stdout/stderr are printed only when the environment variable
    ``OR_VERBOSE_GEOM`` equals ``"1"``. On a non-zero exit the command line and
    any captured output are always printed and the
    ``subprocess.CalledProcessError`` is re-raised.
    """
    try:
        res = subprocess.run(cmd, check=True, capture_output=True, text=True)
    except subprocess.CalledProcessError as e:
        print(f"[GEOM][{label}] Error running: {' '.join(map(str, cmd))}")
        if e.stdout:
            print(f"[GEOM][{label}] stdout:\n{e.stdout}")
        if e.stderr:
            print(f"[GEOM][{label}] stderr:\n{e.stderr}")
        raise
    if os.environ.get("OR_VERBOSE_GEOM", "0") == "1":
        print(f"[GEOM][{label}] stdout:\n{res.stdout}")
        print(f"[GEOM][{label}] stderr:\n{res.stderr}")


def _render_geometry_fallback(gt_json_path: str, pred_json_path: str, out_dir: str,
                              gtr: str, prd: str, ovl: str) -> None:
    """Render with the in-process renderer and copy its files to the final output names."""
    stem = "render"
    gt_obj = load_json(gt_json_path)
    pred_obj = load_json(pred_json_path)
    render_geom_images(gt_obj, pred_obj, out_dir, stem)
    # render_geom_images writes <stem>_{gt,pred,overlap}_geom.png; map those onto the
    # names the external script would have produced.
    for suffix, dst in (("gt_geom", gtr), ("pred_geom", prd), ("overlap_geom", ovl)):
        src = os.path.join(out_dir, f"{stem}_{suffix}.png")
        if not os.path.exists(src):
            continue
        try:
            shutil.copy2(src, dst)
        except OSError as e:
            print(f"[GEOM] warn: could not copy {src} -> {dst}: {e}")


def run_geometry_render(gt_json_path: str, pred_json_path: str, out_dir: str, geometry_py: str):
    """
    Render GT, PRED, and OVERLAY images using the external geometry script if possible;
    otherwise, fall back to the internal renderer.

    The external script (run with the current interpreter) is called three times:
    ``--ann <json> --save_png <png> --no_show`` for GT and PRED, then
    ``--overlay_gt_json/--overlay_pred_json`` for the overlay. If ``geometry_py`` is
    missing or any call fails, the internal matplotlib renderer is used instead.

    Returns (gtr, prd, ovl): paths to GT, PRED, and OVERLAY images inside ``out_dir``
    (``gt_render.png``, ``pred_render.png``, ``overlay_json_vs_pred.png``).
    """
    os.makedirs(out_dir, exist_ok=True)
    gtr = os.path.join(out_dir, "gt_render.png")
    prd = os.path.join(out_dir, "pred_render.png")
    ovl = os.path.join(out_dir, "overlay_json_vs_pred.png")

    if geometry_py and os.path.isfile(geometry_py):
        try:
            _run_geom_step("GT", [sys.executable, geometry_py,
                                  "--ann", gt_json_path, "--save_png", gtr, "--no_show"])
            _run_geom_step("PRED", [sys.executable, geometry_py,
                                    "--ann", pred_json_path, "--save_png", prd, "--no_show"])
            _run_geom_step("OVERLAY", [sys.executable, geometry_py,
                                       "--overlay_gt_json", gt_json_path,
                                       "--overlay_pred_json", pred_json_path,
                                       "--save_png", ovl,
                                       "--no_show"])
            return gtr, prd, ovl
        except Exception as e:  # deliberate best-effort boundary: any failure -> internal renderer
            print(f"[GEOM] external renderer failed ({e}); using internal renderer.")
    else:
        print(f"[GEOM] geometry script not found ({geometry_py!r}); using internal renderer.")

    _render_geometry_fallback(gt_json_path, pred_json_path, out_dir, gtr, prd, ovl)
    return gtr, prd, ovl

# Select the non-interactive Agg backend before pyplot is imported, so figures are
# rendered straight to image files without needing a display.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# OpenAI-compatible Chat Completions endpoint on OpenRouter.
OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"

# System message for every request: fixes the coordinate conventions
# (0..10 square, bottom-left origin, CCW degrees, relative size) and demands a
# JSON-only reply containing just the requested fields.
SYSTEM_PROMPT = (
    "You are a precise vision-to-JSON extractor. "
    "Given a silhouette image of a single tangram piece drawn within a 0..10 x 0..10 inner square, "
    "you must output ONLY a compact JSON with the requested fields (pos, angle, size). "
    "Rules:\n"
    "- Coordinate system: origin at bottom-left of the inner square; x rightwards, y upwards.\n"
    "- pos = [x, y] is the piece's reference point (centroid if unspecified). Range ~[0,10].\n"
    "- angle is in degrees, 0..360, counter-clockwise, relative to the x-axis.\n"
    "- size is a positive scalar factor relative to the canonical template size (1.0 = no scale).\n"
    "- DO NOT include any extra keys, comments, or prose. Return a minimal JSON object.\n"
    "Output must be valid minified JSON with numeric literals (no units, no strings for numbers)."
)

# --- Helpers -----------------------------------------------------------------

def to_data_url(image_path: str) -> str:
    """Read an image file and return it as a base64 ``data:`` URL.

    The MIME type is guessed from the file extension with :mod:`mimetypes`
    (case-insensitive), so PNG, JPEG, GIF, etc. are labelled correctly.
    Anything that does not map to an ``image/*`` type falls back to
    ``image/jpeg`` -- the previous default for every non-PNG file.
    """
    import mimetypes

    guessed, _encoding = mimetypes.guess_type(image_path)
    mime = guessed if guessed and guessed.startswith("image/") else "image/jpeg"
    with open(image_path, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime};base64,{b64}"

def _message_content_to_text(content: Any) -> str:
    """Flatten a chat ``message.content`` value into plain text.

    A string is returned unchanged; a list of content parts is concatenated
    (bare string parts and the ``text`` of dict parts); ``None`` becomes ``""``;
    anything else is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces: List[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                pieces.append(part["text"])
        return "".join(pieces)
    return str(content)


def call_openrouter(
    model: str,
    system_prompt: str,
    user_prompt: str,
    image_path: Optional[str],
    args: Any,
    extra_messages: Optional[List[Dict[str,Any]]] = None,
    refinement_hint: str = "",
) -> str:
    """Send one chat-completion request to OpenRouter and return the reply text.

    Args:
        model: OpenRouter model id (e.g. ``openai/gpt-4o-mini``).
        system_prompt: System message; the module-level ``SYSTEM_PROMPT`` is
            used when this is empty.
        user_prompt: Text of the final user turn.
        image_path: Optional image attached to the final user turn as a data URL.
        args: Namespace-like object. Reads ``api_key`` (else the
            ``OPENROUTER_API_KEY`` env var), ``temperature`` (default 0.0) and,
            when present, ``max_tokens`` (default 300) and ``request_timeout``
            in seconds (default 120).
        extra_messages: Messages inserted between the system message and the
            final user turn (e.g. few-shot examples).
        refinement_hint: Extra text appended to ``user_prompt`` on a new line.

    Returns:
        The assistant message text (list-form content is flattened).

    Raises:
        RuntimeError: missing API key, non-200 status, non-JSON body, unexpected
            response shape, or empty message content.
    """
    api_key = getattr(args, "api_key", None) or os.getenv("OPENROUTER_API_KEY")
    if not api_key:
        raise RuntimeError("Missing OPENROUTER_API_KEY in environment or --api_key argument. export OPENROUTER_API_KEY='sk-or-...' or pass --api_key.")

    if refinement_hint and refinement_hint.strip():
        user_prompt = user_prompt + "\n" + refinement_hint.strip()

    # User text first, then the optional image (OpenAI format: type=image_url).
    content: List[Dict[str, Any]] = [{"type": "text", "text": user_prompt}]
    if image_path:
        content.append({
            "type": "image_url",
            "image_url": {"url": to_data_url(image_path)}
        })

    # [system] + extra_messages + [user]. Honour the caller's system prompt
    # (it used to be ignored in favour of the global constant).
    messages: List[Dict[str, Any]] = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]
    if extra_messages:
        messages.extend(extra_messages)
    messages.append({"role": "user", "content": content})

    temperature = getattr(args, "temperature", None)
    payload = {
        "model": model,
        "messages": messages,
        "temperature": float(temperature) if temperature is not None else 0.0,
        "max_tokens": int(getattr(args, "max_tokens", None) or 300),
    }
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    timeout = float(getattr(args, "request_timeout", None) or 120)

    resp = requests.post(OPENROUTER_URL, headers=headers, json=payload, timeout=timeout)
    if resp.status_code != 200:
        raise RuntimeError(f"OpenRouter API error {resp.status_code}: {resp.text}")

    try:
        data = resp.json()
    except ValueError as e:  # requests' JSONDecodeError subclasses ValueError
        raise RuntimeError(f"OpenRouter returned a non-JSON body: {resp.text[:500]}") from e
    try:
        raw_content = data["choices"][0]["message"]["content"]
    except (KeyError, IndexError, TypeError) as e:
        raise RuntimeError(f"Unexpected OpenRouter response: {data}") from e

    msg = _message_content_to_text(raw_content)
    if not msg:
        raise RuntimeError(f"Empty message content in OpenRouter response: {data}")
    return msg

def extract_json_block(text: str) -> str:
    """
    Return the first complete JSON object embedded in the model output.

    Every ``{`` is tried left to right with ``json.JSONDecoder.raw_decode``, and
    the exact substring of the first one that decodes is returned. Code fences,
    surrounding prose or a second trailing object therefore do not corrupt the
    result (the old greedy regex spanned from the first ``{`` to the last ``}``).
    If no position decodes, the greedy ``{...}`` span is returned, so the caller's
    ``json.loads`` still reports the parse error.

    Raises:
        ValueError: if ``text`` is not a string or contains no ``{...}`` span.
    """
    if not isinstance(text, str):
        raise ValueError(f"No JSON object found in model output: {text!r}")
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            _obj, end = decoder.raw_decode(text, start)
        except ValueError:
            start = text.find("{", start + 1)
            continue
        return text[start:end]
    # Nothing decodes: fall back to the historical greedy match.
    m = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if not m:
        raise ValueError(f"No JSON object found in model output: {text[:200]}...")
    return m.group(0)

def load_json(path: str) -> Dict[str, Any]:
    """Decode and return the UTF-8 encoded JSON document stored at ``path``."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

def save_json(obj: Dict[str, Any], path: str) -> None:
    """Write ``obj`` to ``path`` as UTF-8 JSON (indent=2, non-ASCII kept as-is).

    Missing parent directories are created. A bare filename with no directory
    part is written to the current working directory; previously this crashed
    because ``os.makedirs("")`` raises ``FileNotFoundError``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)

def merge_pred_into_base(base: Dict[str, Any], pred: Dict[str, Any], mode: str) -> Dict[str, Any]:
    """Return a shallow copy of ``base`` with the predicted fields chosen by ``mode`` applied.

    ``mode`` selects a single field (``"pos"``, ``"angle"`` or ``"size"``) or all
    three (``"all"``). A field is overwritten only when ``pred`` actually
    contains it; ``base`` itself is left untouched.
    """
    merged = dict(base)
    for field in ("pos", "angle", "size"):
        selected = mode in (field, "all")
        if selected and field in pred:
            merged[field] = pred[field]
    return merged

def l2_error(p1: List[float], p2: List[float]) -> float:
    """Euclidean distance between the first two coordinates of ``p1`` and ``p2``."""
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return float(math.sqrt(dx ** 2 + dy ** 2))

def angle_diff_deg(a: float, b: float) -> float:
    """Smallest absolute angular difference between ``a`` and ``b`` in degrees, within [0, 180]."""
    wrapped = abs((a - b) % 360.0)  # in [0, 360) for finite inputs
    if wrapped > 180.0:
        # Going the other way round the circle is shorter.
        wrapped = 360.0 - wrapped
    return float(wrapped)

def size_rel_error(s_pred: float, s_gt: float) -> float:
    """Return |pred - gt| / gt when gt is positive, otherwise the plain |pred - gt|."""
    err = abs(s_pred - s_gt)
    return float(err / s_gt if s_gt > 0 else err)

def evaluate_basic(gt: Dict[str, Any], pred: Dict[str, Any]) -> Dict[str, float]:
    """Compute simple per-field errors for every field present in both ``gt`` and ``pred``.

    Keys in the result: ``l2_error_pos`` (Euclidean distance), ``angle_abs_diff_deg``
    (wrapped angular difference) and ``size_rel_error`` (relative size error).
    """
    def _both_have(key: str) -> bool:
        return key in gt and key in pred

    metrics: Dict[str, float] = {}
    if _both_have("pos"):
        metrics["l2_error_pos"] = l2_error(pred["pos"], gt["pos"])
    if _both_have("angle"):
        metrics["angle_abs_diff_deg"] = angle_diff_deg(float(pred["angle"]), float(gt["angle"]))
    if _both_have("size"):
        metrics["size_rel_error"] = size_rel_error(float(pred["size"]), float(gt["size"]))
    return metrics

def run_overlay_render(gt_json: str, pred_json: str, out_dir: str, args: Any) -> Optional[str]:
    """Stage one prediction for the user's overlay script (iou_from_overlay.py) and run it.

    The prediction JSON is copied into the batch layout the overlay script expects:
    ``<out_dir>/overlay_tmp/<gt_stem>/<name>``. ``<name>`` is the GT basename when
    ``args.overlay_name_mode`` is ``"gt"`` (the default) and the prediction
    basename otherwise. A ``pred.json`` alias is placed in the same folder.

    When ``args.overlay_defer`` is truthy, only the staging is done and ``None`` is
    returned; ``run_overlay_exec`` later processes every staged sample in one pass.
    Otherwise ``args.overlay_script`` runs under the current Python interpreter in
    batch mode (``--batch_my_run <overlay_tmp> --gt_dir <args.overlay_gt_dir>``).

    ``args.anchors`` is ``"<gt>,<pred>"``; blank or missing entries default to
    ``centroid``. ``args.overlay_extra`` is split with shell-style quoting.

    Returns:
        Path to ``iou_summary.csv`` if the script produced it, otherwise ``None``.
        ``None`` is also returned when deferred, when inputs are missing, or when
        the script fails.
    """
    import shlex

    defer = bool(getattr(args, "overlay_defer", False))
    overlay_script = getattr(args, "overlay_script", None)
    overlay_gt_dir = getattr(args, "overlay_gt_dir", None)
    if not defer:
        if not overlay_script or not os.path.exists(overlay_script):
            print("[RENDER] overlay_script not provided or not found; skip rendering.")
            return None
        if not overlay_gt_dir or not os.path.isdir(overlay_gt_dir):
            print("[RENDER] overlay_gt_dir not provided or invalid; skip rendering.")
            return None

    # Batch layout: one sub-folder per sample, named after the GT stem.
    tmp_run = os.path.join(out_dir, "overlay_tmp")
    gt_basename = os.path.basename(gt_json)
    sample_dir = os.path.join(tmp_run, os.path.splitext(gt_basename)[0])
    os.makedirs(sample_dir, exist_ok=True)
    name_mode = getattr(args, "overlay_name_mode", "gt")
    dest_name = gt_basename if name_mode == "gt" else os.path.basename(pred_json)
    dest_json = os.path.join(sample_dir, dest_name)
    shutil.copy2(pred_json, dest_json)

    # The overlay script also looks for pred.json; skip the alias when dest already is it.
    pred_alias = os.path.join(sample_dir, "pred.json")
    if os.path.abspath(pred_alias) != os.path.abspath(dest_json):
        try:
            shutil.copy2(dest_json, pred_alias)
        except OSError as e:
            print(f"[RENDER] warn: failed to create pred.json alias: {e}")
    print(f"[RENDER] sample_dir: {sample_dir}")
    staged = sorted(
        os.path.relpath(os.path.join(dirpath, fname), tmp_run)
        for dirpath, _dirs, fnames in os.walk(tmp_run)
        for fname in fnames
        if fname.endswith(".json")
    )
    print("[RENDER] tmp_run jsons:", staged)
    if defer:
        return None

    anchors = [a.strip() for a in str(getattr(args, "anchors", "") or "").split(",")]
    gt_anchor = anchors[0] or "centroid"
    pred_anchor = (anchors[1] if len(anchors) > 1 else "") or "centroid"

    summary_csv = os.path.join(tmp_run, "iou_summary.csv")
    cmd = [
        sys.executable, overlay_script,
        "--batch_my_run", tmp_run,
        "--gt_dir", overlay_gt_dir,
        "--map_to_axes_geom",
        "--gt_anchor", gt_anchor,
        "--pred_anchor", pred_anchor,
        "--size_mode_geom", args.size_mode_geom,
        "--dilate", str(args.dilate),
        "--write_iou_txt",
        "--summary_csv", summary_csv,
        "--save_debug",
        "--verbose",
    ]
    extra = getattr(args, "overlay_extra", None)
    if extra:
        cmd += shlex.split(extra)

    print("[RENDER] Running overlay renderer...\n ", " ".join(shlex.quote(str(c)) for c in cmd))
    try:
        subprocess.run(cmd, check=True)
    except Exception as e:  # best-effort: rendering failures must not abort prediction
        print(f"[RENDER] overlay script failed: {e}")
        return None

    print(f"[RENDER] Done. Look for GT/PRED/OVERLAP images under: {tmp_run}")
    print(f"[RENDER] Summary CSV: {summary_csv}")
    return summary_csv if os.path.exists(summary_csv) else None


# --- Overlay exec helper for deferred rendering ---
def run_overlay_exec(out_dir: str, args: Any) -> Optional[str]:
    """Run the overlay script once over every sample staged in ``<out_dir>/overlay_tmp``.

    This is the second half of deferred rendering: samples are staged first, then
    this makes a single batch invocation of ``args.overlay_script`` under the
    current Python interpreter.

    Reads these attributes from ``args``:
    - ``overlay_script`` and ``overlay_gt_dir``;
    - ``anchors`` (``"<gt>,<pred>"``; blank or missing entries default to
      ``centroid``);
    - ``size_mode_geom`` and ``dilate``;
    - optional ``overlay_extra``, split with shell-style quoting.

    Returns:
        Path to ``iou_summary.csv`` if the script produced it, otherwise ``None``.
        ``None`` is also returned when inputs are missing or the script fails.
    """
    import shlex

    tmp_run = os.path.join(out_dir, "overlay_tmp")
    if not os.path.isdir(tmp_run):
        print("[RENDER] No overlay_tmp to process; skip.")
        return None
    overlay_script = getattr(args, "overlay_script", None)
    if not overlay_script or not os.path.exists(overlay_script):
        print("[RENDER] overlay_script not provided or not found; skip rendering.")
        return None
    overlay_gt_dir = getattr(args, "overlay_gt_dir", None)
    if not overlay_gt_dir or not os.path.isdir(overlay_gt_dir):
        print("[RENDER] overlay_gt_dir not provided or invalid; skip rendering.")
        return None

    anchors = [a.strip() for a in str(getattr(args, "anchors", "") or "").split(",")]
    gt_anchor = anchors[0] or "centroid"
    pred_anchor = (anchors[1] if len(anchors) > 1 else "") or "centroid"

    summary_csv = os.path.join(tmp_run, "iou_summary.csv")
    cmd = [
        sys.executable, overlay_script,
        "--batch_my_run", tmp_run,
        "--gt_dir", overlay_gt_dir,
        "--map_to_axes_geom",
        "--gt_anchor", gt_anchor,
        "--pred_anchor", pred_anchor,
        "--size_mode_geom", args.size_mode_geom,
        "--dilate", str(args.dilate),
        "--write_iou_txt",
        "--summary_csv", summary_csv,
        "--save_debug",
        "--verbose",
    ]
    extra = getattr(args, "overlay_extra", None)
    if extra:
        cmd += shlex.split(extra)
    print("[RENDER] Executing overlay on accumulated batch...\n ", " ".join(shlex.quote(str(c)) for c in cmd))
    try:
        subprocess.run(cmd, check=True)
    except Exception as e:  # best-effort: rendering failures must not abort the batch
        print(f"[RENDER] overlay script failed: {e}")
        return None
    print(f"[RENDER] Done. Summary CSV: {summary_csv}")
    return summary_csv if os.path.exists(summary_csv) else None

# --- Prompts -----------------------------------------------------------------

def make_user_prompt(mode: str) -> str:
    """Return the user-turn instruction asking the model for the field(s) selected by ``mode``.

    Raises:
        ValueError: if ``mode`` is not ``pos``, ``angle``, ``size`` or ``all``.
    """
    catalogue = (
        ("pos",
         "Task: From the image, infer ONLY the 'pos' field as [x, y] within ~[0,10]. "
         "Return JSON like: {\"pos\": [x, y]}."),
        ("angle",
         "Task: From the image, infer ONLY the 'angle' in degrees (0..360, CCW). "
         "Return JSON like: {\"angle\": deg}."),
        ("size",
         "Task: From the image, infer ONLY the 'size' scaling factor (>0, 1.0 = canonical size). "
         "Return JSON like: {\"size\": s}."),
        ("all",
         "Task: From the image, infer 'pos' [x,y], 'angle' (deg), and 'size' (>0). "
         "Return JSON like: {\"pos\":[x,y],\"angle\":deg,\"size\":s}."),
    )
    for key, prompt in catalogue:
        if mode == key:
            return prompt
    raise ValueError(f"Unknown mode: {mode}")

# --- I/O patterns ------------------------------------------------------------

def default_out_name(image_path: str, mode: str) -> str:
    """Build the default prediction file name ``<image stem>_pred_<mode>.json``."""
    stem, _ext = os.path.splitext(os.path.basename(image_path))
    return "{}_pred_{}.json".format(stem, mode)

def find_pair_json(gt_dir: str, image_basename: str) -> Optional[str]:
    """Locate the JSON annotation paired with an image file name.

    Tries ``<gt_dir>/<stem>.json`` first. Otherwise it returns the
    lexicographically first ``<stem>*.json`` in ``gt_dir``; the results are sorted
    so the pick is deterministic, since raw ``glob`` order depends on the
    filesystem. Returns ``None`` if nothing matches. The directory and stem are
    escaped, so names containing ``[``, ``*`` or ``?`` match literally.
    """
    stem = os.path.splitext(image_basename)[0]
    exact = os.path.join(gt_dir, stem + ".json")
    if os.path.exists(exact):
        return exact
    pattern = os.path.join(glob.escape(gt_dir), glob.escape(stem) + "*.json")
    matches = sorted(glob.glob(pattern))
    return matches[0] if matches else None

# --- Validation helpers ------------------------------------------------------

def _has_required_fields(obj: Dict[str, Any], mode: str) -> bool:
    """Check that a parsed model reply contains valid values for the fields of ``mode``.

    - ``pos``: a 2-element list of finite numbers.
    - ``angle``: a finite number.
    - ``size``: a finite number > 0.
    - ``all``: all three of the above.

    Booleans are rejected: in Python ``True``/``False`` are ``int`` instances, so a
    JSON ``true`` used to pass as a number. NaN, infinity and non-dict replies are
    also rejected. An unknown mode returns False.
    """
    def _num(v: Any) -> bool:
        if isinstance(v, bool) or not isinstance(v, (int, float)):
            return False
        try:
            return math.isfinite(v)
        except OverflowError:  # int too large to convert to float
            return False

    if not isinstance(obj, dict):
        return False
    if mode == "pos":
        p = obj.get("pos", None)
        return isinstance(p, list) and len(p) == 2 and all(_num(v) for v in p)
    if mode == "angle":
        return _num(obj.get("angle", None))
    if mode == "size":
        s = obj.get("size", None)
        return _num(s) and s > 0
    if mode == "all":
        return all(_has_required_fields(obj, m) for m in ("pos", "angle", "size"))
    return False

# --- Simple geometry renderer (for visualization only) -----------------------
# (min, max) axis limits applied to both x and y in render_geom_images; matches
# the 0..10 inner square described in SYSTEM_PROMPT.
CANVAS_XY = (0.0, 10.0)

def _piece_scale(obj: Dict[str, Any]) -> float:
    """Return the piece's scale factor: a positive ``size``, else a positive ``scale``, else 1.0."""
    for key in ("size", "scale"):
        value = obj.get(key, None)
        if isinstance(value, (int, float)) and value > 0:
            return float(value)
    return 1.0

def _base_poly(piece_type: str) -> np.ndarray:
    """Return a rough canonical outline (vertices near the origin) for a piece type.

    Matching is case-insensitive and substring-based. Checks are made in this order:
    square, parallelogram, then big/medium/small triangle. Anything else,
    including a bare "triangle" without a size word, gets a diamond.
    These templates are for visualization only.
    """
    name = (piece_type or "").lower()

    def _right_tri(k: float) -> np.ndarray:
        # Unit right triangle shifted by -0.5 on both axes, then scaled by k.
        return k * (np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]], dtype=float) - 0.5)

    if "square" in name:
        return np.array([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]], dtype=float)
    if "parallelog" in name:
        return np.array([[-0.6, -0.5], [0.6, -0.5], [0.3, 0.5], [-0.9, 0.5]], dtype=float)
    if "triangle" in name:
        if "big" in name:
            return _right_tri(2.0)
        if "medium" in name:
            return _right_tri(1.4)
        if "small" in name:
            return _right_tri(1.0)
    return np.array([[0.0, -0.7], [0.7, 0.0], [0.0, 0.7], [-0.7, 0.0]], dtype=float)

def _xf_poly(poly: np.ndarray, pos: List[float], angle_deg: float, scale: float) -> np.ndarray:
    """Place a template polygon: scale about the origin, rotate, then translate.

    ``poly`` has shape (N, 2). It is rotated counter-clockwise by ``angle_deg``
    degrees (a falsy angle such as ``None`` means 0) and translated by
    ``(pos[0], pos[1])``. Returns a new (N, 2) float array.
    """
    theta = math.radians(float(angle_deg or 0.0))
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=float)
    placed = (poly * float(scale)) @ rotation.T
    offset = np.array([[float(pos[0]), float(pos[1])]], dtype=float)
    placed += offset
    return placed

def render_geom_images(gt_obj: Dict[str, Any], pred_obj: Dict[str, Any], save_dir: str, stem: str) -> None:
    """Draw GT, prediction and overlay PNGs for a single piece with matplotlib.

    Writes ``<stem>_gt_geom.png``, ``<stem>_pred_geom.png`` and
    ``<stem>_overlap_geom.png`` into ``save_dir``, which is created if missing.
    Shapes come from the rough ``_base_poly`` templates, so the images are for
    visual inspection only.

    Fallbacks:
    - Prediction fields that are absent fall back to the GT values.
    - The prediction's own size/scale is used whenever it carries a valid one;
      only a prediction without one inherits the GT scale.
    - Either object may be ``None``/empty; missing GT fields default to pos
      ``[5.0, 5.0]``, angle 0 and scale 1.
    """
    os.makedirs(save_dir, exist_ok=True)
    gt = gt_obj or {}
    pred = pred_obj or {}

    gt_type = gt.get("type", pred.get("type", ""))
    gt_pos = gt.get("pos", [5.0, 5.0])
    gt_ang = gt.get("angle", 0.0)
    gt_s = _piece_scale(gt)

    pr_type = pred.get("type", gt_type)
    pr_pos = pred.get("pos", gt_pos)
    pr_ang = pred.get("angle", gt_ang)
    # If the prediction has no size/scale, reuse the GT scale so the two shapes are
    # not drawn at mismatched sizes. An explicit predicted value (even 1.0) is
    # kept; the old check `pred_scale != 1.0` silently replaced it with the GT scale.
    pred_has_scale = any(
        isinstance(pred.get(key), (int, float)) and pred.get(key) > 0
        for key in ("size", "scale")
    )
    pr_s = _piece_scale(pred) if pred_has_scale else gt_s

    gt_poly = _xf_poly(_base_poly(gt_type), gt_pos, gt_ang, gt_s)
    pr_poly = _xf_poly(_base_poly(pr_type), pr_pos, pr_ang, pr_s)

    def _closed(poly):
        """x and y vertex lists with the first vertex repeated to close the outline."""
        return list(poly[:, 0]) + [poly[0, 0]], list(poly[:, 1]) + [poly[0, 1]]

    def _new_axes():
        """Square figure whose axes span CANVAS_XY on both axes, with a light grid."""
        fig = plt.figure(figsize=(6, 6), dpi=180)
        ax = fig.add_subplot(111)
        ax.set_xlim(*CANVAS_XY)
        ax.set_ylim(*CANVAS_XY)
        ax.set_aspect('equal', adjustable='box')
        ax.grid(True, linestyle='--', linewidth=0.5, alpha=0.4)
        return fig, ax

    def _finish(fig, ax, out_path, title):
        """Title, lay out, save and close the figure."""
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(out_path)
        plt.close(fig)

    def _plot_one(poly, label, out_path, title):
        fig, ax = _new_axes()
        xs, ys = _closed(poly)
        ax.plot(xs, ys, linewidth=2)
        ax.fill(poly[:, 0], poly[:, 1], alpha=0.15)
        ax.text(float(np.mean(poly[:, 0])), float(np.mean(poly[:, 1])), label, fontsize=10)
        _finish(fig, ax, out_path, title)

    def _plot_overlay(gt_poly, pr_poly, out_path, title):
        fig, ax = _new_axes()
        # GT: gray wireframe only.
        xs, ys = _closed(gt_poly)
        ax.plot(xs, ys, linewidth=2.5, color='#666666')
        ax.text(float(np.mean(gt_poly[:, 0])), float(np.mean(gt_poly[:, 1])), 'GT', fontsize=10, color='#666666')
        # PRED: filled with a strong black edge.
        ax.fill(pr_poly[:, 0], pr_poly[:, 1], alpha=0.35)
        xs2, ys2 = _closed(pr_poly)
        ax.plot(xs2, ys2, linewidth=3.0, color='black')
        ax.text(float(np.mean(pr_poly[:, 0])), float(np.mean(pr_poly[:, 1])), 'P', fontsize=10)
        _finish(fig, ax, out_path, title)

    _plot_one(gt_poly, "GT", os.path.join(save_dir, f"{stem}_gt_geom.png"), "gt.json (geometry)")
    _plot_one(pr_poly, "P", os.path.join(save_dir, f"{stem}_pred_geom.png"), "pred.json (geometry)")
    _plot_overlay(gt_poly, pr_poly, os.path.join(save_dir, f"{stem}_overlap_geom.png"), "pred.json")

# --- IoU via rasterization (Pillow + NumPy; no extra geometry library) --------

def _poly_from_obj(obj: Dict[str, Any]) -> np.ndarray:
    """Build the placed polygon for a piece dict from its type, pos, angle and scale.

    ``obj`` may be ``None``. Missing fields default to type ``"triangle"``, pos
    ``[5.0, 5.0]`` and angle 0.0; the scale follows ``_piece_scale``.
    """
    piece = obj or {}
    return _xf_poly(
        _base_poly(piece.get("type", "triangle")),
        piece.get("pos", [5.0, 5.0]),
        piece.get("angle", 0.0),
        _piece_scale(piece),
    )

def iou_from_objs(gt_obj: Dict[str, Any], pred_obj: Dict[str, Any], size: int = 512) -> float:
    """Approximate the IoU of two pieces by rasterizing them onto one binary grid.

    Each piece dict becomes a polygon (``_poly_from_obj``). Its canvas-unit
    coordinates (0..10) are mapped to pixels of a ``size`` x ``size`` 1-bit image
    and the polygon is filled. Returns intersection pixels divided by union
    pixels, or 0.0 when neither polygon covers any pixel.
    """
    px_per_unit = (size - 1) / 10.0

    def _filled_mask(obj):
        # y is not flipped to image orientation; harmless because both masks share the mapping.
        vertices = [(float(x) * px_per_unit, float(y) * px_per_unit) for x, y in _poly_from_obj(obj)]
        mask = Image.new("1", (size, size), 0)
        ImageDraw.Draw(mask).polygon(vertices, outline=1, fill=1)
        return mask

    def _on_pixels(img) -> float:
        return float(np.array(img, dtype=np.uint8).sum())

    gt_mask = _filled_mask(gt_obj)
    pred_mask = _filled_mask(pred_obj)
    inter_px = _on_pixels(ImageChops.logical_and(gt_mask, pred_mask))
    union_px = _on_pixels(ImageChops.logical_or(gt_mask, pred_mask))
    if union_px <= 0:
        return 0.0
    return inter_px / union_px

# --- Few-shot message builder -------------------------------------------------
def build_fewshot_messages(
    k: int,
    icl_img_dir: Optional[str],
    icl_json_dir: Optional[str],
    exclude_stem: Optional[str] = None,
    mode: str = "pos",
    pattern: str = "*medium*triangle*.png",
) -> List[Dict[str, Any]]:
    """Build up to ``k`` few-shot (user image -> assistant JSON) message pairs.

    Candidates are the files in ``icl_img_dir`` matching ``pattern`` (default:
    medium-triangle PNGs), in sorted order. ``exclude_stem`` is removed so the
    sample under test never leaks its own GT into the prompt.

    Each usable image contributes two turns:
    - a user turn with ``make_user_prompt(mode)`` and the image as a data URL;
    - an assistant turn with the ``mode`` fields taken from the paired JSON in
      ``icl_json_dir`` (defaults: pos ``[5.0, 5.0]``, angle 0.0, size from
      ``size`` or ``scale`` or 1.0).

    Images without a readable pair are skipped and the next candidate is tried,
    so ``k`` pairs are returned whenever enough candidates exist. With the
    default ``mode="pos"`` the messages match the previous pos-only format.

    Returns an empty list when ``k`` is falsy or non-positive, or when either
    directory argument is empty/``None``.

    Raises:
        ValueError: if ``mode`` is unknown (from ``make_user_prompt``).
    """
    if not k or k <= 0 or not icl_img_dir or not icl_json_dir:
        return []
    candidates = sorted(glob.glob(os.path.join(glob.escape(icl_img_dir), pattern)))
    if exclude_stem:
        candidates = [p for p in candidates if os.path.splitext(os.path.basename(p))[0] != exclude_stem]

    prompt = make_user_prompt(mode)
    wanted = ("pos", "angle", "size") if mode == "all" else (mode,)
    msgs: List[Dict[str, Any]] = []
    n_pairs = 0
    for img in candidates:
        if n_pairs >= k:
            break
        pair_json = find_pair_json(icl_json_dir, os.path.basename(img))
        if not pair_json:
            continue
        try:
            gt = load_json(pair_json)
            image_url = to_data_url(img)
        except (OSError, ValueError):  # unreadable image/JSON: try the next candidate
            continue
        if not isinstance(gt, dict):
            continue
        answer: Dict[str, Any] = {}
        for field in wanted:
            if field == "pos":
                answer["pos"] = gt.get("pos", [5.0, 5.0])
            elif field == "angle":
                answer["angle"] = gt.get("angle", 0.0)
            else:
                answer["size"] = gt.get("size", gt.get("scale", 1.0))
        # Append both turns together so a failure can never leave a dangling user turn.
        msgs.append({
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": image_url}},
            ],
        })
        msgs.append({"role": "assistant", "content": json.dumps(answer)})
        n_pairs += 1
    return msgs

def run_single(args) -> Tuple[Dict[str, Any], Dict[str, float]]:

    # load base JSON (gt) to keep other fields intact
    base = load_json(args.gt_json) if args.gt_json else {}

    # --- ICL few-shot and iterative refinement loop ---
    fewshots = build_fewshot_messages(
        getattr(args, "icl_k", 0),
        getattr(args, "icl_img_dir", None),
        getattr(args, "icl_json_dir", None),
        exclude_stem=os.path.splitext(os.path.basename(args.image))[0] if getattr(args, "image", None) else None,
    )
    if args.verbose:
        try:
            excl = os.path.splitext(os.path.basename(args.image))[0] if getattr(args, "image", None) else None
            print(f"[ICL] Using {len(fewshots)//2} examples (excluded current: {bool(excl)})")
        except Exception:
            pass
    loop_iters = max(1, int(getattr(args, "loop_iters", 1)))
    lam = float(getattr(args, "reward_lambda", 0.1))

    models_to_try = [args.model]
    if getattr(args, "alt_model", None):
        models_to_try.append(args.alt_model)

    best = {"fields": None, "iou": -1.0, "l2": 1e9, "reward": -1e9}
    attempts = []
    pred_fields = None
    last_err = None
    msg = ""

    for it in range(loop_iters):
        refinement = ""
        if it > 0 and best["fields"] is not None and args.gt_json:
            try:
                dx = float(base.get("pos", [5.0,5.0])[0]) - float(best["fields"].get("pos", [5.0,5.0])[0])
                dy = float(base.get("pos", [5.0,5.0])[1]) - float(best["fields"].get("pos", [5.0,5.0])[1])
                dx = max(-1.0, min(1.0, dx))
                dy = max(-1.0, min(1.0, dy))
                refinement = f"Hint: previous IoU={best['iou']:.3f}. Try a small correction ~({dx:+.2f},{dy:+.2f})."
            except Exception:
                refinement = ""

        cand_fields = None
        last_err = None
        for mdl_idx, mdl in enumerate(models_to_try):
            for attempt in range(1, (args.max_retries or 1) + 1):
                try:
                    msg = call_openrouter(
                        model=mdl,
                        system_prompt=SYSTEM_PROMPT,
                        user_prompt=make_user_prompt(args.mode),
                        image_path=args.image,
                        args=args,
                        extra_messages=fewshots,
                        refinement_hint=refinement,
                    )
                    if args.verbose:
                        print(f"[RAW MODEL OUTPUT] (iter={it+1}, model={mdl}, try={attempt})", msg[:500])
                    jb = extract_json_block(msg)
                    maybe = json.loads(jb)
                    if _has_required_fields(maybe, args.mode):
                        cand_fields = maybe
                        break
                    else:
                        last_err = RuntimeError(f"Model {mdl} returned invalid JSON for mode={args.mode}: {maybe}")
                except Exception as e:
                    last_err = e
                    continue
            if cand_fields is not None:
                break
        if cand_fields is None:
            if args.allow_gt_fallback:
                cand_fields = {}
            else:
                raise RuntimeError(f"Prediction failed after retries (iter={it+1}): {last_err}")

        merged_cand = merge_pred_into_base(base, cand_fields, args.mode)
        m = evaluate_basic(base, merged_cand)
        iou = 0.0
        try:
            iou = iou_from_objs(base, merged_cand)
            m["iou_geom"] = float(iou)
        except Exception:
            pass

        l2 = float(m.get("l2_error_pos", 0.0)) if m else 0.0
        reward = float(iou - lam * (l2/10.0))

        # 记录当前迭代
        attempts.append({"iter": it+1, "pred": cand_fields, "metrics": m, "reward": reward})

        # IoU 达到阈值则提前停止
        min_iou_thr = getattr(args, "min_iou", 0.5)
        if iou >= float(min_iou_thr):
            best = {"fields": cand_fields, "iou": iou, "l2": l2, "reward": reward}
            if args.verbose:
                print(f"[EARLY STOP] IoU {iou:.3f} >= threshold {float(min_iou_thr):.3f}")
            break

        # 否则按 reward 取更优
        if reward > best["reward"]:
            best = {"fields": cand_fields, "iou": iou, "l2": l2, "reward": reward}

    # --- Final local refinement (deterministic, no extra API calls) ---
    try:
        min_thr = float(getattr(args, "min_iou", 0.5))
        mode_now = str(getattr(args, "mode", "pos"))
        if best.get("fields") is not None and best.get("iou", 0.0) < min_thr and (mode_now in ("pos", "all")) and args.gt_json:
            base_gt = base  # ground truth object loaded earlier
            cur = dict(best["fields"])  # current best fields
            # ensure pos exists in current best
            bx, by = cur.get("pos", base_gt.get("pos", [5.0, 5.0]))
            try:
                bx = float(bx); by = float(by)
            except Exception:
                bx, by = float(bx[0]), float(by[1]) if isinstance(bx, (list, tuple)) else (5.0, 5.0)
            # keep angle/size from current best or fall back to GT
            if mode_now == "pos":
                if "angle" not in cur and "angle" in base_gt:
                    cur["angle"] = base_gt.get("angle", 0.0)
                if "size" not in cur:
                    if "size" in base_gt:
                        cur["size"] = base_gt.get("size", 1.0)
                    elif "scale" in base_gt:
                        cur["size"] = base_gt.get("scale", 1.0)
            # hierarchical steps (canvas units)
            steps = [0.6, 0.3, 0.15]
            best_loc = {"fields": dict(cur), "iou": float(best.get("iou", 0.0))}
            for step in steps:
                # 9-point neighborhood around (bx, by)
                offsets = [(-step,-step),(-step,0.0),(-step,step),(0.0,-step),(0.0,0.0),(0.0,step),(step,-step),(step,0.0),(step,step)]
                improved = False
                for dx, dy in offsets:
                    nx = max(0.0, min(10.0, bx + dx))
                    ny = max(0.0, min(10.0, by + dy))
                    cand = dict(cur)
                    cand["pos"] = [float(nx), float(ny)]
                    merged_cand = merge_pred_into_base(base_gt, cand, mode_now)
                    iou_try = 0.0
                    try:
                        iou_try = float(iou_from_objs(base_gt, merged_cand))
                    except Exception:
                        iou_try = 0.0
                    if iou_try > best_loc["iou"]:
                        best_loc = {"fields": dict(cand), "iou": iou_try}
                        bx, by = nx, ny
                        improved = True
                        if args.verbose:
                            print(f"[LOCAL SEARCH] step={step:.2f} → IoU {iou_try:.3f} at pos=({nx:.3f},{ny:.3f})")
                        if iou_try >= min_thr:
                            break
                # early exit if reached threshold
                if best_loc["iou"] >= min_thr:
                    break
                # if no improvement at this scale, still proceed to finer step
            if best_loc["iou"] > best.get("iou", -1.0):
                best["fields"] = best_loc["fields"]
                best["iou"] = float(best_loc["iou"])
                # update reward based on refined IoU
                try:
                    l2_ref = l2_error(best["fields"].get("pos", [bx, by]), base_gt.get("pos", [bx, by]))
                    best["l2"] = float(l2_ref)
                    best["reward"] = float(best["iou"] - lam * (best["l2"]/10.0))
                except Exception:
                    pass
                if args.verbose:
                    print(f"[LOCAL SEARCH] Adopted refined pos with IoU={best['iou']:.3f}")
                # record a pseudo-attempt for transparency
                try:
                    attempts.append({"iter": "local_search", "pred": dict(best["fields"]), "metrics": {"iou_geom": best["iou"]}, "reward": best.get("reward", 0.0)})
                except Exception:
                    pass
    except Exception as _e:
        if args.verbose:
            print(f"[LOCAL SEARCH] skipped due to error: {_e}")

    pred_fields = best["fields"]

    # merge into base (keep others unchanged)
    merged = merge_pred_into_base(base, pred_fields, args.mode)

    # write output JSON
    out_json = args.out_json or os.path.join(args.out_dir or ".", default_out_name(args.image, args.mode))

    # Dump raw API response if requested
    if getattr(args, "dump_raw", False):
        try:
            raw_path = (args.out_json or os.path.join(args.out_dir or ".", default_out_name(args.image, args.mode))).replace(".json", "_raw.txt")
            with open(raw_path, "w", encoding="utf-8") as rf:
                rf.write(str(msg))
            if args.verbose:
                print(f"[DEBUG] dumped raw response to: {raw_path}")
        except Exception as _:
            pass

    save_json(merged, out_json)

    # Log iteration attempts
    try:
        log_path = (args.out_json or os.path.join(args.out_dir or ".", default_out_name(args.image, args.mode))).replace(".json", "_iterlog.json")
        with open(log_path, "w", encoding="utf-8") as lf:
            json.dump({"lambda": lam, "attempts": attempts, "best": best}, lf, ensure_ascii=False, indent=2)
        if args.verbose:
            print(f"[DEBUG] wrote iter log: {log_path}")
    except Exception:
        pass

    # basic evaluation if gt provided
    metrics = {}
    if args.gt_json:
        metrics = evaluate_basic(base, merged)
        if metrics:
            print("[EVAL]", json.dumps(metrics, ensure_ascii=False))

    # --- geometry overlay images (GT / PRED / OVERLAP) ---
    try:
        # derive stem (remove _pred_<mode> suffix from filename)
        stem = os.path.splitext(os.path.basename(out_json))[0]
        suffix = f"_pred_{args.mode}"
        if stem.endswith(suffix):
            stem = stem[: -len(suffix)]
        ov_dir = os.path.join(os.path.dirname(out_json), "overlay_tmp", stem)

        if args.geometry_script and os.path.exists(args.geometry_script) and args.gt_json:
            # External geometry renderer (authoritative)
            run_geometry_render(args.gt_json, out_json, ov_dir, args.geometry_script)
        else:
            # Fallback: internal matplotlib renderer
            render_geom_images(base if args.gt_json else {}, merged, ov_dir, stem)
            # Ensure overlay_json_vs_pred.png equals geometry overlap
            try:
                _overlap_geom = os.path.join(ov_dir, f"{stem}_overlap_geom.png")
                _overlay_img = os.path.join(ov_dir, "overlay_json_vs_pred.png")
                if os.path.exists(_overlap_geom):
                    shutil.copy2(_overlap_geom, _overlay_img)
            except Exception:
                pass
    except Exception as e:
        if args.verbose:
            print(f"[WARN] geometry render failed: {e}")

    if getattr(args, "render_three", False):
        # Attempt to render via overlay script. We pass gt_json (baseline) and our out_json.
        run_overlay_render(args.gt_json, out_json, os.path.dirname(out_json), args)

    print(f"[OK] Wrote: {out_json}")
    return merged, metrics

def _batch_list_images(args) -> List[str]:
    """Return sorted ``*.png`` + ``*.jpg`` paths in ``args.in_dir``, truncated to ``args.n`` if set.

    A missing, falsy (``None``/``0``) or non-integer ``args.n`` leaves the list untouched.
    """
    in_dir = args.in_dir
    paths = sorted(
        glob.glob(os.path.join(in_dir, "*.png")) + glob.glob(os.path.join(in_dir, "*.jpg"))
    )
    limit = getattr(args, "n", None)
    if limit:
        try:
            paths = paths[:int(limit)]
        except (TypeError, ValueError):
            pass  # invalid limit: process every image, as before
    return paths


def _batch_single_args(args, img: str, gt_path: Optional[str]) -> argparse.Namespace:
    """Build the per-image Namespace handed to ``run_single`` from the batch CLI args."""
    return argparse.Namespace(
        model=args.model,
        mode=args.mode,
        image=img,
        gt_json=gt_path,
        out_json=None,  # run_single then derives the path from out_dir (see _batch_out_json)
        out_dir=args.out_dir,
        geometry_script=args.geometry_script,
        verbose=args.verbose,
        api_key=args.api_key,
        overlay_script=args.overlay_script,
        overlay_gt_dir=args.overlay_gt_dir,
        render_three=args.render_three,
        dilate=args.dilate,
        anchors=args.anchors,
        size_mode_geom=args.size_mode_geom,
        overlay_extra=args.overlay_extra,
        overlay_name_mode=args.overlay_name_mode,
        max_retries=args.max_retries,
        alt_model=args.alt_model,
        allow_gt_fallback=args.allow_gt_fallback,
        dump_raw=args.dump_raw,
        overlay_defer=getattr(args, "overlay_defer", False),
        icl_k=args.icl_k,
        icl_img_dir=args.icl_img_dir,
        icl_json_dir=args.icl_json_dir,
        loop_iters=args.loop_iters,
        reward_lambda=args.reward_lambda,
        min_iou=getattr(args, "min_iou", 0.5),
        temperature=getattr(args, "temperature", 0.0),
    )


def _batch_out_json(args, img: str) -> str:
    """Return the JSON path ``run_single`` writes for *img* in batch mode.

    ``run_single`` is always called with ``out_json=None`` here, so the user's
    ``--out_json`` (if any) must NOT be used to locate its output.
    """
    return os.path.join(args.out_dir or ".", default_out_name(img, args.mode))


def _batch_read_best_iou(out_json: str) -> Optional[float]:
    """Return ``best.iou`` (or ``best.iou_geom``) from the iterlog next to *out_json*, or None.

    Best-effort: a missing, unreadable or malformed iterlog yields None.
    """
    # Same derivation run_single uses when writing the iteration log.
    iterlog = out_json.replace(".json", "_iterlog.json")
    if not os.path.exists(iterlog):
        return None
    try:
        with open(iterlog, "r", encoding="utf-8") as lf:
            data = json.load(lf)
    except (OSError, ValueError):  # ValueError covers JSON / Unicode decode errors
        return None
    if not isinstance(data, dict):
        return None
    best = data.get("best")
    if not isinstance(best, dict):
        return None
    biou = best.get("iou", best.get("iou_geom"))
    if isinstance(biou, (int, float)):
        return float(biou)
    return None


def _batch_write_metrics_csv(rows: List[Dict[str, Any]], csv_path: str) -> None:
    """Write metric *rows* to *csv_path*.

    The header is the sorted union of keys across ALL rows, so rows carrying
    optional columns (e.g. ``iou_best``) do not make ``DictWriter`` raise;
    missing cells are written empty.
    """
    import csv

    fieldnames = sorted({key for row in rows for key in row})
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def run_batch(args):
    """Run ``run_single`` on every ``*.png`` / ``*.jpg`` image in ``args.in_dir``.

    Outputs go to ``args.out_dir`` (created if needed). When a GT directory is
    given, each image is paired with its GT JSON via ``find_pair_json``. Per-item
    metrics (plus ``iou_best`` read from the item's iterlog) are collected into
    ``<out_dir>/metrics_<mode>.csv``. Errors on a single item are printed and the
    batch continues.

    Raises:
        RuntimeError: if no images are found in ``args.in_dir``.
    """
    os.makedirs(args.out_dir, exist_ok=True)
    pngs = _batch_list_images(args)
    if not pngs:
        raise RuntimeError(f"No images found in: {args.in_dir}")

    all_metrics = []
    for img in pngs:
        base = os.path.basename(img)
        gt_path = find_pair_json(args.gt_dir, base) if args.gt_dir else None
        if args.gt_dir and not gt_path:
            print(f"[WARN] No GT json for {base}, skip evaluation.")
        try:
            _merged, metrics = run_single(_batch_single_args(args, img, gt_path))
            out_json = _batch_out_json(args, img)

            # When deferred, overlay rendering runs once after the loop instead.
            # NOTE(review): run_single itself calls run_overlay_render when
            # render_three is set, so this may render each item twice — confirm
            # run_overlay_render is idempotent / intended before removing.
            if not getattr(args, "overlay_defer", False):
                if args.render_three and gt_path:
                    run_overlay_render(gt_path, out_json, args.out_dir, args)

            if metrics:
                metrics_row = {"name": os.path.splitext(base)[0], **metrics}
                # Attach the refinement loop's best IoU from the iterlog, if present.
                best_iou = _batch_read_best_iou(out_json)
                if best_iou is not None:
                    metrics_row["iou_best"] = best_iou
                all_metrics.append(metrics_row)
        except Exception as e:  # keep the batch going; report the failing item
            print(f"[ERR] {base}: {e}")

    # After the batch, run the deferred overlay rendering once.
    if getattr(args, "overlay_defer", False):
        if args.render_three and args.overlay_script:
            run_overlay_exec(args.out_dir, args)

    if all_metrics:
        csv_path = os.path.join(args.out_dir, f"metrics_{args.mode}.csv")
        _batch_write_metrics_csv(all_metrics, csv_path)
        print(f"[OK] Wrote metrics CSV: {csv_path}")

def _build_cli_parser() -> argparse.ArgumentParser:
    """Create the command-line parser shared by single-image and batch runs."""
    parser = argparse.ArgumentParser()
    opt = parser.add_argument

    # --- model, input and output locations ---
    opt("--model", type=str, default="openai/gpt-4o-mini", help="Vision-capable model on OpenRouter.")
    opt("--mode", type=str, choices=["pos", "angle", "size", "all"], required=True, help="Which fields to predict.")
    opt("--image", type=str, help="Path to a single PNG/JPG image.")
    opt("--gt_json", type=str, help="Path to the original JSON (to keep other fields unchanged, and for evaluation).")
    opt("--out_json", type=str, help="Output JSON path (if not set, use out_dir + default name).")
    opt("--in_dir", type=str, help="Batch mode: directory of images.")
    opt("--gt_dir", type=str, help="Batch mode: directory of GT jsons with matching basenames.")
    opt("--out_dir", type=str, help="Output directory for batch or when out_json is not specified.")
    opt("--verbose", action="store_true", help="Print raw model output.")
    opt("--api_key", type=str, help="OpenRouter API key (optional, otherwise read from env).")

    # --- overlay / geometry rendering ---
    opt("--overlay_script", type=str, help="Path to iou_from_overlay.py for rendering GT/PRED/OVERLAP.")
    opt("--overlay_gt_dir", type=str, help="Directory containing GT JSONs (for overlay script batch mode).")
    opt("--render_three", action="store_true", help="Render GT, PRED, and OVERLAP using the overlay script after prediction.")
    opt("--dilate", type=int, default=2, help="Dilation for overlap in overlay script (default: 2).")
    opt("--anchors", type=str, default="centroid,centroid", help="gt_anchor,pred_anchor for overlay script (default: centroid,centroid).")
    opt("--size_mode_geom", type=str, default="keep", help="size mode for overlay script (default: keep).")
    opt("--overlay_extra", type=str, default="", help="Extra flags to pass to overlay script (advanced).")
    opt("--overlay_name_mode", type=str, choices=["gt", "pred"], default="gt", help="When rendering, name the predicted JSON as the GT basename (gt) or keep the original pred basename (pred). Default: gt.")
    opt("--overlay_defer", action="store_true", help="In batch mode, only stage files during per-item processing and run the overlay renderer once at the end.")
    opt(
        "--geometry_script",
        type=str,
        help="Path to external geometry renderer (supports --ann / --overlay_gt_json / --overlay_pred_json / --save_png / --no_show). If provided, it will be used to render gt_render.png, pred_render.png and overlay_json_vs_pred.png.",
    )

    # --- retries and fallbacks ---
    opt("--max_retries", type=int, default=2, help="Max retries if model output is empty/invalid.")
    opt("--alt_model", type=str, help="Fallback model if primary fails (e.g., google/gemini-1.5-flash).")
    opt("--allow_gt_fallback", action="store_true", help="If set, when prediction fails after retries, keep GT fields instead of error.")
    opt("--dump_raw", action="store_true", help="Dump raw OpenRouter response JSON next to out_json for debugging.")

    # --- in-context examples and the reward-guided refinement loop ---
    opt("--icl_k", type=int, default=0, help="#few-shot examples to include (medium triangle only)")
    opt("--icl_img_dir", type=str, help="Directory of ICL example images")
    opt("--icl_json_dir", type=str, help="Directory of ICL example GT jsons")
    opt("--loop_iters", type=int, default=2, help="Refinement loop iterations (>=1, default=2).")
    opt("--reward_lambda", type=float, default=0.1, help="Reward = IoU - lambda*(L2/10).")
    opt("--min_iou", type=float, default=0.5, help="Minimum IoU threshold to stop refinement early (default=0.5).")
    opt("--n", type=int, help="Limit number of images to process in batch.")
    opt("--temperature", type=float, default=0.0, help="Sampling temperature for the VLM (higher = more diverse). Default 0.0")

    return parser


def main():
    """Parse the command line and dispatch to batch or single-image prediction.

    ``--in_dir`` selects batch mode, which also requires ``--out_dir``.
    Otherwise ``--image`` is required.

    Raises:
        RuntimeError: when the arguments required by the selected mode are missing.
    """
    args = _build_cli_parser().parse_args()

    if not args.in_dir:
        # Single-image mode.
        if not args.image:
            raise RuntimeError("Single mode requires --image")
        run_single(args)
        return

    # Batch mode.
    if not args.out_dir:
        raise RuntimeError("Batch mode requires --out_dir")
    run_batch(args)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
