import numpy as np
import open3d as o3d
import trimesh

from deptheval.utils.proj import depth_to_xyz, apply_SE3


def gen_triangle_v_idx(H, W):
    '''
    Triangulate an H x W pixel grid.

    Pixel (r, c) has flat index r * W + c. Every 2x2 pixel quad is split into two triangles:
      [0] = (top-left, bottom-left, top-right)
      [1] = (bottom-right, top-right, bottom-left)

    :param H: number of pixel rows
    :param W: number of pixel columns
    :return: integer array of shape (H - 1, W - 1, 2, 3) holding flat pixel indices
    '''
    grid = np.arange(H * W).reshape(H, W)
    top_left, top_right = grid[:-1, :-1], grid[:-1, 1:]
    bot_left, bot_right = grid[1:, :-1], grid[1:, 1:]

    first_tri = np.stack([top_left, bot_left, top_right], axis=-1)  # (H - 1, W - 1, 3)
    second_tri = np.stack([bot_right, top_right, bot_left], axis=-1)  # (H - 1, W - 1, 3)
    return np.stack([first_tri, second_tri], axis=-2)


def gen_trimesh_mesh(vs, cs, triangles):
    '''
    Wrap vertex / colour / face arrays into a double-sided, vertex-coloured trimesh mesh.

    Open3D is used only to compute vertex normals (``compute_vertex_normals``). The resulting
    vertices, faces and normals are handed to ``trimesh.Trimesh`` with ``process=False``.

    :param vs: vertex positions; any array that reshapes to (N, 3)
    :param cs: per-vertex colours; reshaped to (N, 3)
    :param triangles: (M, 3) vertex indices into ``vs``
    :return: trimesh.Trimesh whose material is a PBRMaterial(vertexColors=True, doubleSided=True)
    '''
    o3d_mesh = o3d.geometry.TriangleMesh(
        o3d.utility.Vector3dVector(vs.reshape(-1, 3)),
        o3d.utility.Vector3iVector(triangles),
    )
    o3d_mesh.compute_vertex_normals()

    result = trimesh.Trimesh(
        vertices=np.asarray(o3d_mesh.vertices),
        faces=np.asarray(o3d_mesh.triangles),
        vertex_normals=np.asarray(o3d_mesh.vertex_normals),
        vertex_colors=cs.reshape(-1, 3),
        process=False,
    )
    result.visual.material = trimesh.visual.material.PBRMaterial(
        vertexColors=True,
        doubleSided=True,
    )
    return result


def concatenate_mesh_data(mesh_datas):
    '''
    Merge several (vertices, colors, faces) tuples into a single tuple.

    Face indices of each part are shifted by the number of vertices that precede it, so they
    keep pointing at the right rows of the concatenated vertex array.

    :param mesh_datas: iterable of (vertices (N_k, 3), colors (N_k, 3), faces (M_k, 3)) tuples
    :return: tuple (vertices, colors, faces) concatenated along axis 0
    '''
    vertex_parts, color_parts, face_parts = [], [], []
    offset = 0
    for verts, colors, faces in mesh_datas:
        vertex_parts.append(verts)
        color_parts.append(colors)
        face_parts.append(faces + offset)
        offset += verts.shape[0]
    return tuple(np.concatenate(parts, axis=0) for parts in (vertex_parts, color_parts, face_parts))


