import os, json, argparse, re, textwrap, math
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageOps
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextStreamer, StoppingCriteria, StoppingCriteriaList
from qwen_vl_utils import process_vision_info

# =========================
# [PROMPT] Main task instruction: strict JSON output with exactly 7 piece objects
# and a fixed set of fields. The same text is the user text of every few-shot
# turn (when --sequential) and the base of the final test prompt.
# =========================
INSTR = textwrap.dedent("""
You are now a tangram puzzle solver.
Task: Output STRICT JSON, exactly 7 objects. Each object MUST have the following fields:
  name: one of [T1,T2,T3,T4,T5,P1,S1]
  type: one of [triangle,square,parallelogram]
  size: one of [large,medium,small,na]
  pos: [x,y] two floats (0..10 canvas grid)
  angle: float, MUST be chosen from {-135,-90,-45,0,45,90,135,180}
  flip: boolean
  scale: float
Return PURE JSON only, no extra text, no Markdown.

Constraints:
- Pieces must NOT share identical pos; avoid stacking at the same coordinates.
- Prefer non-zero angles when needed to match slanted edges; at least 4 pieces should use non-zero angles.
- Pieces should not significantly overlap each other; use the full canvas area to cover the silhouette.
- If your previous attempt had identical positions or all-zero angles, adjust pos/angle/flip/scale to spread pieces out and align with the outline.
""").strip()

# =========================
# [PROMPT] Optional cheat sheet: lists the 7 fixed pieces with their type and
# size. It is appended to the final prompt only when --teach is given.
# =========================
CHEATSHEET = """
Tangram 7 fixed pieces:
- T1: triangle large
- T2: triangle large
- T3: triangle medium
- T4: triangle small
- T5: triangle small
- P1: parallelogram na
- S1: square na
"""

