# qwen_centroid_demo.py
# Compare model-predicted {"pos":[x,y]} with the GT pos from our JSON
import os, sys, json, random, time, math, re, argparse, base64
from pathlib import Path

# NOTE:
# - We lazily import Qwen libs only if backend=='qwen'
# - We lazily import OpenAI SDK only if backend=='openai'

# ---- CLI ----
# Arguments are parsed once, at module load; everything below reads them through `args`.
parser = argparse.ArgumentParser(description="Evaluate Qwen pos-only predictions against GT pos.")
# Dataset locations (auto-detected when omitted) and sampling controls.
parser.add_argument("--img_dir", type=str, default=None, help="Directory of PNG images.")
parser.add_argument("--json_dir", type=str, default=None, help="Directory of GT JSON files.")
parser.add_argument("--n", type=int, default=10, help="How many images to sample.")
parser.add_argument("--seed", type=int, default=2025, help="Random seed.")
# Selects the wording used in the system prompt to explain what 'pos' means.
parser.add_argument("--pos_def", type=str, choices=["corner","com"], default="corner", help="Meaning of 'pos': 'corner' (anchor corner) or 'com' (centre of mass).")
# Few-shot (in-context learning) options: an explicit file, or auto-sampling from the dataset.
parser.add_argument("--icl_file", type=str, default=None, help="Path to a small JSON file of few-shot examples: [{\"image\":\"...\",\"pos\":[x,y]}].")
parser.add_argument("--icl_k", type=int, default=3, help="How many ICL examples to include (if --icl_file is provided).")
parser.add_argument("--icl_auto", type=int, default=0, help="If >0 and --icl_file not provided, auto-sample this many ICL examples from the dataset (excluded from test set).")
# Backend selection and OpenAI-compatible connection settings.
parser.add_argument("--backend", type=str, choices=["qwen","openai"], default="qwen", help="Which model backend to use.")
parser.add_argument("--openai_model", type=str, default="gpt-4o-mini", help="OpenAI vision-capable model name.")
parser.add_argument("--openai_api_key", type=str, default=None, help="If not provided, will read from env OPENAI_API_KEY.")
parser.add_argument("--openai_base_url", type=str, default=None, help="Base URL for OpenAI-compatible APIs. Use 'https://openrouter.ai/api/v1' for OpenRouter.")
parser.add_argument("--max_new", type=int, default=64, help="Max new tokens for generation.")
# NOTE(review): args.predict_angle is not read anywhere else in this file as shown,
# so setting the flag currently has no effect on the prompt or the evaluation.
parser.add_argument("--predict_angle", action="store_true", help="If set, ask the model to also predict angle in degrees.")
# ---- New CLI args ----
parser.add_argument("--all", action="store_true", help="Evaluate ALL eligible images (ignore --n).")
parser.add_argument("--out_dir", type=str, default=None, help="Directory to save results. Default runs/pos_eval/{timestamp}.")
parser.add_argument("--geom_render", action="store_true", help="Also render GT, Pred, and Overlay via geometry.py and compute IoU from the rendered masks.")
parser.add_argument("--save_iou_debug", action="store_true", help="Save IoU debug masks from iou_from_overlay.py for each sample.")
args = parser.parse_args()

# ---- Paths & params ----
# Directory containing this script; dataset candidates and helper scripts are located relative to it.
ROOT = Path(__file__).resolve().parent

# Python interpreter and helper scripts
# (the helper scripts are launched as subprocesses with the same interpreter)
PY = sys.executable  # current Python interpreter
GEOM_PY = (ROOT / "geometry.py").resolve()
IOU_PY  = (ROOT / "iou_from_overlay.py").resolve()

