# -*- coding: utf-8 -*-
"""
单图测试：给定一张彩图（params_vis），让 Qwen 输出 7-piece JSON
"""

import os, json, re, torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, StoppingCriteria, StoppingCriteriaList

# Fixed tangram roster: piece name -> (shape type, size class).
# Insertion order matters: minimal_repair emits pieces in this order.
ROSTER = dict(
    T1=("triangle", "large"),
    T2=("triangle", "large"),
    T3=("triangle", "medium"),
    T4=("triangle", "small"),
    T5=("triangle", "small"),
    P1=("parallelogram", "na"),
    S1=("square", "na"),
)

def _coerce_float(value, default):
    """Return ``value`` as a finite float, or ``default`` if conversion fails."""
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    # json.loads accepts NaN/Infinity; they would be written back as invalid JSON.
    return number if math.isfinite(number) else default


def _coerce_bool(value, default=False):
    """Interpret JSON-ish booleans, including the strings "true"/"false"."""
    if value is None:
        return default
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1", "yes", "y", "t"):
            return True
        if text in ("false", "0", "no", "n", "f", ""):
            return False
        return default
    return bool(value)


def _coerce_pos(value):
    """Return ``[x, y]`` floats from a list/tuple of >= 2 items, else ``[0.0, 0.0]``."""
    if isinstance(value, (list, tuple)) and len(value) >= 2:
        x = _coerce_float(value[0], None)
        y = _coerce_float(value[1], None)
        if x is not None and y is not None:
            return [x, y]
    return [0.0, 0.0]


def minimal_repair(obj, *, roster=None):
    """Coerce a near-miss JSON into the strict 7-piece schema with a fixed roster.

    Args:
        obj: Parsed model output, either a list of piece dicts or a dict whose
            ``"pieces"`` key holds that list.
        roster: Mapping ``name -> (type, size)`` defining the required pieces and
            their output order. Defaults to the module-level ``ROSTER``.

    Returns:
        ``{"pieces": [...]}`` with exactly one entry per roster name, in roster
        order. ``type``/``size`` always come from the roster. Each numeric or
        boolean field is coerced independently. Missing or invalid values fall
        back to: pos ``[0.0, 0.0]``, angle ``0.0``, flip ``False``, scale ``1.0``.

    Raises:
        ValueError: if ``obj`` does not contain a list of pieces.
    """
    if roster is None:
        roster = ROSTER
    items = obj.get("pieces", obj) if isinstance(obj, dict) else obj
    if not isinstance(items, list):
        raise ValueError("Model output is not a list")

    # Short names the model sometimes emits for the single P/S pieces.
    aliases = {"P": "P1", "p": "P1", "S": "S1", "s": "S1"}
    by_name = {}
    for piece in items:
        if not isinstance(piece, dict):
            continue
        name = piece.get("name")
        if not isinstance(name, str):  # also guards against unhashable names
            continue
        name = aliases.get(name, name)
        # Unknown names are dropped; for duplicates the first occurrence wins.
        if name in roster and name not in by_name:
            by_name[name] = piece

    pieces = []
    for name, (ptype, size) in roster.items():
        src = by_name.get(name, {})
        pieces.append({
            "name": name, "type": ptype, "size": size,
            "pos": _coerce_pos(src.get("pos")),
            "angle": _coerce_float(src.get("angle"), 0.0),
            "flip": _coerce_bool(src.get("flip")),
            "scale": _coerce_float(src.get("scale"), 1.0),
        })
    return {"pieces": pieces}

# ========= Configuration =========
# Source colored solution image and the prediction JSON path (user-specific
# locations under the home directory; edit as needed).
INPUT_IMG, SAVE_JSON = map(os.path.expanduser, (
    '~/Desktop/qwen_vl_demo/kilogram-main/dataset/params_vis/page1-1.png',
    '~/Desktop/qwen_vl_demo/kilogram-main/dataset/preds/page1-1.pred.json',
))

# Instruction text sent to the model together with the image. The roster it
# lists mirrors ROSTER defined earlier in this file.
INSTR = """
You are a tangram puzzle solver. Given ONE colored solution image on a 10×10 grid,
OUTPUT STRICT JSON ONLY: an array of EXACTLY 7 objects, each with the fields:
  name ∈ ["T1","T2","T3","T4","T5","P1","S1"]
  type ∈ ["triangle","square","parallelogram"]
  size ∈ ["large","medium","small","na"]
  pos  : [x, y]  (two floats in the SAME 0–10 grid as the image)
  angle: float in degrees (use multiples of 45 when possible)
  flip : boolean
  scale: float

FIXED ROSTER (must match exactly):
  T1: (triangle, large)
  T2: (triangle, large)
  T3: (triangle, medium)
  T4: (triangle, small)
  T5: (triangle, small)
  P1: (parallelogram, na)
  S1: (square, na)

Rules:
- Return PURE JSON ONLY (no prose, no markdown fences).
- Use the full roster once each; names must be unique and cover all 7 pieces.
- "square" and "parallelogram" MUST have size "na".
- Coordinates must be in the 0–10 grid as seen on the image; do NOT use pixels.
- Do NOT stack all pieces at the same pos; pieces should tile the silhouette.
"""

