import argparse

import numpy as np
import pyvista as pv
import dgl

from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.BRep import BRep_Tool
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE
from OCC.Core.TopAbs import TopAbs_REVERSED
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
from OCC.Core.gp import gp_Pnt
from OCC.Core.gp import gp_Vec
from OCC.Core.TopoDS import topods_Face
from OCC.Core.TopLoc import TopLoc_Location


def load_step_solid(path):
    """
    Read a STEP file and return its content as one TopoDS_Shape.

    Parameters
    ----------
    path : str or os.PathLike
        STEP file to read.

    Returns
    -------
    TopoDS_Shape
        The result of ``STEPControl_Reader.OneShape()`` after transferring all roots.

    Raises
    ------
    RuntimeError
        If the file cannot be read, or if the transfer produced no shape.
    """
    reader = STEPControl_Reader()
    # str() lets callers pass pathlib.Path; the OCC binding takes a plain string.
    status = reader.ReadFile(str(path))
    if status != IFSelect_RetDone:
        raise RuntimeError(f"Failed to read STEP file: {path}")
    reader.TransferRoots()
    shape = reader.OneShape()
    # Fail here with a clear message instead of later, inside meshing.
    if shape.IsNull():
        raise RuntimeError(f"No shape could be transferred from STEP file: {path}")
    return shape


def mesh_shape(shape, linear_deflection=0.1):
    """
    Triangulate ``shape`` and merge every face's triangulation into one mesh.

    BRepMesh_IncrementalMesh attaches a triangulation to each face. Each one is
    then read back with ``BRep_Tool.Triangulation`` and concatenated.

    Parameters
    ----------
    shape : TopoDS_Shape
        Shape to mesh.
    linear_deflection : float
        Linear deflection passed to BRepMesh_IncrementalMesh.

    Returns
    -------
    vertices : (N, 3) float64 ndarray
        Node coordinates, with each face's location transformation applied.
    triangles : (M, 3) int32 ndarray
        0-based indices into ``vertices``. Winding is flipped for reversed faces,
        so triangles follow the face orientation. Nodes are not shared between
        faces, so points on shared edges appear once per face.

    Raises
    ------
    RuntimeError
        If no face carries a triangulation after meshing.
    """
    # The (shape, deflection) constructor runs the meshing itself, so there is
    # no need to call Perform() a second time.
    BRepMesh_IncrementalMesh(shape, linear_deflection)

    all_pts = []
    all_tris = []
    idx_offset = 0

    exp = TopExp_Explorer(shape, TopAbs_FACE)
    while exp.More():
        face = topods_Face(exp.Current())
        exp.Next()

        loc = TopLoc_Location()
        # Fills `loc` with the face placement; node coordinates are relative to it.
        tri = BRep_Tool.Triangulation(face, loc)
        if tri is None:
            continue

        # Node(i) is 1-based. It is used instead of Nodes(), which recent
        # OCCT/pythonocc releases dropped.
        # NOTE(review): confirm against the pinned pythonocc version.
        n_nodes = tri.NbNodes()
        pts = np.empty((n_nodes, 3), dtype=np.float64)
        apply_loc = not loc.IsIdentity()
        trsf = loc.Transformation()
        for i in range(n_nodes):
            p = tri.Node(i + 1)
            if apply_loc:
                p = p.Transformed(trsf)
            pts[i] = (p.X(), p.Y(), p.Z())

        # Triangle node indices are 1-based in OpenCASCADE; shift them to 0-based.
        tris = np.array(
            [tri.Triangle(i).Get() for i in range(1, tri.NbTriangles() + 1)],
            dtype=np.int32,
        ).reshape(-1, 3) - 1
        # For reversed faces, swap two indices to flip the winding.
        if face.Orientation() == TopAbs_REVERSED:
            tris = tris[:, [0, 2, 1]]

        all_pts.append(pts)
        all_tris.append(tris + idx_offset)
        idx_offset += pts.shape[0]

    if not all_pts:
        raise RuntimeError("No mesh generated from shape")
    vertices = np.vstack(all_pts)
    triangles = np.vstack(all_tris)
    return vertices, triangles


