import argparse
import pathlib
import signal
from multiprocessing.pool import Pool
from itertools import repeat

import dgl
import numpy as np
import torch
from occwl.graph import face_adjacency
from occwl.io import load_step
from occwl.uvgrid import ugrid, uvgrid
from tqdm import tqdm


def build_graph(solid, curv_num_u_samples, surf_num_u_samples, surf_num_v_samples):
    """Convert a solid into a DGL face-adjacency graph with sampled features.

    Faces of the solid become graph nodes and face adjacencies become graph
    edges. Every node stores a UV-grid (points, normals and a visibility mask
    concatenated along the last axis) in ``ndata["x"]``; every edge stores a
    U-grid (points and tangents concatenated along the last axis) in
    ``edata["x"]``.

    Args:
        solid: Solid handed to ``face_adjacency``.
        curv_num_u_samples: Number of samples taken along each edge curve.
        surf_num_u_samples: Number of samples per face along the u-direction.
        surf_num_v_samples: Number of samples per face along the v-direction.

    Returns:
        A DGL graph with one node per face and one edge per non-degenerate
        adjacency edge.
    """
    # NOTE: the solid's transform is left untouched; the reset below is disabled.
    # solid.set_transform_to_identity()
    # Build face adjacency graph with B-rep entities as node and edge features
    graph = face_adjacency(solid)

    # Compute the UV-grids for faces
    surf_samples = {"num_u": surf_num_u_samples, "num_v": surf_num_v_samples}
    graph_face_feat = []
    for face_idx in graph.nodes:
        # Get the B-rep face
        face = graph.nodes[face_idx]["face"]
        # Compute UV-grids
        points = uvgrid(face, method="point", **surf_samples)
        normals = uvgrid(face, method="normal", **surf_samples)
        visibility_status = uvgrid(face, method="visibility_status", **surf_samples)
        # 0: Inside, 1: Outside, 2: On boundary -> keep inside and boundary samples
        mask = np.logical_or(visibility_status == 0, visibility_status == 2)
        # Concatenate channel-wise to form face feature tensor
        # (presumably 3 + 3 + 1 channels; confirm against the uvgrid output)
        graph_face_feat.append(np.concatenate((points, normals, mask), axis=-1))
    graph_face_feat = np.asarray(graph_face_feat)

    # Compute the U-grids for edges. The endpoints are collected in the same
    # pass so that the DGL edge list and the edge feature rows always line up,
    # even when a degenerate edge is skipped.
    src = []
    dst = []
    graph_edge_feat = []
    for edge_idx in graph.edges:
        # Get the B-rep edge
        edge = graph.edges[edge_idx]["edge"]
        # Ignore degenerate edges, e.g. at apex of cone
        if not edge.has_curve():
            continue
        # Compute U-grids
        points = ugrid(edge, method="point", num_u=curv_num_u_samples)
        tangents = ugrid(edge, method="tangent", num_u=curv_num_u_samples)
        # Concatenate channel-wise to form edge feature tensor
        graph_edge_feat.append(np.concatenate((points, tangents), axis=-1))
        src.append(edge_idx[0])
        dst.append(edge_idx[1])
    graph_edge_feat = np.asarray(graph_edge_feat)

    # Convert face-adj graph to DGL format
    # (assumes the node ids are 0..N-1 in iteration order; TODO confirm)
    dgl_graph = dgl.graph((src, dst), num_nodes=len(graph.nodes))
    dgl_graph.ndata["x"] = torch.from_numpy(graph_face_feat)
    dgl_graph.edata["x"] = torch.from_numpy(graph_edge_feat)
    return dgl_graph


