import argparse
import numpy as np
import pyvista as pv
import random 

from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.BRep import BRep_Tool
from OCC.Core.TopExp import TopExp_Explorer
from OCC.Core.TopAbs import TopAbs_FACE, TopAbs_EDGE, TopAbs_VERTEX, TopAbs_REVERSED
from OCC.Core.BRepAdaptor import BRepAdaptor_Surface, BRepAdaptor_Curve
from OCC.Core.gp import gp_Pnt, gp_Vec, gp_Trsf, gp_TrsfForm
from OCC.Core.TopoDS import topods_Face, topods_Edge, topods_Vertex
from OCC.Core.TopLoc import TopLoc_Location
from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Transform


def load_step_solid(path):
    """Read a STEP file and return its contents as a single OCC shape.

    Args:
        path: Filesystem path of the STEP file to read.

    Returns:
        The shape returned by ``STEPControl_Reader.OneShape()`` after the
        file's roots have been transferred.

    Raises:
        RuntimeError: If the reader reports anything other than
            ``IFSelect_RetDone`` for the file, or if the transfer leaves
            the reader with a null shape.
    """
    # Imported locally so this function does not depend on edits to the
    # module-level import block.
    from OCC.Core.IFSelect import IFSelect_RetDone

    reader = STEPControl_Reader()
    if reader.ReadFile(path) != IFSelect_RetDone:
        raise RuntimeError(f"Failed to read STEP file: {path}")

    reader.TransferRoots()
    shape = reader.OneShape()
    # A readable file can still transfer nothing; fail here with a clear
    # message instead of handing a null shape to the rendering code.
    if shape.IsNull():
        raise RuntimeError(f"STEP file contains no transferable shape: {path}")
    return shape


def mesh_exploded_face(shape, linear_deflection=0.1):
    """Triangulate a single face and return its mesh with the face location applied.

    Args:
        shape: Shape to tessellate. Only a shape whose ``ShapeType()`` is
            ``TopAbs_FACE`` produces a mesh.
        linear_deflection: Linear deflection handed to
            ``BRepMesh_IncrementalMesh``.

    Returns:
        ``(points, triangles)`` where ``points`` is an ``(n, 3)`` float64
        array and ``triangles`` is an ``(m, 3)`` int32 array of zero-based
        node indices. ``(None, None)`` is returned when meshing does not
        complete, the shape is not a face, or the face has no usable
        triangulation.
    """
    # The (shape, deflection) constructor already runs the mesher, so calling
    # Perform() again would only tessellate the shape a second time.
    mesher = BRepMesh_IncrementalMesh(shape, linear_deflection)
    if not mesher.IsDone():
        return None, None

    if shape.ShapeType() != TopAbs_FACE:
        return None, None

    face = topods_Face(shape)
    loc = TopLoc_Location()  # out-parameter required by BRep_Tool.Triangulation
    tri = BRep_Tool.Triangulation(face, loc)
    if tri is None or tri.NbNodes() <= 0 or tri.NbTriangles() <= 0:
        return None, None

    # Triangulation nodes are stored without the face location, so it has to
    # be applied here unless it is a plain identity.
    trsf = face.Location().Transformation()
    is_identity = (trsf.Form() == gp_TrsfForm.gp_Identity
                   and trsf.ScaleFactor() == 1.0)

    node_count = tri.NbNodes()
    points = np.empty((node_count, 3), dtype=np.float64)
    for row in range(node_count):
        # OCC indices are 1-based. Node(i) is the indexed accessor that pairs
        # with the Triangle(i) call below and avoids the Nodes() array.
        node = tri.Node(row + 1)
        if not is_identity:
            node = node.Transformed(trsf)
        points[row] = (node.X(), node.Y(), node.Z())

    # Shift the 1-based OCC node indices to the 0-based indices numpy expects.
    triangles = np.array(
        [tri.Triangle(i).Get() for i in range(1, tri.NbTriangles() + 1)],
        dtype=np.int32,
    ) - 1
    return points, triangles


