from matplotlib import pyplot as plt
import numpy as np
import open3d as o3d
import shapely
import time
import torch

from cgal_triangulations import triangulate_polygon


def color_to_instance(color: np.ndarray) -> np.ndarray:
    """Decode an (N, 3) array of colors into a 1D array of N integer instance ids.

    The three channels are treated as the base-256 digits of the id
    (``id = c0 + 256 * c1 + 256**2 * c2``), so every 8-bit color maps to a
    unique instance.

    Floating-point input whose maximum is below 2 is taken to be normalised to
    [0, 1] and is rescaled to [0, 255] first. Caveat: float colors that already
    hold 0..255 values but happen to contain only 0s and 1s are
    indistinguishable from normalised colors and will be rescaled as well.
    Integer-typed input is never rescaled.

    Returns:
        Integer array of shape (N,); an empty integer array for empty input.
    """
    color = np.asarray(color)
    if len(color) == 0:
        return np.array([], dtype=int)
    if np.issubdtype(color.dtype, np.floating):
        if color.max() < 2:
            color = color * 255
        # Round instead of truncating: k / 255 * 255 can land just below k.
        color = np.rint(color)
    # Widen before combining so narrow dtypes such as uint8 cannot overflow.
    channels = color.astype(np.int64)
    return (channels[:, 0] + 256 * channels[:, 1] + 256 * 256 * channels[:, 2]).astype(int)