def pick_existing_dir(candidates, must_have_ext=None):
    """
    Choose the first usable directory among ``candidates``.

    A candidate qualifies when it is an existing directory and, if
    ``must_have_ext`` is truthy (e.g. '.png'), at least one entry directly
    inside it matches the glob ``*<must_have_ext>``.

    Returns the qualifying directory as a ``Path``, or ``None`` when no
    candidate qualifies.
    """
    for candidate in map(Path, candidates):
        if not candidate.is_dir():
            continue
        if not must_have_ext:
            # No extension requirement: existing directory is enough.
            return candidate
        first_hit = next(candidate.glob(f"*{must_have_ext}"), None)
        if first_hit is not None:
            return candidate
    return None

# Image directory: an explicit --img_dir wins; otherwise probe the known
# dataset layouts below the script directory (outlines are the last resort).
if args.img_dir:
    IMG_DIR = Path(args.img_dir)
else:
    IMG_DIR = pick_existing_dir(
        [
            ROOT / "dataset" / "onepiece_images",
            ROOT / "kilogram-main" / "dataset" / "onepiece_images",
            ROOT / "dataset" / "outlines_qwen_png",
        ],
        must_have_ext=".png",
    )

# GT JSON directory: same policy as for the images.
if args.json_dir:
    JSON_DIR = Path(args.json_dir)
else:
    JSON_DIR = pick_existing_dir(
        [
            ROOT / "dataset" / "onepiece_from_svg",
            ROOT / "kilogram-main" / "dataset" / "onepiece_from_svg",
        ],
        must_have_ext=".json",
    )

# Abort early with actionable hints when either directory is unresolved.
if IMG_DIR is None or JSON_DIR is None:
    raise FileNotFoundError(
        "\n".join(
            [
                "[PATH ERROR] Could not locate required directories.",
                f"  Tried IMG candidates under: {ROOT}",
                f"  Tried JSON candidates under: {ROOT}",
                "  Fix by either:",
                "    a) passing absolute paths with --img_dir and --json_dir",
                "    b) moving data to ./dataset or ./kilogram-main/dataset",
            ]
        )
    )

# Sampling parameters; at least one sample is always requested.
SEED = int(args.seed)
N_SAMPLES = max(1, int(args.n))

# Results go to --out_dir when given, else to a timestamped run folder.
ts = time.strftime("%Y%m%d_%H%M%S")
SAVE_DIR = Path(args.out_dir) if args.out_dir else ROOT / "runs" / "pos_eval" / ts
SAVE_DIR.mkdir(parents=True, exist_ok=True)

for _label, _chosen in (("IMG_DIR", IMG_DIR), ("JSON_DIR", JSON_DIR)):
    print(f"[PATH] Using {_label} = {_chosen}")

def _img_to_data_url(p: Path) -> str:
    with open(p, "rb") as f:
        b64 = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/png;base64,{b64}"

# Seed the global RNG so the random.shuffle calls that pick ICL examples and
# test images are repeatable for a given --seed.
random.seed(SEED)