def sample_face_uv(face, nu=20, nv=20):
    """Sample points and unit normals on a regular UV grid of a face.

    Args:
        face: The face to sample.
        nu: Number of samples along U.
        nv: Number of samples along V.

    Returns:
        ``(points, normals)`` as two ``(k, 3)`` arrays, one row per grid
        parameter that could be evaluated. Normals are unit length and are
        negated when the face orientation is ``TopAbs_REVERSED``. Returns
        ``(None, None)`` when the face has no surface, its parameter range
        is empty, or no sample could be evaluated.
    """
    adaptor = BRepAdaptor_Surface(face, True)
    surf_geom = adaptor.Surface().Surface()
    if surf_geom is None:
        return None, None

    umin, umax = adaptor.FirstUParameter(), adaptor.LastUParameter()
    vmin, vmax = adaptor.FirstVParameter(), adaptor.LastVParameter()
    if umin >= umax or vmin >= vmax:
        return None, None

    us = np.linspace(umin, umax, nu)
    vs = np.linspace(vmin, vmax, nv)

    # surf_geom is the raw underlying surface, so its points still need the
    # adaptor's transformation. Per the OCCT reference, Trsf() combines the
    # face location with any location stored alongside the surface, which
    # face.Location() on its own would miss.
    trsf = adaptor.Trsf()
    apply_trsf = not (trsf.Form() == gp_TrsfForm.gp_Identity
                      and trsf.ScaleFactor() == 1.0)
    is_reversed = face.Orientation() == TopAbs_REVERSED

    def _unit(vec):
        # Degenerate vectors fall back to +Z so callers always get a unit normal.
        length = np.linalg.norm(vec)
        if length > 1e-8:
            return vec / length
        return np.array([0., 0., 1.])

    pts_list, nors_list = [], []
    for u_param in us:
        for v_param in vs:
            p_loc, du_loc, dv_loc = gp_Pnt(), gp_Vec(), gp_Vec()
            try:
                surf_geom.D1(u_param, v_param, p_loc, du_loc, dv_loc)
            except Exception:
                # Best effort: skip parameters the surface cannot evaluate.
                continue

            normal_local = _unit(np.cross(
                [du_loc.X(), du_loc.Y(), du_loc.Z()],
                [dv_loc.X(), dv_loc.Y(), dv_loc.Z()],
            ))

            p_world = gp_Pnt(p_loc.X(), p_loc.Y(), p_loc.Z())
            normal_gp = gp_Vec(normal_local[0], normal_local[1], normal_local[2])
            if apply_trsf:
                p_world.Transform(trsf)
                # Transforming a vector applies rotation/scale, not translation.
                normal_gp.Transform(trsf)

            # Re-normalise because a scaled transformation changes the length.
            normal_world = _unit(
                np.array([normal_gp.X(), normal_gp.Y(), normal_gp.Z()])
            )
            if is_reversed:
                normal_world = normal_world * -1

            pts_list.append([p_world.X(), p_world.Y(), p_world.Z()])
            nors_list.append(normal_world)

    if not pts_list:
        return None, None
    return np.array(pts_list), np.array(nors_list)


def sample_edge_curve(edge, n_points=20):
    """Sample points along an edge's 3D curve and collect its vertex positions.

    Args:
        edge: The edge to sample.
        n_points: Number of evenly spaced parameters to evaluate between the
            curve's first and last parameter.

    Returns:
        ``(sampled_points, vertex_points)``. ``sampled_points`` is an
        ``(n, 3)`` float64 array; a curve whose first and last parameter
        coincide yields a single point. ``vertex_points`` is a ``(k, 3)``
        float64 array of the edge's vertices, or ``None`` when the edge has
        none. Returns ``(None, None)`` when ``n_points`` is not positive, the
        edge has no 3D curve, or its parameter range is inverted.
    """
    if n_points <= 0:
        return None, None

    curve_info = BRep_Tool.Curve(edge)
    # NOTE(review): some pythonocc builds appear to drop a null curve handle
    # from the returned tuple instead of yielding None -- tolerate both forms.
    if len(curve_info) != 3 or curve_info[0] is None:
        return None, None
    curve3d, u_min, u_max = curve_info

    if u_min < u_max:
        params = np.linspace(u_min, u_max, n_points)
    elif u_min == u_max:
        # Zero-length parameter range: the curve collapses to one point.
        params = [u_min]
    else:
        return None, None

    # BRep_Tool.Curve(edge), the overload without a location argument, returns
    # a curve that already has the edge location applied. Transforming the
    # samples by edge.Location() again would move located edges twice.
    sampled_points = np.array(
        [[pnt.X(), pnt.Y(), pnt.Z()]
         for pnt in (curve3d.Value(u_param) for u_param in params)],
        dtype=np.float64,
    )

    vertex_coords = []
    v_exp = TopExp_Explorer(edge, TopAbs_VERTEX)
    while v_exp.More():
        # BRep_Tool.Pnt applies the vertex location, and vertices handed out
        # by the explorer carry the edge's location, so no extra transform.
        pnt = BRep_Tool.Pnt(topods_Vertex(v_exp.Current()))
        vertex_coords.append([pnt.X(), pnt.Y(), pnt.Z()])
        v_exp.Next()

    vertex_points = (np.array(vertex_coords, dtype=np.float64)
                     if vertex_coords else None)
    return sampled_points, vertex_points