def instance_to_color(instance: np.ndarray, scale_to=None) -> np.ndarray:
    """Encode a 1D array of N instance ids as an (N, 3) uint8 color array.

    The id is split into its base-256 digits: channel 0 holds the least
    significant byte, channel 2 the most significant one.

    Args:
        instance: non-negative ids, each smaller than 256**3.
        scale_to: if given, ids are first rescaled with
            ``instance / scale_to * 255**3`` (and truncated to int) so that
            ``scale_to`` maps to the top of the encodable range.

    Returns:
        uint8 array of shape (N, 3); shape (0, 3) for empty input.
    """
    instance = np.asarray(instance)
    if instance.size == 0:
        # np.max / np.min raise on empty arrays; there is nothing to encode.
        return np.zeros((0, 3), dtype=np.uint8)
    assert np.max(instance) < 256 * 256 * 256, "Instance values are too large"
    assert np.min(instance) >= 0, "Instance values are negative"
    if scale_to is None:
        instance = instance.astype(int)
    else:
        instance = ((instance / scale_to) * 255 * 255 * 255).astype(int)
    color = np.zeros((len(instance), 3), dtype=np.uint8)
    color[:, 0] = instance % 256
    color[:, 1] = (instance // 256) % 256
    color[:, 2] = (instance // (256 * 256)) % 256
    return color

def project_points_to_3D_plane(points, plane_eq):
    """Orthogonally project points onto the plane ``a*x + b*y + c*z + d = 0``.

    The signed distance is evaluated directly as ``(n . p + d) / |n|``, so the
    function also works for planes with ``c == 0`` (e.g. vertical planes) and
    for plane equations whose normal is not unit length. The normal must be
    non-zero.

    Args:
        points: array of shape (N, 3).
        plane_eq: the four plane coefficients ``[a, b, c, d]``.

    Returns:
        Tuple ``(projected_points, distances)``: the projected points and the
        absolute point-to-plane distances.
    """
    points = np.asarray(points, dtype=float)
    a, b, c, d = plane_eq
    plane_normal = np.array([a, b, c], dtype=float)
    normal_length = np.linalg.norm(plane_normal)
    unit_normal = plane_normal / normal_length
    dist = (np.dot(points, plane_normal) + d) / normal_length
    return points - np.outer(dist, unit_normal), np.abs(dist)


def compute_signed_distances_to_3D_plane(points, plane_eq):
    """Compute the signed distances of the points to the plane ``a*x + b*y + c*z + d = 0``.

    The value returned is ``n . p + d`` with ``n = (a, b, c)``. It is a true
    (metric) distance only when ``n`` has unit length; no normalisation is
    applied here.

    Evaluating the plane equation directly avoids building a point on the
    plane via ``-d / c``, which is undefined for ``c == 0`` and loses the
    offset ``d`` for (near-)vertical planes.

    Args:
        points: array of shape (N, 3).
        plane_eq: the four plane coefficients ``[a, b, c, d]``.

    Returns:
        Array of shape (N,) with one signed value per point.
    """
    a, b, c, d = plane_eq
    plane_normal = np.array([a, b, c], dtype=np.float32)
    return np.dot(points, plane_normal) + np.float32(d)


def compute_signed_distances_to_3D_plane_batched(points, plane_eqs):
    """Compute the signed distances of N points to P planes at once.

    Each row of ``plane_eqs`` holds the coefficients ``[a, b, c, d]`` of a
    plane ``a*x + b*y + c*z + d = 0``. The result is ``n . p + d`` for every
    point/plane pair; it is a metric distance only for unit-length normals.

    The plane equation is evaluated directly (one matrix product) instead of
    going through a point on the plane at ``-d / c``, which drops the offset
    ``d`` for planes with ``c == 0``.

    Args:
        points: array of shape (N, 3).
        plane_eqs: array of shape (P, 4).

    Returns:
        Array of shape (N, P).
    """
    plane_eqs = np.asarray(plane_eqs, dtype=np.float32)
    plane_normals = plane_eqs[:, :3]
    plane_offsets = plane_eqs[:, 3]
    return np.asarray(points) @ plane_normals.T + plane_offsets[None, :]


def adjacency_list_to_connected_contour(adjacency_list: dict) -> list:
    """Walk the undirected edges stored in ``adjacency_list`` and return the visited vertices in order.

    ``adjacency_list`` maps each vertex to a collection of its neighbors that
    supports ``pop()`` and ``remove()``. The walk starts at the first key that
    still has neighbors and repeatedly follows (and consumes) one edge of the
    current vertex.

    NOTE: ``adjacency_list`` is modified in place. Every traversed edge is
    removed from both of its endpoints, so edges that are left afterwards
    were not reached by this walk.

    Returns:
        The list of visited vertices, or ``[]`` if no vertex has any neighbor.
    """
    vertices_that_still_have_neighbors = [k for k in adjacency_list.keys() if len(adjacency_list[k]) > 0]
    if len(vertices_that_still_have_neighbors) == 0:
        return []
    vertex_list = [vertices_that_still_have_neighbors[0]]
    while True:
        current_vertex = vertex_list[-1]
        if len(adjacency_list[current_vertex]) == 0:
            # Dead end: look for an already visited vertex that still has unused edges.
            vertices_that_still_have_neighbors = [k for k in adjacency_list.keys() if len(adjacency_list[k]) > 0 and k in vertex_list]
            if len(vertices_that_still_have_neighbors) > 0:
                # Roll the vertex list so that it starts with a vertex that still has neighbors.
                # The slice starting at index 1 drops the old first entry.
                # NOTE(review): after rolling, the list ends with the vertex that preceded
                # current_vertex, not with current_vertex itself, so the neighbor appended
                # below directly follows that predecessor - confirm this is intended.
                current_vertex = vertices_that_still_have_neighbors[0]
                vertex_list = vertex_list[vertex_list.index(current_vertex):] + vertex_list[1:vertex_list.index(current_vertex)]
            else:
                # No visited vertex has unused edges left: the walk is finished.
                break
        # Consume one edge of the current vertex ...
        neighbor = adjacency_list[current_vertex].pop()
        # ... and also remove the edge from the neighbor back to the current vertex.
        adjacency_list[neighbor].remove(current_vertex)
        vertex_list.append(neighbor)

    return vertex_list

def extract_largest_connected_component(mesh, only_count_mask=None):
    """Extract the largest connected component of a mesh.

    Components are the triangle clusters reported by
    ``mesh.cluster_connected_triangles()``. Without a mask the cluster with
    the most triangles is chosen. If ``only_count_mask`` is given (indexed by
    vertex), the cluster whose vertices have the largest mask sum is chosen
    instead.

    Returns:
        ``mesh.select_by_index(...)`` applied to the vertices of the chosen cluster.
    """
    triangle_clusters, cluster_n_triangles, _ = mesh.cluster_connected_triangles()
    # Convert to NumPy so that `== cluster_ix` is an element-wise comparison
    # regardless of the container type the mesh returns.
    triangle_clusters = np.asarray(triangle_clusters)
    triangles = np.asarray(mesh.triangles)

    if only_count_mask is not None:
        only_count_mask = np.asarray(only_count_mask)
        cluster_scores = [
            np.sum(only_count_mask[np.unique(triangles[triangle_clusters == cluster_ix])])
            for cluster_ix in range(len(cluster_n_triangles))
        ]
    else:
        cluster_scores = np.asarray(cluster_n_triangles)

    best_cluster = np.argmax(cluster_scores)
    selected_vertices = np.unique(triangles[triangle_clusters == best_cluster])
    assert selected_vertices.max() < len(mesh.vertices)
    return mesh.select_by_index(selected_vertices)



def get_contour_of_planar_mesh(mesh, plot=False, mesh_is_index_colored=False):
    """Find the boundary contours of a triangle mesh.

    A boundary edge is an edge that is referenced by exactly one triangle.
    The boundary edges are chained into vertex loops with
    ``adjacency_list_to_connected_contour`` until no edge is left.

    The function works on a copy of ``mesh``. Unless ``mesh_is_index_colored``
    is set, every vertex of the copy gets its own index encoded as vertex
    color; the contours are reported as the ids decoded from those colors.

    Args:
        mesh: the triangle mesh to analyse.
        plot: if True, draw every contour with matplotlib and show the figure.
        mesh_is_index_colored: if True, the existing vertex colors are kept
            and decoded as ids.

    Returns:
        Tuple ``(is_contour, contours)``. ``is_contour`` is a boolean array
        with one entry per mesh vertex, True where the vertex index occurs in
        any contour. ``contours`` is a list of integer arrays, longest first;
        contours with fewer than three entries are dropped. If nothing
        remains, a warning is printed and an all-False mask plus an empty
        list are returned.
    """
    # Work on a copy so the caller's mesh keeps its own vertex colors.
    mesh = o3d.geometry.TriangleMesh(mesh)
    if not mesh_is_index_colored:
        index_colors = instance_to_color(np.arange(0, len(mesh.vertices)))
        mesh.vertex_colors = o3d.utility.Vector3dVector(index_colors)

    # Collect the three edges of every triangle, with sorted endpoints so that
    # both orientations of an edge compare equal.
    faces = np.array(mesh.triangles)
    face_edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]])
    face_edges = np.sort(face_edges, axis=1)
    unique_edges, edge_counts = np.unique(face_edges, return_counts=True, axis=0)

    # Edges that belong to a single triangle form the boundary.
    boundary_edges = unique_edges[edge_counts == 1]

    adjacency_list = {}
    for first, second in boundary_edges:
        adjacency_list.setdefault(first, set()).add(second)
        adjacency_list.setdefault(second, set()).add(first)

    # Each call consumes the edges it walks; repeat until none are left.
    contours = [adjacency_list_to_connected_contour(adjacency_list)]
    while any(adjacency_list.values()):
        contours.append(adjacency_list_to_connected_contour(adjacency_list))

    # Longest contour first.
    contours.sort(key=len, reverse=True)

    colors = np.array(mesh.vertex_colors)

    if plot:
        vertex_positions = np.array(mesh.vertices)
        for contour in contours:
            polygon = shapely.geometry.Polygon(vertex_positions[contour])
            x, y = polygon.exterior.xy
            plt.gca().set_aspect('equal', adjustable='box')
            plt.plot(x, y)
        plt.show()

    # Translate mesh-local vertex indices into the ids stored in the colors.
    decoded_contours = [color_to_instance(colors[contour]) for contour in contours if len(contour) >= 3]

    if not decoded_contours or sum(map(len, decoded_contours)) == 0:
        print("Warning: no contour found")
        return np.zeros(len(mesh.vertices), dtype=bool), []

    is_contour = np.isin(np.arange(0, len(mesh.vertices)), np.concatenate(decoded_contours))
    return is_contour, [np.array(contour).astype(int) for contour in decoded_contours]




