from graphviz import Digraph
import os
from helpers import viz_util
import json
import numpy as np
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
from matplotlib import cm       # color map
from collections import defaultdict
import time
from helpers.gui_hub import init_main_window


# Module-level GUI state.
# NOTE(review): _APP_GUI_READY is not read by any active code visible in this
# file (its only use is inside the commented-out legacy function) -- confirm
# before removing.
_APP_GUI_READY = False
# Scene widgets created by _ensure_window, keyed by window title; an entry is
# removed again when its window is closed.
_WINS: dict[str, gui.SceneWidget] = {}     # title -> SceneWidget
# ------------------------------------------------------------------


def _ensure_window(title: str,
                   w: int = 1280, h: int = 800) -> gui.SceneWidget:
    """Return the SceneWidget for the window called *title*, creating it on demand.

    A new window of size ``w`` x ``h`` is created with a single SceneWidget
    backed by an ``Open3DScene``. The widget is registered in ``_WINS`` so that
    later calls with the same title reuse it. Closing the window removes the
    registration, so the next call opens a fresh window.
    """
    app = gui.Application.instance
    existing = _WINS.get(title)
    if existing is not None:
        return existing

    window = app.create_window(title, w, h)
    scene_widget = gui.SceneWidget()
    scene_widget.scene = rendering.Open3DScene(window.renderer)
    window.add_child(scene_widget)
    _WINS[title] = scene_widget

    def _forget_on_close():
        # Drop the registry entry; returning True lets the window close.
        _WINS.pop(title, None)
        return True

    window.set_on_close(_forget_on_close)
    return scene_widget


# ------------------------------------------------------------------
# Visualize an incremental scene graph in 3D with the Open3D GUI
# ------------------------------------------------------------------
def visualize_incremental_scene_graph_3d_gui(
        scan_id: str,
        classes,
        relationships_names,
        new_obj_ids, triples_step,
        row_global_ids, boxes_row, row_cls_ids,
        w_graph, cache,
        node_rad=0.12,
        line_col=(0, 0, 1),
        new_col =(1., 0., 0.),
        txt_gap=0.08,
        step_idx: int = None,
        step_palette=None,
        label_scale: float = 2.0,
        label_color=(0.0, 0.0, 0.0),
        show_edge_labels: bool = False,
    ):
    """Add one incremental step of a scene graph to an Open3D SceneWidget.

    New objects become spheres, each with a 3D text label, at their box
    centres. New (subject, object) pairs become line segments. Everything
    added in one step shares the colour ``step_palette[step_idx % len]``.
    Nodes and edges already present in ``cache`` are never added again, so
    repeated calls with the returned cache grow the scene step by step.

    Args:
        scan_id: not used by the drawing code; kept for interface compatibility.
        classes: class-id -> name lookup. Either a list/tuple indexed by id
            (names are stripped) or a dict; unknown ids fall back to ``str(id)``.
        relationships_names: relation names, looked up as
            ``relationships_names[rid - 1]``. Only used when
            ``show_edge_labels`` is true.
        new_obj_ids: global ids of the objects to add in this step.
        triples_step: (subject, rel_id, object) triples, reshaped to (-1, 3).
        row_global_ids, boxes_row, row_cls_ids: per-row global id, 7-value box
            (centre read from indices 3:6) and class id for this batch.
        w_graph: the ``gui.SceneWidget`` to draw into.
        cache: state returned by a previous call, or None to start fresh.
        node_rad: sphere radius.
        line_col, new_col: not used; colours come from the per-step palette.
        txt_gap: sideways offset of edge labels from their edge.
        step_idx: step number; defaults to one past the highest step seen so far.
        step_palette: sequence of RGB tuples. None or empty selects the
            built-in 10-colour palette.
        label_scale, label_color: styling applied to the 3D text labels.
        show_edge_labels: also draw the relation name next to each newly
            added edge.

    Returns:
        ``(w_graph, cache)``. Pass ``cache`` back in on the next step.
    """
    if step_palette is None or len(step_palette) == 0:
        step_palette = [
            (0.90, 0.10, 0.10),  # red
            (0.10, 0.60, 0.95),  # blue
            (0.20, 0.75, 0.25),  # green
            (1.00, 0.55, 0.00),  # orange
            (0.60, 0.30, 0.80),  # purple
            (0.15, 0.70, 0.70),  # teal
            (0.80, 0.80, 0.20),  # yellow
            (0.90, 0.30, 0.50),  # pink
            (0.35, 0.35, 0.35),  # grey
            (0.10, 0.10, 0.80),  # indigo
        ]

    def to_np(x):
        # Objects exposing .cpu() (e.g. torch tensors) are moved to host first.
        return x.cpu().numpy() if hasattr(x, "cpu") else np.asarray(x)

    new_obj_ids    = to_np(new_obj_ids).astype(int).ravel().tolist()
    triples_step   = to_np(triples_step).astype(int).reshape(-1, 3)
    row_global_ids = to_np(row_global_ids).astype(int).ravel()
    boxes_row      = to_np(boxes_row).astype(np.float32).reshape(-1, 7)
    row_cls_ids    = to_np(row_cls_ids).astype(int).ravel()

    def get_label(cid: int):
        if isinstance(classes, (list, tuple)):
            return classes[cid].strip() if 0 <= cid < len(classes) else str(cid)
        if isinstance(classes, dict):
            return classes.get(cid, str(cid))
        return str(cid)

    # ---------- widget & cache ----------
    w3d = w_graph
    if cache is None:
        cache = {}
    # setdefault (instead of building the dict only when cache is None) also
    # fills in keys missing from caches created by older versions of this function.
    cache.setdefault("v", {})               # gid -> geometry name of its sphere
    cache.setdefault("e", set())            # frozenset({a, b}) of drawn edges
    cache.setdefault("lbl", {})             # gid -> node Label3D
    cache.setdefault("elbl", {})            # frozenset({a, b}) -> edge Label3D
    cache.setdefault("box", {})             # gid -> 7-value box
    cache.setdefault("lab", {})             # gid -> class-name string
    cache.setdefault("cam", False)          # camera already set up?
    cache.setdefault("gid2step", {})        # gid -> step that introduced it
    cache.setdefault("edge2step", {})       # frozenset({a, b}) -> step that drew it
    cache.setdefault("step2col", {})        # step -> (r, g, b)
    cache.setdefault("last_step", -1)
    cache.setdefault("theme_scaled", False)  # global theme font already enlarged?

    cache["box"].update(zip(row_global_ids, boxes_row))
    cache["lab"].update(
        (gid, get_label(int(cid))) for gid, cid in zip(row_global_ids, row_cls_ids))

    # ---------- per-step colour ----------
    if step_idx is None:
        step_idx = cache["last_step"] + 1
    cache["last_step"] = max(cache["last_step"], step_idx)
    if step_idx not in cache["step2col"]:
        cache["step2col"][step_idx] = step_palette[step_idx % len(step_palette)]
    cur_col = cache["step2col"][step_idx]

    def _style_label(lbl):
        """Best-effort label styling; the Label3D API differs across Open3D versions."""
        try:
            lbl.color = gui.Color(float(label_color[0]), float(label_color[1]), float(label_color[2]))
        except (AttributeError, TypeError):
            pass  # colour not settable on this Open3D version
        if hasattr(lbl, "scale"):
            try:
                lbl.scale = float(label_scale)
            except (AttributeError, TypeError):
                pass
        elif not cache["theme_scaled"]:
            # No per-label scale here, so enlarge the global theme font instead.
            # Do this only once per cache; repeating it for every label would
            # multiply the font size by label_scale each time.
            cache["theme_scaled"] = True
            try:
                app = gui.Application.instance
                theme = app.theme
                base = getattr(theme, "font_size", 16)
                theme.font_size = int(base * float(label_scale))
                app.theme = theme
            except Exception:  # theme API is version dependent; styling is optional
                pass

    # ---------- nodes ----------
    label_lift = np.array([0., 0.04, 0.], np.float32)  # node label sits just above the centre
    for gid in new_obj_ids:
        if gid in cache["v"] or gid not in cache["box"]:
            cache["gid2step"].setdefault(gid, step_idx)
            continue
        _, _, _, cx, cy, cz, _ = cache["box"][gid]
        center = np.array([cx, cy, cz], np.float32)

        sph = o3d.geometry.TriangleMesh.create_sphere(node_rad)
        sph.translate(center)
        sph.paint_uniform_color(cur_col)

        mesh_name = f"v_{gid}"
        mat = rendering.MaterialRecord()
        mat.shader = "defaultLit"
        w3d.scene.add_geometry(mesh_name, sph, mat)
        cache["v"][gid] = mesh_name
        cache["gid2step"][gid] = step_idx

        lbl = w3d.add_3d_label(center + label_lift, cache["lab"].get(gid, str(gid)))
        _style_label(lbl)
        cache["lbl"][gid] = lbl

    # ---------- edges ----------
    edge_dict = defaultdict(list)        # (small, big) -> [rid, ...]
    for s, rid, o in triples_step:
        s, rid, o = int(s), int(rid), int(o)
        if s == o:
            continue
        a, b = sorted((s, o))
        edge_dict[(a, b)].append(rid)

    up_vec = np.array([0., 1., 1.], np.float32)  # edge-label lift direction (+y and +z, not unit)

    def _add_edge_label(key, pa, pb, rid):
        """Place the name of relation *rid* beside the midpoint of edge pa-pb."""
        if not 1 <= rid <= len(relationships_names):
            return  # unknown relation id: keep the edge, skip its label
        edge_vec = pb - pa
        edge_len = np.linalg.norm(edge_vec)
        edge_dir = edge_vec / edge_len          # caller guarantees edge_len > 1e-6
        perp = np.cross(edge_dir, up_vec)
        if np.linalg.norm(perp) < 1e-3:         # edge (nearly) parallel to up_vec
            perp = np.cross(edge_dir, np.array([1., 0., 0.], np.float32))
        perp /= np.linalg.norm(perp)
        pos = pa + edge_dir * (edge_len / 2) + up_vec * 0.08 + perp * txt_gap
        lbl_edge = w3d.add_3d_label(pos, relationships_names[rid - 1].strip())
        _style_label(lbl_edge)
        cache["elbl"][key] = lbl_edge

    for (a, b), rid_list in edge_dict.items():
        if a not in cache["box"] or b not in cache["box"]:
            continue
        key = frozenset({a, b})
        if key in cache["e"]:
            continue  # drawn in an earlier step
        pa, pb = cache["box"][a][3:6], cache["box"][b][3:6]
        if np.linalg.norm(pa - pb) <= 1e-6:
            continue  # coincident centres: nothing visible to draw

        cache["edge2step"][key] = step_idx
        ls = o3d.geometry.LineSet(
            points=o3d.utility.Vector3dVector(np.vstack([pa, pb])),
            lines=o3d.utility.Vector2iVector([[0, 1]]))
        ls.colors = o3d.utility.Vector3dVector([cur_col])
        mat_l = rendering.MaterialRecord()
        mat_l.shader = "unlitLine"
        mat_l.line_width = 3.0
        w3d.scene.add_geometry(f"e_{a}_{b}", ls, mat_l)
        cache["e"].add(key)

        if show_edge_labels:
            # Only the first relation of the pair is shown.
            _add_edge_label(key, pa, pb, rid_list[0])

    # ---------- camera ----------
    # Set the camera up once, and only after something has been drawn.
    # Otherwise an empty first step would lock in a degenerate view.
    if not cache["cam"] and (cache["v"] or cache["e"]):
        bbox = w3d.scene.bounding_box
        w3d.setup_camera(60., bbox, bbox.get_center())
        cache["cam"] = True

    # Pump one iteration of the GUI event loop so the changes are rendered.
    gui.Application.instance.run_one_tick()
    return w3d, cache

# def visualize_incremental_scene_graph_3d_gui(
#         scan_id: str,
#         classes,
#         relationships_names,
#         new_obj_ids, triples_step,
#         row_global_ids, boxes_row, row_cls_ids,
#         w_graph, cache,
#         node_rad=0.12,
#         line_col=(0, 0, 1),
#         new_col =(1., 0., 0.),
#         txt_gap=0.05):

#     # ---------------- 0. GUI init -----------------
#     # global _APP_GUI_READY
#     # app = gui.Application.instance
#     # if not _APP_GUI_READY:
#     #     app.initialize()
#     #     _APP_GUI_READY = True

#     # ---------------- 1. numpy  -----------------
#     to_np = lambda x: x.cpu().numpy() if hasattr(x, "cpu") else np.asarray(x)
#     new_obj_ids    = to_np(new_obj_ids).astype(int).ravel().tolist()
#     triples_step   = to_np(triples_step).astype(int).reshape(-1, 3)
#     row_global_ids = to_np(row_global_ids).astype(int).ravel()
#     boxes_row      = to_np(boxes_row).astype(np.float32).reshape(-1, 7)
#     row_cls_ids    = to_np(row_cls_ids).astype(int).ravel()

#     # ---------------- 2. label lookup --------------
#     def get_label(cid: int):
#         if isinstance(classes, (list, tuple)):
#             return classes[cid].strip() if 0 <= cid < len(classes) else str(cid)
#         if isinstance(classes, dict):
#             return classes.get(cid, str(cid))
#         return str(cid)

#     batch_id2box   = dict(zip(row_global_ids, boxes_row))
#     batch_id2label = {gid: get_label(int(cid))
#                       for gid, cid in zip(row_global_ids, row_cls_ids)}

#     # ---------------- 3. widget & cache ------------
#     #w3d = w3d_vis or _ensure_window(f"Incremental SG – {scan_id}")

#     w3d= w_graph

#     if cache is None:
#         cache = {
#             "v":   {},          # gid -> geom name
#             "e":   set(),       # frozenset({a,b}) 
#             "lbl": {},          # gid -> gui label
#             "box": {},          # gid -> box
#             "lab": {},          # gid -> str label
#             "cam": False
#         }
#     cache["box"].update(batch_id2box)
#     cache["lab"].update(batch_id2label)

#     # ---------------- 4. new nodes -----------------
#     for gid in new_obj_ids:
#         if gid in cache["v"] or gid not in cache["box"]:
#             continue
#         *_, cx, cy, cz, _ = cache["box"][gid]
#         center = np.array([cx, cy, cz], np.float32)

#         sph = o3d.geometry.TriangleMesh.create_sphere(node_rad)
#         sph.translate(center)
#         sph.paint_uniform_color(new_col)

#         mesh_name = f"v_{gid}"
#         mat = rendering.MaterialRecord(); mat.shader = "defaultLit"
#         w3d.scene.add_geometry(mesh_name, sph, mat)
#         cache["v"][gid] = mesh_name

#         cache["lbl"][gid] = w3d.add_3d_label(
#             center + np.array([0., 0.04, 0.], np.float32),
#             cache["lab"][gid])

#     # ---------------- 5. gather edges --------------
#     #print('triples_step: ', triples_step)
#     edge_dict = defaultdict(list)        # (small,big) -> [rid...]
#     for s, rid, o  in triples_step:
#         s, rid, o = int(s), int(rid), int(o)
#         if s == o:
#             continue
#         a, b = sorted((s, o))            
#         edge_dict[(a, b)].append(rid)

#     # ---------------- 6. draw ---------------------
#     up_vec = np.array([0., 1., 1.], np.float32)

#     for (a, b), rid_list in edge_dict.items():
#         if a not in cache["box"] or b not in cache["box"]:
#             continue
#         pa, pb = cache["box"][a][3:6], cache["box"][b][3:6]

#        
#         key = frozenset({a, b})
#         if key not in cache["e"] and np.linalg.norm(pa - pb) > 1e-6:
#             ls = o3d.geometry.LineSet(
#                     points=o3d.utility.Vector3dVector(np.vstack([pa, pb])),
#                     lines=o3d.utility.Vector2iVector([[0, 1]]))
#             ls.colors = o3d.utility.Vector3dVector([line_col])
#             mat_l = rendering.MaterialRecord(); mat_l.shader = "unlitLine"
#             w3d.scene.add_geometry(f"e_{a}_{b}", ls, mat_l)
#             cache["e"].add(key)

#         # 6.2 label 
#         main_rid = rid_list[0]                 
#         edge_dir = pb - pa
#         ed_len   = np.linalg.norm(edge_dir)
#         if ed_len < 1e-6:
#             continue
#         edge_dir /= ed_len

#         perp = np.cross(edge_dir, up_vec)
#         if np.linalg.norm(perp) < 1e-3:
#             perp = np.cross(edge_dir, np.array([1., 0., 0.], np.float32))
#         perp /= np.linalg.norm(perp)

#         anchor = pa + edge_dir * (ed_len / 2)               
#         pos    = anchor + up_vec * 0.08 + perp * txt_gap

#         txt = relationships_names[main_rid - 1].strip()    
#         w3d.add_3d_label(pos, txt)
#         #w3d.add_3d_label(pos, txt)

#     # ---------------- 7. camera -------------------
#     if not cache["cam"]:
#         bbox = w3d.scene.bounding_box
#         w3d.setup_camera(60., bbox, bbox.get_center())
#         cache["cam"] = True

#     # ---------------- 8. refresh ------------------
#     gui.Application.instance.run_one_tick()  
#     #time.sleep(1) 
#     return w3d, cache


def visualize_scene_graph_3d(two_d_graph, boxes, relationships,
                             rel_filter_in = (), rel_filter_out = (),
                             obj_ids       = (),
                             win_name      = "Incremental 3D-Scene-Graph"):
    """Show a whole scene graph in a blocking Open3D viewer window.

    Nodes are spheres at the box centres (``param7[3:6]``), coloured per class
    label. Relations are line segments between the two centres.

    Args:
        two_d_graph: dict with "objects" ([{"id", "label"}, ...]) and
            "relationships" ([[sub, obj, rel_id, rel_txt], ...]).
        boxes: mapping object id -> {"param7": [...], ...}. Objects with no
            box are skipped, with a message printed.
        relationships: relation names. ``relationships[rel_id - 1]`` is used
            only for entries whose ``rel_txt`` is empty or not a string.
        rel_filter_in: if non-empty, only these relation names are drawn (in red).
        rel_filter_out: relation names that are never drawn.
        obj_ids: if non-empty, only these object ids are drawn.
        win_name: viewer window title.
    """
    import matplotlib  # for matplotlib.colormaps; cm.get_cmap was removed in Matplotlib 3.9

    want = set(obj_ids) if obj_ids else None
    centers, labels = {}, {}

    for obj in two_d_graph["objects"]:
        oid = int(obj["id"])
        if want and oid not in want:
            continue
        # .get avoids inserting empty entries when boxes is a defaultdict.
        box = boxes.get(oid)
        if not box or "param7" not in box:
            print(f"[skip] object {oid} ({obj['label']}) has no box")
            continue
        cx, cy, cz = box["param7"][3:6]
        centers[oid] = np.array([cx, cy, cz], np.float32)
        labels[oid]  = obj["label"]

    try:
        cmap = matplotlib.colormaps["tab20"]   # Matplotlib >= 3.5
    except AttributeError:
        cmap = cm.get_cmap("tab20")            # older Matplotlib

    geoms, class2c = [], {}
    for oid, xyz in centers.items():
        # One colour per class label; wraps around after the 20 tab20 colours
        # instead of giving every later class the last colour.
        color = class2c.setdefault(labels[oid],
                                   cmap((len(class2c) % 20) / 20)[:3])

        sph = o3d.geometry.TriangleMesh.create_sphere(0.12)
        sph.translate(xyz)
        sph.paint_uniform_color(color)
        geoms.append(sph)

        if hasattr(o3d.geometry.TriangleMesh, "create_text"):
            txt = o3d.geometry.TriangleMesh.create_text(
                text        = labels[oid],
                depth       = 0.02,
                font_size   = 20,
                density     = 5.0,
                font        = "Liberation Sans")
            txt.translate(xyz + np.array([0, 0.18, 0]))
            txt.paint_uniform_color([0,0,0])
            geoms.append(txt)
        else:
            # No mesh text support in this Open3D build: print the label instead.
            print(f"[label] {labels[oid]:>15} @ {xyz.round(2)}")

    lines, cols, pts = [], [], []
    for s, o, rid, rtxt in two_d_graph["relationships"]:
        if s not in centers or o not in centers:
            continue
        # Prefer the relation text stored with the entry. Entries produced by
        # read_3D_graph_relationship_json already carry rel_id - 1, so indexing
        # relationships[rid - 1] on them would shift the id a second time.
        if isinstance(rtxt, str) and rtxt.strip():
            rname = rtxt.strip()
        else:
            rname = relationships[rid-1].rstrip()
        if (rel_filter_in and rname not in rel_filter_in) or \
           (rname in rel_filter_out):
            continue
        pts.extend([centers[s], centers[o]])
        lines.append([len(pts) - 2, len(pts) - 1])
        cols.append([1,0,0] if rname in rel_filter_in else [0.6]*3)

    if lines:
        ls = o3d.geometry.LineSet(
            points = o3d.utility.Vector3dVector(np.asarray(pts)),
            lines  = o3d.utility.Vector2iVector(np.asarray(lines)))
        ls.colors = o3d.utility.Vector3dVector(np.asarray(cols))
        geoms.append(ls)

    o3d.visualization.draw_geometries(
        geoms,
        window_name       = win_name,
        mesh_show_back_face=True)

def visualize_scene_graph_2d(rel, relationships, rel_filter_in=(), rel_filter_out=(), obj_ids=(), title="", scan_id="",
                             outfolder="./vis_graphs/"):
    """Render a scene graph as a PNG with graphviz.

    Args:
        rel: dict with "objects" ([{"id", "label"}, ...]) and "relationships".
            The optional "node_mask" / "edge_mask" lists mark masked entries
            with 0.
        relationships: relation names, forwarded to draw_edges.
        rel_filter_in, rel_filter_out: relation-name filters forwarded to draw_edges.
        obj_ids: if non-empty, only these object ids (and the edges between
            them) are drawn.
        title: appended to the graph comment.
        scan_id: output file stem.
        outfolder: output directory. It is joined with ``os.path.join``, so a
            trailing separator is optional.
    """
    g = Digraph(comment='Scene Graph' + title, format='png')

    node_mask = rel["node_mask"] if "node_mask" in rel else None
    for i, obj in enumerate(rel["objects"]):
        if len(obj_ids) != 0 and int(obj['id']) not in obj_ids:
            continue
        if node_mask is not None and node_mask[i] == 0:
            # Masked node: drawn unfilled with red text.
            g.node(str(obj['id']), obj["label"], fontname='helvetica', color='lightgoldenrod1', fontcolor='red')
        else:
            g.node(str(obj['id']), obj["label"], fontname='helvetica', color='lightgoldenrod1', style='filled')

    edge_mask = rel["edge_mask"] if "edge_mask" in rel else None
    draw_edges(g, rel["relationships"], relationships, rel_filter_in, rel_filter_out, obj_ids, edge_mask)
    g.render(os.path.join(outfolder, scan_id))


def draw_edges(g, graph_relationships, relationships, rel_filter_in, rel_filter_out, obj_ids, edge_mask=None):
    """Add the filtered relationships of a scene graph to *g* as labelled edges.

    All relationships of one (subject, object) pair are merged into a single
    edge. Its label is their ``rel[3]`` texts joined by ", ", in input order.

    Args:
        g: graph object exposing ``edge(tail, head, **attrs)`` (graphviz Digraph).
        graph_relationships: sequence of [sub, obj, rel_id, rel_txt] entries.
        relationships: relation names. The filters are matched against
            ``relationships[rel_id - 1]`` with trailing whitespace stripped.
        rel_filter_in: if non-empty, keep only relations whose name is listed.
        rel_filter_out: drop relations whose name is listed.
        obj_ids: if non-empty, keep only edges whose endpoints are both listed.
        edge_mask: optional flags aligned with *graph_relationships*. An edge
            whose merged flags contain a 0 is drawn red and dotted.
    """
    # Tuple keys keep subject/object ids intact. A joined "sub_obj" string
    # would be split wrongly for ids that themselves contain "_".
    edge_labels = {}   # (sub, obj) -> [rel_txt, ...]
    edge_flags = {}    # (sub, obj) -> [mask value, ...]; filled only when edge_mask is given
    for i, rel in enumerate(graph_relationships):
        rel_name = relationships[rel[2] - 1].rstrip()
        if len(rel_filter_in) != 0 and rel_name not in rel_filter_in:
            continue
        if rel_name in rel_filter_out:
            continue
        if len(obj_ids) != 0 and (rel[0] not in obj_ids or rel[1] not in obj_ids):
            continue
        key = (rel[0], rel[1])
        edge_labels.setdefault(key, []).append(rel[3])
        if edge_mask is not None:
            edge_flags.setdefault(key, []).append(edge_mask[i])

    for (sub, obj), texts in edge_labels.items():
        label = ', '.join(texts)
        if edge_mask is not None and 0 in edge_flags[(sub, obj)]:
            g.edge(str(sub), str(obj), label=label, color='red', style='dotted')
        else:
            g.edge(str(sub), str(obj), label=label, color='grey')

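# Layout of the `graphs` dict returned by read_3D_graph_relationship_json:
#     { "<scan-id>": {
#           "objects": [ {"id": <int>, "label": <str>}, ... ],
#           "relationships": [ [sub, obj, rel_id - 1, rel_txt], ... ]
#       }, ...
#     }
# Note: each rel_id is stored shifted by -1 relative to the value in the input JSON.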
def read_3D_graph_relationship_json(rel_json_file: str,
                                    box_json_file: str):
    """Load per-scan scene graphs and object boxes from two JSON files.

    Args:
        rel_json_file: JSON with a "scans" list. Each scan has "scan" (id),
            "objects" ({id_str: label}) and "relationships"
            ([[sub, obj, rel_id, rel_txt], ...]).
        box_json_file: JSON mapping scan id -> {"scene_center": ...,
            inst_id_str: {"param7", "scale", optional "model_path"}}.

    Returns:
        ``(graphs, tight_boxes)``:
        * ``graphs[scan_id]`` is {"objects": [{"id": int, "label": ...}],
          "relationships": [[sub, obj, rel_id - 1, rel_txt], ...]}.
        * ``tight_boxes[scan_id]`` is a ``defaultdict(dict)`` mapping int
          instance id -> {"param7", "scale", "model_path"}, plus the
          "scene_center" entry copied from the box file.
    """
    with open(box_json_file, "r") as fh:
        all_boxes = json.load(fh)

    with open(rel_json_file, "r") as fh:
        rel_data = json.load(fh)

    graphs = {}
    tight_boxes = {}
    for scan in rel_data["scans"]:
        sid = scan["scan"]

        graphs[sid] = {
            "objects": [{"id": int(key), "label": name}
                        for key, name in scan["objects"].items()],
            # relation ids are shifted by -1 here
            "relationships": [[sub, obj, rid - 1, txt]
                              for sub, obj, rid, txt in scan["relationships"]],
        }

        scan_boxes = all_boxes[sid]
        center = scan_boxes["scene_center"]
        per_instance = defaultdict(dict)
        for key, info in scan_boxes.items():
            if key == "scene_center":
                continue
            per_instance[int(key)].update(
                param7=info["param7"],
                scale=info["scale"],
                model_path=info.get("model_path"),
            )
        per_instance["scene_center"] = center
        tight_boxes[sid] = per_instance

    return graphs, tight_boxes

# Parsed relationships.txt contents, keyed by absolute file path. The
# incremental viewer is called once per step, so re-reading the file on
# every call would be wasted disk I/O.
_RELATIONSHIPS_CACHE = {}


def _cached_relationships(data_path):
    """Return the relation names in data_path/relationships.txt, parsing the file at most once."""
    rel_file = os.path.join(data_path, "relationships.txt")
    key = os.path.abspath(rel_file)
    if key not in _RELATIONSHIPS_CACHE:
        _RELATIONSHIPS_CACHE[key] = viz_util.read_relationships(rel_file)
    return _RELATIONSHIPS_CACHE[key]


def incremental_run(scan_id, classes,
                    new_obj_ids, triples_step_k,
                    row_global_ids, boxes_row, row_cls_ids,
                    data_path,
                    w_graph=None, cache=None):
    """Draw one incremental scene-graph step into *w_graph*.

    Thin wrapper around visualize_incremental_scene_graph_3d_gui. It supplies
    the relation names read from data_path/relationships.txt, which are
    cached after the first call.

    Args:
        w_graph: SceneWidget to draw into. The downstream drawing code calls
            methods on it, so the default None is not usable in practice.
        cache: state returned by the previous step, or None for the first step.

    Returns:
        ``(widget, cache)``. Pass ``cache`` back in on the next step.
    """
    relationships = _cached_relationships(data_path)
    return visualize_incremental_scene_graph_3d_gui(
        scan_id,
        classes,
        relationships,
        new_obj_ids,
        triples_step_k,
        row_global_ids,
        boxes_row,
        row_cls_ids,
        w_graph=w_graph,
        cache=cache)

	
def run(scan_id='', split=None, room_type=None, data_path='./GT',
        rel_filter_in=None, rel_filter_out=("in",)):
    """Load the ground-truth scene graph of *scan_id* and show it in 3D.

    Args:
        scan_id: scan to display; must be present in the selected JSON file.
        split: 'train_scans' selects the *_trainval.json files. Any other
            value selects the *_test.json files.
        room_type: inserted into the JSON file names.
        data_path: directory holding relationships.txt and the JSON files.
        rel_filter_in: relation names to keep, drawn in red. None selects the
            default spatial/comparative set below; () keeps every relation,
            drawn in grey.
        rel_filter_out: relation names that are never drawn.

    Raises:
        KeyError: if *scan_id* is not in the scene-graph file.
    """
    suffix = 'trainval' if split == 'train_scans' else 'test'
    scene_graph_json_file = os.path.join(data_path, 'relationships_{}_{}.json'.format(room_type, suffix))
    box_json_file = os.path.join(data_path, 'obj_boxes_{}_{}.json'.format(room_type, suffix))

    if rel_filter_in is None:
        rel_filter_in = ("left", "right", "behind", "front", "same as", "same symmetry as",
                         "bigger than", "lower than", "higher than", "close by")

    relationships = viz_util.read_relationships(os.path.join(data_path, "relationships.txt"))
    graphs, tight_boxes = read_3D_graph_relationship_json(scene_graph_json_file, box_json_file)
    if scan_id not in graphs:
        raise KeyError(f"scan_id {scan_id!r} not found in {scene_graph_json_file}")

    visualize_scene_graph_3d(
        two_d_graph=graphs[scan_id],
        boxes=tight_boxes[scan_id],
        relationships=relationships,
        rel_filter_in=rel_filter_in,
        rel_filter_out=rel_filter_out,
        obj_ids=())



