import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
from PIL import Image
import matplotlib.image as mpimg
import os
import math
from typing import List, Dict, Any
# sqrt(2): used for the triangle and square template side lengths in this module.
SQRT2 = math.sqrt(2.0)





class TangramPiece:
    """One tangram tile: a named template polygon plus its current placement.

    Attributes:
        name: label drawn at the piece's centre; also the lookup key for the piece.
        shape: free-form shape tag (e.g. ``"big_triangle"``).
        original_vertices: float array of template vertices, one ``(x, y)`` row each.
        transformed_vertices: vertices produced by the last :meth:`transform`
            call (a copy of the template until then).
        color: matplotlib face colour used by :meth:`draw`.
    """

    def __init__(self, name, shape, vertices, color='blue'):
        self.name = name
        self.shape = shape
        # Always keep a float copy: integer templates (e.g. the medium triangle)
        # would otherwise start out as int arrays, and copying detaches the
        # template from whatever array the caller passed in.
        self.original_vertices = np.array(vertices, dtype=float)
        self.transformed_vertices = self.original_vertices.copy()
        self.color = color

    def transform(self, position, angle_deg, flip=False, scale=1.0):
        """Place the piece: flip -> rotate -> scale -> translate, about its centroid.

        Every step happens in a local frame whose origin is the mean of the
        template vertices, so ``position`` is where that centroid ends up.

        Args:
            position: target centroid ``(x, y)`` on the [0, 10] x [0, 10] canvas.
            angle_deg: counter-clockwise rotation in degrees.
            flip: if True, mirror the template across the local vertical axis
                (``x -> -x``) before rotating.
            scale: uniform scale factor.

        The result is stored in ``self.transformed_vertices``; nothing is returned.
        """
        angle_rad = np.deg2rad(angle_deg)
        cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)
        rotation = np.array([[cos_a, -sin_a],
                             [sin_a,  cos_a]])

        template = np.asarray(self.original_vertices, dtype=float)
        local = template - template.mean(axis=0)   # centroid moved to the origin
        if flip:
            local = local * np.array([-1.0, 1.0])  # x -> -x
        local = local @ rotation.T                 # rotate each row vector
        local = local * float(scale)
        self.transformed_vertices = local + np.asarray(position, dtype=float)

    def draw(self, ax):
        """Add the piece to *ax* as a semi-transparent polygon labelled with its name."""
        # facecolor (rather than color) avoids the Matplotlib warning that
        # 'color' overrides the edge/face colours, so the black edge survives.
        polygon = Polygon(self.transformed_vertices, closed=True, facecolor=self.color, alpha=0.6, edgecolor='black')
        ax.add_patch(polygon)
        cx, cy = self.transformed_vertices.mean(axis=0)
        ax.text(cx, cy, self.name, ha='center', va='center', fontsize=8, color='black')

    def set_by_instruction(self, instruction):
        """Apply a placement dict ``{"pos", "angle", "flip", "scale"}`` via :meth:`transform`.

        Missing keys default to ``pos=[0, 0]``, ``angle=0``, ``flip=False`` and
        ``scale=1.0``.
        """
        pos = instruction.get("pos", [0, 0])
        angle = instruction.get("angle", 0)
        flip = instruction.get("flip", False)
        scale = instruction.get("scale", 1.0)
        self.transform(pos, angle, flip, scale)