def _build_arg_parser():
    """Create the command-line parser covering all three rendering modes."""
    parser = argparse.ArgumentParser(description="Visualize STEP file: exploded view, random face sample, random edge sample.")
    # Exploded view args
    parser.add_argument("step_file", help="Input STEP file path")
    parser.add_argument("graph_file", help="DGL graph .bin file path (placeholder, not used by this script)")
    parser.add_argument("--output_image", help="Output image for exploded view", default="render_exploded.png")
    parser.add_argument("--explode_distance", type=float, default=0.5, help="Explode distance")
    parser.add_argument("--face_deflection", type=float, default=0.05, help="Mesh deflection for faces")
    parser.add_argument("--opacity", type=float, default=0.9, help="Opacity for exploded faces")
    parser.add_argument("--exploded_top_bottom_color", type=str, default="LightGrey", help="Color for top/bottom faces in exploded view")
    parser.add_argument("--exploded_side_color", type=str, default="White", help="Color for side faces in exploded view")
    # Face sampling args
    parser.add_argument("--sample_face_output", type=str, default=None, help="Output image for random face sample")
    parser.add_argument("--sample_nu", type=int, default=10, help="Samples in U for face")
    parser.add_argument("--sample_nv", type=int, default=10, help="Samples in V for face")
    parser.add_argument("--sample_point_color", type=str, default="red", help="Color for sampled points on face")
    parser.add_argument("--sample_show_normals", action='store_true', help="Show normals for face samples")
    parser.add_argument("--sample_arrow_color", type=str, default="blue", help="Color for normal arrows")
    parser.add_argument("--sample_arrow_scale", type=float, default=0.05, help="Scale for normal arrows")
    parser.add_argument("--sampled_face_color", type=str, default="lightcyan", help="Color of the sampled face")
    parser.add_argument("--sampled_face_opacity", type=float, default=0.7, help="Opacity of the sampled face")
    # Edge sampling args
    parser.add_argument("--sample_edge_output", type=str, default=None, help="Output image for random edge sample")
    parser.add_argument("--sample_edge_n_points", type=int, default=10, help="Number of points to sample on edge")
    # This option used to be parsed but ignored (the points were always drawn
    # red). It is honoured now; the default is "red" so that renders made
    # without the flag look exactly as before.
    parser.add_argument("--sample_edge_point_color", type=str, default="red", help="Color for sampled points on edge")
    parser.add_argument("--sampled_edge_color", type=str, default="purple", help="Color of the sampled edge line")
    parser.add_argument("--sampled_edge_line_width", type=float, default=3, help="Line width for the sampled edge")
    parser.add_argument("--sampled_edge_vertex_color", type=str, default="black", help="Color for edge vertices")
    return parser


def _triangles_to_polydata(verts, tris):
    """Wrap a vertex array and a triangle index array as a ``pv.PolyData``."""
    # Each face row is prefixed with its vertex count (3 for a triangle).
    cells = np.hstack([np.full((tris.shape[0], 1), 3, dtype=tris.dtype), tris])
    return pv.PolyData(verts, faces=cells)


