import numpy as np
import pyvista as pv
import os, json

class CenterLine:
    """One centerline polyline read from a markups-style JSON file.

    The file must contain ``markups[0]['controlPoints']`` (each entry with a
    3-component ``'position'``) and, for non-empty lines,
    ``markups[0]['measurements'][5]['controlPointValues']`` with one radius
    per control point (presumably a 3D Slicer markups export -- confirm).

    Attributes set on every instance:
        path, cl_json, points, n, ids, is_good
    Attributes set only when ``is_good`` is True:
        radiuss, ids_from_points, is_terminal, terminal_id, OutIn_ids,
        length, in_out
    """

    def __init__(self, path):
        """Load the centerline stored at *path*."""
        self.path = path
        # Context manager so the file handle is closed deterministically.
        with open(path) as f:
            self.cl_json = json.load(f)

        markup = self.cl_json['markups'][0]
        # x and y are negated while z is kept
        # (presumably a RAS <-> LPS axis convention change -- confirm).
        self.points = np.array([
            [-cp['position'][0], -cp['position'][1], cp['position'][2]]
            for cp in markup['controlPoints']
        ])
        self.n = len(self.points)
        self.ids = list(range(self.n))

        # A line is usable only if it has at least one control point.
        self.is_good = self.n > 0

        if self.is_good:
            # Measurement index 5 is hard-coded
            # (presumably the per-point radius measurement -- confirm).
            self.radiuss = np.array(markup['measurements'][5]['controlPointValues'])

            # Exact-coordinate lookup: point tuple -> index along the line.
            self.ids_from_points = {tuple(p): i for i, p in enumerate(self.points)}

            self.is_terminal = False
            self.terminal_id = -1
            self.OutIn_ids = []

            # Arc length of the polyline as loaded. NOTE: resample() does not
            # refresh this value, so it keeps describing the original polyline.
            self.length = float(
                np.sum(np.linalg.norm(np.diff(self.points, axis=0), axis=1))
            )

            self.in_out = "not_in_out"  # "in" or "out" or "not_in_out"

    def resample(self, target_n):
        """Resample points and radii to *target_n* samples evenly spaced in arc length.

        Values are linearly interpolated between the existing samples. Lines
        with fewer than two points are left untouched. ``points``, ``radiuss``,
        ``n``, ``ids`` and ``ids_from_points`` are replaced in place.
        """
        if self.n <= 1:
            return

        # Cumulative arc length at every existing sample (s[0] == 0).
        seg_lengths = np.linalg.norm(np.diff(self.points, axis=0), axis=1)
        s = np.concatenate(([0.], np.cumsum(seg_lengths)))

        # Use the arc length of the *current* polyline rather than the cached
        # self.length, which is stale after a previous resample and would make
        # the last samples extrapolate beyond the line.
        sample_positions = np.linspace(0, s[-1], target_n)

        # Index of the segment [s[k], s[k+1]] containing each sample position.
        seg = np.clip(np.searchsorted(s, sample_positions) - 1, 0, self.n - 2)
        seg_len = s[seg + 1] - s[seg]

        # Fractional position inside the segment; 0 for degenerate segments.
        t = np.zeros(len(sample_positions))
        nonzero = seg_len != 0
        t[nonzero] = (sample_positions[nonzero] - s[seg[nonzero]]) / seg_len[nonzero]

        new_points = self.points[seg] + t[:, None] * (self.points[seg + 1] - self.points[seg])
        new_radii = self.radiuss[seg] + t * (self.radiuss[seg + 1] - self.radiuss[seg])

        # Pin the end samples to the exact original end values: a + 1*(b - a)
        # can differ from b by rounding, which would break lookups that match
        # line end points by exact coordinates.
        if target_n >= 1:
            new_points[0] = self.points[0]
            new_radii[0] = self.radiuss[0]
        if target_n >= 2:
            new_points[-1] = self.points[-1]
            new_radii[-1] = self.radiuss[-1]

        self.points = new_points
        self.radiuss = new_radii
        self.n = target_n
        self.ids = list(range(target_n))
        self.ids_from_points = {tuple(p): i for i, p in enumerate(self.points)}

    def set_terminal(self, terminal_p=None):
        """Mark the line as terminal, optionally recording which point is the terminal.

        Args:
            terminal_p: coordinates (tuple, list or array) of a point of this
                line; must match a stored point exactly. When given,
                ``terminal_id`` is set and ``OutIn_ids`` lists the point ids
                starting from the terminal point.

        Raises:
            KeyError: if *terminal_p* is not one of the line's points.
        """
        self.is_terminal = True

        # "is not None" instead of "== None": comparing a NumPy array with
        # "==" is element-wise and cannot be used as a truth value.
        if terminal_p is not None:
            self.terminal_id = self.ids_from_points[tuple(terminal_p)]

            if self.terminal_id == 0:
                self.OutIn_ids = list(self.ids)
            else:
                self.OutIn_ids = self.ids[::-1]

    def get_cut_point_id(self, min_radius=0.3, cut_ratio=0.33):
        """Choose the pair of consecutive points between which the line is cut.

        Walking from the terminal point, the point whose arc length is closest
        to ``self.length * cut_ratio`` is selected. If its radius is below
        *min_radius*, the first later point with a radius of at least
        *min_radius* is used instead (if there is one).

        Returns:
            ``(point_id_1, point_id_2, cut_length)``: ids of the two
            consecutive points and the arc length (from the terminal) of the
            midpoint between them.

        Raises:
            ValueError: if fewer than two ordered points are available, e.g.
                ``set_terminal`` was never called with a point.
        """
        order = self.OutIn_ids
        if len(order) < 2:
            raise ValueError(
                "get_cut_point_id needs at least two ordered points; "
                "call set_terminal(terminal_p) on a line with >= 2 points first"
            )

        # Arc length from the terminal point to every point, in OutIn order.
        ordered_points = self.points[order]
        steps = np.linalg.norm(np.diff(ordered_points, axis=0), axis=1)
        lengths = np.concatenate(([0.], np.cumsum(steps)))

        target_length = self.length * cut_ratio
        first = int(np.argmin(np.abs(lengths - target_length)))

        radii = self.radiuss[order]
        if radii[first] < min_radius:
            for i in range(first, len(radii)):
                if radii[i] >= min_radius:
                    first = i
                    break

        # The cut lies between `first` and `first + 1`, so `first` must not be
        # the last point; previously this case raised an IndexError.
        first = min(first, len(order) - 2)
        second = first + 1
        cut_length = (lengths[second] + lengths[first]) / 2.

        return order[first], order[second], cut_length