def process_one_file(arguments):
    """Convert one STEP file into a DGL graph file, never raising on failure.

    Designed to run inside a worker process: any exception raised while
    loading, converting or saving is caught and reported in the return value.

    Args:
        arguments: ``(fn, args)`` pair, where ``fn`` is the path of the STEP
            file and ``args`` provides ``output``, ``curv_u_samples``,
            ``surf_u_samples`` and ``surf_v_samples``.

    Returns:
        ``(True, fn)`` on success, or ``(False, (fn, message))`` on failure,
        where ``message`` is ``"<ExceptionType>: <details>"``.
    """
    fn, args = arguments
    fn_stem = fn.stem
    output_path = pathlib.Path(args.output)
    try:
        solids = load_step(fn)
        if len(solids) == 0:
            # Report this explicitly instead of a bare "list index out of range".
            raise ValueError("no solid found in STEP file")
        solid = solids[0]  # Assume there's one solid per file
        graph = build_graph(
            solid, args.curv_u_samples, args.surf_u_samples, args.surf_v_samples
        )
        dgl.data.utils.save_graphs(str(output_path.joinpath(fn_stem + ".bin")), [graph])
        return True, fn
    except Exception as e:
        # Include the exception type: str(e) alone can be empty or ambiguous.
        return False, (fn, f"{type(e).__name__}: {e}")


def initializer():
    """Make the calling process ignore SIGINT (CTRL+C).

    Passed as the pool ``initializer`` so that the keyboard interrupt is
    handled by the parent process only.
    """
    ignore_handler = signal.SIG_IGN
    signal.signal(signal.SIGINT, ignore_handler)


def process(args):
    """Convert all STEP files in ``args.input`` into graph files in ``args.output``.

    Files matching ``*.st*p`` in the input folder are processed by a pool of
    ``args.num_processes`` workers. Failures are collected and written to
    ``processing_errors.log`` in the output folder, and a summary is printed.
    Pressing CTRL+C stops the workers; results gathered up to that point are
    still counted and logged.

    Args:
        args: Namespace providing ``input``, ``output`` and ``num_processes``;
            it is also forwarded unchanged to ``process_one_file``.
    """
    input_path = pathlib.Path(args.input)
    output_path = pathlib.Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)
    step_files = list(input_path.glob("*.st*p"))

    # Error log file
    error_log = output_path.joinpath("processing_errors.log")

    succeeded = 0
    failed_files = []

    pool = Pool(processes=args.num_processes, initializer=initializer)
    try:
        jobs = zip(step_files, repeat(args))
        # Consume results one by one so partial progress survives an interrupt.
        for status, info in tqdm(
            pool.imap(process_one_file, jobs), total=len(step_files)
        ):
            if status:
                succeeded += 1
            else:
                failed_file, error_msg = info
                failed_files.append((str(failed_file), error_msg))
    except KeyboardInterrupt:
        # Fall through to the cleanup below and report what was gathered.
        pass
    finally:
        # Always release the worker processes, whether the run finished,
        # was interrupted, or failed unexpectedly.
        pool.terminate()
        pool.join()

    failed = len(failed_files)

    # Write the error log
    if failed > 0:
        # Explicit UTF-8 so non-ASCII file names or messages cannot break logging.
        with open(error_log, "w", encoding="utf-8") as f:
            f.write(f"Failed files: {failed}\n")
            for failed_file, error_msg in failed_files:
                f.write(f"{failed_file}: {error_msg}\n")

    print(f"Processing complete. Succeeded: {succeeded}, Failed: {failed}")
    if failed > 0:
        print(f"See {error_log} for details on failed files")


def main():
    """Parse the command-line arguments and run the conversion.

    Positional arguments are the input folder of STEP files and the output
    folder for the graph files; the optional arguments control the sampling
    resolution and the number of worker processes.
    """
    # The text is passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would put it in the usage line instead.
    parser = argparse.ArgumentParser(
        description="Convert solid models to face-adjacency graphs with UV-grid features"
    )
    parser.add_argument("input", type=str, help="Input folder of STEP files")
    parser.add_argument("output", type=str, help="Output folder of DGL graph BIN files")
    parser.add_argument(
        "--curv_u_samples", type=int, default=10, help="Number of samples on each curve"
    )
    parser.add_argument(
        "--surf_u_samples",
        type=int,
        default=10,
        help="Number of samples on each surface along the u-direction",
    )
    parser.add_argument(
        "--surf_v_samples",
        type=int,
        default=10,
        help="Number of samples on each surface along the v-direction",
    )
    parser.add_argument(
        "--num_processes",
        type=int,
        default=8,
        help="Number of processes to use",
    )
    args = parser.parse_args()
    process(args)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()