def _collect_subshapes(shape, shape_kind, caster):
    """Return every sub-shape of ``shape_kind`` in ``shape``, cast with ``caster``."""
    found = []
    explorer = TopExp_Explorer(shape, shape_kind)
    while explorer.More():
        found.append(caster(explorer.Current()))
        explorer.Next()
    return found


def _add_exploded_face(plotter, face, face_index, args):
    """Translate one face along its mid-parameter normal and add its mesh to ``plotter``."""
    adaptor = BRepAdaptor_Surface(face, True)
    if adaptor.Surface().Surface() is None:
        return

    u_mid = (adaptor.FirstUParameter() + adaptor.LastUParameter()) / 2
    v_mid = (adaptor.FirstVParameter() + adaptor.LastVParameter()) / 2
    point, d_u, d_v = gp_Pnt(), gp_Vec(), gp_Vec()
    try:
        # Evaluate through the adaptor rather than the raw surface: the
        # adaptor applies the face's transformation, so the normal is in the
        # same frame as the translation applied below.
        adaptor.D1(u_mid, v_mid, point, d_u, d_v)
        normal = d_u.Crossed(d_v)
        if normal.Magnitude() < 1e-9:
            return
        normal.Normalize()
    except Exception:
        # Best effort: a face whose normal cannot be evaluated is left out.
        return

    direction = gp_Vec(normal.X(), normal.Y(), normal.Z())
    if face.Orientation() == TopAbs_REVERSED:
        direction.Reverse()
    shift = gp_Trsf()
    shift.SetTranslation(direction.Scaled(args.explode_distance))
    mover = BRepBuilderAPI_Transform(face, shift, True)
    if not mover.IsDone():
        return
    moved_face = mover.Shape()

    try:
        verts, tris = mesh_exploded_face(moved_face, args.face_deflection)
        if verts is not None and tris is not None:
            # Faces whose normal is predominantly along Z get the
            # top/bottom colour; everything else gets the side colour.
            if abs(normal.Z()) > 0.9:
                face_color = args.exploded_top_bottom_color
            else:
                face_color = args.exploded_side_color
            plotter.add_mesh(_triangles_to_polydata(verts, tris),
                             color=face_color, opacity=args.opacity)
    except Exception as exc:
        # Keep rendering the remaining faces, but report what was dropped.
        print(f"Warning: skipping face {face_index} in exploded view: {exc}")


def _render_exploded_view(shape, args):
    """Render every face offset along its normal and save the exploded view."""
    print("\n--- Generating Exploded View ---")
    plotter = pv.Plotter(off_screen=True, window_size=[1024, 768])
    explorer = TopExp_Explorer(shape, TopAbs_FACE)
    face_count = 0
    while explorer.More():
        face_count += 1
        face = topods_Face(explorer.Current())
        explorer.Next()
        _add_exploded_face(plotter, face, face_count, args)
    print(f"Exploded face iteration complete. Processed: {face_count}")

    plotter.camera_position = 'iso'
    try:
        if args.output_image.lower().endswith(".svg"):
            # save_graphic is used here in place of export_svg.
            plotter.save_graphic(args.output_image)
            print(f"Exploded view saved as SVG to {args.output_image}")
        else:
            plotter.screenshot(args.output_image)
            print(f"Exploded view saved to {args.output_image}")
    except Exception as e:
        print(f"Error saving exploded view: {e}")
    plotter.close()