class CenterLines:
    """Collection of centerlines loaded from a directory tree.

    Attributes:
        centerline_dir: directory that was scanned.
        lines: dict mapping an integer line id to its ``CenterLine``.
        all_terminal_points: dict mapping an end-point coordinate tuple to the
            ids of the lines that start or end exactly there.
        terminal_ids: ids of lines owning an end point no other line shares.
        ep_positions, ep_selected: endpoint positions / ``selected`` flags read
            by ``read_endpoints`` (empty lists until it is called).
    """

    def __init__(self, centerline_dir, endpoints_path=None):
        """Load all centerlines, detect terminal lines and optionally read endpoints."""
        self.centerline_dir = centerline_dir
        self.ep_positions = []
        self.ep_selected = []

        self.read_centerlines(centerline_dir)

        self.get_terminal_lines_ids()

        if endpoints_path is not None:
            self.read_endpoints(endpoints_path)

    def read_centerlines(self, centerline_dir):
        """Load every ``*.json`` file under *centerline_dir* as a ``CenterLine``.

        Lines without control points (``is_good`` False) are dropped; the rest
        are resampled to 20 points and stored in ``self.lines``.
        """
        centerlines = []
        for root, dirs, files in os.walk(centerline_dir):
            # Sort in place so the traversal, and therefore the line ids, do
            # not depend on the order the filesystem lists entries in.
            dirs.sort()
            for file in sorted(files):
                if not file.endswith(".json"):
                    continue
                cl = CenterLine(os.path.join(root, file))
                if cl.is_good:
                    centerlines.append(cl)

        self.lines = dict(enumerate(centerlines))
        for cl in self.lines.values():
            cl.resample(target_n=20)

    def get_terminal_lines_ids(self):
        """Find the lines that own an end point not shared with any other line.

        End points are matched by exact coordinates. For each unshared end
        point the owning line's ``set_terminal`` is called with that point.

        NOTE: a line stores a single terminal point, so for a line whose two
        ends are both unshared the end visited last wins. Such a line is
        listed only once in ``terminal_ids`` so callers handle it once.
        """
        self.all_terminal_points = {}
        for line_id, cl in self.lines.items():
            for end_point in (tuple(cl.points[0]), tuple(cl.points[-1])):
                self.all_terminal_points.setdefault(end_point, []).append(line_id)

        self.terminal_ids = []
        for terminal_p, owners in self.all_terminal_points.items():
            if len(owners) != 1:
                continue
            line_id = owners[0]
            if line_id not in self.terminal_ids:
                self.terminal_ids.append(line_id)
            self.lines[line_id].set_terminal(terminal_p)

    def read_endpoints(self, endpoints_path):
        """Label every terminal line as "in" or "out" from an endpoints JSON file.

        The file must contain ``markups[0]['controlPoints']``, each entry with
        a ``'position'`` and a ``'selected'`` flag. Positions get the same
        x/y negation as the centerline points. Each terminal line takes the
        flag of the endpoint nearest to its terminal point: selected -> "out",
        otherwise "in". If the file lists no endpoints, nothing is labelled.
        """
        with open(endpoints_path, 'r') as f:
            ep_json = json.load(f)

        control_points = ep_json['markups'][0]['controlPoints']

        self.ep_positions = [
            [-p['position'][0], -p['position'][1], p['position'][2]]
            for p in control_points
        ]
        self.ep_selected = [p['selected'] for p in control_points]

        # With no endpoints the distance computation below cannot broadcast
        # and there is nothing to match, so keep each line's current label.
        if not self.ep_positions:
            return

        ep_positions = np.array(self.ep_positions)

        for line_id in self.terminal_ids:
            cl = self.lines[line_id]
            terminal_p = np.array(cl.points[cl.terminal_id])

            dists = np.linalg.norm(ep_positions - terminal_p, axis=1)
            nearest = int(np.argmin(dists))

            cl.in_out = "out" if self.ep_selected[nearest] else "in"
        
        
