import os, slicer, shutil, SimpleITK as sitk, vtk
import ExtractCenterline
import datetime

# Module-level wall-clock timestamps consumed by save_centerlines' progress report:
# t1 marks the start of the current case (reassigned by load_endpoints) and t2
# its end (reassigned by save_centerlines). Both are initialised at import time.
t1, t2 = datetime.datetime.now(), datetime.datetime.now()

def extract_endpoints(filename):
    """Load a surface model and auto-detect its centerline endpoints.

    The model is loaded from ``model_dir/stl_1`` and displayed at 40% opacity.
    Endpoints are found with ExtractCenterline's ``extractNetwork`` /
    ``getEndPoints`` and stored as control points of a new
    vtkMRMLMarkupsFiducialNode named after ``filename`` without its extension.

    Note: unlike load_endpoints, this does not reset the module-level ``t1``
    timestamp, so the "time used" printed by save_centerlines is measured from
    t1's last assignment.

    Args:
        filename: model file name inside ``model_dir/stl_1`` (e.g. "case.stl").

    Returns:
        The markups fiducial node holding the detected endpoints.
    """
    modelNode = slicer.util.loadModel(os.path.join("model_dir", "stl_1", filename))
    modelNode.GetDisplayNode().SetOpacity(0.4)

    # splitext strips only the real extension; str.replace(".stl", "") also
    # matched ".stl" in the middle of a name and missed an upper-case ".STL".
    node_name = os.path.splitext(filename)[0]

    extractLogic = ExtractCenterline.ExtractCenterlineLogic()
    inputSurfacePolyData = modelNode.GetPolyData()

    # Auto-detect the endpoints: build the network from the surface, then ask
    # for its end points.
    endPointsMarkupsNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", node_name)
    networkPolyData = extractLogic.extractNetwork(inputSurfacePolyData, endPointsMarkupsNode)
    endpointPositions = extractLogic.getEndPoints(networkPolyData, startPointPosition=None)
    # Clear anything extractNetwork may have put in the node before writing the
    # detected positions as its only control points.
    endPointsMarkupsNode.RemoveAllControlPoints()
    for position in endpointPositions:
        endPointsMarkupsNode.AddControlPoint(vtk.vtkVector3d(position))
    return endPointsMarkupsNode

def _detect_endpoints_from_model(modelNode, node_name):
    """Auto-detect endpoints on an already-loaded model into a new markups node.

    Args:
        modelNode: model node whose surface poly data is analysed.
        node_name: name for the new vtkMRMLMarkupsFiducialNode.

    Returns:
        The new markups fiducial node holding the detected endpoints.
    """
    extractLogic = ExtractCenterline.ExtractCenterlineLogic()
    inputSurfacePolyData = modelNode.GetPolyData()

    endPointsMarkupsNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsFiducialNode", node_name)
    networkPolyData = extractLogic.extractNetwork(inputSurfacePolyData, endPointsMarkupsNode)
    endpointPositions = extractLogic.getEndPoints(networkPolyData, startPointPosition=None)
    # Clear anything extractNetwork may have put in the node before writing the
    # detected positions as its only control points.
    endPointsMarkupsNode.RemoveAllControlPoints()
    for position in endpointPositions:
        endPointsMarkupsNode.AddControlPoint(vtk.vtkVector3d(position))
    return endPointsMarkupsNode


def load_endpoints(filename):
    """Load a model and its reference endpoints, auto-detecting them if none exist.

    Resets the module-level ``t1`` timestamp (start of this case), loads the
    model from ``model_dir/stl_1_artery_normal`` at 40% opacity and switches
    the view layout. Reference endpoints ``<stem>.json`` are then looked up in
    ``ref_endpoints_dir/endpoints`` and, failing that,
    ``ref_endpoints_dir_2/endpoints``; the first match is loaded as markups.
    When neither exists, endpoints are auto-detected on the loaded model into a
    new markups fiducial node named ``<stem>``.

    Args:
        filename: model file name (e.g. "case.stl").

    Returns:
        The markups node holding the endpoints.
    """
    global t1
    t1 = datetime.datetime.now()

    modelNode = slicer.util.loadModel(os.path.join("model_dir", "stl_1_artery_normal", filename))
    modelNode.GetDisplayNode().SetOpacity(0.4)

    # NOTE(review): 4 is presumably vtkMRMLLayoutNode.SlicerLayoutOneUp3DView -- confirm.
    layoutNodes = slicer.mrmlScene.GetNodesByClass('vtkMRMLLayoutNode')
    layoutNodes.InitTraversal()
    layoutNode = layoutNodes.GetNextItemAsObject()
    if layoutNode is not None:  # e.g. no layout node in the scene; the layout is cosmetic
        layoutNode.SetViewArrangement(4)

    # splitext strips only the real extension (str.replace(".stl", ...) also hit
    # ".stl" mid-name and missed ".STL").
    stem = os.path.splitext(filename)[0]

    # Reference endpoint folders, searched in priority order.
    candidate_dirs = (
        os.path.join("ref_endpoints_dir", "endpoints"),
        os.path.join("ref_endpoints_dir_2", "endpoints"),
    )
    for ref_endpoints_dir in candidate_dirs:
        json_path = os.path.join(ref_endpoints_dir, stem + ".json")
        if os.path.exists(json_path):
            return slicer.util.loadMarkups(json_path)

    print("[Warning] No reference endpoints found, extracting endpoints manually...")
    return _detect_endpoints_from_model(modelNode, stem)