def load_icl_examples(path: Path):
    """
    Load a few-shot (ICL) example set from a JSON file.

    Accepted layouts:
      - list of {"image": "...", "pos": [x, y]}
      - dict: {"filename.png": [x, y], ...}

    Relative image paths are resolved against ROOT (the script directory);
    absolute paths are kept as given.

    Loading is best-effort and never raises for bad content:
      - ``path`` is None, missing, or unreadable/invalid JSON -> [] (with a warning
        in the latter two cases);
      - entries that do not match the layout are silently skipped;
      - entries whose image or coordinates cannot be converted are skipped
        with a warning instead of aborting the whole run.

    Returns a list of (image_path, [x, y]) with x and y as floats.
    """
    if path is None:
        return []
    p = Path(path)
    if not p.exists():
        print(f"[WARN] --icl_file not found: {p}")
        return []
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(p, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as e:
        print(f"[WARN] Failed to read ICL file: {e}")
        return []

    def _to_example(image, pos):
        # Build one (image_path, [x, y]) tuple, or None if the entry is unusable.
        if not (isinstance(pos, (list, tuple)) and len(pos) == 2):
            return None
        try:
            xy = [float(pos[0]), float(pos[1])]
            img = Path(image)
        except (TypeError, ValueError, OverflowError):
            print(f"[WARN] Skipping malformed ICL example: {image!r} -> {pos!r}")
            return None
        if not img.is_absolute():
            # resolve relative to project root
            img = (ROOT / img).resolve()
        return img, xy

    if isinstance(data, list):
        candidates = [
            (item["image"], item["pos"])
            for item in data
            if isinstance(item, dict) and "image" in item and "pos" in item
        ]
    elif isinstance(data, dict):
        candidates = list(data.items())
    else:
        print(f"[WARN] Unsupported ICL JSON format: {type(data)}")
        return []

    examples = []
    for image, pos in candidates:
        example = _to_example(image, pos)
        if example is not None:
            examples.append(example)
    return examples

# ---- Helper: safe json parsing from model text ----
# Decimal literal as float() accepts it: optional sign, digits with an optional
# fraction (or a bare fraction such as ".5"), and an optional exponent.
_NUM_PATTERN = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
_PAIR_PATTERN = r"\[\s*(" + _NUM_PATTERN + r")\s*,\s*(" + _NUM_PATTERN + r")\s*\]"
_POS_KEY_RE = re.compile(r'"pos"\s*:\s*' + _PAIR_PATTERN)
_ANY_PAIR_RE = re.compile(_PAIR_PATTERN)
_ANGLE_RE = re.compile(r'"angle"\s*:\s*(' + _NUM_PATTERN + r")")


def _as_float_or_none(value):
    """Return float(value), or None when value is None or not convertible."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None


def extract_pos_angle(text: str):
    """
    Parse model output into a position and an optional angle.

    Strategy, in order:
      1. parse the whole text as a JSON object with a two-element "pos";
      2. regex for a literal "pos":[x,y] anywhere in the text;
      3. regex for the first bracketed pair of numbers [x,y].
    In the regex paths the angle is taken from a literal "angle": <number>.

    Returns ((x, y), angle_or_None) when a position is found, otherwise None
    (also for non-string input).
    """
    if not isinstance(text, str):
        return None

    # Try JSON
    try:
        obj = json.loads(text)
    except (ValueError, RecursionError):
        obj = None
    if isinstance(obj, dict):
        pos = obj.get("pos")
        if isinstance(pos, (list, tuple)) and len(pos) == 2:
            try:
                px, py = float(pos[0]), float(pos[1])
            except (TypeError, ValueError, OverflowError):
                pass  # unusable values: fall through to the regex fallbacks
            else:
                return (px, py), _as_float_or_none(obj.get("angle"))

    # fallback: explicit "pos":[x,y], then any bracketed pair of numbers
    match = _POS_KEY_RE.search(text) or _ANY_PAIR_RE.search(text)
    if match is None:
        return None
    px, py = float(match.group(1)), float(match.group(2))

    # Fallback: find angle
    angle_match = _ANGLE_RE.search(text)
    angle = float(angle_match.group(1)) if angle_match else None
    return (px, py), angle

# --- Helper functions for geometry rendering and IoU ---
def _run(cmd, cwd=None):
    import subprocess, shlex
    if isinstance(cmd, str):
        cmd_list = shlex.split(cmd)
    else:
        cmd_list = cmd
    r = subprocess.run(cmd_list, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
    return r.returncode, r.stdout

def render_with_geometry(gt_json: Path, pred_json: Path, out_dir: Path):
    """
    Render GT, prediction and overlay images by invoking geometry.py.

    Three subprocess calls are made with the current interpreter:
      1. GT       -> out_dir/gt_render.png
      2. Pred     -> out_dir/pred_render.png
      3. Overlay  -> out_dir/overlay_json_vs_pred.png

    A failing call (non-zero exit code) is reported with a warning that
    includes the tool's output, but does not raise, so the remaining renders
    are still attempted.

    Returns (gt_png, pred_png, overlay) as paths; a path may not exist on
    disk if the corresponding render failed.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    gt_png = out_dir / "gt_render.png"
    pred_png = out_dir / "pred_render.png"
    overlay = out_dir / "overlay_json_vs_pred.png"

    jobs = (
        ("GT", ["--ann", str(gt_json)], gt_png),
        ("Pred", ["--ann", str(pred_json)], pred_png),
        (
            "Overlay",
            ["--overlay_gt_json", str(gt_json), "--overlay_pred_json", str(pred_json)],
            overlay,
        ),
    )
    for label, extra_args, target in jobs:
        rc, out = _run([PY, str(GEOM_PY), *extra_args, "--save_png", str(target), "--no_show"])
        if rc != 0:
            # Surface the failure now; otherwise it only shows up later as a NaN IoU.
            print(f"[WARN] geometry.py {label} render failed (exit {rc}) -> {target}\n{(out or '').strip()}")
    return gt_png, pred_png, overlay

def iou_from_masks(gt_png: Path, pred_png: Path, save_debug: bool = False):
    """
    Compute IoU for two rendered images via iou_from_overlay.py.

    The script is run in its "two-image" mode (--gt_img / --pred_img, plus
    --save_debug when requested) and the first "[IoU] <number>" token in its
    output is parsed. Returns that number as a float, or NaN when the output
    contains no such token.
    """
    import re

    argv = [PY, str(IOU_PY), "--gt_img", str(gt_png), "--pred_img", str(pred_png)]
    argv += ["--save_debug"] if save_debug else []
    _exit_code, output = _run(argv)
    # expected line looks like: "[IoU] 0.5339  (intersection=..., union=...)"
    found = re.search(r"\[IoU\]\s+([0-9]*\.?[0-9]+)", output)
    if found is None:
        return float("nan")
    return float(found.group(1))

# ---- Backend setup ----
# Exactly one backend is initialised; its third-party imports happen only inside
# the selected branch, so the other backend's packages need not be installed.
if args.backend == "qwen":
    print("[INFO] Loading Qwen2.5-VL...")
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    from qwen_vl_utils import process_vision_info

    # Model is loaded with device_map="cpu" and switched to eval mode.
    qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu"
    ).eval()
    qwen_processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