def _closest_component(components, point):
    """Return ``(distance, index)`` of the component with a vertex nearest to *point*."""
    dists = [np.min(np.linalg.norm(c.points - point, axis=1)) for c in components]
    return np.min(dists), int(np.argmin(dists))


def cut_stl_with_plane(mesh, plane_origin, direction_vector, terminal_p):
    """Cut *mesh* with a plane and discard the piece next to *terminal_p*.

    The mesh is clipped on both sides of the plane through *plane_origin*
    with normal *direction_vector*. Among the connected components of both
    halves, the one with a vertex closest to *terminal_p* is dropped. The
    remaining components of that half are merged into the other half and the
    connected component of largest volume is returned.

    Args:
        mesh: surface mesh to cut (a PyVista dataset supporting ``clip``).
        plane_origin: a point on the cutting plane.
        direction_vector: plane normal; need not be normalised.
        terminal_p: point identifying the side to remove.

    Returns:
        The retained mesh. If one side of the plane is empty, the other side
        is returned as is.

    Raises:
        ValueError: if *direction_vector* has zero or non-finite length.
    """
    direction_vector = np.asarray(direction_vector, dtype=float)
    length = np.linalg.norm(direction_vector)
    # A zero-length direction would silently produce a NaN normal and a
    # meaningless clip, so fail loudly instead.
    if not np.isfinite(length) or length == 0:
        raise ValueError("direction_vector must be a finite, non-zero vector")
    plane_normal = direction_vector / length

    nocap_1 = mesh.clip(normal=plane_normal, origin=plane_origin)
    nocap_2 = mesh.clip(normal=-plane_normal, origin=plane_origin)

    # n_cells instead of PolyData.n_faces, whose meaning is deprecated in
    # recent PyVista releases; for a surface mesh the two counts agree.
    if nocap_1.n_cells == 0:
        return nocap_2
    if nocap_2.n_cells == 0:
        return nocap_1

    components_1 = nocap_1.split_bodies().as_polydata_blocks()
    components_2 = nocap_2.split_bodies().as_polydata_blocks()

    min_dist_1, closest_id_1 = _closest_component(components_1, terminal_p)
    min_dist_2, closest_id_2 = _closest_component(components_2, terminal_p)

    # Drop the component nearest to the terminal, then keep the opposite half
    # plus whatever else was on the terminal's side of the plane.
    if min_dist_1 < min_dist_2:
        components_1.pop(closest_id_1)
        merged_mesh, leftovers = nocap_2, components_1
    else:
        components_2.pop(closest_id_2)
        merged_mesh, leftovers = nocap_1, components_2
    for component in leftovers:
        merged_mesh = merged_mesh.merge(component)

    # Keep only the largest connected body of the merged result.
    bodies = merged_mesh.split_bodies().as_polydata_blocks()
    return bodies[int(np.argmax([body.volume for body in bodies]))]