def gen_mesh_data(intr, depth, depth_valid, SE3=np.eye(4), rgb=None, valid_triangle=None, crop_region=None):
    '''
    Build raw mesh arrays for rendering a depth map.

    Two parts are produced:
      * a surface: one vertex per pixel (back-projected with ``depth_to_xyz`` and moved by ``SE3``)
        and two triangles per 2x2 pixel quad (layout from ``gen_triangle_v_idx``), keeping only
        triangles whose three corner pixels are valid;
      * small boxes, 8 corners per pixel, whose faces are emitted only for valid pixels that are
        not a corner of any kept surface triangle (e.g. isolated pixels), so they stay visible.

    :param intr: shape (4,), read as (fx, fy, cx, cy)
    :param depth: shape (H, W)
    :param depth_valid: shape (H, W) validity mask; True / non-zero means valid. Non-bool masks
        (e.g. uint8 0/1) are converted to bool.
    :param SE3: shape (4, 4), points coords: apply_SE3(SE3, depth_to_xyz(intr, depth))
    :param rgb: per-pixel colours, reshaped to (H * W, 3)
        if rgb.dtype == np.uint8:
            use rgb / 255
        else:
            assert rgb.dtype == np.float32
            use rgb
        if None, every vertex is grey (0.7, 0.7, 0.7)
    :param valid_triangle: optional mask broadcastable to (H - 1, W - 1, 2); ANDed with the
        corner-validity mask to drop additional surface triangles
    :param crop_region: [lb_i, ub_i, lb_j, ub_j]; pixels outside rows lb_i:ub_i and columns
        lb_j:ub_j are treated as invalid. None or empty disables cropping.
    :return: [(xyz, colors, faces), (box_xyz, box_colors, box_faces)]
        xyz: (H * W, 3), zeroed at invalid pixels; faces: (M, 3) indices into xyz
        box_xyz: (H * W * 8, 3), zeroed at invalid pixels; box_faces: 12 triangles for each valid
        pixel not covered by the surface, indexing into box_xyz
    '''
    depth = depth.astype(np.float32)
    SE3 = SE3.astype(np.float32)
    # `~` and boolean indexing below only behave as intended on bool arrays: for a uint8 mask
    # `~` is a bitwise NOT (1 -> 254), and indexing with it becomes integer (fancy) indexing.
    depth_valid = np.asarray(depth_valid, dtype=np.bool_)

    H, W = depth.shape

    if crop_region is not None and len(crop_region) > 0:
        lb_i, ub_i, lb_j, ub_j = crop_region
        region_valid = np.zeros_like(depth_valid)
        region_valid[lb_i:ub_i, lb_j:ub_j] = True
        depth_valid = depth_valid & region_valid

    xyz = apply_SE3(SE3, depth_to_xyz(intr, depth))

    # surface triangles: two per 2x2 pixel quad
    triangle_v_idx = gen_triangle_v_idx(H, W)  # (H - 1, W - 1, 2, 3)

    # a surface triangle is kept only if all three of its corner pixels are valid
    valid_flattened = depth_valid.reshape(-1)
    xyz_flattened = xyz.reshape(-1, 3)
    valid_triangle_vertex = valid_flattened[triangle_v_idx].all(axis=-1)  # (H - 1, W - 1, 2)
    if valid_triangle is None:
        valid_triangle = valid_triangle_vertex
    else:
        valid_triangle = valid_triangle_vertex & np.asarray(valid_triangle, dtype=np.bool_)

    if rgb is None:
        vertex_colors = .7 * np.ones_like(xyz_flattened)
    elif rgb.dtype == np.uint8:
        vertex_colors = rgb.reshape(-1, 3).astype(np.float32) / 255.
    else:
        assert rgb.dtype == np.float32
        vertex_colors = rgb.reshape(-1, 3)

    # Mark every pixel that is a corner of at least one kept triangle. Corner layout follows
    # gen_triangle_v_idx: tri 0 = (top-left, bottom-left, top-right),
    # tri 1 = (bottom-right, top-right, bottom-left).
    tri_0, tri_1 = valid_triangle[..., 0], valid_triangle[..., 1]
    pxl_displayed = np.zeros((H, W), dtype=np.bool_)
    pxl_displayed[:-1, :-1] |= tri_0
    pxl_displayed[1:, :-1] |= tri_0 | tri_1
    pxl_displayed[:-1, 1:] |= tri_0 | tri_1
    pxl_displayed[1:, 1:] |= tri_1
    # valid pixels the surface would not show at all; these get a box instead
    invisible_to_display = depth_valid & (~pxl_displayed)

    fx, fy, cx, cy = intr[0], intr[1], intr[2], intr[3]
    v, u = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

    def box_face_corners(d):
        # corners at pixel offsets (u +- 1, v +- 1) lifted to depth d, ordered [dv][du]
        corners = [
            np.stack([((u + du) - cx) / fx * d, ((v + dv) - cy) / fy * d, d], axis=-1)
            for dv in (-1, 1) for du in (-1, 1)
        ]
        return apply_SE3(SE3, np.stack(corners, axis=-2).reshape(H, W, 2, 2, 3))

    # relative depth extent 1/f_mean, so the box is about as deep as its +-1 pixel lateral extent
    depth_range = 1 / (.5 * (fx + fy))
    up_xyz_fnt = box_face_corners((1 - depth_range) * depth)
    up_xyz_bck = box_face_corners((1 + depth_range) * depth)

    # box corner index c = 4 * (front/back) + 2 * (row -1/+1) + (col -1/+1)
    up_xyz = np.stack([up_xyz_fnt, up_xyz_bck], axis=2).reshape(H, W, 8, 3)  # (H, W, 8, 3)
    up_vertex_idx = np.arange(H * W * 8).reshape(H, W, 8)
    box_quads = ([0, 2, 3, 1], [0, 4, 6, 2], [0, 1, 5, 4], [7, 5, 1, 3], [7, 3, 2, 6], [7, 6, 4, 5])
    up_triangles = np.stack(
        [up_vertex_idx[..., tri] for p, q, r, s in box_quads for tri in ([p, q, r], [r, s, p])],
        axis=-2,
    )  # (H, W, 12, 3)
    up_vertex_colors = np.repeat(vertex_colors.reshape(H, W, 1, 3), 8, axis=-2).reshape(-1, 3)

    # park vertices of invalid pixels at the origin; no kept face references them
    xyz_flattened[~valid_flattened] = 0
    up_xyz[~depth_valid] = 0

    return [
        (xyz_flattened, vertex_colors, triangle_v_idx[valid_triangle]),
        (up_xyz.reshape(-1, 3), up_vertex_colors, up_triangles[invisible_to_display].reshape(-1, 3)),
    ]