elif args.backend == "openai":
    print("[INFO] Using OpenAI backend...")
    # Try OpenAI SDK v1
    # NOTE: api_key, base_url and _mask_key defined below are module-level names;
    # run_once reads them again when it builds its 401 diagnostics.
    try:
        from openai import OpenAI
        # Show only the first 6 and last 4 characters of a key; empty or
        # shorter-than-8 keys are printed as "(none)".
        def _mask_key(k: str) -> str:
            if not k or len(k) < 8:
                return "(none)"
            return k[:6] + "..." + k[-4:]

        # Prefer CLI key; then OPENAI_API_KEY; then OPENROUTER_API_KEY for convenience
        api_key = (
            args.openai_api_key
            or os.getenv("OPENAI_API_KEY")
            or os.getenv("OPENROUTER_API_KEY")
        )
        # Base URL precedence: CLI flag, then OPENAI_BASE_URL env var.
        base_url = args.openai_base_url or os.getenv("OPENAI_BASE_URL")

        # Auto-detect OpenRouter by key prefix if base_url not given
        if not base_url and api_key and api_key.startswith("sk-or-"):
            base_url = "https://openrouter.ai/api/v1"

        # Diagnostics: print where we think the creds are coming from (masked)
        print(f"[AUTH] base_url = {base_url or 'default (OpenAI)'}")
        print(f"[AUTH] api_key  = {_mask_key(api_key)}")
        if (base_url and "openrouter.ai" in base_url) and (not api_key or not api_key.startswith("sk-or-")):
            print("[WARN] Using OpenRouter base_url but API key does not start with 'sk-or-'. This will 401.")

        if not api_key:
            raise RuntimeError(
                "Missing API key. Pass --openai_api_key or set one of:\n"
                "  OPENAI_API_KEY  (for api.openai.com)\n"
                "  OPENROUTER_API_KEY (for openrouter.ai)\n"
            )

        # base_url is only passed when set, so the SDK default applies otherwise.
        client_kwargs = {"api_key": api_key}
        if base_url:
            client_kwargs["base_url"] = base_url

        # Optional headers for OpenRouter
        if base_url and "openrouter.ai" in base_url:
            # These help OpenRouter attribute requests to your app
            client_kwargs["default_headers"] = {
                "HTTP-Referer": "https://local.test",
                "X-Title": "qwen_centroid_demo"
            }

        openai_client = OpenAI(**client_kwargs)
        print(f"[INFO] OpenAI-compatible client ready (base_url={base_url or 'https://api.openai.com/v1'})")
    except Exception as e:
        # NOTE(review): this handler also catches the missing-key RuntimeError raised
        # above, so the install hint printed here can be misleading in that case.
        # The original exception is re-raised unchanged either way.
        print("[FATAL] Failed to import or init OpenAI SDK. Install with `pip install openai` and set OPENAI_API_KEY.")
        raise