def project_points_3d_to_2d_plane_based(points: np.ndarray, plane_equation: np.ndarray):
    """Express 3D points lying on a plane in a 2D coordinate system of that plane.

    The plane is ``a*x + b*y + c*z + d = 0`` with ``plane_equation = [a, b, c, d]``.
    The 2D frame consists of an origin ``r_O`` on the plane and two orthonormal
    in-plane directions ``e_1`` and ``e_2``:

    * ``c != 0``: ``r_O = (0, 0, -d / c)`` and ``e_1`` is the normalised
      ``(1, 0, -a / c)``.
    * ``c == 0`` (e.g. vertical planes): ``r_O`` is placed on the coordinate
      axis of the largest normal component, and ``e_1`` is ``(1, 0, 0)`` if
      ``a == 0`` and ``(0, 0, -sign(a))`` otherwise.

    In both cases ``e_2 = n x e_1``. The origin is computed from the
    un-normalised coefficients, so it lies on the plane even if ``(a, b, c)``
    is not unit length.

    Args:
        points: array of shape (N, 3).
        plane_equation: the four plane coefficients.

    Returns:
        Array of shape (N, 2) with the in-plane coordinates ``(t_1, t_2)``.
    """
    plane_equation = np.asarray(plane_equation, dtype=float)
    abc = plane_equation[:3]
    d = plane_equation[3]

    # Normalize the plane normal
    n = abc / np.linalg.norm(abc)

    if n[2] != 0:
        # Origin where the plane crosses the z axis; e_1 lies in the plane
        # and has no y component.
        r_O = np.array([0.0, 0.0, -d / abc[2]])
        e_1 = np.array([1.0, 0.0, -n[0] / n[2]])
    else:
        # The plane is parallel to the z axis and may never cross it, so put
        # the origin on the axis the normal is most aligned with.
        axis = int(np.argmax(np.abs(abc)))
        r_O = np.zeros(3)
        r_O[axis] = -d / abc[axis]
        if n[0] == 0:
            e_1 = np.array([1.0, 0.0, 0.0])
        else:
            e_1 = np.array([0.0, 0.0, -np.sign(n[0])])
    e_1 /= np.linalg.norm(e_1)
    e_2 = np.cross(n, e_1)
    e_2 /= np.linalg.norm(e_2)

    # Coordinates of (p - r_O) along e_1 and e_2 for all points at once.
    basis = np.stack([e_1, e_2], axis=1)
    return (np.asarray(points, dtype=float) - r_O) @ basis