def gen_mesh_and_pcd(intr, depth, depth_valid, SE3=np.eye(4), rgb=None, valid_triangle=None, crop_region=None):
    '''
    Build a renderable depth-map mesh plus a "point cloud" made of one small box per valid pixel.

    The surface arrays and the per-pixel box vertices come from ``gen_mesh_data`` (called with
    the same arguments); this function assembles them into two trimesh meshes.

    :param intr: shape (4,)
    :param depth: shape (H, W)
    :param depth_valid: shape (H, W) validity mask (True / non-zero means valid)
    :param SE3: shape (4, 4), points coords: apply_SE3(SE3, depth_to_xyz(intr, depth))
    :param rgb:
        if rgb.dtype == np.uint8:
            use rgb / 255
        else:
            assert rgb.dtype == np.float32
            use rgb
    :param valid_triangle: optional extra mask over the (H - 1, W - 1, 2) surface triangles
    :param crop_region: [lb_i, ub_i, lb_j, ub_j]
    :return: (trimesh_mesh, pcd)
        trimesh_mesh: surface triangles plus boxes for valid pixels that no surface triangle covers
        pcd: a box for every valid pixel (after cropping), sharing the same box vertices
    '''
    surface_data, box_data = gen_mesh_data(
        intr, depth, depth_valid, SE3=SE3, rgb=rgb, valid_triangle=valid_triangle, crop_region=crop_region,
    )
    trimesh_mesh = gen_trimesh_mesh(*concatenate_mesh_data([surface_data, box_data]))

    # The point cloud shows a box for EVERY valid pixel, so rebuild the cropped validity mask.
    valid = np.asarray(depth_valid, dtype=np.bool_)
    if crop_region is not None and len(crop_region) > 0:
        lb_i, ub_i, lb_j, ub_j = crop_region
        region_valid = np.zeros_like(valid)
        region_valid[lb_i:ub_i, lb_j:ub_j] = True
        valid = valid & region_valid

    # Box vertex layout produced by gen_mesh_data: pixel p (row-major) owns vertices 8p .. 8p + 7,
    # corner c = 4 * (front/back) + 2 * (row -1/+1) + (col -1/+1). Each box side is split into two
    # triangles in the same order gen_mesh_data uses.
    box_quads = ([0, 2, 3, 1], [0, 4, 6, 2], [0, 1, 5, 4], [7, 5, 1, 3], [7, 3, 2, 6], [7, 6, 4, 5])
    box_faces = np.array([tri for p, q, r, s in box_quads for tri in ([p, q, r], [r, s, p])])  # (12, 3)
    valid_pixels = np.flatnonzero(valid)  # row-major order, like boolean indexing of an (H, W) mask
    pcd_faces = (8 * valid_pixels[:, None, None] + box_faces).reshape(-1, 3)

    box_xyz, box_colors, _ = box_data
    pcd = gen_trimesh_mesh(box_xyz, box_colors, pcd_faces)
    return trimesh_mesh, pcd