else:
    # argparse `choices` already restricts --backend, so this is a defensive guard.
    raise ValueError(f"Unknown backend: {args.backend}")

# Wording that explains "pos" to the model, chosen by --pos_def.
if args.pos_def == "corner":
    pos_meaning = "the piece's template corner anchor (same corner as the dataset JSON)"
else:
    pos_meaning = "the piece's centre of mass (geometric centroid in the 0..10 grid frame)"

# Wording for the angle convention (defined here; not inserted into sys_prompt).
angle_meaning = "the piece's rotation angle in DEGREES, measured counterclockwise from the +X axis (to the right) in the 0..10 grid frame; range [-180,180] or [0,360) are both accepted"

# System prompt: one instruction per fragment, concatenated in order.
_prompt_parts = [
    "You are a geometry assistant.\n",
    "Task: Given ONE tangram piece on a 10×10 plotted grid image, return STRICT JSON ONLY:\n",
    '{\"pos\":[x,y]}\n',
    "- Coordinates MUST use the 0..10 axes drawn on the image (not pixel coords).\n",
    f"- 'pos' refers to {pos_meaning}.\n",
    "- Output pure JSON, no extra text, no code block.",
]
sys_prompt = "".join(_prompt_parts)

# ---- Dataset selection ----
# Sorted so the starting order (and therefore the seeded shuffles below) does not
# depend on filesystem listing order.
all_imgs = sorted(IMG_DIR.glob("*.png"))
if len(all_imgs) == 0:
    raise FileNotFoundError(f"No PNGs in {IMG_DIR}")

# Build list of eligible images that have a matching GT with a valid pos
# Each entry is (image_path, gt_json_path, [x, y]); the GT file must share the image's stem.
eligible = []
for img_path in all_imgs:
    base = img_path.stem
    gt_path = JSON_DIR / f"{base}.json"
    if not gt_path.exists():
        continue
    try:
        with open(gt_path, "r") as f:
            gt_obj = json.load(f)
        # "pos" is preferred; "position" is accepted as an alternative key.
        gt_pos = gt_obj.get("pos") or gt_obj.get("position")
        if isinstance(gt_pos, (list, tuple)) and len(gt_pos) == 2:
            eligible.append((img_path, gt_path, [float(gt_pos[0]), float(gt_pos[1])]))
    except Exception:
        # Unreadable / malformed GT files are skipped silently.
        continue

if not eligible:
    raise RuntimeError("No eligible (image,json with pos) pairs found.")

# Optionally auto-build ICL examples if no --icl_file is given
icl_examples = []
if not args.icl_file and args.icl_auto > 0:
    k_auto = min(args.icl_auto, len(eligible))
    # First draw from the seeded RNG: picks the ICL examples.
    random.shuffle(eligible)
    icl_pick = eligible[:k_auto]
    eligible = eligible[k_auto:]  # remove from pool so test set doesn't reuse
    for img_path, _gt_path, pos in icl_pick:
        icl_examples.append((img_path.resolve(), [float(pos[0]), float(pos[1])]))
    print(f"[ICL] Auto-sampled {len(icl_examples)} example(s) from dataset.")