def project_points_2d_to_3d_plane_based(points_2d: np.ndarray, plane_equation: np.ndarray):
    """Map 2D in-plane coordinates back to 3D points on the plane.

    Inverse transform of project_points_3d_to_2d_plane_based: every pair
    ``(t_1, t_2)`` becomes ``r_O + t_1 * e_1 + t_2 * e_2``.

    The plane is ``a*x + b*y + c*z + d = 0`` with ``plane_equation = [a, b, c, d]``.
    The frame is built as follows:

    * ``c != 0``: ``r_O = (0, 0, -d / c)`` and ``e_1`` is the normalised
      ``(1, 0, -a / c)``.
    * ``c == 0`` (e.g. vertical planes): ``r_O`` is placed on the coordinate
      axis of the largest normal component, and ``e_1`` is ``(1, 0, 0)`` if
      ``a == 0`` and ``(0, 0, -sign(a))`` otherwise.

    In both cases ``e_2 = n x e_1``. The origin is computed from the
    un-normalised coefficients, so the returned points lie on the plane even
    if ``(a, b, c)`` is not unit length.

    Args:
        points_2d: array of shape (N, 2).
        plane_equation: the four plane coefficients.

    Returns:
        Array of shape (N, 3); shape (0, 3) for empty input.
    """
    plane_equation = np.asarray(plane_equation, dtype=float)
    abc = plane_equation[:3]
    d = plane_equation[3]
    n = abc / np.linalg.norm(abc)

    if n[2] != 0:
        # Origin where the plane crosses the z axis; e_1 lies in the plane
        # and has no y component.
        r_O = np.array([0.0, 0.0, -d / abc[2]])
        e_1 = np.array([1.0, 0.0, -n[0] / n[2]])
    else:
        # The plane is parallel to the z axis and may never cross it, so put
        # the origin on the axis the normal is most aligned with.
        axis = int(np.argmax(np.abs(abc)))
        r_O = np.zeros(3)
        r_O[axis] = -d / abc[axis]
        if n[0] == 0:
            e_1 = np.array([1.0, 0.0, 0.0])
        else:
            e_1 = np.array([0.0, 0.0, -np.sign(n[0])])
    e_1 /= np.linalg.norm(e_1)
    e_2 = np.cross(n, e_1)
    e_2 /= np.linalg.norm(e_2)

    # Column slices keep a trailing axis of length 1 so they broadcast against e_1 / e_2.
    points_2d = np.asarray(points_2d, dtype=float).reshape(-1, 2)
    return r_O + points_2d[:, :1] * e_1 + points_2d[:, 1:] * e_2