def _case_name(node_name):
    """Derive the case name from an endpoints node name.

    Keeps everything before the first "_patch_", the separator itself and the
    single character that follows it, e.g. "case_patch_1_1" -> "case_patch_1".
    Names without "_patch_" (or with nothing after it) are returned unchanged
    instead of raising IndexError.
    """
    head, sep, tail = node_name.partition("_patch_")
    if not sep or not tail:
        return node_name
    return head + sep + tail[0]


def _report_progress(name, elapsed):
    """Print batch progress for the case whose model file is ``name + ".stl"``.

    Position and total come from the sorted listing of
    ``listdir/stl_1_artery_normal``; nothing is printed if the file is absent.

    Args:
        name: case name (model file name without ".stl").
        elapsed: datetime.timedelta spent on this case.
    """
    models = sorted(os.listdir(os.path.join("listdir", "stl_1_artery_normal")))
    try:
        index = models.index(name + ".stl")
    except ValueError:
        return
    total = len(models)
    done = index + 1
    # total_seconds(): timedelta.seconds alone silently drops whole days.
    seconds = int(elapsed.total_seconds())
    # The last file has no successor; it used to produce no report at all.
    next_name = models[done] if done < total else "none"
    print(f"finished [{done}/{total}] (process {100. * done / total:.1f}%), "
          f"time used {seconds // 60}m{seconds % 60}s, unfinish {total - done}, next is: ",
          next_name)


def save_centerlines(endPointsMarkupsNode):
    """Save the current case's curves and endpoints, clear the scene, report progress.

    The case name comes from the endpoints node name (see ``_case_name``).
    Curve nodes named "Centerline curve (i)", "Centerline curve_j (i)"
    (j = 1..9) and "Network curve (i)", for i = 0..19, are saved when present
    as ``centerlines_dir/centerlines/<name>/<node name>.json``. The endpoints
    node is saved as ``endpoints_dir/endpoints/<name>.json``. The MRML scene is
    then cleared, the module-level ``t2`` timestamp is set, and a progress line
    (elapsed time ``t2 - t1``) is printed.

    Args:
        endPointsMarkupsNode: markups node holding the case's endpoints.
    """
    global t2

    name = _case_name(endPointsMarkupsNode.GetName())

    # Same order as before: plain centerlines, numbered variants, network curves.
    curve_names = [f"Centerline curve ({i})" for i in range(20)]
    curve_names += [f"Centerline curve_{j} ({i})" for j in range(1, 10) for i in range(20)]
    curve_names += [f"Network curve ({i})" for i in range(20)]

    centerlines_dir = os.path.join("centerlines_dir", "centerlines", name)
    # Created up front, like endpoints_dir below, so the saves do not depend on
    # the per-case folder already existing.
    os.makedirs(centerlines_dir, exist_ok=True)
    for curve_name in curve_names:
        nodes = list(slicer.mrmlScene.GetNodesByName(curve_name))
        if nodes:
            slicer.util.saveNode(nodes[0], os.path.join(centerlines_dir, curve_name + ".json"))

    endpoints_dir = os.path.join("endpoints_dir", "endpoints")
    os.makedirs(endpoints_dir, exist_ok=True)
    slicer.util.saveNode(endPointsMarkupsNode, os.path.join(endpoints_dir, name + ".json"))
    slicer.mrmlScene.Clear()

    t2 = datetime.datetime.now()
    _report_progress(name, t2 - t1)
# Script entry: obtain the endpoints for one model, then export everything for it.
# Swap the first call for the commented one to prefer saved reference endpoints:
#   endPointsMarkupsNode = load_endpoints("test.stl")
# NOTE(review): nothing in this file creates the "Centerline curve (i)" nodes
# that save_centerlines exports; presumably centerlines are extracted
# interactively between these steps -- confirm.
endPointsMarkupsNode = extract_endpoints(filename="test.stl")
save_centerlines(endPointsMarkupsNode=endPointsMarkupsNode)