# If a manual icl_file was provided, load it; this takes precedence and does not alter pool
if args.icl_file:
    loaded = load_icl_examples(Path(args.icl_file))
    if loaded:
        icl_examples = loaded[: max(0, int(args.icl_k))]
        print(f"[ICL] Using {len(icl_examples)} example(s) from {args.icl_file}")
    else:
        print("[ICL] No few-shot examples provided.")
else:
    # If auto ICL was used and icl_k set smaller, trim here
    # NOTE(review): examples trimmed here were already removed from `eligible`
    # above, so they end up neither as ICL examples nor in the test pool.
    if icl_examples and args.icl_k > 0:
        icl_examples = icl_examples[: min(args.icl_k, len(icl_examples))]

# Now sample test images from remaining eligible pool
if not eligible:
    raise RuntimeError("Pool exhausted after taking ICL examples; reduce --icl_auto or ensure more data.")

# Second draw from the seeded RNG: orders the test pool. With --all the whole
# (shuffled) pool is used; otherwise the first N_SAMPLES entries.
random.shuffle(eligible)
if args.all:
    selected = eligible
else:
    test_k = min(N_SAMPLES, len(eligible))
    selected = eligible[:test_k]
# Only the image paths are carried forward.
samples = [t[0] for t in selected]

# Per-image result rows and the wall-clock start used for the final "Time used" line.
rows = []
t0 = time.time()

def run_once(messages, img_for_openai: Path = None):
    """
    Send one chat to the active backend and return the raw reply text.

    messages: list of {"role": ..., "content": [parts]} where each part is
        {"type": "text", "text": ...} or {"type": "image", "image": <path>}.
    img_for_openai: unused; kept so existing call sites keep working.

    Returns the stripped model output. For the OpenAI backend an empty string
    is returned when the response carries no text content.

    Raises RuntimeError with auth diagnostics when the server answers 401;
    any other API error propagates unchanged. Raises ValueError for an
    unsupported backend.
    """
    if args.backend == "qwen":
        chat = qwen_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        images, videos = process_vision_info(messages)
        inputs = qwen_processor(text=[chat], images=images, videos=videos, return_tensors="pt", padding=True)
        out_ids = qwen_model.generate(**inputs, max_new_tokens=args.max_new, do_sample=False)
        # Drop the prompt tokens so only newly generated tokens are decoded.
        trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
        text = qwen_processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip()
        return text

    elif args.backend == "openai":
        # Convert messages to the Chat Completions multimodal format (also used
        # by OpenRouter): text parts are {"type": "text"}, images are
        # {"type": "image_url", "image_url": {"url": ...}}. The "input_text" /
        # "input_image" part types belong to the Responses API and are rejected
        # by chat.completions.create.
        oai_msgs = []
        for m in messages:
            parts = []
            for c in m["content"]:
                if c["type"] == "text":
                    parts.append({"type": "text", "text": c["text"]})
                elif c["type"] == "image":
                    # Images are inlined as base64 data URLs.
                    parts.append({
                        "type": "image_url",
                        "image_url": {"url": _img_to_data_url(Path(c["image"]))},
                    })
            oai_msgs.append({"role": m["role"], "content": parts})

        try:
            resp = openai_client.chat.completions.create(
                model=args.openai_model,
                messages=oai_msgs,
                temperature=0,
                max_tokens=args.max_new,
            )
            # message.content can be None (e.g. refusal or empty completion).
            return (resp.choices[0].message.content or "").strip()
        except Exception as e:
            # Surface helpful hints for common 401/404 issues
            err_msg = str(e)
            if "401" in err_msg or "Unauthorized" in err_msg or "No auth credentials" in err_msg:
                # base_url and api_key are in the outer scope
                raise RuntimeError(
                    "[AUTH ERROR] 401 from server. If you are using OpenRouter, ensure:\n"
                    "  - base_url = https://openrouter.ai/api/v1 (pass --openai_base_url)\n"
                    "  - API key starts with sk-or- (pass --openai_api_key or set OPENROUTER_API_KEY)\n"
                    "  - Model name is routable on OpenRouter (e.g., openai/gpt-4o-mini)\n"
                    f"Diagnostics: base_url={base_url or 'default'}, api_key={_mask_key(api_key)}"
                ) from e
            raise
    else:
        raise ValueError("Unsupported backend")