def distances_to_line_segment(A, B, points):
    """Compute the distances of the points to the line segment AB.

    Every point is projected onto the line through A and B, the projection
    parameter is clamped to [0, 1] so that it stays on the segment, and the
    Euclidean distance to that closest point is returned.

    If A and B coincide, the segment is a single point and the distance to A
    is returned.

    Args:
        A, B: segment end points, arrays of shape (D,).
        points: array of shape (N, D).

    Returns:
        Array of shape (N,); an empty array for empty input.
    """
    points = np.asarray(points, dtype=float)
    if points.size == 0:
        return np.array([])
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)

    AB = B - A
    AP = points - A
    squared_length = np.dot(AB, AB)
    if squared_length == 0:
        # Degenerate segment: avoid 0 / 0 and measure the distance to A.
        t = np.zeros(len(points))
    else:
        t = np.clip(AP @ AB / squared_length, 0, 1)
    projection = A + t[:, None] * AB
    return np.linalg.norm(points - projection, axis=1)

def lineseg_dists(p, a, b):
    """Compute the distances of the points ``p`` to the line segment from ``a`` to ``b``.

    Based on
    https://stackoverflow.com/questions/54442057/calculate-the-euclidian-distance-between-an-array-of-points-to-a-line-segment-in

    The distance is split into a component parallel to the segment (non-zero
    only when the closest point is one of the end points) and a component
    perpendicular to it. The perpendicular component is computed by removing
    the parallel part from ``p - a``, which works for points of any dimension.

    Args:
        p: array of shape (N, D), or a single point of shape (D,).
        a, b: segment end points of shape (D,).

    Returns:
        Array of shape (N,).
    """
    # Handle case where p is a single point, i.e. 1d array.
    p = np.atleast_2d(np.asarray(p, dtype=float))
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)

    # Degenerate segment: plain point-to-point distance.
    if np.all(a == b):
        return np.linalg.norm(p - a, axis=1)

    # normalized tangent vector
    d = (b - a) / np.linalg.norm(b - a)

    # signed parallel distance components
    s = np.dot(a - p, d)
    t = np.dot(p - b, d)

    # clamped parallel distance
    h = np.maximum.reduce([s, t, np.zeros(len(p))])

    # perpendicular distance component: what is left of (p - a) after
    # removing its part along the tangent
    offset = p - a
    perpendicular = offset - np.outer(offset @ d, d)
    c = np.linalg.norm(perpendicular, axis=1)

    # use hypot for Pythagoras to improve accuracy
    return np.hypot(h, c)