def gen_pole_mesh(above_ground=True, plane_size=2.0, pole_width=0.05, pole_height=0.7, SE3=np.eye(4), dist_to_cam=5):
    '''
    Build a box-shaped pole attached to a square plane, returned as a grey trimesh mesh.

    coords system: camera coords system, pole is along y-axis, plane is zx direction.
    With ``above_ground`` the pole spans y in [0, pole_height] and the plane lies at
    y = pole_height. Otherwise the pole spans y in [-pole_height, 0] and the plane lies at
    y = -pole_height. The model is centred at z = 5, then uniformly scaled by dist_to_cam / 5,
    so it ends up centred at z = dist_to_cam and all sizes scale with it. Finally ``SE3`` is
    applied via apply_SE3.

    :param above_ground: which end of the pole the plane is attached to (see above)
    :param plane_size: side length of the square plane (before scaling)
    :param pole_width: side length of the pole's square cross-section (before scaling)
    :param pole_height: pole length along y (before scaling)
    :param SE3: (4, 4) transform applied to the vertices
    :param dist_to_cam: z of the pole's centre axis after scaling
    :return: mesh from gen_trimesh_mesh with all vertex colours 0.7
    '''
    ref_dist = 5  # distance the model is built at before rescaling to dist_to_cam
    half_w, half_h, half_p = .5 * pole_width, .5 * pole_height, .5 * plane_size

    def square(half, y):
        # 4 corners of an axis-aligned square at height y, (x, z) order: (-,-), (+,-), (+,+), (-,+)
        return [[-half, y, -half], [half, y, -half], [half, y, half], [-half, y, half]]

    def quad(a, b, c, d):
        # split quad a-b-c-d into two triangles sharing the a-c diagonal
        return [[a, b, c], [c, d, a]]

    # pole corners: 0-3 at y = +half_h, 4-7 at y = -half_h
    pole_vs = np.array(square(half_w, half_h) + square(half_w, -half_h))
    pole_quads = [(0, 3, 2, 1), (0, 1, 5, 4), (1, 2, 6, 5), (2, 3, 7, 6), (3, 0, 4, 7), (4, 5, 6, 7)]
    pole_fs = np.array([tri for corners in pole_quads for tri in quad(*corners)])

    y_offset = half_h if above_ground else -half_h
    plane_vs = np.array(square(half_p, y_offset))
    plane_fs = np.array(quad(0, 1, 2, 3))

    vs = np.concatenate([pole_vs, plane_vs], axis=0)
    vs[..., 1] += y_offset
    vs[..., 2] += ref_dist
    vs *= dist_to_cam / ref_dist
    vs = apply_SE3(SE3, vs)

    cs = .7 * np.ones_like(vs)
    fs = np.concatenate([pole_fs, plane_fs + len(pole_vs)], axis=0)
    return gen_trimesh_mesh(vs, cs, fs)