# ---- Evaluation loop ----
# One model call per sampled image: build the chat, parse the predicted pos,
# compare with GT, write pred.json, optionally render + IoU, and record a row.
for img_path in samples:
    base = img_path.stem
    gt_path = JSON_DIR / f"{base}.json"
    if not gt_path.exists():
        print(f"[WARN] Missing GT for {base}, skip.")
        continue

    # read GT
    # (the GT file is re-read here; `samples` only carries image paths)
    with open(gt_path, "r") as f:
        gt = json.load(f)
    gt_pos = None
    if isinstance(gt, dict):
        # try common keys
        if "pos" in gt:
            gt_pos = gt["pos"]
        elif "position" in gt:
            gt_pos = gt["position"]
    if not gt_pos or len(gt_pos) != 2:
        print(f"[WARN] GT json for {base} has no 'pos', skip.")
        continue
    gt_x, gt_y = float(gt_pos[0]), float(gt_pos[1])

    # Chat layout: system prompt, then one user/assistant pair per few-shot
    # example, then the query image as the final user turn.
    messages = [{"role":"system","content":[{"type":"text","text":sys_prompt}]}]
    # few-shot exemplars
    for ex_img_path, ex_pos in icl_examples:
        # Both branches of the conditional end up as a plain string path.
        ex_img_str = str(ex_img_path if isinstance(ex_img_path, (str, Path)) else str(ex_img_path))
        messages.append({
            "role": "user",
            "content": [
                {"type":"image","image": ex_img_str},
                {"type":"text","text":"Return only {\"pos\":[x,y]} ."}
            ],
        })
        messages.append({
            "role": "assistant",
            "content": [
                {"type":"text","text": json.dumps({"pos":[float(ex_pos[0]), float(ex_pos[1])]}) }
            ],
        })
    # current query
    messages.append({
        "role":"user",
        "content":[
            {"type":"image","image":str(img_path.resolve())},
            {"type":"text","text":"Return only {\"pos\":[x,y]} ."}
        ],
    })

    text = run_once(messages, img_for_openai=img_path)

    # Use extract_pos_angle and get pred_pos
    # Only the position part of the parse result is used; a falsy result means
    # no position could be parsed from the reply.
    tmp = extract_pos_angle(text)
    pred_pos = tmp[0] if tmp else None
    if pred_pos is None:
        err = None
        dx = dy = None
    else:
        # Per-axis absolute errors and Euclidean (L2) distance, in grid units.
        px, py = pred_pos
        dx, dy = abs(px - gt_x), abs(py - gt_y)
        err = math.hypot(px - gt_x, py - gt_y)

    # Save prediction JSON next to per-image folder
    per_dir = SAVE_DIR / img_path.stem
    per_dir.mkdir(parents=True, exist_ok=True)
    pred_json_path = per_dir / "pred.json"

    # Seed pred JSON with GT metadata so geometry.py can render the shape
    # (only "pos" comes from the model; every other field is copied from GT)
    pred_obj = {}
    for k in ("type", "size", "angle", "flip", "scale", "instance"):
        if isinstance(gt, dict) and (k in gt):
            pred_obj[k] = gt[k]

    # Position from the model (centre-of-mass in the 0..10 grid if --pos_def com)
    if pred_pos is not None:
        pred_obj["pos"] = [float(px), float(py)]
    else:
        # Placeholder so the file still has a "pos" key; serialised as [null, null].
        pred_obj["pos"] = [None, None]



    with open(pred_json_path, "w") as f:
        json.dump(pred_obj, f)

    # Optional geometry renders + IoU
    # iou_val stays None unless --geom_render is set and the IoU step completes.
    iou_val = None
    if args.geom_render:
        gt_png, pred_png, overlay_png = render_with_geometry(gt_path, pred_json_path, per_dir)
        try:
            iou_val = iou_from_masks(gt_png, pred_png, save_debug=args.save_iou_debug)
            print(f"[IOU] {base} -> {iou_val:.4f} | gt={gt_x:.3f},{gt_y:.3f} pred={None if pred_pos is None else f'{pred_pos[0]:.3f},{pred_pos[1]:.3f}'} | overlay={overlay_png}")
            # also save a small marker file in the image folder
            with open(per_dir / "iou.txt", "w") as f:
                f.write(f"{iou_val:.6f}\n")
        except Exception as _e:
            # IoU problems are reported but never abort the evaluation run.
            print(f"[IOU] failed for {img_path.name}: {_e}")

    # One CSV row per evaluated image; missing predictions are stored as None.
    rows.append({
        "image": base + ".png",
        "gt_pos_x": gt_x, "gt_pos_y": gt_y,
        "pred_text": text,
        "pred_x": None if pred_pos is None else pred_pos[0],
        "pred_y": None if pred_pos is None else pred_pos[1],
        "abs_dx": dx, "abs_dy": dy, "l2_error": err,
        "iou_mask": iou_val
    })
    print(f"[DONE] {base}: GT=({gt_x:.3f},{gt_y:.3f})  PRED={pred_pos}  L2={err}  IOU={iou_val}")