def polygon_to_mesh(polygon : dict):
    """Triangulate a planar polygon and return it as a uniformly colored Open3D triangle mesh.

    ``polygon`` must provide 'plane_equation' and 'color', plus either
    'vertices_closed_2d' or 'vertices_closed_3d'. If the 2D outline is missing
    it is derived from the 3D outline with project_points_3d_to_2d_plane_based.
    The outline is passed to ``triangulate_polygon``; the 2D vertices it
    returns are lifted back onto the plane and every vertex receives
    ``polygon['color']``.
    """
    if "vertices_closed_2d" in polygon:
        outline_2d = polygon['vertices_closed_2d']
    else:
        outline_2d = project_points_3d_to_2d_plane_based(polygon['vertices_closed_3d'], polygon['plane_equation'])

    # The second return value of the triangulation is not needed here.
    mesh_vertices_2d, _, triangle_ixes = triangulate_polygon(outline_2d)
    mesh_vertices_3d = project_points_2d_to_3d_plane_based(mesh_vertices_2d, polygon['plane_equation'])

    # One copy of the polygon color per mesh vertex.
    vertex_colors = np.array([polygon['color'] for _ in range(len(mesh_vertices_3d))]).astype(float)

    new_mesh = o3d.geometry.TriangleMesh()
    new_mesh.vertices = o3d.utility.Vector3dVector(mesh_vertices_3d)
    new_mesh.triangles = o3d.utility.Vector3iVector(triangle_ixes)
    new_mesh.vertex_colors = o3d.utility.Vector3dVector(vertex_colors)
    return new_mesh