def _render_face_sample(shape, args):
    """Render one randomly chosen face with its UV samples and optional normals."""
    print(f"\n--- Generating Random Face Sample View (Output: {args.sample_face_output}) ---")
    faces = _collect_subshapes(shape, TopAbs_FACE, topods_Face)
    if not faces:
        print("No faces to sample.")
        return

    face = random.choice(faces)
    plotter = pv.Plotter(off_screen=True, window_size=[1024, 768])
    verts, tris = mesh_exploded_face(face, args.face_deflection)
    if verts is not None and tris is not None:
        plotter.add_mesh(_triangles_to_polydata(verts, tris),
                         color=args.sampled_face_color,
                         opacity=args.sampled_face_opacity)

    pts, nors = sample_face_uv(face, args.sample_nu, args.sample_nv)
    if pts is not None:
        plotter.add_points(pts, color=args.sample_point_color, point_size=8)
        if args.sample_show_normals and nors is not None:
            arrows = [
                pv.Arrow(start=p, direction=n, scale=args.sample_arrow_scale)
                for p, n in zip(pts, nors)
                if np.linalg.norm(n) > 1e-6
            ]
            if arrows:
                plotter.add_mesh(pv.MultiBlock(arrows).combine(merge_points=False),
                                 color=args.sample_arrow_color)

    plotter.camera_position = 'iso'
    plotter.reset_camera(render=False)
    try:
        plotter.screenshot(args.sample_face_output)
        print(f"Face sample view saved to {args.sample_face_output}")
    except Exception as e:
        print(f"Error saving face sample: {e}")
    plotter.close()


def _render_edge_sample(shape, args):
    """Render one randomly chosen edge with its sampled points and vertices."""
    print(f"\n--- Generating Random Edge Sample View (Output: {args.sample_edge_output}) ---")
    edges = _collect_subshapes(shape, TopAbs_EDGE, topods_Edge)
    if not edges:
        print("No edges found in the shape to sample. Skipping random edge sampling.")
        return

    edge = random.choice(edges)
    print(f"Randomly selected one edge for sampling (out of {len(edges)} edges).")

    plotter = pv.Plotter(off_screen=True, window_size=[1024, 768])
    print("Random edge sample view plotter initialized.")

    sampled, vertex_coords = sample_edge_curve(edge, n_points=args.sample_edge_n_points)

    if sampled is not None and len(sampled) > 0:
        print(f"Successfully sampled {len(sampled)} points from the edge.")
        if len(sampled) > 1:
            plotter.add_mesh(pv.lines_from_points(sampled),
                             color=args.sampled_edge_color,
                             line_width=args.sampled_edge_line_width)
            print(f"Edge curve added to plotter with color '{args.sampled_edge_color}'.")
        else:
            # A single sample cannot form a polyline; draw the edge as a point.
            plotter.add_points(sampled, color=args.sampled_edge_color, point_size=5)

        plotter.add_points(sampled, color=args.sample_edge_point_color, point_size=10)
        print(f"Sampled points on edge added with color '{args.sample_edge_point_color}' and size 10.")

        if vertex_coords is not None and len(vertex_coords) > 0:
            plotter.add_points(vertex_coords, color=args.sampled_edge_vertex_color, point_size=8)
            print(f"Edge vertices ({len(vertex_coords)}) added to plotter with color '{args.sampled_edge_vertex_color}'.")
    else:
        print("Failed to sample points from the selected edge or no points returned.")

    plotter.camera_position = 'iso'
    plotter.reset_camera(render=False)  # fit the camera to the edge and points
    print("Edge sample view camera position set and reset.")
    try:
        plotter.screenshot(args.sample_edge_output)
        print(f"Random edge sample view render saved to {args.sample_edge_output}")
    except Exception as e:
        print(f"Error saving random edge sample view screenshot: {e}")
    plotter.close()


def main():
    """Parse the command line, load the STEP file and produce the requested renders.

    The exploded view is always rendered to ``--output_image``. The random
    face and random edge views are rendered only when ``--sample_face_output``
    or ``--sample_edge_output`` is given.
    """
    args = _build_arg_parser().parse_args()
    print("Starting visualization script...")
    shape = load_step_solid(args.step_file)
    print(f"STEP file loaded. Shape type: {type(shape)}")

    # --- Part 1: Exploded View ---
    _render_exploded_view(shape, args)

    # --- Part 2: Random Face Sampling ---
    if args.sample_face_output:
        _render_face_sample(shape, args)
    else:
        print("\nSkipping random face sampling.")

    # --- Part 3: Random Edge Sampling ---
    if args.sample_edge_output:
        _render_edge_sample(shape, args)
    else:
        print("\nSkipping random edge sampling as --sample_edge_output was not provided.")

    print("\nVisualization script finished.")

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()