# The seven tangram pieces (standard proportions; template sizes compatible with full.json).
# Areas: 2 x 16 (big) + 8 (medium) + 2 x 4 (small) + 8 (square) + 8 (parallelogram) = 64,
# i.e. the area of an 8 x 8 square.
tangram_pieces = [
    TangramPiece(piece_name, piece_shape, piece_vertices, piece_color)
    for piece_name, piece_shape, piece_vertices, piece_color in (
        # Two big isosceles right triangles, legs 4*sqrt(2).
        ("T1", "big_triangle", [[0, 0], [4 * SQRT2, 0], [0, 4 * SQRT2]], 'red'),
        ("T2", "big_triangle", [[0, 0], [4 * SQRT2, 0], [0, 4 * SQRT2]], 'orange'),
        # Medium triangle, legs 4.
        ("T3", "medium_triangle", [[0, 0], [4, 0], [0, 4]], 'green'),
        # Two small triangles, legs 2*sqrt(2).
        ("T4", "small_triangle", [[0, 0], [2 * SQRT2, 0], [0, 2 * SQRT2]], 'blue'),
        ("T5", "small_triangle", [[0, 0], [2 * SQRT2, 0], [0, 2 * SQRT2]], 'purple'),
        # Square, side 2*sqrt(2).
        ("S", "square", [[0, 0], [2 * SQRT2, 0], [2 * SQRT2, 2 * SQRT2], [0, 2 * SQRT2]], 'brown'),
        # Parallelogram: base 4, top edge shifted by (2, 2).
        ("P", "parallelogram", [[0, 0], [4, 0], [6, 2], [2, 2]], 'pink'),
    )
]

# Legacy (type, instance) -> internal template name; instance None marks a unique piece.
_TYPE_INSTANCE_TO_NAME = {
    ("big_triangle", "L1"): "T1",
    ("big_triangle", "L2"): "T2",
    ("medium_triangle", None): "T3",
    ("small_triangle", "S1"): "T4",
    ("small_triangle", "S2"): "T5",
    ("square", None): "S",
    ("parallelogram", None): "P",
}

# Internal template name -> the shared TangramPiece instance for that name.
_NAME_TO_PIECE = {piece.name: piece for piece in tangram_pieces}


# Fallback: map (type,size,instance) to our internal template name when 'name' is absent.
def _name_from_type_size(t: str, size: str | None, inst: str | None = None) -> str | None:
    """Fallback: map (type,size,instance) to our internal template name when 'name' is absent."""
    if not t:
        return None
    t = str(t).lower()
    size = (str(size).lower() if size is not None else "na")
    inst = (str(inst) if inst is not None else None)

    if t == "triangle":
        # Big triangles have two instances (L1/L2). If missing, default to T1.
        if size in ("big", "large"):
            if inst in ("L2", "l2", "2"):
                return "T2"
            return "T1"
        if size == "medium":
            return "T3"
        if size == "small":
            # Two small triangles (S1/S2). Default to T4.
            if inst in ("S2", "s2", "2"):
                return "T5"
            return "T4"
        # If size missing for triangle, fallback to medium template.
        return "T3"
    if t == "square":
        return "S"
    if t == "parallelogram":
        return "P"
    return None

def _coerce_pos(pos: Any) -> Any:
    """Return a 2-element list/tuple as ``[float, float]``; leave any other value unchanged."""
    if isinstance(pos, (list, tuple)) and len(pos) == 2:
        return [float(pos[0]), float(pos[1])]
    return pos


def _coerce_float(value: Any, default: float) -> float:
    """``float(value)``, using *default* when the value is missing or JSON ``null``."""
    return float(default if value is None else value)