def compute_closest_triangle_to_points_o3d(vertices: torch.Tensor, triangles: torch.Tensor, query_points: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
    """Find, for every query point, the closest triangle of a mesh using Open3D.

    The mesh is rebuilt as an Open3D ``RaycastingScene`` on the CPU and
    queried with ``compute_closest_points``.

    Args:
        vertices: mesh vertex positions.
        triangles: vertex indices of the mesh triangles.
        query_points: points to look up.

    Returns:
        Tuple ``(primitive_ids, is_inside)``, both on ``vertices.device``.
        ``primitive_ids`` holds the index of the closest triangle per query
        point. ``is_inside`` is True where the closest point lies strictly
        inside that triangle, i.e. on none of its edges or corners.
    """
    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(vertices.detach().cpu().numpy().astype(np.float64))
    mesh.triangles = o3d.utility.Vector3iVector(triangles.detach().cpu().numpy().astype(np.int32))

    scene = o3d.t.geometry.RaycastingScene()
    scene.add_triangles(o3d.t.geometry.TriangleMesh.from_legacy(mesh))

    INVALID_ID = o3d.t.geometry.RaycastingScene.INVALID_ID
    query_points_o3d = o3d.core.Tensor(query_points.detach().cpu().numpy().astype(np.float64), dtype=o3d.core.Dtype.Float32)

    ans = scene.compute_closest_points(query_points_o3d)

    # per point: which mesh did we hit (the scene contains a single mesh, id 0)
    geometry_ids = ans['geometry_ids'].numpy()
    assert np.all((geometry_ids == 0) | (geometry_ids == INVALID_ID)), "All rays should hit the same mesh"
    # per point: which triangle did we hit
    primitive_ids = ans['primitive_ids'].numpy()
    assert np.all((primitive_ids < len(mesh.triangles)) | (primitive_ids == INVALID_ID)), "All rays should hit a triangle"
    # per point: barycentric (u, v) coordinates of the closest point in its triangle
    primitive_uvs = torch.from_numpy(ans['primitive_uvs'].numpy()).to(vertices.device)

    # The third barycentric weight is 1 - u - v. The closest point is strictly
    # inside only if all three weights are non-zero; u + v == 1 means it sits
    # on the edge opposite the first vertex. The sum is compared with a small
    # tolerance because it is accumulated in float32.
    edge_tolerance = 1e-6
    is_inside = (primitive_uvs != 0).all(dim=-1) & (primitive_uvs.sum(dim=-1) < 1 - edge_tolerance)

    primitive_ids = torch.from_numpy(primitive_ids.astype(int)).to(vertices.device)

    return primitive_ids, is_inside


def compute_distance_to_mesh(vertices: torch.Tensor, triangles: torch.Tensor, query_points: torch.Tensor) -> torch.Tensor:
    """Compute the distance of every query point to a mesh using Open3D.

    The mesh is rebuilt as an Open3D ``RaycastingScene`` on the CPU and
    queried with ``compute_distance``.

    Args:
        vertices: mesh vertex positions.
        triangles: vertex indices of the mesh triangles.
        query_points: points to measure from.

    Returns:
        Tensor with one distance per query point, on ``vertices.device``.
    """
    mesh = o3d.geometry.TriangleMesh()
    # Cast explicitly, like the other Open3D helpers in this file, so that
    # float32 vertices and int64 indices are accepted as well.
    mesh.vertices = o3d.utility.Vector3dVector(vertices.detach().cpu().numpy().astype(np.float64))
    mesh.triangles = o3d.utility.Vector3iVector(triangles.detach().cpu().numpy().astype(np.int32))

    scene = o3d.t.geometry.RaycastingScene()
    scene.add_triangles(o3d.t.geometry.TriangleMesh.from_legacy(mesh))

    query_points_o3d = o3d.core.Tensor(query_points.detach().cpu().numpy().astype(np.float32), dtype=o3d.core.Dtype.Float32)

    distances = scene.compute_distance(query_points_o3d)

    return torch.from_numpy(distances.numpy()).to(vertices.device)




@torch.no_grad()
def compute_ray_mesh_intersections_ray_tracing(vertices: torch.Tensor, triangles: torch.Tensor, ray_origins: torch.Tensor, ray_dests: torch.Tensor, margin : float = 0) -> torch.Tensor:
    """Test which segments from ``ray_origins`` to ``ray_dests`` hit a mesh, using Open3D ray casting.

    Each ray starts at its origin and uses ``ray_dests - ray_origins`` as
    direction, so a hit parameter of 1 corresponds to the destination. A
    segment counts as intersecting when a triangle is hit before the
    destination (``t_hit < 1``) and the hit point is at least ``margin`` away
    from the destination.

    Returns:
        Tuple ``(does_intersect, primitive_ids, intersection_points, len_of_overlap)``:
        the boolean intersection flag, the id of the hit triangle, the hit
        position ``origin + direction * t_hit`` and its distance to the
        destination, one entry per ray.
    """
    target_device = vertices.device

    mesh = o3d.geometry.TriangleMesh()
    vertex_array = vertices.detach().cpu().numpy().astype(np.float64)
    triangle_array = triangles.detach().cpu().numpy().astype(np.int32)
    mesh.vertices = o3d.utility.Vector3dVector(vertex_array)
    mesh.triangles = o3d.utility.Vector3iVector(triangle_array)

    scene = o3d.t.geometry.RaycastingScene()
    scene.add_triangles(o3d.t.geometry.TriangleMesh.from_legacy(mesh))

    invalid_id = o3d.t.geometry.RaycastingScene.INVALID_ID

    # Every ray is the 6-vector (origin, direction).
    ray_directions = ray_dests - ray_origins
    rays = torch.concatenate([ray_origins, ray_directions], dim=-1).float().detach().cpu().numpy()
    hits = scene.cast_rays(rays)

    # per ray: which mesh did we hit (the scene contains a single mesh, id 0)
    geometry_ids = hits['geometry_ids'].numpy()
    assert np.all((geometry_ids == 0) | (geometry_ids == invalid_id)), "All rays should hit the same mesh"
    # per ray: at what ray parameter did we hit the triangle
    hit_parameters = hits['t_hit'].numpy()
    # per ray: which triangle did we hit
    primitive_ids = hits['primitive_ids'].numpy()
    ids_non_negative = (primitive_ids >= 0).all()
    ids_in_range = primitive_ids < len(mesh.triangles)
    ids_invalid = primitive_ids == invalid_id
    assert (ids_non_negative & ids_in_range | ids_invalid).all(), "All rays should hit a triangle"

    t_hit = torch.from_numpy(hit_parameters).to(target_device)
    hit_any_triangle = torch.from_numpy(primitive_ids != invalid_id).to(target_device)
    hit_before_destination = t_hit < 1
    primitive_ids = torch.from_numpy(primitive_ids.astype(int)).to(target_device)

    intersection_points = ray_origins + ray_directions * t_hit[..., None]
    len_of_overlap = (intersection_points - ray_dests).norm(dim=-1)

    far_enough_from_destination = len_of_overlap >= margin
    does_intersect = (hit_any_triangle & hit_before_destination & far_enough_from_destination).bool()

    return does_intersect, primitive_ids, intersection_points, len_of_overlap