# ---- Save CSV & summary ----
import csv


def _mean_or_nan(values):
    """Arithmetic mean of ``values``; NaN for an empty sequence."""
    values = list(values)
    return sum(values) / len(values) if values else float("nan")


ts = time.strftime("%Y%m%d_%H%M%S")
csv_path = SAVE_DIR / f"pos_eval_{ts}.csv"
if rows:
    # utf-8 so raw model text in "pred_text" can always be written.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)
else:
    # Every sample was skipped: there is no header to derive, so write nothing.
    print("[WARN] No samples were evaluated; CSV not written.")

# summary
# "valid" rows are those where a position could be parsed from the model reply.
valid = [r for r in rows if r["l2_error"] is not None]
mean_l2 = _mean_or_nan(r["l2_error"] for r in valid)
mean_dx = _mean_or_nan(r["abs_dx"] for r in valid)
mean_dy = _mean_or_nan(r["abs_dy"] for r in valid)

# IoU statistics
# Rows without IoU (None) or with an unparsable IoU (NaN) are excluded.
valid_iou = [r["iou_mask"] for r in rows if r.get("iou_mask") is not None and not math.isnan(r["iou_mask"])]
mean_iou = _mean_or_nan(valid_iou)

print("\n================ SUMMARY ================")
print(f"Samples: {len(rows)}  |  Valid: {len(valid)}")
print(f"Mean|dx| = {mean_dx:.4f}   Mean|dy| = {mean_dy:.4f}   Mean L2 = {mean_l2:.4f}")
print(f"Mean IoU (mask-from-geometry) = {mean_iou:.4f}  over {len(valid_iou)} items")
if rows:
    print(f"CSV saved -> {csv_path}")
print(f"Images source: {IMG_DIR}")
print(f"GT source    : {JSON_DIR}")
print(f"pos_def      : {args.pos_def}")
print(f"ICL examples : {len(icl_examples)}")
print(f"ICL mode     : {'auto' if (not args.icl_file and args.icl_auto>0) else ('file' if args.icl_file else 'none')}")
if args.backend == "openai":
    # Report the base URL the client was actually built with (CLI flag, then
    # OPENAI_BASE_URL, then OpenRouter auto-detection) rather than re-deriving
    # it with a different precedence.
    base_url_summary = base_url or "default"
    print(f"Backend      : {args.backend} (model={args.openai_model}, base_url={base_url_summary})")
else:
    print(f"Backend      : {args.backend} (Qwen/Qwen2.5-VL-3B-Instruct)")
print(f"Time used: {time.time()-t0:.1f}s")