def _coerce_flip(value: Any) -> bool:
    """``bool(value)``, except that false-like strings (``"false"``, ``"0"``, ``"no"``, ...) are False.

    Plain ``bool("false")`` is True, which silently mirrored pieces whose JSON
    carried the flag as a string.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "false", "f", "0", "no", "n", "off", "none", "null")
    return bool(value)


def _legacy_piece_name(item: Dict[str, Any]) -> Any:
    """Resolve an item that has no ``name`` from its type/instance/size; falsy if unresolved."""
    t = item.get("type")
    inst = item.get("instance")
    name = _TYPE_INSTANCE_TO_NAME.get((t, inst))
    if name is None:
        name = _TYPE_INSTANCE_TO_NAME.get((t, None))
    if not name:
        # Fallback: allow (type, size[, instance]) without an explicit 'name'.
        name = _name_from_type_size(t, item.get("size"), inst)
    return name


def _coerce_to_response(ann: Any) -> List[Dict[str, Any]]:
    """Normalize the supported annotation JSON layouts into the internal response list.

    Every output item is ``{"name", "pos", "angle", "flip", "scale"}``.

    Accepted inputs:
    - strict format: ``[{"name", "type", "size", "pos", "angle", "flip", "scale"}, ...]``
    - wrapped format: ``{"pieces": [...]}``
    - a single piece dict (treated as a one-element list)
    - legacy format: ``[{"type": "big_triangle" | ..., "instance": "L1" | ..., "pos", "angle", "flip", "scale"}, ...]``

    Items are resolved one at a time, so named and unnamed items may be mixed:
    an item with a ``"name"`` key uses that name (``P1``/``S1`` are normalized to
    the internal ``P``/``S``) and is always kept; an item without one is resolved
    from type/instance/size and dropped if that fails. Non-dict items are skipped.
    ``pos`` pairs become floats, ``angle``/``scale`` default to 0/1 when missing
    or null, and ``flip`` accepts booleans or boolean-like strings.

    Returns ``[]`` when *ann* is not (or does not wrap) a list of pieces; the
    caller decides how to handle that.
    """
    # Strict names P1/S1 map to the internal template names P/S.
    name_aliases = {"P1": "P", "S1": "S"}

    if isinstance(ann, dict) and "pieces" in ann:
        ann = ann["pieces"]
    # A single dict that looks like one piece is wrapped into a list.
    if isinstance(ann, dict) and ("name" in ann or "type" in ann or "pos" in ann):
        ann = [ann]
    if not isinstance(ann, list):
        return []

    response: List[Dict[str, Any]] = []
    for item in ann:
        if not isinstance(item, dict):
            continue
        if "name" in item:
            name = item.get("name")
            if isinstance(name, str):
                name = name_aliases.get(name, name)
        else:
            name = _legacy_piece_name(item)
            if not name:
                continue
        response.append({
            "name": name,
            "pos": _coerce_pos(item.get("pos", [0, 0])),
            "angle": _coerce_float(item.get("angle"), 0.0),
            "flip": _coerce_flip(item.get("flip", False)),
            "scale": _coerce_float(item.get("scale"), 1.0),
        })
    return response

def draw_from_annotation_json(json_path, outline_png_path=None, title=None, flip_y: bool=False, show: bool=True):
    """Render the tangram described by an annotations JSON, optionally over an outline PNG.

    Args:
        json_path: annotation file in any layout accepted by ``_coerce_to_response``.
        outline_png_path: optional outline image drawn at 25% opacity under the
            pieces for visual comparison; silently skipped if the path does not exist.
        title: optional axes title.
        flip_y: map every piece position ``y -> 10 - y`` before drawing, and flip
            the outline image vertically so its top row is drawn at y = 10.
            Without it the outline is drawn with image row 0 at y = 0.
        show: call ``plt.show()`` before returning.

    Returns:
        ``(response, fig, ax)``. ``response`` is the normalized piece list with the
        ``flip_y`` mapping ALREADY applied to each ``pos``; pass it on with
        ``flip_y=False`` (e.g. to ``approx_iou_with_outline``) to avoid flipping twice.

    Side effect: every shared piece in ``tangram_pieces`` is reset to its template,
    and the pieces named in the response are transformed in place.
    """
    import json
    with open(json_path, "r", encoding="utf-8") as f:
        ann = json.load(f)

    response = _coerce_to_response(ann)
    # Optional y-flip on the [0, 10] canvas (y -> 10 - y), applied in place.
    if flip_y:
        for r in response:
            if isinstance(r.get("pos"), (list, tuple)) and len(r["pos"]) == 2:
                x, y = r["pos"]
                r["pos"] = [float(x), 10.0 - float(y)]

    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_xticks(np.arange(0, 11, 1))
    ax.set_yticks(np.arange(0, 11, 1))
    ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

    if outline_png_path and os.path.exists(outline_png_path):
        img = mpimg.imread(outline_png_path)
        # Grayscale is for display only; the geometry is unaffected.
        if img.ndim == 3:
            img = img[..., :3].mean(axis=2)
        if flip_y:
            # Image row 0 is the top edge of the picture; flipping puts it at
            # y = 10 under origin='lower' (y-up, origin bottom-left).
            img = np.flipud(img)
        ax.imshow(img, extent=[0, 10, 0, 10], origin='lower', cmap='gray', alpha=0.25)

    # Reset every shared piece, then place and draw the ones named in the response.
    for p in tangram_pieces:
        p.transformed_vertices = p.original_vertices.copy()
    for inst in response:
        piece = _NAME_TO_PIECE.get(inst["name"])
        if piece:
            piece.set_by_instruction(inst)
            piece.draw(ax)

    if title:
        ax.set_title(title)
    if show:
        plt.show()
    return response, fig, ax

def approx_iou_with_outline(response, outline_png_path, raster=160, thresh=0.5, flip_y: bool=False):
    """Approximate the IoU between the union of placed pieces and an outline image.

    Both shapes are rasterized on a ``raster x raster`` grid of cell centres over
    the [0, 10] x [0, 10] canvas, with y pointing up.

    Args:
        response: normalized piece list (``name``/``pos``/``angle``/``flip``/``scale``).
            The dicts are not modified. Pieces with an unknown name or an invalid
            polygon are ignored. The shared template pieces are transformed in place.
        outline_png_path: outline image, dark shape on a light background.
        raster: grid resolution per axis.
        thresh: pixels whose normalized intensity is ``<= thresh`` count as shape.
        flip_y: map ``y -> 10 - y`` on piece positions first. Do not set this for a
            response that is already flipped, such as the one returned by
            ``draw_from_annotation_json(..., flip_y=True)``.

    Returns:
        IoU in [0, 1]; ``0.0`` if no valid piece could be placed; ``None`` if
        shapely cannot be imported.
    """
    try:
        from shapely.geometry import Point, Polygon as ShapelyPolygon
        from shapely.ops import unary_union
        from shapely.prepared import prep
    except Exception:
        # shapely is optional; without it the metric is unavailable.
        return None

    polys = []
    for inst in response:
        if flip_y:
            pos = inst.get("pos")
            if isinstance(pos, (list, tuple)) and len(pos) == 2:
                # Copy so the caller's response is left untouched.
                inst = dict(inst)
                inst["pos"] = [float(pos[0]), 10.0 - float(pos[1])]
        piece = _NAME_TO_PIECE.get(inst.get("name"))  # template lookup
        if not piece:
            continue
        piece.set_by_instruction(inst)
        poly = ShapelyPolygon(piece.transformed_vertices)
        if poly.is_valid:
            polys.append(poly)
    if not polys:
        return 0.0

    # Predicted mask: sample the union at the centre of every raster cell.
    # A prepared geometry makes the many repeated `contains` tests much cheaper.
    union_prepared = prep(unary_union(polys))
    centres = np.linspace(0, 10, raster, endpoint=False) + 10 / (2 * raster)
    grid_x, grid_y = np.meshgrid(centres, centres)
    pred = np.fromiter(
        (union_prepared.contains(Point(x, y)) for x, y in zip(grid_x.ravel(), grid_y.ravel())),
        dtype=bool,
        count=raster * raster,
    ).reshape(raster, raster)

    # Ground-truth mask.
    img = mpimg.imread(outline_png_path)
    # mpimg returns floats in [0, 1] for PNG but integer arrays for formats it
    # reads through Pillow; rescale those so `thresh` keeps the same meaning.
    if np.issubdtype(img.dtype, np.integer):
        img = img.astype(float) / np.iinfo(img.dtype).max
    if img.ndim == 3:
        img = img[..., :3].mean(axis=2)
    # Black shape on white: dark pixels are foreground. flipud puts the image's
    # bottom row at row 0, matching the y-up prediction grid.
    gt = np.flipud(img <= thresh)
    if gt.shape != pred.shape:
        gt_img = Image.fromarray(gt.astype(np.uint8) * 255).resize((raster, raster), Image.NEAREST)
        gt = np.array(gt_img) > 127

    inter = (pred & gt).sum()
    union = (pred | gt).sum()
    return float(inter) / float(union + 1e-9)

def draw_from_qwen_response(response, flip_y: bool=False, show: bool=True):
    """Draw the pieces of a model response on a fresh [0, 10] x [0, 10] grid.

    For each template piece, the FIRST response item with a matching ``name``
    is used. Items without a string ``name`` are ignored instead of raising
    ``KeyError``.

    Args:
        response: list of ``{"name", "pos", "angle", "flip", "scale"}`` dicts
            (not modified).
        flip_y: map positions ``y -> 10 - y`` first (for y-down input).
        show: call ``plt.show()``. This defaults to True, matching the previous
            unconditional behaviour.

    Returns:
        ``(fig, ax)`` so callers can save or further annotate the figure.
    """
    fig, ax = plt.subplots()
    ax.set_aspect('equal')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_xticks(np.arange(0, 11, 1))
    ax.set_yticks(np.arange(0, 11, 1))
    ax.grid(True, which='both', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

    # First item per piece name, with the optional y-flip applied to a copy.
    first_by_name = {}
    for item in response:
        name = item.get("name")
        if not isinstance(name, str) or name in first_by_name:
            continue
        pos = item.get("pos")
        if flip_y and isinstance(pos, (list, tuple)) and len(pos) == 2:
            item = dict(item)
            item["pos"] = [float(pos[0]), 10.0 - float(pos[1])]
        first_by_name[name] = item

    for piece in tangram_pieces:
        inst = first_by_name.get(piece.name)
        if inst:
            piece.set_by_instruction(inst)
            piece.draw(ax)

    if show:
        plt.show()
    return fig, ax


# === Overlay helpers ===
def draw_overlay_jsons(gt_json_path: str, pred_json_path: str, outline_png_path: str | None = None, flip_y: bool=False, title: str | None=None, show: bool=True):
    """Overlay a GT annotation (filled) and a prediction (outlined) on one axes.

    The GT is rendered by ``draw_from_annotation_json`` without any y-flip.
    When ``flip_y`` is set, only the prediction positions are mapped
    ``y -> 10 - y``, and the optional outline PNG underlay is flipped vertically
    before drawing. The GT coordinates are never flipped, so ``flip_y`` is only
    advisable when the underlay is aligned with pixel coordinates.

    Returns:
        ``(resp_gt, resp_pred, fig, ax)``.
    """
    import json

    # GT first and without an underlay: passing the PNG through here would also
    # apply flip_y to the GT JSON, so the underlay is added separately below.
    resp_gt, fig, ax = draw_from_annotation_json(
        gt_json_path,
        outline_png_path=None,
        title=(title or os.path.basename(pred_json_path)),
        flip_y=False,
        show=False,
    )

    if outline_png_path and os.path.exists(outline_png_path):
        underlay = mpimg.imread(outline_png_path)
        if underlay.ndim == 3:
            underlay = underlay[..., :3].mean(axis=2)
        if flip_y:
            underlay = np.flipud(underlay)
        ax.imshow(underlay, extent=[0, 10, 0, 10], origin='lower', cmap='gray', alpha=0.25)

    with open(pred_json_path, "r", encoding="utf-8") as fh:
        resp_pred = _coerce_to_response(json.load(fh))

    def _mirrored(entry):
        # Copy-on-write y -> 10 - y for entries that carry an (x, y) pair.
        xy = entry.get("pos")
        if isinstance(xy, (list, tuple)) and len(xy) == 2:
            return {**entry, "pos": [float(xy[0]), 10.0 - float(xy[1])]}
        return entry

    if flip_y:
        resp_pred = [_mirrored(entry) for entry in resp_pred]

    # Predictions are drawn as hollow outlines over the filled GT.
    for entry in resp_pred:
        template = _NAME_TO_PIECE.get(entry["name"])
        if template is None:
            continue
        template.set_by_instruction(entry)
        outline = Polygon(template.transformed_vertices, closed=True, facecolor="none", edgecolor="0.2", linewidth=2.0, alpha=0.9)
        ax.add_patch(outline)

    if title:
        ax.set_title(title)
    if show:
        plt.show()
    return resp_gt, resp_pred, fig, ax


def save_overlay_jsons(gt_json_path: str, pred_json_path: str, out_png: str, outline_png_path: str | None = None, flip_y: bool=False, title: str | None=None):
    """Render the GT/pred overlay (see ``draw_overlay_jsons``) and save it as an image without showing it.

    Parent directories of ``out_png`` are created as needed; a bare filename is
    written to the current directory. ``os.makedirs("")`` would raise, so it is
    only called when there is a directory part. The figure is closed even if
    saving fails.
    """
    _, _, fig, _ = draw_overlay_jsons(gt_json_path, pred_json_path, outline_png_path=outline_png_path, flip_y=flip_y, title=title, show=False)
    try:
        out_dir = os.path.dirname(out_png)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_png, dpi=220, bbox_inches="tight")
    finally:
        plt.close(fig)




# Optional-dependency probe: record whether shapely can be imported, and why not.
# NOTE(review): _HAS_SHAPELY / _SHAPELY_IMPORT_ERROR are not read anywhere else in
# the visible file; approx_iou_with_outline performs its own guarded import.
try:
    from shapely.geometry import Polygon as ShapelyPolygon
    from shapely.ops import unary_union
    _HAS_SHAPELY = True
except Exception as _e:
    _HAS_SHAPELY = False
    _SHAPELY_IMPORT_ERROR = _e


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--ann", type=str, default=os.environ.get("ANN_JSON"), help="annotations JSON 路径")
    parser.add_argument("--outline", type=str, default=os.environ.get("OUTLINE_PNG"), help="outline PNG 路径(可选)")
    parser.add_argument("--under_png", type=str, default=None, help="outline PNG 路径的别名(可选, 兼容参数)")
    parser.add_argument("--save_png", type=str, default=None, help="保存渲染后的图像到此路径")
    # --- Overlay arguments ---
    parser.add_argument("--overlay_gt_json", type=str, default=None, help="若与 --overlay_pred_json 同时提供，则在同一坐标轴叠加 GT 与 Pred")
    parser.add_argument("--overlay_pred_json", type=str, default=None, help="与 --overlay_gt_json 搭配使用的 Pred JSON 路径")
    parser.add_argument("--map_to_axes", action="store_true", help="兼容参数, 在本脚本中无实际作用")
    parser.add_argument("--anchor", type=str, choices=["corner","centroid"], default=None, help="兼容参数, 在本脚本中无实际作用")
    parser.add_argument("--size_into_scale", type=str, choices=["keep","fold","auto"], default=None, help="兼容参数, 在本脚本中无实际作用")
    parser.add_argument("--no_show", action="store_true", help="不弹窗，仅计算 IoU")
    parser.add_argument("--flip_y", action="store_true", help="仅在与像素 PNG 叠加时使用：对输入JSON做 y -> 10 - y，使之与像素坐标(原点左上)对齐")
    args = parser.parse_args()

    # --outline wins; --under_png is a compatibility alias for it.
    outline_path = args.outline if args.outline else args.under_png

    if args.overlay_gt_json and args.overlay_pred_json:
        # Overlay mode: GT JSON vs Pred JSON on the same axes, saved to disk; nothing
        # else runs. flip_y is only advisable when the outline PNG is pixel-aligned.
        save_path = args.save_png or (os.path.splitext(args.overlay_pred_json)[0] + "_overlay.png")
        save_overlay_jsons(
            args.overlay_gt_json,
            args.overlay_pred_json,
            out_png=save_path,
            outline_png_path=outline_path,
            flip_y=args.flip_y,
            title=os.path.basename(args.overlay_pred_json),
        )
        print(f"[OVERLAY] Saved -> {save_path}")
    elif args.ann:
        # Render without showing yet, so the figure is saved and the IoU printed
        # before any blocking window appears.
        resp, fig, ax = draw_from_annotation_json(
            args.ann,
            outline_png_path=outline_path,
            title=os.path.basename(args.ann),
            flip_y=args.flip_y,
            show=False,
        )
        if outline_path and os.path.exists(outline_path):
            # `resp` already carries the flip_y mapping (draw_from_annotation_json
            # applies it in place); flipping again here would score the un-flipped
            # layout instead of the one that was drawn.
            iou = approx_iou_with_outline(resp, outline_path, flip_y=False)
            if iou is not None:
                print(f"[IoU] {os.path.basename(args.ann)} vs {os.path.basename(outline_path)} -> {iou:.3f}")
        elif outline_path:
            print(f"[IoU] skipped: outline not found -> {outline_path}")
        if args.save_png:
            fig.savefig(args.save_png, dpi=200, bbox_inches="tight")
        if args.no_show:
            plt.close(fig)
        else:
            plt.show()
    else:
        # No inputs provided; render a tiny demo instead of crashing.
        demo_response = [
            {"name": "T3", "pos": [5.0, 5.0], "angle": -135.0, "flip": False, "scale": 1.0}
        ]
        draw_from_qwen_response(demo_response)


# ==== Adapter API for ICL_qwen_onepiece.py ====
# 让 ICL 脚本可以调用本模块把单个/多个拼片渲染成二值 mask。
from typing import Dict, List, Tuple, Union

def _unit_poly_from_piece(piece: Dict) -> List[Tuple[float, float]]:
    """Return the grid-unit template polygon for one piece, with a corner at the origin.

    - triangle: isosceles right triangle with legs 1 (big / default),
      sqrt(2)/2 (``"medium"``) or 0.5 (``"small"``);
    - square: unit square;
    - parallelogram (also the fallback for any unrecognised type): base 1,
      top edge shifted by (0.5, 0.5).

    ``size`` only selects the triangle template. The on-canvas size is controlled
    by the piece's ``scale``, so size is not applied twice. Type and size are
    case- and whitespace-insensitive. Legacy compound types such as
    ``"small_triangle"`` are accepted, with the prefix used as the size; before
    this they fell through to the parallelogram template.
    """
    kind = str(piece.get("type", "triangle")).strip().lower()
    size = str(piece.get("size", "na")).strip().lower()
    if kind.endswith("_triangle"):
        size = kind[: -len("_triangle")]
        kind = "triangle"

    if kind == "triangle":
        base = 1.0
        if size == "medium":
            base = SQRT2 / 2.0
        elif size == "small":
            base = 0.5
        return [(0.0, 0.0), (base, 0.0), (0.0, base)]
    if kind == "square":
        return [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
    # parallelogram (and any unknown type)
    return [(0.0, 0.0), (1.0, 0.0), (1.5, 0.5), (0.5, 0.5)]

def _grid_frame_params(use_bbox: bool, W: int, H: int, bbox: Tuple[int,int,int,int] | None):
    """
    计算 0..10 网格到像素的映射：返回 (px, offx, offy)。
    use_bbox=True 时用前景 bbox 的短边做单位；否则用整图短边。
    """
    if use_bbox and bbox is not None:
        x0, y0, x1, y1 = bbox
        bw = max(1, int(x1 - x0))
        bh = max(1, int(y1 - y0))
        px = min(bw, bh) / 10.0
        return float(px), float(x0), float(y0)
    px = min(W, H) / 10.0
    return float(px), 0.0, 0.0

def _transform_poly(poly, pos, angle_deg, flip, scale, px, offx, offy, anchor="corner"):
    """Flip, scale, rotate and translate a grid polygon, then map it to pixels.

    For each vertex, in order:
    1. shift by the anchor (the vertex mean when ``anchor == "centroid"``,
       otherwise no shift);
    2. negate x if ``flip``;
    3. multiply by ``scale``;
    4. rotate by ``angle_deg`` degrees with the standard 2-D rotation matrix;
    5. add ``pos``;
    6. convert to pixels as ``value * px + offset``.

    Returns a list of ``(X, Y)`` float tuples.
    """
    theta = math.radians(float(angle_deg))
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)

    if anchor == "centroid" and len(poly) > 0:
        count = len(poly)
        pivot_x = sum(vx for vx, _ in poly) / count
        pivot_y = sum(vy for _, vy in poly) / count
    else:
        pivot_x = pivot_y = 0.0

    pixels = []
    for vx, vy in poly:
        lx = vx - pivot_x
        ly = vy - pivot_y
        if flip:
            lx = -lx
        lx *= float(scale)
        ly *= float(scale)
        rot_x = lx * cos_t - ly * sin_t
        rot_y = lx * sin_t + ly * cos_t
        pixels.append((
            float((rot_x + float(pos[0])) * px + offx),
            float((rot_y + float(pos[1])) * px + offy),
        ))
    return pixels

def piece_to_polys(piece: Dict, W: int, H: int,
                   bbox: Tuple[int,int,int,int] | None = None,
                   anchor: str = "corner",
                   map_to_bbox: bool | None = None):
    """Convert one piece dict into a one-element list of pixel polygons.

    Reads ``pos`` (default ``[5.0, 5.0]``), ``angle``, ``flip`` and ``scale`` from
    *piece*. The grid-to-pixel frame comes from ``_grid_frame_params``, and the
    bbox frame is used only when *map_to_bbox* is truthy.
    """
    pose = (
        piece.get("pos", [5.0, 5.0]),
        float(piece.get("angle", 0.0)),
        bool(piece.get("flip", False)),
        float(piece.get("scale", 1.0)),
    )
    template = _unit_poly_from_piece(piece)
    frame = _grid_frame_params(bool(map_to_bbox), int(W), int(H), bbox)
    return [_transform_poly(template, *pose, *frame, anchor=anchor)]

def pieces_to_polys(pieces, W: int, H: int,
                    bbox: Tuple[int,int,int,int] | None = None,
                    anchor: str = "corner",
                    map_to_bbox: bool | None = None):
    """Concatenate ``piece_to_polys`` over several pieces.

    Anything that is not a list or tuple is treated as a single piece.
    """
    batch = pieces if isinstance(pieces, (list, tuple)) else [pieces]
    return [
        polygon
        for single in batch
        for polygon in piece_to_polys(single, W, H, bbox=bbox, anchor=anchor, map_to_bbox=map_to_bbox)
    ]

def piece_to_mask(piece: Dict, W: int, H: int,
                  bbox: Tuple[int,int,int,int] | None = None,
                  anchor: str = "corner",
                  map_to_bbox: bool | None = None):
    """Rasterize one piece into a binary ``uint8`` mask (1 inside the piece, 0 elsewhere).

    The canvas is a ``W x H`` grayscale PIL image, so the returned array has
    shape ``(H, W)``.
    """
    from PIL import Image, ImageDraw
    import numpy as _np

    outlines = pieces_to_polys([piece], W, H, bbox=bbox, anchor=anchor, map_to_bbox=map_to_bbox)
    canvas = Image.new("L", (int(W), int(H)), 0)
    painter = ImageDraw.Draw(canvas)
    for outline in outlines:
        painter.polygon(outline, fill=1)
    return _np.array(canvas, dtype=_np.uint8)

def render_mask(pieces: Union[Dict, List[Dict]], W: int, H: int,
                bbox: Tuple[int,int,int,int] | None = None,
                anchor: str = "corner",
                map_to_bbox: bool | None = None):
    """Public entry point: rasterize one piece dict, or an iterable of them, into a ``uint8`` mask.

    Covered pixels are 1, the rest 0. A single dict is delegated to
    ``piece_to_mask``; otherwise every piece is drawn onto one shared canvas.
    """
    from PIL import Image, ImageDraw
    import numpy as _np

    if isinstance(pieces, dict):
        return piece_to_mask(pieces, W, H, bbox=bbox, anchor=anchor, map_to_bbox=map_to_bbox)

    canvas = Image.new("L", (int(W), int(H)), 0)
    painter = ImageDraw.Draw(canvas)
    for single in pieces:
        outlines = pieces_to_polys([single], W, H, bbox=bbox, anchor=anchor, map_to_bbox=map_to_bbox)
        for outline in outlines:
            painter.polygon(outline, fill=1)
    return _np.array(canvas, dtype=_np.uint8)