def sample_face_uv(face, nu=20, nv=20):
    """
    Sample a face on a regular ``nu`` x ``nv`` grid over its (u, v) parameter bounds.

    Parameters
    ----------
    face : TopoDS_Face
        Face to sample. It must already be downcast with ``topods_Face``.
    nu, nv : int
        Number of samples along u and along v. Both parameter bounds are included
        (``np.linspace``).

    Returns
    -------
    pts : (nu*nv, 3) float64 ndarray
        Sample points, ordered u-major: index ``iu * nv + iv``.
    normals : (nu*nv, 3) float64 ndarray
        Unit normals ``du x dv``, negated for reversed faces. Where the surface is
        degenerate (``|du x dv| <= 1e-8``), the normal is the zero vector.

    Notes
    -----
    The grid spans the face's UV bounding box (``BRepAdaptor_Surface(face, True)``).
    On trimmed faces, some samples may therefore lie outside the actual face
    boundary.
    """
    adaptor = BRepAdaptor_Surface(face, True)
    umin, umax = adaptor.FirstUParameter(), adaptor.LastUParameter()
    vmin, vmax = adaptor.FirstVParameter(), adaptor.LastVParameter()
    us = np.linspace(umin, umax, nu)
    vs = np.linspace(vmin, vmax, nv)

    # A reversed face's normal is opposite to the surface's du x dv direction.
    sign = -1.0 if face.Orientation() == TopAbs_REVERSED else 1.0

    pts = np.empty((len(us) * len(vs), 3), dtype=np.float64)
    nors = np.zeros((len(us) * len(vs), 3), dtype=np.float64)
    k = 0
    for u in us:
        for v in vs:
            p = gp_Pnt()
            du = gp_Vec()
            dv = gp_Vec()
            # Evaluate through the adaptor, not adaptor.Surface().Surface(). The bare
            # Geom_Surface ignores the face's TopLoc_Location, so its samples would
            # not line up with the location-transformed mesh.
            adaptor.D1(u, v, p, du, dv)
            pts[k] = (p.X(), p.Y(), p.Z())
            n = np.cross((du.X(), du.Y(), du.Z()), (dv.X(), dv.Y(), dv.Z()))
            norm = np.linalg.norm(n)
            if norm > 1e-8:
                nors[k] = sign * n / norm
            k += 1
    return pts, nors

def _add_face_uv_samples(plotter, shape, nu=20, nv=20):
    """Plot UV sample points and normal arrows for every face; return per-face sample centroids."""
    exp = TopExp_Explorer(shape, TopAbs_FACE)
    face_centers = []
    face_idx = 0
    while exp.More():
        face_idx += 1
        print(f"Processing face {face_idx}...")
        # TopExp_Explorer.Current() yields a generic TopoDS_Shape, but
        # BRepAdaptor_Surface (inside sample_face_uv) needs a TopoDS_Face.
        current_face = topods_Face(exp.Current())
        exp.Next()

        print(f"  Sampling UV for face {face_idx}...")
        pts, nors = sample_face_uv(current_face, nu=nu, nv=nv)
        print(f"  UVs sampled. Points: {len(pts)}, Normals: {len(nors)}")

        if len(pts) == 0:
            print(f"  No points sampled for face {face_idx}, skipping plotting for this face.")
            continue

        print(f"  Adding points for face {face_idx} to plotter...")
        plotter.add_points(pts, color="purple", point_size=5)
        print(f"  Points added for face {face_idx}.")

        print(f"  Adding arrows for face {face_idx} to plotter...")
        # One glyph actor per face instead of one actor per arrow. Unit normals
        # scaled by 0.05 give arrows of the same length as pv.Arrow(scale=0.05).
        plotter.add_arrows(pts, nors, mag=0.05, color="purple")
        print(f"  Arrows added for face {face_idx}.")

        face_centers.append(pts.mean(axis=0))

    print(f"Face iteration complete. Total faces processed: {face_idx}")
    return face_centers


def _add_adjacency_edges(plotter, g, face_centers):
    """Draw each graph edge (u, v) as a line between face_centers[u] and face_centers[v]; return the edge count."""
    src, dst = g.edges()
    n_centers = len(face_centers)
    segment_points = []
    edge_count = 0
    for u, v in zip(src.tolist(), dst.tolist()):
        edge_count += 1
        if u < n_centers and v < n_centers:
            segment_points.append(face_centers[u])
            segment_points.append(face_centers[v])
        else:
            print(f"  Skipping edge {u}-{v} due to out-of-bounds face_centers index.")
    if segment_points:
        # add_lines reads consecutive point pairs as separate segments, so all
        # edges go into a single actor.
        plotter.add_lines(np.vstack(segment_points), color="black", width=3)
    return edge_count