# ========= JSON stop criterion =========
class BalancedJSONStop(StoppingCriteria):
    """Stop generation once a balanced top-level JSON array/object is complete.

    For each batch row, the tokens generated after the first ``start_len``
    positions are decoded (special tokens skipped) and scanned for ``[``/``{``
    and ``]``/``}`` outside string literals. A row counts as finished when:
    - at least one bracket has been opened,
    - nesting depth is back to zero, and
    - the last non-whitespace character is ``}`` or ``]``.

    ``__call__`` returns a bool tensor with one flag per row, on the same device
    as ``input_ids``. This is the return form documented for transformers
    ``StoppingCriteria``.
    """

    def __init__(self, tokenizer, start_len: int):
        super().__init__()
        self.tok = tokenizer
        self.start_len = start_len  # prompt length; only tokens after it are scanned

    @staticmethod
    def _is_complete(text: str) -> bool:
        """Return True if ``text`` ends right after a balanced top-level JSON value."""
        depth = 0
        in_string = False
        escape = False
        started = False
        for ch in text:
            if in_string:
                # Backslash escapes only exist inside JSON string literals.
                if escape:
                    escape = False
                elif ch == "\\":
                    escape = True
                elif ch == '"':
                    in_string = False
                continue
            if ch == '"':
                in_string = True
            elif ch in "[{":
                depth += 1
                started = True
            elif ch in "]}":
                # Clamp so stray closers emitted before the JSON cannot go negative.
                depth = max(depth - 1, 0)
        if not (started and depth == 0):
            return False
        tail = text.rstrip()
        return bool(tail) and tail[-1] in "}]"

    def __call__(self, input_ids, scores, **kwargs):
        flags = []
        for row in input_ids:
            generated = row[self.start_len:].tolist()
            text = self.tok.decode(generated, skip_special_tokens=True,
                                   clean_up_tokenization_spaces=False)
            flags.append(self._is_complete(text))
        return torch.tensor(flags, dtype=torch.bool, device=input_ids.device)

# ========= Main flow =========
def _parse_model_json(gen):
    """Parse the JSON payload from raw model text.

    First strips a leading Markdown code fence and tries to parse the whole text.
    If that fails, it decodes the first JSON value that starts at the first
    ``[`` or ``{``. This way a short prose preamble, trailing text, or a stray
    fence does not break parsing.

    Raises:
        json.JSONDecodeError (a ValueError) if no JSON value can be decoded.
    """
    text = gen.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text, flags=re.DOTALL).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError as first_err:
        starts = [i for i in (text.find("["), text.find("{")) if i >= 0]
        if not starts:
            raise
        try:
            obj, _end = json.JSONDecoder().raw_decode(text, min(starts))
        except json.JSONDecodeError:
            # The error for the full text is the more informative one to report.
            raise first_err from None
        return obj


def main():
    """Run Qwen2.5-VL on INPUT_IMG and write the repaired piece JSON to SAVE_JSON.

    Prints a ``[FAIL]`` message and returns early, without writing, when the
    model output cannot be parsed as JSON or when ``minimal_repair`` raises.
    """
    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"
    print("[INFO] Loading Qwen2.5-VL ...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        model_id, device_map="cpu", torch_dtype=torch.float32
    ).eval()
    processor = AutoProcessor.from_pretrained(model_id)

    # Build the chat message: one image entry followed by the instruction text.
    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": INPUT_IMG},
            {"type": "text", "text": INSTR},
        ]}
    ]

    # Open the image; convert() returns a new RGB image that stays usable after
    # the with-block closes the file.
    with Image.open(INPUT_IMG) as im:
        imgs = [im.convert("RGB")]

    chat = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[chat], images=imgs, padding=True, return_tensors="pt")
    dev = next(model.parameters()).device
    inputs = {k: (v.to(dev) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
    prompt_len = inputs["input_ids"].shape[1]

    # Stop as soon as a balanced top-level JSON value has been generated.
    stopper = StoppingCriteriaList([BalancedJSONStop(processor.tokenizer, prompt_len)])
    with torch.no_grad():
        out_ids = model.generate(
            **inputs, max_new_tokens=900, do_sample=False, stopping_criteria=stopper
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    gen = processor.batch_decode(out_ids[:, prompt_len:], skip_special_tokens=True)[0].strip()

    try:
        obj = _parse_model_json(gen)
    except (ValueError, RecursionError) as e:
        print("[FAIL] JSON parse error:", e)
        print("Raw text:", gen)
        return

    try:
        obj = minimal_repair(obj)
    except Exception as e:  # top-level boundary: report and stop without writing
        print("[FAIL] Schema repair failed:", e)
        print("Raw JSON:", json.dumps(obj, ensure_ascii=False) if isinstance(obj, (list, dict)) else str(obj))
        return

    pieces = obj.get("pieces") if isinstance(obj, dict) else None
    if pieces:
        print("[DEBUG] First repaired piece:", pieces[0])

    # os.makedirs("") raises, so only create a directory when the path has one.
    save_dir = os.path.dirname(SAVE_JSON)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    with open(SAVE_JSON, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)
    print(f"[SAVE] wrote {SAVE_JSON}")

# Run the single-image pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()