def main(root_dir="root_dir", need_cut_list=(), min_radius=0.3, cut_ratio=0.02):
    """Cut the terminal branches off every selected STL mesh under *root_dir*.

    Expected layout below *root_dir*:
        stl_1_artery_normal/<name>.stl   input meshes
        centerlines/<name>/              centerline JSON files of that mesh
        endpoints/<name>.json            endpoint markups of that mesh
    Written below *root_dir*:
        stl_2_cut/<name>.stl             mesh after all terminal cuts
        cut_info/<name>.json             plane origin and in/out label per cut

    Args:
        root_dir: base directory of the data set.
        need_cut_list: mesh names (file name without ``.stl``) to process;
            meshes not listed are skipped, so the default processes nothing.
        min_radius: forwarded to ``CenterLine.get_cut_point_id``.
        cut_ratio: forwarded to ``CenterLine.get_cut_point_id``
            (0.02 by default; 0.01 or 0.005 are alternatives).
    """
    # os.path.join keeps the paths valid on every platform, not only Windows.
    stl_dir = os.path.join(root_dir, "stl_1_artery_normal")
    cut_dir = os.path.join(root_dir, "stl_2_cut")
    save_plane_origin_dir = os.path.join(root_dir, "cut_info")

    wanted = set(need_cut_list)
    # CenterLine negates x and y when loading; negating them again gives the
    # coordinates in which the cut is applied to the STL mesh.
    flip_xy = np.array([-1., -1., 1.])

    for filename in sorted(os.listdir(stl_dir)):
        stem, ext = os.path.splitext(filename)
        if ext != ".stl" or stem not in wanted:
            continue

        centerline_dir = os.path.join(root_dir, "centerlines", stem)
        endpoints_path = os.path.join(root_dir, "endpoints", stem + ".json")
        stl_path = os.path.join(stl_dir, filename)
        cut_path = os.path.join(cut_dir, filename)
        os.makedirs(cut_dir, exist_ok=True)

        centerlines = CenterLines(centerline_dir, endpoints_path=endpoints_path)
        cut_info = []

        mesh = pv.read(stl_path)
        # dict.fromkeys de-duplicates while keeping order: a line listed twice
        # must not be cut twice.
        for line_id in dict.fromkeys(centerlines.terminal_ids):
            cl = centerlines.lines[line_id]
            # Work on a flipped copy rather than flipping cl.points in place,
            # so a repeated visit can never flip the coordinates back.
            points = cl.points * flip_xy

            point_id_1, point_id_2, cut_length = cl.get_cut_point_id(min_radius=min_radius, cut_ratio=cut_ratio)
            print(f"centerline {line_id}, cut length {cut_length/cl.length*100:.1f}%: {cut_length:.2f}/{cl.length:.2f}, target radius: {cl.radiuss[point_id_1]:.4f}, startpoint: {point_id_1}, endpoint: {point_id_2}")
            p1, p2, terminal_p = points[point_id_1], points[point_id_2], points[cl.terminal_id]
            # Plane through the midpoint of the two points, normal along them.
            plane_origin, direction_vector = (p1 + p2) / 2., p2 - p1
            mesh = cut_stl_with_plane(mesh, plane_origin, direction_vector, terminal_p)

            cut_info.append({
                "plane_origin": plane_origin.tolist(),
                "in_out": cl.in_out
            })
        mesh.save(cut_path)

        print(f"finished, results saved {cut_path}")

        os.makedirs(save_plane_origin_dir, exist_ok=True)
        save_plane_origin_path = os.path.join(save_plane_origin_dir, stem + ".json")
        with open(save_plane_origin_path, "w") as f:
            json.dump(cut_info, f, indent=4)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()