def create_cam_frame_mesh(gt_depth, intr, out_f, *, frustum_depth=1.5, edge_radius=0.02, scale=.2):
    '''
    Export a wireframe camera-frustum mesh built from thin cylinders.

    The frustum apex is the origin. Its far corners are the back-projections (``depth_to_xyz``)
    of the four image corners of ``gt_depth`` at depth ``frustum_depth``. Each of the 8 edges
    (4 apex-to-corner, 4 around the far rectangle) becomes a cylinder of radius ``edge_radius``.
    The combined geometry is then uniformly scaled by ``scale`` (radius included) and coloured
    black.

    :param gt_depth: (H, W) array; only its shape and dtype are used
    :param intr: camera intrinsics passed to depth_to_xyz
    :param out_f: output path (str or pathlib.Path); parent directories are created
    :param frustum_depth: depth of the far rectangle before scaling (default 1.5)
    :param edge_radius: cylinder radius before scaling (default 0.02)
    :param scale: uniform scale applied to all vertices (default 0.2)
    '''
    from pathlib import Path

    from trimesh.creation import cylinder

    # accept plain strings as well as Path objects (out_f.parent below requires a Path)
    out_f = Path(out_f)

    def edge_cylinder(p1, p2, radius, sections=16):
        """
        Create a cylinder mesh from point p1 to point p2, or None if they coincide.
        """
        vec = np.array(p2) - np.array(p1)
        length = np.linalg.norm(vec)
        if length == 0:
            return None

        # base cylinder along z, centred at the origin; shift it so its base sits at the origin
        cyl = cylinder(radius=radius, height=length, sections=sections)
        cyl.apply_translation([0, 0, length / 2])

        # rotate +z onto the edge direction, then move the base to p1
        cyl.apply_transform(trimesh.geometry.align_vectors([0, 0, 1], vec))
        cyl.apply_translation(p1)
        return cyl

    cam_mesh_z = np.ones_like(gt_depth) * frustum_depth
    cam_mesh_xyz = depth_to_xyz(intr, cam_mesh_z)
    # apex, then image corners: top-left, top-right, bottom-right, bottom-left
    cam_mesh_xyz = np.stack([
        np.zeros(3),
        cam_mesh_xyz[0, 0],
        cam_mesh_xyz[0, -1],
        cam_mesh_xyz[-1, -1],
        cam_mesh_xyz[-1, 0],
    ], axis=0)
    cam_mesh_edges = [
        (0, 1), (0, 2), (0, 3), (0, 4),  # apex to each far corner
        (1, 2), (2, 3), (3, 4), (4, 1),  # far rectangle
    ]

    # collect cylinder vertices and faces, re-indexing faces into the combined vertex array
    all_vertices = []
    all_faces = []
    vertex_offset = 0
    for a, b in cam_mesh_edges:
        cyl = edge_cylinder(cam_mesh_xyz[a], cam_mesh_xyz[b], radius=edge_radius)
        if cyl is None:
            continue
        all_vertices.append(cyl.vertices)
        all_faces.append(cyl.faces + vertex_offset)
        vertex_offset += len(cyl.vertices)

    all_vertices = np.vstack(all_vertices) * scale
    all_faces = np.vstack(all_faces)
    all_colors = np.zeros_like(all_vertices)

    mesh = gen_trimesh_mesh(all_vertices, all_colors, all_faces)

    out_f.parent.mkdir(parents=True, exist_ok=True)
    mesh.export(out_f)
    print(f'Saved to {out_f}')
