import glob
import os
import subprocess
import time
import traceback
from multiprocessing import Process, cpu_count

from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Core.IFSelect import IFSelect_RetDone
from OCC.Core.Bnd import Bnd_Box
from OCC.Core.BRepBndLib import brepbndlib_Add
from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB
from OCC.Core.AIS import AIS_Shape
from OCC.Extend.DataExchange import read_step_file_with_names_colors
from PIL import Image

# Root folder searched for STEP models, and the root that receives the rendered images.
INPUT_ROOT = "../data/s5_step"
OUTPUT_ROOT = "../data/s6_img"
# Number of worker processes; main_process also uses it as the stride when
# splitting the file list between workers.
num_processes = 1

# === Recursively find all .step files ===
step_pattern = os.path.join(INPUT_ROOT, "**", "*.step")
step_files = glob.glob(step_pattern, recursive=True)

# === Mirror each input path under OUTPUT_ROOT, swapping the extension for .jpg ===
step_to_output_map = {}
for step_path in step_files:
    stem, _ext = os.path.splitext(os.path.relpath(step_path, INPUT_ROOT))
    out_path = os.path.join(OUTPUT_ROOT, stem + ".jpg")
    # Create the destination folder up front so the image can be written there later.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    step_to_output_map[step_path] = out_path

def _start_xvfb(display_num, timeout=10.0, poll_interval=0.05):
    """Launch an Xvfb server on ``:display_num`` and wait for it to come up.

    The server is started without a shell and its output is discarded.
    Readiness is detected by polling for the X11 Unix socket
    ``/tmp/.X11-unix/X<display_num>``, so the caller does not race the
    server start-up.

    Args:
        display_num: X display number to serve.
        timeout: Maximum number of seconds to wait for the socket.
        poll_interval: Seconds to sleep between two socket checks.

    Returns:
        The ``subprocess.Popen`` handle of the server, or ``None`` when the
        ``Xvfb`` executable could not be launched. Failure to start is
        reported but not raised, so a server that already owns the display
        can still be used.
    """
    try:
        server = subprocess.Popen(
            ["Xvfb", f":{display_num}", "-screen", "0", "1024x768x24"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    except OSError:
        print(f"Could not launch Xvfb for display :{display_num}")
        return None

    socket_path = f"/tmp/.X11-unix/X{display_num}"
    deadline = time.monotonic() + timeout
    # Stop waiting as soon as the socket shows up, the server dies (e.g. the
    # display is already taken by another server), or the timeout elapses.
    while (
        not os.path.exists(socket_path)
        and server.poll() is None
        and time.monotonic() < deadline
    ):
        time.sleep(poll_interval)
    return server


def _render_step_file(display, step_file, output_img_path):
    """Render one STEP file with ``display`` and dump it to ``output_img_path``.

    Every sub-shape returned by ``read_step_file_with_names_colors`` is shown
    fully opaque with its RGB colour scaled to 25 % of the stored value.

    Args:
        display: Viewer object returned by ``init_display``.
        step_file: Path of the STEP file to render.
        output_img_path: Path of the image file to write.

    Returns:
        ``True`` when the image was dumped, ``False`` when the STEP reader
        rejected the file (a message is printed in that case).
    """
    # Parse-only probe to report unreadable files. The geometry itself comes
    # from read_step_file_with_names_colors below, so no transfer is done here.
    step_reader = STEPControl_Reader()
    if step_reader.ReadFile(step_file) != IFSelect_RetDone:
        print(f"Error reading file: {step_file}")
        return False

    shapes_labels_colors = read_step_file_with_names_colors(step_file)
    for sub_shape, (_label, rgb) in shapes_labels_colors.items():
        darkened = Quantity_Color(
            rgb.Red() * 0.25, rgb.Green() * 0.25, rgb.Blue() * 0.25, Quantity_TOC_RGB
        )
        ais_shape = AIS_Shape(sub_shape)
        ais_shape.SetColor(darkened)
        # The trailing False is passed through unchanged from the original
        # calls; presumably it defers the viewer refresh — TODO confirm.
        display.Context.SetTransparency(ais_shape, 0.0, False)
        display.Context.Display(ais_shape, False)

    display.FitAll()
    display.View.Dump(output_img_path)
    print(f"Rendered {output_img_path}")
    return True


def main_process(process_id):
    """Render this worker's share of the STEP files to images.

    The worker runs its own Xvfb server on display ``99 + process_id``,
    points ``DISPLAY`` at it, and renders every ``num_processes``-th entry of
    ``step_to_output_map`` starting at index ``process_id``. Files whose
    output image already exists are skipped. A failure on one file is printed
    with its traceback and does not stop the remaining files.

    Args:
        process_id: Zero-based index of this worker.
    """
    display_num = 99 + process_id
    os.environ["DISPLAY"] = f":{display_num}"
    xvfb = _start_xvfb(display_num)
    try:
        # Imported here rather than at module level so that it happens after
        # DISPLAY has been set (presumably required by the GUI backend —
        # TODO confirm).
        from OCC.Display.SimpleGui import init_display

        display, _, _, _ = init_display()
        white = Quantity_Color(1.0, 1.0, 1.0, Quantity_TOC_RGB)
        other_white = Quantity_Color(1.0, 1.0, 1.0, Quantity_TOC_RGB)
        display.View.SetBgGradientColors(white, other_white, 2, True)

        all_step_paths = list(step_to_output_map)
        for step_file in all_step_paths[process_id::num_processes]:
            output_img_path = step_to_output_map[step_file]
            if os.path.isfile(output_img_path):
                continue
            try:
                _render_step_file(display, step_file, output_img_path)
            except Exception:
                # Boundary handler: one broken model must not abort the batch.
                print(f"Failed: {step_file}")
                traceback.print_exc()
            finally:
                # Always clear the scene so shapes never leak into the next image.
                display.EraseAll()
    finally:
        # Shut down the virtual X server this worker started.
        if xvfb is not None:
            xvfb.terminate()
            try:
                xvfb.wait(timeout=5)
            except subprocess.TimeoutExpired:
                xvfb.kill()
                xvfb.wait()

if __name__ == "__main__":
    # One worker per shard: start them all, then wait for each to finish.
    workers = [
        Process(target=main_process, args=(shard,))
        for shard in range(num_processes)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