# ---------- CLI ----------
def parse_args():
    """Parse the command line from ``sys.argv``.

    Required: --params_dir, --outlines_dir, --input_dir, --train_ids (one or
    more ids) and --test_id. Optional: --sequential and --teach (booleans),
    --max_new (int, default 512) and --save_pred (path, default "").
    Returns the resulting ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    for required_dir in ("--params_dir", "--outlines_dir", "--input_dir"):
        parser.add_argument(required_dir, required=True)
    parser.add_argument("--train_ids", nargs="+", required=True)
    parser.add_argument("--test_id", required=True)
    for switch in ("--sequential", "--teach"):
        parser.add_argument(switch, action="store_true")
    parser.add_argument("--max_new", type=int, default=512)
    parser.add_argument("--save_pred", default="")
    namespace = parser.parse_args()
    return namespace

# Parse the command line once, at import time. All later sections read their
# configuration from this module-level namespace.
ARGS = parse_args()

# ---------- Helper ----------
def must_exist(p):
    """Raise FileNotFoundError carrying ``p`` unless that path exists; otherwise return None."""
    if os.path.exists(p):
        return None
    raise FileNotFoundError(p)

def load_json_str(path):
    """Read the UTF-8 JSON file at ``path`` and return it re-serialized as a string.

    Non-ASCII characters are kept as-is (``ensure_ascii=False``). The file is
    read inside a ``with`` block so the handle is closed deterministically;
    the previous bare ``open()`` left it to the garbage collector.
    """
    with open(path, "r", encoding="utf-8") as fh:
        data = json.load(fh)
    return json.dumps(data, ensure_ascii=False)

def outline_id_paths(_id: str):
    """Return ``(outline_png_path, params_json_path)`` for sample ``_id``.

    The directories are read from the module-level ``ARGS`` namespace.
    """
    png_name = f"{_id}.png"
    json_name = f"{_id}.json"
    outline_path = os.path.join(ARGS.outlines_dir, png_name)
    params_path = os.path.join(ARGS.params_dir, json_name)
    return outline_path, params_path

# ---------- IoU 评估（用实际七巧板多边形 + 坐标归一化到像素） ----------
def _poly_triangle(size):
    # 以直角等腰三角为模板，边长（单位网格）：
    base = {"large": 1.0, "medium": math.sqrt(2) / 2.0, "small": 0.5}[size]
    # 局部坐标（右上象限），后续会旋转/翻转/平移
    return [(0.0, 0.0), (base, 0.0), (0.0, base)]

def _poly_square():
    # 1x1 单位正方形
    return [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]

def _poly_parallelogram():
    # 1x0.5 平行四边形（简单模板）
    return [(0.0, 0.0), (1.0, 0.0), (1.5, 0.5), (0.5, 0.5)]

def _apply_transform(poly, pos, angle_deg, flip, scale, px, offx=0.0, offy=0.0):
    """
    Apply flip/scale/rotate and then map from grid units to pixel coordinates.
    Adds an (offx, offy) pixel offset so we can align to the silhouette bbox.
    """
    ax = math.radians(float(angle_deg))
    ca, sa = math.cos(ax), math.sin(ax)
    out = []
    for (x, y) in poly:
        if flip:
            x = -x
        x *= float(scale)
        y *= float(scale)
        xr = x * ca - y * sa
        yr = x * sa + y * ca
        X = (xr + float(pos[0])) * px + offx
        Y = (yr + float(pos[1])) * px + offy
        out.append((X, Y))
    return out

def _pieces_to_polys_in_pixels(pieces, w, h, bbox=None):
    """
    Convert pieces (0..10 grid) to pixel polygons.
    If bbox is provided as (x0,y0,x1,y1), map the 0..10 grid into that bbox.
    Otherwise, fall back to filling the whole image with px = max(w,h)/10.
    """
    if bbox is not None:
        x0, y0, x1, y1 = bbox
        bw = max(1, int(x1 - x0))
        bh = max(1, int(y1 - y0))
        # keep aspect by using the smaller step
        px = min(bw, bh) / 10.0
        offx, offy = float(x0), float(y0)
    else:
        px = max(w, h) / 10.0
        offx = offy = 0.0

    polys = []
    for p in pieces:
        ptype = p["type"]
        if ptype == "triangle":
            poly = _poly_triangle(p["size"])
        elif ptype == "square":
            poly = _poly_square()
        else:
            poly = _poly_parallelogram()
        poly_px = _apply_transform(poly, p["pos"], p["angle"], p["flip"], p["scale"], px, offx, offy)
        polys.append(poly_px)
    return polys

def render_mask(pieces, w, h, bbox=None):
    """Rasterize ``pieces`` into an (h, w) uint8 array: 1 inside any piece, 0 elsewhere.

    Polygons come from ``_pieces_to_polys_in_pixels`` with the same ``bbox``
    mapping. Overlapping pieces simply paint 1 again, so the mask is their union.
    """
    canvas = Image.new("L", (w, h), 0)
    pen = ImageDraw.Draw(canvas)
    polygons = _pieces_to_polys_in_pixels(pieces, w, h, bbox=bbox)
    for vertices in polygons:
        pen.polygon(vertices, fill=1)
    return np.array(canvas, dtype=np.uint8)

def load_silhouette(path, thresh=128):
    """Load the outline image at ``path`` as a binary uint8 foreground mask.

    Foreground is first taken to be the dark pixels (gray < ``thresh``). If that
    gives almost no foreground or almost all foreground (< 1% or > 99% of
    pixels), the polarity is assumed inverted and bright pixels
    (gray > ``thresh``) are used instead.

    Images carrying transparency (an "A" band or a "transparency" entry) are
    composited onto white before the gray conversion. A plain
    ``convert("L")`` ignores alpha, so transparent pixels whose RGB is dark
    would otherwise count as foreground. The file is opened in a ``with``
    block so its handle is released as soon as the pixels are read.
    """
    with Image.open(path) as im:
        has_alpha = "A" in im.getbands() or "transparency" in im.info
        if has_alpha:
            rgba = im.convert("RGBA")
            white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
            gray = Image.alpha_composite(white, rgba).convert("L")
        else:
            gray = im.convert("L")
    arr = np.array(gray)
    mask = (arr < thresh).astype(np.uint8)
    # auto-fix polarity if foreground is almost empty or almost full
    fg_ratio = mask.mean()
    if fg_ratio < 0.01 or fg_ratio > 0.99:
        mask = (arr > thresh).astype(np.uint8)
    return mask

def foreground_bbox(mask):
    """Return the tight ``(x0, y0, x1, y1)`` box around nonzero pixels (x1/y1 exclusive).

    An all-zero mask yields the whole frame, ``(0, 0, width, height)``.
    """
    rows, cols = np.nonzero(mask > 0)
    if cols.size == 0 or rows.size == 0:
        # Nothing to crop to: fall back to the full image.
        height, width = mask.shape
        return (0, 0, width, height)
    left, right = int(cols.min()), int(cols.max()) + 1
    top, bottom = int(rows.min()), int(rows.max()) + 1
    return (left, top, right, bottom)

def iou_score(pred, gt):
    """Intersection over union of two masks (nonzero = foreground); 0.0 when both are empty."""
    p = pred.astype(bool)
    g = gt.astype(bool)
    overlap = (p & g).sum()
    combined = (p | g).sum()
    return overlap / max(1, combined)

def geometry_metrics(pred_mask, gt_mask):
    """Compare a predicted mask against a ground-truth mask.

    Returns ``(iou, overflow, undercov)``:
      * iou      - |pred AND gt| / |pred OR gt|, or 0 when both are empty;
      * overflow - pixels predicted but not in gt, as a fraction of all pixels;
      * undercov - pixels in gt but not predicted, as a fraction of all pixels.
    Both fractions are divided by the total pixel count of ``pred_mask``.
    """
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)
    total = pred.size
    iou = np.logical_and(pred, gt).sum() / max(1, np.logical_or(pred, gt).sum())
    overflow = (pred & ~gt).sum() / total
    undercov = (gt & ~pred).sum() / total
    return iou, overflow, undercov

# ---------- Model loading ----------
# Load the Qwen2.5-VL-3B-Instruct model and processor by hub id. device_map="cpu"
# places the model on CPU, and .eval() switches it to inference mode.
print("[INFO] Loading Qwen2.5-VL...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu").eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# =========================
# [PROMPT] Conversation construction: few-shot turns (colored image + JSON
# answer), then the final task turn (new outline image).
# =========================
messages=[]

# Few-shot part: added only when --sequential is set. Each training id adds a
# user turn (params_vis image + INSTR) followed by an assistant turn holding
# that example's JSON from params_dir. must_exist raises FileNotFoundError if
# either file is missing.
if ARGS.sequential:
    for _id in ARGS.train_ids:
        img=os.path.join(ARGS.input_dir,f"{_id}.png"); js=os.path.join(ARGS.params_dir,f"{_id}.json")
        must_exist(img); must_exist(js)
        ex_json=load_json_str(js)
        messages.append({"role":"user","content":[{"type":"image","image":img},{"type":"text","text":INSTR+"\nSolved example (JSON below):"}]})
        messages.append({"role":"assistant","content":[{"type":"text","text":ex_json}]})
# [LOG] Number of few-shot examples added to the conversation.
if ARGS.sequential:
    print(f"[LOG] 已向模型对话加入 few-shot 示例数量: {len(ARGS.train_ids)}")

# Final task: the model is shown the black outline image of the test sample.
test_img=os.path.join(ARGS.outlines_dir,f"{ARGS.test_id}.png")
# =========================
# [LOG] Transparent summary of how data is used in this run:
# - few-shot (training examples): the colored params_vis image + its JSON are
#   fed to the model as references, but ONLY when --sequential is set;
# - test: only the black outline image is fed; the test ground-truth JSON is
#   never added to the conversation;
# - evaluation: after generation, IoU is computed against the outline
#   silhouette image. The ground-truth JSON path is printed for reference but
#   is not read anywhere in this script.
# =========================
test_gt_json = os.path.join(ARGS.params_dir, f"{ARGS.test_id}.json")  # printed for reference only
print("\n[LOG] ===== 数据使用清单 =====")
# Few-shot examples only enter the conversation with --sequential, so the log
# must not claim they are always fed to the model.
fewshot_fed = bool(ARGS.sequential)
if fewshot_fed:
    print("[LOG] Few-shot 训练示例（会喂给模型）:")
else:
    print("[LOG] Few-shot 训练示例（未开启 --sequential，不会喂给模型）:")
for _id in ARGS.train_ids:
    img_path = os.path.join(ARGS.input_dir, f"{_id}.png")
    js_path  = os.path.join(ARGS.params_dir,  f"{_id}.json")
    # Existence check to ease troubleshooting (warning only, never fatal here).
    try:
        must_exist(img_path); must_exist(js_path)
    except FileNotFoundError:
        print(f"[WARN] few-shot 示例缺失文件: {_id} -> {img_path} 或 {js_path} 不存在")
    print(f"  - image(params_vis): {img_path}")
    if fewshot_fed:
        print(f"    json (ground-truth，作为示例喂给模型): {js_path}")
    else:
        print(f"    json (ground-truth，未喂给模型): {js_path}")

print("\n[LOG] 测试阶段（推理）:")
try:
    must_exist(test_img)
except FileNotFoundError:
    print(f"[ERROR] 测试 outline 图像不存在: {test_img}")
print(f"  - 输入给模型的 outline 图像: {test_img}")

# test_gt_json is only printed: evaluation below uses the outline silhouette.
print("\n[LOG] 评估阶段（不会喂给模型）:")
print(f"  - 真值 JSON（本脚本不读取，不加入对话；IoU 用 outline 轮廓图计算）: {test_gt_json}")
print("[LOG] =======================\n")

# The check above only printed an error; this one raises FileNotFoundError.
must_exist(test_img)
final_prompt = INSTR
if ARGS.teach:
    final_prompt += "\n" + CHEATSHEET
final_prompt += "\nNow solve this outline and return JSON."

messages.append({"role":"user","content":[{"type":"image","image":test_img},{"type":"text","text":final_prompt}]})

# ---------- Build model inputs ----------
# Steps:
#   1. Render the chat template as text, appending the assistant generation prompt.
#   2. Collect the image/video inputs referenced by the messages.
#   3. Tokenize everything into PyTorch tensors.
#   4. Move every tensor to CPU, where the model was loaded.
chat_text=processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs,video_inputs=process_vision_info(messages)
inputs=processor(text=[chat_text],images=image_inputs,videos=video_inputs,return_tensors="pt",padding=True)
inputs={k:v.to("cpu") if isinstance(v,torch.Tensor) else v for k,v in inputs.items()}

#
# ---------- Generation ----------
# [LOG] Conversation summary, to confirm the test ground-truth JSON is not mixed in:
#   - the turn count printed below covers only the few-shot (image + JSON) pairs
#     and the test outline image + task prompt;
#   - the test ground-truth JSON is never part of it.
print(f"[RUN] Generating... | 对话条目数={len(messages)} | teach={'ON' if ARGS.teach else 'OFF'} | sequential={'ON' if ARGS.sequential else 'OFF'}")
# Stream tokens to stdout as they are generated (special tokens hidden).
streamer=TextStreamer(processor.tokenizer,skip_special_tokens=True)

with torch.no_grad():
    out_ids = model.generate(
        **inputs,
        max_new_tokens=ARGS.max_new,
        streamer=streamer,
        do_sample=True,          # enable sampling to avoid mode collapse
        temperature=0.3,         # low temperature for stability
        top_p=0.9,               # nucleus sampling
    )
# Strip the prompt tokens so only the newly generated text is decoded.
trimmed=[o[len(i):] for i,o in zip(inputs["input_ids"],out_ids)]
output=processor.batch_decode(trimmed,skip_special_tokens=True)[0].strip()

print("\n=== RAW OUTPUT ===\n",output)

# ---------- Parse JSON ----------
# The candidate JSON array is everything from the first "[" to the last "]"
# (greedy, across newlines).
# Parse failures still end the run with exit status 0. They now use
# SystemExit: exit() is injected by the site module for interactive use and
# is missing under `python -S`.
m=re.search(r"\[.*\]",output,re.S)
if not m:
    print("[FAIL] 无法解析 JSON")
    raise SystemExit(0)
try:
    pieces=json.loads(m.group(0))
except Exception as e:
    print("[FAIL] JSON 解析失败:",e)
    raise SystemExit(0)
# [LOG] Basic info about the parsed JSON (best effort; logging must not abort the run).
try:
    _names = [p.get("name","?") for p in pieces if isinstance(p, dict)]
    print(f"[LOG] 解析到的 piece 数量: {len(pieces)} | 名称: {_names}")
except Exception:
    pass
# Keep only object entries. Anything else (numbers, strings, nested lists)
# cannot be rendered and would crash the collapse check and geometry
# evaluation below.
_n_parsed = len(pieces)
pieces = [p for p in pieces if isinstance(p, dict)]
if len(pieces) != _n_parsed:
    print(f"[WARN] 丢弃了 {_n_parsed - len(pieces)} 个非对象 piece 条目")

def _anti_collapse_need_retry(pieces, min_unique=5, tol=1e-6):
    # Count unique positions and check if many pieces are stacked at the same coordinates
    coords = [(round(float(p.get('pos',[0,0])[0]),6), round(float(p.get('pos',[0,0])[1]),6)) for p in pieces]
    uniq = len(set(coords))
    return uniq < min_unique

if _anti_collapse_need_retry(pieces):
    # Add a corrective turn and re-generate once to discourage identical positions
    feedback = (
        "Geometry feedback: Many pieces share identical positions. "
        "Spread pieces across the silhouette; avoid stacking. "
        "Adjust only pos/angle/flip/scale and return STRICT JSON only."
    )
    # Append assistant's last JSON and user's feedback as a new turn
    messages.append({"role":"assistant","content":[{"type":"text","text":json.dumps(pieces, ensure_ascii=False)}]})
    messages.append({"role":"user","content":[{"type":"text","text":feedback}]})

    # Rebuild inputs and generate again. Sampling as before, but with a slightly
    # higher temperature (0.35) and no streaming.
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[chat_text], images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
    inputs = {k: (v.to("cpu") if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=ARGS.max_new,
            do_sample=True,
            temperature=0.35,
            top_p=0.9,
        )
    trimmed=[o[len(i):] for i,o in zip(inputs["input_ids"],out_ids)]
    output=processor.batch_decode(trimmed,skip_special_tokens=True)[0].strip()
    print("\n=== RAW OUTPUT (ANTI-COLLAPSE RETRY) ===\n", output)
    # Parse failures end the run with status 0, as before; SystemExit replaces
    # the interactive-only exit() builtin.
    m=re.search(r"\[.*\]",output,re.S)
    if not m:
        print("[FAIL] 无法解析 JSON (retry)")
        raise SystemExit(0)
    try:
        pieces=json.loads(m.group(0))
    except Exception as e:
        print("[FAIL] JSON 解析失败 (retry):", e)
        raise SystemExit(0)
    # Drop non-object entries; they cannot be rendered during evaluation.
    pieces = [p for p in pieces if isinstance(p, dict)]

def _clone_pieces(pcs):
    return [dict(p) for p in pcs]

def _replace_and_render(pcs, idx, cand, w, h, bbox):
    pcs2 = _clone_pieces(pcs)
    pcs2[idx] = cand
    return render_mask(pcs2, w, h, bbox=bbox), pcs2

def _discrete_refine(pieces, gt_mask, bbox, angle_set=None, pos_delta=0.5, pos_span=1.0, scale_set=None, try_flip=True, iters=2):
    """
    Small discrete local search around each piece to improve IoU.
    Coordinates are in 0..10 grid units (same as model output).
    """
    h, w = gt_mask.shape
    x0, y0, x1, y1 = bbox
    if angle_set is None:
        angle_set = [-135,-90,-45,0,45,90,135,180]
    if scale_set is None:
        scale_set = [0.8, 1.0, 1.2]

    # Precompute baseline
    best_mask = render_mask(pieces, w, h, bbox=bbox)
    best_iou, best_ov, best_un = geometry_metrics(best_mask[y0:y1, x0:x1], gt_mask[y0:y1, x0:x1])

    for _ in range(iters):
        improved = False
        for i, p in enumerate(pieces):
            base = dict(p)
            x00, y00 = float(base["pos"][0]), float(base["pos"][1])
            sc0 = float(base["scale"])
            fl0 = bool(base["flip"])
            ang0 = float(base["angle"])

            # Candidate grids
            deltas = [-pos_span, -pos_delta, 0.0, pos_delta, pos_span]
            pos_cands = [(x00+dx, y00+dy) for dx in deltas for dy in deltas]
            ang_cands = angle_set
            sc_cands  = scale_set
            flip_cands = [fl0, (not fl0)] if try_flip else [fl0]

            local_best = (best_iou, None, None)  # (iou, mask, cand_piece)

            for (xx, yy) in pos_cands:
                # keep in 0..10 grid
                xx = max(0.0, min(10.0, xx))
                yy = max(0.0, min(10.0, yy))
                for aa in ang_cands:
                    for ss in sc_cands:
                        for ff in flip_cands:
                            cand = dict(base)
                            cand["pos"] = [xx, yy]
                            cand["angle"] = float(aa)
                            cand["scale"] = float(ss)
                            cand["flip"]  = bool(ff)
                            mask_cand, pcs_cand = _replace_and_render(pieces, i, cand, w, h, bbox)
                            iou_cand, ov_cand, un_cand = geometry_metrics(mask_cand[y0:y1, x0:x1], gt_mask[y0:y1, x0:x1])
                            if iou_cand > local_best[0] + 1e-6:  # strict improve
                                local_best = (iou_cand, mask_cand, pcs_cand)
            if local_best[1] is not None:
                # accept local improvement
                pieces = local_best[2]
                best_mask = local_best[1]
                best_iou = local_best[0]
                improved = True
        if not improved:
            break
    # final metrics
    best_iou, best_ov, best_un = geometry_metrics(best_mask[y0:y1, x0:x1], gt_mask[y0:y1, x0:x1])
    return pieces, best_iou, best_ov, best_un

#
# ---------- Evaluation ----------
# Score against the outline image that was fed to the model (test_img), not
# against any ground-truth JSON:
#   - the 0..10 grid is mapped into the silhouette's bounding box;
#   - IoU / overflow / undercov are computed on the bbox crop only, so anything
#     rendered outside that box is not counted.
gt_mask = load_silhouette(test_img)
h, w = gt_mask.shape
bbox = foreground_bbox(gt_mask)
pred_mask_full = render_mask(pieces, w, h, bbox=bbox)
x0, y0, x1, y1 = bbox
gt_crop = gt_mask[y0:y1, x0:x1]
pred_crop = pred_mask_full[y0:y1, x0:x1]
iou, ov, un = geometry_metrics(pred_crop, gt_crop)
print(f"\n[GEOMETRY] (bbox-aligned) IoU={iou:.3f} | overflow={ov:.4f} | undercov={un:.4f}")

# If geometry is poor (IoU < 0.30), run the discrete local search over each
# piece's pos/angle/scale/flip.
# - `pieces` is always replaced by the refine output; _discrete_refine returns
#   the input list when no candidate improved.
# - The reported metrics and the predicted mask/crop are updated only when IoU
#   rose by more than 1e-6.
if iou < 0.30:
    print("[REFINE] Running discrete local search...")
    pieces, iou2, ov2, un2 = _discrete_refine(
        pieces, gt_mask, bbox,
        angle_set=[-135,-90,-45,0,45,90,135,180],
        pos_delta=0.5, pos_span=1.0,
        scale_set=[0.8, 1.0, 1.2],
        try_flip=True, iters=2
    )
    if iou2 > iou + 1e-6:
        iou, ov, un = iou2, ov2, un2
        pred_mask_full = render_mask(pieces, w, h, bbox=bbox)
        pred_crop = pred_mask_full[y0:y1, x0:x1]
    print(f"[REFINE] After refine: IoU={iou:.3f} | overflow={ov:.4f} | undercov={un:.4f}")

# If still poor, try one more guided retry with numeric feedback.
# The retry is parsed, rendered, scored (and refined if IoU < 0.30) into
# temporaries. `pieces`/`iou`/`ov`/`un` are replaced only when that succeeds
# AND beats the current IoU. Previously `pieces` was overwritten before
# rendering, so a failure part-way left a saved answer that did not match the
# reported metrics, and a worse retry silently replaced a better answer.
if iou < 0.25:
    feedback = (
        f"Geometry feedback: IoU={iou:.3f}, overflow={ov:.4f}, undercov={un:.4f}. "
        "Increase coverage of the silhouette with non-zero angles from {-135,-90,-45,45,90,135,180}. "
        "Use flips where needed (especially P1 / small triangles). "
        "Avoid overlapping pieces and avoid identical positions. "
        "Adjust only pos/angle/flip/scale and return STRICT JSON only."
    )
    # Append assistant's last JSON and user's feedback as a new turn
    messages.append({"role":"assistant","content":[{"type":"text","text":json.dumps(pieces, ensure_ascii=False)}]})
    messages.append({"role":"user","content":[{"type":"text","text":feedback}]})

    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[chat_text], images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
    inputs = {k: (v.to("cpu") if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=ARGS.max_new,
            do_sample=True,
            temperature=0.35,
            top_p=0.9,
        )
    trimmed=[o[len(i):] for i,o in zip(inputs["input_ids"],out_ids)]
    output=processor.batch_decode(trimmed,skip_special_tokens=True)[0].strip()
    print("\n=== RAW OUTPUT (GEOMETRY FEEDBACK RETRY) ===\n", output)
    m=re.search(r"\[.*\]",output,re.S)
    if not m:
        print("[WARN] Feedback retry produced no JSON array; keeping previous answer")
    else:
        try:
            # Non-object entries cannot be rendered; drop them before scoring.
            retry_pieces = [p for p in json.loads(m.group(0)) if isinstance(p, dict)]
            retry_mask = render_mask(retry_pieces, w, h, bbox=bbox)
            r_iou, r_ov, r_un = geometry_metrics(retry_mask[y0:y1, x0:x1], gt_mask[y0:y1, x0:x1])
            if r_iou < 0.30:
                print("[REFINE] Running discrete local search after feedback...")
                # The refine output's metrics describe exactly the pieces it returns.
                retry_pieces, r_iou, r_ov, r_un = _discrete_refine(retry_pieces, gt_mask, bbox, iters=2)
            print(f"[GEOMETRY] After feedback: IoU={r_iou:.3f} | overflow={r_ov:.4f} | undercov={r_un:.4f}")
        except Exception as e:
            print("[WARN] Feedback retry JSON parse failed:", e)
        else:
            if r_iou > iou + 1e-6:
                pieces, iou, ov, un = retry_pieces, r_iou, r_ov, r_un
            else:
                print(f"[GEOMETRY] Feedback retry did not beat IoU={iou:.3f}; keeping previous answer")

# [LOG] What gets saved:
# - the model's predicted pieces (after any refinement/retry above), for
#   reproduction and visualization;
# - NOT the test ground-truth JSON. This script never reads that file; the
#   geometry evaluation used the outline silhouette.
if ARGS.save_pred:
    out_dir = os.path.dirname(ARGS.save_pred)
    # A bare filename such as "pred.json" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create real directories.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(ARGS.save_pred, "w", encoding="utf-8") as f:
        json.dump({"pieces": pieces}, f, ensure_ascii=False)
    print("[SAVE] 推理结果已写入:", ARGS.save_pred)