def main():
    """
    Render a STEP model, its per-face UV samples and normals, and face-adjacency edges to an image.

    Command-line arguments:
      step_file      -- input STEP file.
      graph_file     -- DGL graph ``.bin`` file. The first graph is used. Graph node
                        ``i`` is assumed to be the ``i``-th face visited by
                        ``TopExp_Explorer``. NOTE(review): confirm against the code
                        that built the graph.
      --output_image -- output image path (default ``render.png``).
    """
    parser = argparse.ArgumentParser(
        description="Visualize UV grids and face adjacency graph from STEP + DGL graph"
    )
    parser.add_argument("step_file", help="Input STEP file path")
    parser.add_argument("graph_file", help="DGL graph .bin file path")
    parser.add_argument("--output_image", help="Output image file path (e.g., render.png)", default="render.png")
    args = parser.parse_args()

    print("Starting visualization script...")

    # 1. Load and mesh the STEP model.
    print("Loading STEP file...")
    shape = load_step_solid(args.step_file)
    print(f"STEP file loaded. Shape type: {type(shape)}")

    print("Meshing shape...")
    # mesh_shape raises RuntimeError when nothing could be meshed; it never returns None.
    vertices, triangles = mesh_shape(shape, linear_deflection=0.05)
    print(f"Shape meshed. Vertices count: {len(vertices)}, Triangles count: {len(triangles)}")

    # 2. Off-screen PyVista rendering of the base mesh.
    print("Initializing PyVista Plotter with off_screen=True...")
    plotter = pv.Plotter(off_screen=True)
    print("Plotter initialized.")

    print("Creating PolyData for base mesh...")
    # VTK's flat connectivity layout: [3, i0, i1, i2, 3, j0, j1, j2, ...].
    faces = np.hstack([
        np.full((triangles.shape[0], 1), 3, dtype=np.int32),
        triangles.astype(np.int32),
    ]).ravel()
    mesh_polydata = pv.PolyData(vertices, faces)
    print("PolyData created.")

    print("Adding base mesh to plotter...")
    plotter.add_mesh(mesh_polydata, opacity=0.5, color="lightgray")
    print("Base mesh added to plotter.")

    # 3. UV samples and normals for every face.
    print("Starting face iteration for UV sampling...")
    face_centers = _add_face_uv_samples(plotter, shape, nu=20, nv=20)

    if not face_centers:
        print("No face centers were computed. There might be an issue with face processing or the model has no faces.")
    else:
        face_centers = np.vstack(face_centers)
        print(f"Face centers computed: {len(face_centers)}")

    # 4. Load the DGL graph and draw face-face connections as line segments.
    print("Loading DGL graph...")
    glist, _ = dgl.load_graphs(args.graph_file)
    if not glist:
        raise RuntimeError(f"No graphs found in {args.graph_file}")
    g = glist[0]
    print(f"DGL graph loaded. Number of nodes: {g.number_of_nodes()}, Number of edges: {g.number_of_edges()}")
    if g.number_of_nodes() != len(face_centers):
        print(
            f"Warning: graph has {g.number_of_nodes()} nodes but {len(face_centers)} face centers; "
            "node indices may not match face order."
        )

    if len(face_centers) > 0:
        print("Adding DGL graph edges to plotter...")
        edge_count = _add_adjacency_edges(plotter, g, face_centers)
        print(f"DGL graph edges added to plotter. Total edges processed: {edge_count}")
    else:
        print("Skipping DGL graph edge plotting as no face centers are available.")

    # 5. Save a screenshot instead of opening a window.
    print("Setting camera position...")
    plotter.camera_position = 'iso'
    print("Camera position set.")

    print(f"Attempting to save screenshot to {args.output_image}...")
    plotter.screenshot(args.output_image, window_size=[1024, 768])
    print(f"Render saved to {args.output_image}")
    # Release the off-screen render window.
    plotter.close()

    print("Visualization script finished.")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()