from __future__ import annotations

from typing import Sequence, Tuple
import jax.numpy as jnp
import numpy as np
import jax
import einops
from functools import partial

try:
    import open3d as o3d
except:
    print('open3d is not installed')

import util.transform_util as tutil
import util.structs as structs

# Typing
import numpy.typing as npt
# Camera intrinsic layout used throughout this module:
# (width, height, fx, fy, cx, cy).
IntrinsicT = Tuple[int, int, float, float, float, float]


def default_intrinsic(pixel_size:Sequence[int]) -> IntrinsicT:
    """Build a placeholder intrinsic for an image of the given size.

    The focal lengths equal the image extent along each axis and the
    principal point sits at the image center.

    Args:
        pixel_size (Sequence[int]): (height, width).

    Returns:
        jnp.ndarray: (6,) array ``[width, height, Fx, Fy, Cx, Cy]`` with
        ``Fx = width``, ``Fy = height``, ``Cx = width / 2``, ``Cy = height / 2``.
    """
    height, width = pixel_size[0], pixel_size[1]
    focal_x, focal_y = width, height
    center_x, center_y = 0.5 * width, 0.5 * height
    return jnp.array([width, height, focal_x, focal_y, center_x, center_y])

def pb_viewmatrix_to_cam_posquat(view_matrix: np.ndarray) -> np.ndarray:
    """Convert an OpenGL-style view matrix into a camera pose.

    Args:
        view_matrix: 16 values in column-major order (OpenGL / PyBullet
            layout), or anything reshapeable to (4, 4) in that order.

    Returns:
        np.ndarray: (7,) array ``[x, y, z, qx, qy, qz, qw]``. This is the camera
        position followed by the camera orientation as a SciPy (scalar-last)
        quaternion.
    """
    from scipy.spatial.transform import Rotation as sciR

    # The flat matrix is column-major, so reshape + transpose yields the usual
    # world->camera transform; its inverse is the camera->world pose.
    cam_H = np.linalg.inv(np.asarray(view_matrix).reshape(4, 4).T)
    cam_pos = cam_H[..., :3, 3]
    cam_quat = sciR.from_matrix(cam_H[..., :3, :3]).as_quat()
    return np.concatenate([cam_pos, cam_quat], axis=-1)

def pixel_ray(pixel_size:Sequence[int], cam_pos:jnp.ndarray, cam_quat:jnp.ndarray,
                intrinsic:jnp.ndarray, near:float, far:float, coordinate:str='opengl'):
    '''Cast one ray per pixel and express it in world coordinates (batchable).

    Forward model (column vectors): pixel = P V M, with M the points, V the inverse
    camera SE3 and P the z-projection / intrinsic matrix. This function inverts
    it: pixels go to the camera frame via inv(K), then to the world via the
    camera pose.

    Args:
        pixel_size: (2,) image size in (i, j) = (height, width) order.
        cam_pos: (..., 3) camera position.
        cam_quat: (..., 4) camera quaternion, in the convention tutil.pq_action expects.
        intrinsic: (..., 6) camera intrinsic (width, height, fx, fy, cx, cy).
        near: depth (|z| in the camera frame) of the ray start points.
        far: depth (|z| in the camera frame) of the ray end points.
        coordinate: 'opengl' (forward is -z) or 'open3d' (forward is +z).
            Any value other than 'opengl' is handled like 'open3d'.

    Returns:
        rays_s: (..., H, W, 3) ray start points.
        rays_e: (..., H, W, 3) ray end points.
        ray_dir: (..., H, W, 3) normalized ray directions (rays_e - rays_s).
    '''
    K_mat_inv = jnp.linalg.inv(intrinsic_to_Kmat(intrinsic))

    # Homogeneous pixel coordinates (x, y, 1) of shape (H, W, 3). Both index
    # ranges are reversed: column j holds x = W-1-j and row i holds y = H-1-i.
    x_grid_idx, y_grid_idx = jnp.meshgrid(jnp.arange(pixel_size[1])[::-1], jnp.arange(pixel_size[0])[::-1])
    pixel_pnts = jnp.stack([x_grid_idx, y_grid_idx, jnp.ones_like(y_grid_idx)], axis=-1)
    pixel_pnts = pixel_pnts.astype(jnp.float32)

    # Back-project to the camera frame. K's last row is (0, 0, 1), so every point has z = 1.
    pixel_pnts = jnp.matmul(K_mat_inv[...,None,None,:,:], pixel_pnts[...,None])[...,0]
    if coordinate == 'opengl':
        # OpenGL cameras look along -z: negate z, then reverse the column order (axis -2).
        pixel_pnts = jnp.concatenate([pixel_pnts[...,:-1], -pixel_pnts[...,-1:]], axis=-1)
        pixel_pnts = pixel_pnts[...,::-1,:]
    rays_s_canonical = pixel_pnts * near
    rays_e_canonical = pixel_pnts * far

    # Camera SE3 transformation into the world frame (broadcast over H, W).
    pos = cam_pos[...,None,None,:]
    quat = cam_quat[...,None,None,:]
    rays_s = tutil.pq_action(pos, quat, rays_s_canonical)
    rays_e = tutil.pq_action(pos, quat, rays_e_canonical)
    ray_dir_normalized = tutil.normalize(rays_e - rays_s)

    return rays_s, rays_e, ray_dir_normalized


def pbfov_to_intrinsic(
        img_size: Tuple[int, int], 
        fov_deg: float
) -> IntrinsicT:
    """Derive pinhole intrinsics from a vertical field of view.

    Both axes share one focal length (Fx == Fy), and the principal point is
    the image center.

    Args:
        img_size (Tuple[int, int]): (height, width).
        fov_deg (float): Vertical FOV in degrees (spanning the image height).

    Returns:
        Tuple ``(width, height, Fx, Fy, Cx, Cy)``.
    """
    height, width = img_size[0], img_size[1]
    fov_rad = fov_deg * np.pi/180.0
    focal = height * 0.5 / np.tan(fov_rad * 0.5)
    return (width, height, focal, focal, width * 0.5, height * 0.5)


def intrinsic_to_fov(intrinsic: npt.NDArray):
    """Recover the vertical field of view and aspect ratio from an intrinsic.

    Args:
        intrinsic: (..., 6) array (width, height, fx, fy, cx, cy).

    Returns:
        Tuple of (vertical FOV in radians, width / height aspect ratio).
    """
    height = intrinsic[..., 1]
    focal_y = intrinsic[..., 3]
    vertical_fov = 2 * np.arctan(height / focal_y * 0.5)
    aspect = intrinsic[..., 0] / intrinsic[..., 1]
    return vertical_fov, aspect

def intrinsic_to_pb_lrbt(
        intrinsic: Sequence[IntrinsicT]|npt.NDArray, 
        near: float
) -> Tuple[float, float, float, float]:
    """Compute frustum bounds (left, right, bottom, top) on the near plane.

    The pixel-space bounds are divided by the focal length and scaled by
    ``near``. The name suggests they are meant for a PyBullet/OpenGL frustum
    projection; confirm against callers.

    In pixels the bounds are ``left = cx - width``, ``right = cx``,
    ``bottom = cy - height`` and ``top = cy``. For a centered principal point
    this is the symmetric frustum +-width/2, +-height/2.
    NOTE(review): for an off-center cx the horizontal offset is mirrored relative
    to the vertical one; confirm the intended sign.

    Args:
        intrinsic (Sequence[IntrinsicT] | npt.NDArray): (..., 6) intrinsic
            (width, height, fx, fy, cx, cy). Lists and tuples are converted to numpy.
        near (float): OpenGL near plane distance.

    Returns:
        Tuple[float, float, float, float]: (left, right, bottom, top) at the near plane.
    """
    if isinstance(intrinsic, (list, tuple)):
        intrinsic = np.array(intrinsic)
    width = intrinsic[..., 0]
    height = intrinsic[..., 1]
    fx, fy, cx, cy = (intrinsic[..., k] for k in range(2, 6))

    half_w = width * 0.5
    x_mid = cx - half_w
    left, right = x_mid - half_w, x_mid + half_w

    half_h = height * 0.5
    y_mid = cy - half_h
    bottom, top = y_mid - half_h, y_mid + half_h

    return left / fx * near, right / fx * near, bottom / fy * near, top / fy * near


def intrinsic_to_Kmat(intrinsic):
    '''Build (..., 3, 3) pinhole camera matrices from (..., 6) intrinsics.

    Intrinsic layout: (width, height, fx, fy, cx, cy). The principal point's
    y entry is stored flipped as ``height - cy``, to compensate for this
    project's y-axis convention.
    '''
    fx = intrinsic[..., 2]
    fy = intrinsic[..., 3]
    cx = intrinsic[..., 4]
    cy_flipped = intrinsic[..., 1] - intrinsic[..., 5]
    zero = jnp.zeros_like(fx)
    one = jnp.ones_like(fx)
    rows = [
        jnp.stack([fx, zero, cx], axis=-1),
        jnp.stack([zero, fy, cy_flipped], axis=-1),
        jnp.stack([zero, zero, one], axis=-1),
    ]
    return jnp.stack(rows, axis=-2)
    

def global_pnts_to_pixel(intrinsic, cam_posquat, pnts, expand=False):
    '''Project world-frame points to (i, j) pixel coordinates.

    Points are moved into the camera frame with the inverse camera pose and
    divided by -z (OpenGL convention: the camera looks along -z). They are then
    mapped through the matrix from intrinsic_to_Kmat.

    Args:
        intrinsic: (..., 6) intrinsic (width, height, fx, fy, cx, cy).
        cam_posquat: (..., 7) array [pos(3), quat(4)] or a (pos, quat) tuple.
        pnts: (..., 3) world points.
        expand:
            False - intrinsic, cam_posquat and pnts must broadcast against each other.
            True  - intrinsic, cam_posquat: (... NR ...), pnts: (... NS 3);
                    outputs have shape (... NR NS ...).

    Returns:
        px_coord_ij: (..., 2) float pixel coordinates in (i, j) order, computed as
            (height - y, x) after clipping x, y to [0.001, size - 0.001].
        out_pnts: (...,) bool, True where the unclipped projection falls outside
            the image. Points behind the camera are not flagged separately.
    '''
    if not isinstance(cam_posquat, tuple):
        cam_posquat = (cam_posquat[...,:3], cam_posquat[...,3:])

    if expand:
        # Broadcast cameras (... NR 1 ...) against points (... 1 NS 3).
        # jax.tree_map is deprecated/removed; jax.tree_util.tree_map is the stable API.
        intrinsic, cam_posquat = jax.tree_util.tree_map(lambda x: x[...,None,:], (intrinsic, cam_posquat))
        pnts = pnts[...,None,:,:]

    pixel_size = intrinsic[...,:2]  # (width, height): xy order
    pnt_img_pj = tutil.pq_action(*tutil.pq_inv(*cam_posquat), pnts)  # camera-frame points (... 3)
    kmat = intrinsic_to_Kmat(intrinsic)
    # Perspective divide by -z, then apply the focal / principal-point part of K.
    px_coord_xy = jnp.einsum('...ij,...j', kmat[...,:2,:2], pnt_img_pj[...,:2]/(-pnt_img_pj[...,-1:])) + kmat[...,:2,2]
    out_pnts = jnp.logical_or(jnp.any(px_coord_xy<0, -1), jnp.any(px_coord_xy>=pixel_size, -1))
    px_coord_xy = jnp.clip(px_coord_xy, 0.001, pixel_size-0.001)
    # xy -> ij: i = height - y, j = x.
    px_coord_ij = jnp.stack([pixel_size[...,1]-px_coord_xy[...,1], px_coord_xy[...,0]], -1)
    return px_coord_ij, out_pnts

def cam_info_to_render_params(cam_info):
    """Unpack a ``(cam_posquat, intrinsic)`` pair into renderer keyword arguments.

    Args:
        cam_info: tuple of ``cam_posquat`` (..., 7) [pos(3), quat(4)] and
            ``intrinsic`` (..., 6) (width, height, fx, fy, cx, cy).

    Returns:
        dict with ``intrinsic``, ``pixel_size`` (int32, height then width),
        ``camera_pos`` and ``camera_quat``.
    """
    posquat, intr = cam_info
    height_width = jnp.c_[intr[..., 1:2], intr[..., 0:1]]
    return {
        "intrinsic": intr,
        "pixel_size": height_width.astype(jnp.int32),
        "camera_pos": posquat[..., :3],
        "camera_quat": posquat[..., 3:],
    }


def pcd_from_depth(depth, intrinsic, pixel_size, cam_posquat=None, coordinate='opengl', visualize=False):
    """Back-project depth maps into 3D points.

    Args:
        depth: (..., H, W) or (..., H, W, 1) depth values. A trailing axis is
            added unless the last dimension is already 1.
        intrinsic: (..., 6) intrinsic (width, height, fx, fy, cx, cy).
        pixel_size: (height, width) of the depth map.
        cam_posquat: optional (..., 7) pose [pos(3), quat(4)]. When given, the
            points are transformed with tutil.pq_action; otherwise they stay in
            the camera frame.
        coordinate: 'opengl' negates y and z (camera looks along -z); any other
            value keeps the (x, y, depth) layout.
        visualize: if True, show the points with Open3D (blocking window).

    Returns:
        (..., H, W, 3) array of points.
    """
    if depth.shape[-1] != 1:
        depth = depth[..., None]

    cols, rows = np.meshgrid(np.arange(pixel_size[1]), np.arange(pixel_size[0]), indexing='xy')
    pixel_xy = np.stack([cols, rows], axis=-1)
    focal = intrinsic[..., None, None, 2:4]
    principal = intrinsic[..., None, None, 4:6]
    xy = (pixel_xy - principal) * depth / focal

    xyz = jnp.concatenate([xy, depth], axis=-1)

    if coordinate == 'opengl':
        xyz = jnp.stack([xyz[..., 0], -xyz[..., 1], -xyz[..., 2]], axis=-1)

    if cam_posquat is not None:
        xyz = tutil.pq_action(cam_posquat[..., :3], cam_posquat[..., 3:], xyz)

    if visualize:
        import open3d as o3d
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(xyz.reshape(-1, 3))
        origin_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=0.1, origin=[0, 0, 0])
        o3d.visualization.draw_geometries([cloud, origin_frame])

    return xyz

def resize_intrinsic(intrinsic, origin_pixel_size, output_pixel_size, base_axis='i', base_operator='jnp'):
    '''
    Rescale intrinsic parameters (pixelx, pixely, Fx, Fy, Cx, Cy) for a resized image.

    All entries are multiplied by one scale factor taken from the base axis.
    The image size is then replaced by the output size. The principal point
    keeps its scaled offset from the image center.

    origin_pixel_size: [height, width] - ij convention
    output_pixel_size: [height, width] - ij convention
    base_axis: 'i' / 0 / 'height' / 'vertical' / 'y' scales by height; anything else by width.
    base_operator: 'np' / 'numpy' uses in-place numpy writes; anything else uses jax.numpy.
    '''
    use_numpy = base_operator in ('np', 'numpy')
    xp = np if use_numpy else jnp
    intrinsic = xp.array(intrinsic)

    axis = 0 if base_axis in ('i', 0, 'height', 'vertical', 'y') else 1
    scale = output_pixel_size[axis] / origin_pixel_size[axis]

    resized = intrinsic * scale

    # Principal point = output center + scaled offset from the original center (xy order).
    origin_center_xy = xp.array(origin_pixel_size)[::-1] / 2.0
    output_center_xy = xp.array(output_pixel_size)[::-1] / 2.0
    principal = output_center_xy + (intrinsic[..., -2:] - origin_center_xy) * scale
    output_size_xy = np.array([output_pixel_size[1], output_pixel_size[0]])

    if use_numpy:
        resized[..., :2] = output_size_xy
        resized[..., -2:] = principal
    else:
        resized = resized.at[..., :2].set(output_size_xy).at[..., -2:].set(principal)
    return resized

def create_o3d_cameras(cam_posquat, intrinsic, color=np.array((1.,0.,0.))):
    """Build Open3D geometries that visualize a single camera.

    Args:
        cam_posquat: (7,) camera pose [pos(3), quat(4)], converted to a matrix via tutil.pq2H.
        intrinsic: (6,) intrinsic (width, height, fx, fy, cx, cy).
        color: RGB color painted on the frustum line set.

    Returns:
        list: [coordinate frame placed at the camera pose, frustum LineSet].
        The frustum's extrinsic is the inverse of the pq2H matrix.
    """
    import open3d as o3d

    width_px, height_px = intrinsic[:2].astype(np.int32)
    cam_to_world = tutil.pq2H(cam_posquat[:3], cam_posquat[3:])

    axes = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.1, origin=[0, 0, 0])
    axes.transform(np.array(cam_to_world))

    world_to_cam = np.linalg.inv(cam_to_world)
    k_matrix = intrinsic_to_Kmat(intrinsic)
    frustum = o3d.geometry.LineSet.create_camera_visualization(
        view_width_px=width_px,
        view_height_px=height_px,
        intrinsic=k_matrix,
        extrinsic=world_to_cam,
        scale=0.1,
    )
    frustum.paint_uniform_color(color)
    return [axes, frustum]
    

def resize_img(img, output_pixel_size, base_axis='i', method='linear'):
    """
    Resize an image with jax.image.resize while preserving its aspect ratio.

    One scale factor is derived from the base axis so that the base dimension
    lands on the target size. The other (non-base) dimension is then
    center-cropped or zero-padded (extra pixel after) to match the target exactly.

    Parameters:
      img: array with shape (..., H, W, C); leading batch dims are preserved.
      output_pixel_size: tuple (target_H, target_W)
      base_axis: axis the scale factor is based on. For vertical scaling use 'i', 0, 'height', 'vertical', or 'y';
                 for horizontal scaling use 'j', 1, 'width', 'horizontal', or 'x'.
      method: interpolation method for jax.image.resize (e.g. 'linear' for bilinear or 'nearest')

    Returns:
      Resized image with shape (..., target_H, target_W, C) in the original dtype.
      For integer dtypes the interpolated values are rounded and clipped to the
      dtype range before casting back, rather than truncated.

    Raises:
      ValueError: if base_axis is not one of the accepted values.
    """
    orig_h, orig_w = img.shape[-3:-1]
    target_h, target_w = output_pixel_size

    # Determine the base axis and the scalar resize factor.
    if base_axis in ('i', 0, 'height', 'vertical', 'y'):
        base_is_height = True
        factor = target_h / orig_h
    elif base_axis in ('j', 1, 'width', 'horizontal', 'x'):
        base_is_height = False
        factor = target_w / orig_w
    else:
        raise ValueError("Invalid base_axis. Use 'i' (or height variants) or 'j' (or width variants).")

    # Size after scaling both dimensions uniformly.
    new_h = int(round(orig_h * factor))
    new_w = int(round(orig_w * factor))

    # Interpolate in floating point.
    orig_dtype = img.dtype
    if not jnp.issubdtype(orig_dtype, jnp.floating):
        img = img.astype(jnp.float32)

    new_shape = img.shape[:-3] + (new_h, new_w, img.shape[-1])
    resized = jax.image.resize(img, new_shape, method, antialias=False)

    # The non-base axis is width (-2) when scaling by height, height (-3) otherwise.
    axis = -2 if base_is_height else -3
    current_non = new_w if base_is_height else new_h
    target_non = target_w if base_is_height else target_h

    if current_non > target_non:
        # Center-crop along the non-base axis.
        start = (current_non - target_non) // 2
        if axis == -2:
            resized = resized[..., :, start:start+target_non, :]
        else:
            resized = resized[..., start:start+target_non, :, :]
    elif current_non < target_non:
        # Zero-pad along the non-base axis; the odd pixel (if any) goes after.
        pad_total = target_non - current_non
        pad_before = pad_total // 2
        pad_widths = [(0, 0)] * resized.ndim
        pad_widths[resized.ndim + axis] = (pad_before, pad_total - pad_before)
        resized = jnp.pad(resized, pad_widths, mode='constant')

    # Guard against rounding leaving the base axis one pixel too large.
    resized = resized[..., :target_h, :target_w, :]

    if jnp.issubdtype(orig_dtype, jnp.integer):
        # Casting float -> int truncates (254.9999 -> 254); round and clip first.
        info = jnp.iinfo(orig_dtype)
        resized = jnp.clip(jnp.round(resized), info.min, info.max)
    return resized.astype(orig_dtype)


def resize_img_myself(img, output_pixel_size, method='linear', base_axis='i', base_operator='jnp'):
    '''
    Resize an image by sampling it on a centered, uniformly scaled grid.

    img has shape (..., H, W, C)
    output_pixel_size: (target_H, target_W)
    method: 'nearest', or 'linear' / 'bilinear'
    base_axis: 'i' or 'j' - i for height, j for width. Output pixel centers are
        mapped back to the input with one scale factor (origin/output along the
        base axis), so the aspect ratio is preserved. Output pixels whose
        nearest source pixel falls outside the input are set to 0.
    base_operator: 'np' / 'numpy' for numpy output ops, anything else for jax.numpy.

    Returns the resized image cast back to img.dtype. Interpolated integer
    images are rounded before the cast. The input is returned unchanged when
    the size already matches.

    Raises:
        ValueError: for an unknown method.
    '''
    origin_size = img.shape[-3:-1]
    origin_type = img.dtype

    if all(i == j for i, j in zip(origin_size, output_pixel_size)):
        return img

    bp = np if base_operator in ('np', 'numpy') else jnp
    scale_axis = 0 if base_axis in ('i', 0, 'height', 'vertical', 'y') else 1
    if method not in ('nearest', 'bilinear', 'linear'):
        raise ValueError(f"Unsupported method {method!r}; use 'nearest' or 'linear'/'bilinear'.")

    max_idx = np.array(origin_size) - 1

    # Output pixel grid (H_out, W_out, 2) in (i, j) order, mapped into input pixel coordinates.
    grid = np.moveaxis(np.mgrid[0:output_pixel_size[0], 0:output_pixel_size[1]], 0, -1).astype(np.float32)
    grid = grid - (np.array(output_pixel_size) - 1) / 2
    grid = grid / output_pixel_size[scale_axis] * origin_size[scale_axis]
    grid_float = grid + (np.array(origin_size) - 1) / 2

    grid_int = grid_float.round().astype(np.int32)
    nearest_idx = np.clip(grid_int, 0, max_idx)
    out_mask = np.any(grid_int != nearest_idx, axis=-1)

    def _gather(idx):
        # Pick img pixels at an (H_out, W_out, 2) index grid -> (..., H_out, W_out, C).
        flat = idx.reshape(-1, 2)
        sel = img[..., flat[:, 0], flat[:, 1], :]
        return sel.reshape(*sel.shape[:-2], output_pixel_size[0], output_pixel_size[1], sel.shape[-1])

    if method == 'nearest':
        img_out = _gather(nearest_idx)
    else:
        # Standard bilinear: bracket each sample with floor / floor+1 (clamped to the border).
        grid_clip = np.clip(grid_float, 0, max_idx)
        lo = np.floor(grid_clip).astype(np.int32)
        hi = np.minimum(lo + 1, max_idx)
        frac = (grid_clip - lo).astype(np.float32)
        wi, wj = frac[..., 0:1], frac[..., 1:2]
        img_out = (_gather(lo) * ((1 - wi) * (1 - wj))
                   + _gather(np.stack([lo[..., 0], hi[..., 1]], axis=-1)) * ((1 - wi) * wj)
                   + _gather(np.stack([hi[..., 0], lo[..., 1]], axis=-1)) * (wi * (1 - wj))
                   + _gather(hi) * (wi * wj))
        if np.issubdtype(origin_type, np.integer):
            img_out = bp.round(img_out)

    img_out = bp.where(out_mask[..., None], 0, img_out)
    return img_out.astype(origin_type)


def visualize_pcd(rgb_list_, depth_list_, cam_posquat_, intrinsic_, return_elements=False, pcd_o3d=None, tag_size=None, area=None):
    """Fuse per-camera RGB-D images into a colored world-frame point cloud and show it with Open3D.

    Args:
        rgb_list_: per-camera RGB images. Values are divided by 255 and reshaped
            to (-1, 3), so presumably (#cam, H, W, 3) in [0, 255] -- confirm.
        depth_list_: per-camera depth maps, passed to pcd_from_depth.
        cam_posquat_: (#cam, 7) camera poses [pos(3), quat(4)].
        intrinsic_: (#cam, 6) intrinsics (width, height, fx, fy, cx, cy). The
            back-projection image size is read from the first camera only.
        return_elements: if True, return the geometries instead of opening a window.
        pcd_o3d: optional existing o3d PointCloud to fill in, instead of creating a new one.
        tag_size: if given, add red spheres (radius 0.01) in the world z=0 plane at
            the center, corners and edge midpoints of a tag_size x tag_size square
            centered at the world origin.
        area: optional (lower, upper) bounds; only points strictly inside the box are kept.

    Returns:
        None when return_elements is False; otherwise
        (geometries, world points (N, 3), colors (N, 3)).
    """
    # Image size in (height, width) order, read from the first camera's intrinsic.
    pixel_size_ = (intrinsic_[0,1], intrinsic_[0,0])
    intrinsic_matrix = intrinsic_to_Kmat(intrinsic_)
    # Camera-frame points; no camera pose is applied here.
    pcd = pcd_from_depth(depth_list_, intrinsic_, pixel_size_, visualize=False)

    pcd_flat = pcd.reshape(cam_posquat_.shape[0],-1,3)

    # Move each camera's points into the world frame with that camera's pose.
    pcd_tf = tutil.pq_action(cam_posquat_[...,None,:3], cam_posquat_[...,None,3:], pcd_flat)

    import open3d as o3d

    # Per-camera coordinate frames (at the camera pose) and frustum line sets.
    cam_vis_list = []
    cam_pq_mesh_list = []
    for i in range(cam_posquat_.shape[0]):
        x, y = intrinsic_[i,:2].astype(np.int32)
        Tm = tutil.pq2H(cam_posquat_[i,:3], cam_posquat_[i,3:])
        mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=0.1, origin=[0, 0, 0])
        mesh_frame.transform(np.array(Tm))
        cam_pq_mesh_list.append(mesh_frame)
        # The frustum's extrinsic is the inverse of the pose matrix.
        Tm = np.linalg.inv(Tm)
        cam_vis_list.append(o3d.geometry.LineSet.create_camera_visualization(view_width_px=x, view_height_px=y, 
                                                                            intrinsic=intrinsic_matrix[i], extrinsic=Tm, scale=0.1))

    pcd_tf = pcd_tf.reshape(-1,3)
    # Colors divided by 255 (presumably uint8 RGB -> [0, 1] for Open3D).
    rgb_list_reshape = rgb_list_.reshape(-1,3)/255.
    if area is not None:
        # Keep only points strictly inside the axis-aligned box (area[0], area[1]).
        valid_pcd_mask = np.logical_and(np.all(pcd_tf>area[0],-1), np.all(pcd_tf<area[1],-1))
        pcd_tf = pcd_tf[valid_pcd_mask]
        rgb_list_reshape = rgb_list_reshape[valid_pcd_mask]

    if pcd_o3d is None:
        pcd_o3d = o3d.geometry.PointCloud()
    pcd_o3d.points = o3d.utility.Vector3dVector(pcd_tf)
    pcd_o3d.colors = o3d.utility.Vector3dVector(rgb_list_reshape)
    # World-origin coordinate frame.
    mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.1, origin=[0, 0, 0])
    
    # Optional marker spheres outlining a square tag at the world origin.
    tag_pnts = []
    if tag_size is not None:
        test_pnts = np.array([[0.,0.,0.],
                        [tag_size/2., tag_size/2., 0.],
                        [-tag_size/2., tag_size/2., 0.],
                        [tag_size/2., -tag_size/2., 0.],
                        [-tag_size/2., -tag_size/2., 0.],
                        [tag_size/2., 0., 0.],
                        [-tag_size/2., 0., 0.],
                        [0., tag_size/2., 0.],
                        [0., -tag_size/2., 0.],])
        for test_pnt_ in test_pnts:
            pnts_o3d = o3d.geometry.TriangleMesh.create_sphere(radius=0.010)
            pnts_o3d.compute_vertex_normals()
            pnts_o3d.paint_uniform_color(np.array([1.0,0,0]))
            pnts_o3d.translate(test_pnt_)
            tag_pnts.append(pnts_o3d)

    ## o3d point projection test
    # o3d_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    # o3d_intrinsic.set_intrinsics(int(intrinsic_[0,0]), int(intrinsic_[0,1]), intrinsic_[0,2], intrinsic_[0,3], intrinsic_[0,4], intrinsic_[0,5])
    # o3d_depth = o3d.geometry.Image((depth_list_[0]*1000).astype(np.uint16))
    # pcd_from_o3d = o3d.geometry.PointCloud.create_from_depth_image(o3d_depth, o3d_intrinsic, depth_scale=1000.0, depth_trunc=1000.0, stride=1)

    if not return_elements:
        # Blocking Open3D window.
        o3d.visualization.draw_geometries([*cam_vis_list, *cam_pq_mesh_list, pcd_o3d, mesh_frame, *tag_pnts])
    else:
        return [*cam_vis_list, *cam_pq_mesh_list, pcd_o3d, mesh_frame, *tag_pnts], pcd_tf, rgb_list_reshape

def partial_pcd_from_depth(depth, intrinsic, pixel_size, coordinate='opengl', visualize=False):
    """Back-project depth maps into camera-frame points (no camera pose applied).

    Previously this was a verbatim copy of pcd_from_depth without the pose step;
    it now delegates so the two cannot drift apart.

    Args:
        depth: (#cam, H, W) or (#cam, H, W, 1) depth maps. A trailing axis is
            added unless the last dimension is already 1.
        intrinsic: (#cam, 6) intrinsics (width, height, fx, fy, cx, cy).
        pixel_size: (height, width) of the depth maps.
        coordinate: 'opengl' negates y and z; any other value keeps (x, y, depth).
        visualize: if True, show the points with Open3D (blocking window).

    Returns:
        (#cam, H, W, 3) camera-frame points.
    """
    return pcd_from_depth(depth, intrinsic, pixel_size, cam_posquat=None,
                          coordinate=coordinate, visualize=visualize)

def gen_partial_pcd(depth_list_, seg_boolean, cam_posquat_, intrinsic_, o3dvis=None):
    """Extract one world-frame point cloud per segment and visualize each with Open3D.

    Args:
        depth_list_: per-camera depth maps, passed to partial_pcd_from_depth.
        seg_boolean: (#cam, H, W, #particle) masks; entry [..., k] selects the pixels
            of segment k (nonzero / True means selected).
        cam_posquat_: (#cam, 7) camera poses [pos(3), quat(4)].
        intrinsic_: (#cam, 6) intrinsics (width, height, fx, fy, cx, cy). The
            back-projection image size is read from the first camera.
        o3dvis: optional Open3D visualizer. If None, a blocking draw_geometries
            window is opened for every segment; otherwise the geometries are
            passed to update_geometry and the renderer is refreshed.

    Returns:
        list of np.ndarray, one (N_k, 3) world-frame point array per segment.
    """
    pixel_size_ = (intrinsic_[0,1], intrinsic_[0,0])
    intrinsic_matrix = intrinsic_to_Kmat(intrinsic_)
    # (#cam, H, W, 3) camera-frame points.
    pcd = partial_pcd_from_depth(depth_list_, intrinsic_, pixel_size_, visualize=False)

    num_particles = seg_boolean.shape[-1]
    if num_particles == 0:
        return []

    # The world-frame points do not depend on the segment: compute them once.
    pcd_flat = pcd.reshape(cam_posquat_.shape[0], -1, 3)
    pcd_tf = tutil.pq_action(cam_posquat_[...,None,:3], cam_posquat_[...,None,3:], pcd_flat)
    pcd_tf = np.asarray(pcd_tf).reshape(-1, 3)

    import open3d as o3d

    # Camera frustums / pose frames are also segment-independent; build them once.
    cam_vis_list = []
    cam_pq_mesh_list = []
    for i in range(cam_posquat_.shape[0]):
        x, y = intrinsic_[i,:2].astype(np.int32)
        Tm = tutil.pq2H(cam_posquat_[i,:3], cam_posquat_[i,3:])
        mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
            size=0.1, origin=[0, 0, 0])
        mesh_frame.transform(np.array(Tm))
        cam_pq_mesh_list.append(mesh_frame)
        # The frustum's extrinsic is the inverse of the pose matrix.
        Tm = np.linalg.inv(Tm)
        cam_vis_list.append(o3d.geometry.LineSet.create_camera_visualization(view_width_px=x, view_height_px=y,
                                                                            intrinsic=intrinsic_matrix[i], extrinsic=Tm, scale=0.1))
    world_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(
        size=0.1, origin=[0, 0, 0])

    total_partial_pcd = []
    for particle_num in range(num_particles):
        # Boolean row mask over the flattened (#cam * H * W) pixels of this segment.
        seg_mask = np.asarray(seg_boolean[..., particle_num]).reshape(-1).astype(bool)
        # Vectorized selection (replaces a per-point Python loop); empty masks give (0, 3).
        partial_pcd = pcd_tf[seg_mask]
        total_partial_pcd.append(partial_pcd)

        pcd_o3d = o3d.geometry.PointCloud()
        pcd_o3d.points = o3d.utility.Vector3dVector(partial_pcd)
        geometries = [*cam_vis_list, *cam_pq_mesh_list, pcd_o3d, world_frame]
        if o3dvis is None:
            o3d.visualization.draw_geometries(geometries)
        else:
            for o in geometries:
                o3dvis.update_geometry(o)
            o3dvis.poll_events()
            o3dvis.update_renderer()

    return total_partial_pcd



def take_img_features(img_ft, intrinsic, cam_posquat, qpnts):
    '''
    Sample image features at the projections of 3D query points (bilinear interpolation).

    input
    img_ft : (... NI NJ NF) feature map. Its resolution may differ from the image
             size stored in `intrinsic`; projected coordinates are rescaled to NI x NJ.
    intrinsic : (..., 6) (width, height, fx, fy, cx, cy)
    cam posquat : (..., 7) array, or a tuple of (..., 3), (..., 4)
    qpnts : (..., 3) world-frame query points

    output
    (..., 1 + NF): channel 0 is 1.0 where the point projects outside the image
    (see global_pnts_to_pixel), followed by the interpolated NF features.
    The interpolation weights are wrapped in stop_gradient.
    '''
    img_ft_size = jnp.array(img_ft.shape[-3:-1])  # (NI, NJ)
    # global_pnts_to_pixel returns (i, j) = (row, col) coordinates, but intrinsic stores
    # (width, height). Reorder to (height, width) so each axis is normalized by its own
    # extent; dividing by (width, height) was wrong for non-square images.
    input_img_size_ij = intrinsic[..., 1::-1]
    px_coord_ctn, out_pnts_indicator = global_pnts_to_pixel(intrinsic, cam_posquat, qpnts) # (... 2)
    px_coord_ctn = px_coord_ctn/input_img_size_ij * img_ft_size.astype(jnp.float32)

    # Top-left neighbour index; feature-pixel centers sit at integer + 0.5.
    px_coord = (px_coord_ctn-0.5).astype(jnp.int32) # (... 2)

    # Bilinear weights: each corner is weighted by the residual to its opposite corner
    # (hence the (3,2,1,0) reorder). A small epsilon avoids all-zero weights.
    min_pnt = px_coord.astype(jnp.float32) + 0.5
    bound_pnt = min_pnt[...,None,:] + jnp.array([[0,0], [0,1], [1,0], [1,1]], dtype=jnp.float32)
    bound_pnt = bound_pnt.clip(0.5, img_ft_size-0.5)
    px_coord_residual = jnp.abs(px_coord_ctn[...,None,:].clip(0.5, img_ft_size.astype(jnp.float32)-0.5) - bound_pnt)[...,(3,2,1,0),:]
    px_coord_ratio = jnp.prod(px_coord_residual+1e-2, axis=-1)
    px_coord_ratio = px_coord_ratio/jnp.sum(px_coord_ratio, axis=-1, keepdims=True).clip(1e-4)
    px_coord_ratio = jax.lax.stop_gradient(px_coord_ratio)

    # Integer corner indices, clamped to the feature map.
    px_coord = px_coord.clip(0, img_ft_size-1).astype(jnp.int32)
    px_coord_bound = px_coord[...,None,:] + jnp.array([[0,0], [0,1], [1,0], [1,1]], dtype=jnp.int32)
    px_coord_bound = px_coord_bound.clip(0, img_ft_size-1)

    px_flat_idx = px_coord_bound[...,1] + px_coord_bound[...,0] * img_ft_size[...,1] # (... 4)

    img_fts_flat = einops.rearrange(img_ft, '... i j k -> ... (i j) k')
    selected_img_fts = jnp.take_along_axis(img_fts_flat, px_flat_idx[...,None], axis=-2) # (... 4 NF)
    img_ft_res = jnp.sum(px_coord_ratio[...,None] * selected_img_fts, axis=-2)
    img_ft_res = jnp.concatenate([out_pnts_indicator[...,None].astype(jnp.float32), img_ft_res], axis=-1)

    return img_ft_res

def resize_rgb_and_cam_info(rgb, cam_info, output_pixel_size):
    '''
    Resize images and rescale the matching intrinsics to output_pixel_size.

    input : rgb (NB, NC, H, W, C) or (NC, H, W, C). 4-D input gets a leading batch
            axis of size 1 (also added to cam_info), and the outputs keep it.
    cam_info : (cam_posquat, intrinsic); only the intrinsic is changed.
    output_pixel_size : (height, width)

    returns : (resized rgb, (cam_posquat, resized intrinsic))
    '''
    if rgb.ndim == 4:
        # jax.tree_map is deprecated/removed; jax.tree_util.tree_map is the stable API.
        rgb, cam_info = jax.tree_util.tree_map(lambda x: x[None], (rgb, cam_info))

    intrinsic_rs = jax.vmap(partial(resize_intrinsic, origin_pixel_size=rgb.shape[-3:-1], output_pixel_size=output_pixel_size))(cam_info[1])
    cam_info_init_rs = (cam_info[0], intrinsic_rs)
    rgb_rs = jax.vmap(partial(resize_img, output_pixel_size=output_pixel_size))(rgb)

    return rgb_rs, cam_info_init_rs

def default_cond_feat(pixel_size)->structs.ImgFeatures:
    """Build placeholder image-condition features for one default camera.

    The camera uses default_intrinsic(pixel_size), sits at (0, 0, -1) and has the
    rotation tutil.aa2q returns for a zero axis-angle. No image features are
    attached (None).

    Returns:
        structs.ImgFeatures with a leading camera axis of size 1.
    """
    intrinsic = default_intrinsic(pixel_size)
    identity_quat = tutil.aa2q(jnp.array([0, 0, 0.]))
    cam_posquat = jnp.concatenate([jnp.array([0, 0, -1]), identity_quat], -1)
    return structs.ImgFeatures(intrinsic[None], cam_posquat[None], None)


def generate_random_camera_views(num_views, view_target, view_dist, cam_base_direction=(1.0, 0.0, 0.0), rng=None):
    """
    Generate random camera viewpoints on a partial sphere around a target point.

    The previous version could not run: it referenced the undefined names
    `p`/`cutil`, dotted a flat 9-element matrix with a 3-vector, and fed degrees
    to a radian API. This version computes the look-at pose directly with
    numpy/scipy, with no pybullet dependency.

    Each camera is placed at ``view_target + Rz(yaw) @ Ry(-pitch) @ cam_base_direction * dist``,
    where:
      - dist ~ view_dist + U(-0.1, 0.1)
      - yaw ~ U(-60, 60) degrees about the world z axis
      - pitch ~ U(-30, 30) degrees of elevation; positive raises the camera
        for a horizontal base direction.
    The camera looks at view_target with world +z as the up hint.

    Args:
        num_views (int): Number of camera views to generate.
        view_target (array-like): (3,) point the cameras look at.
        view_dist (float): Nominal distance of the cameras from the target.
        cam_base_direction (array-like): (3,) direction from the target before rotation.
            Must not be parallel to world z.
        rng: optional object with a ``uniform(low, high)`` method, e.g.
            ``np.random.default_rng(seed)``. Defaults to the global ``np.random``.

    Returns:
        np.ndarray: (num_views, 7) camera poses ``[x, y, z, qx, qy, qz, qw]``. This
            uses the OpenGL camera frame (x right, y up, looking along -z) and a
            SciPy scalar-last quaternion, the same layout as
            pb_viewmatrix_to_cam_posquat.
        np.ndarray: (num_views, 6) intrinsics ``(W, H, Fx, Fy, Cx, Cy)`` for a
            640x480 image with the principal point at the image center.

    Raises:
        ValueError: if a camera coincides with the target or looks straight along world z.
    """
    from scipy.spatial.transform import Rotation as sciR

    rng = np.random if rng is None else rng
    view_target = np.asarray(view_target, dtype=np.float64)
    base_dir = np.asarray(cam_base_direction, dtype=np.float64)
    world_up = np.array([0.0, 0.0, 1.0])
    # (W, H, Fx, Fy, Cx, Cy): the principal point is the image center (the old
    # value put it at the image corner).
    cam_intrinsic = [640, 480, 320, 320, 320, 240]

    cam_posquats = []
    cam_intrinsics = []
    for _ in range(num_views):
        # Randomize camera distance slightly for variance.
        dist = view_dist + rng.uniform(-0.1, 0.1)
        yaw = np.deg2rad(rng.uniform(-60, 60))     # horizontal angle from base direction
        pitch = np.deg2rad(rng.uniform(-30, 30))   # vertical angle

        # Extrinsic 'yz': rotate about y by -pitch (elevation), then about world z by yaw.
        rot = sciR.from_euler('yz', [-pitch, yaw]).as_matrix()
        cam_pos = rot @ base_dir * dist + view_target

        # Look-at frame in OpenGL convention: columns are (right, up, -forward).
        forward = view_target - cam_pos
        forward_norm = np.linalg.norm(forward)
        right = np.cross(forward, world_up)
        right_norm = np.linalg.norm(right)
        if forward_norm < 1e-9 or right_norm < 1e-9 * max(forward_norm, 1e-9):
            raise ValueError("Degenerate camera: coincides with the target or looks along world z.")
        forward = forward / forward_norm
        right = right / right_norm
        up = np.cross(right, forward)
        cam_rot = np.stack([right, up, -forward], axis=-1)

        cam_quat = sciR.from_matrix(cam_rot).as_quat()
        cam_posquats.append(np.concatenate([cam_pos, cam_quat]))
        cam_intrinsics.append(list(cam_intrinsic))

    cam_posquats = np.array(cam_posquats).reshape(num_views, 7)
    cam_intrinsics = np.array(cam_intrinsics).reshape(num_views, 6)

    return cam_posquats, cam_intrinsics

from scipy.spatial.transform import Rotation as R
def convert_opengl_to_open3d(intrinsics_gl: np.ndarray, cam_posquat: np.ndarray):
    """
    Convert camera intrinsics and camera pose from an OpenGL-style frame to the Open3D frame.

    Parameters
    ----------
    intrinsics_gl : np.ndarray, shape (6,)
        Camera intrinsics in the OpenGL format:
        [px, py, fx, fy, cx, cy]
        where:
            - px, py: image dimensions (width, height)
            - fx, fy: focal lengths
            - cx, cy: principal point (origin at lower-left)
    cam_posquat : np.ndarray, shape (7,)
        Camera pose in the OpenGL format:
        [tx, ty, tz, qw, qx, qy, qz]
        where:
            - (tx, ty, tz) is the camera position in world coordinates.
            - (qw, qx, qy, qz) is the camera rotation as a quaternion in [w, x, y, z] order.
        OpenGL camera axes: x right, y up, looking along -z.
        NOTE(review): pb_viewmatrix_to_cam_posquat in this module emits scalar-last
        [x, y, z, w] quaternions; confirm which order callers pass here.

    Returns
    -------
    intrinsics_o3d : np.ndarray, shape (6,)
        The converted intrinsics; only the principal point's y coordinate
        changes: new cy = py - cy.
    cam_posquat_o3d : np.ndarray, shape (7,)
        The converted pose [tx, ty, tz, qw, qx, qy, qz]. The world frame is
        rotated 180 degrees about x (y and z flipped). The camera axes become
        Open3D's x right, y down, looking along +z.
    """
    # Flip the y coordinate of the principal point (py is the image height).
    px, py, fx, fy, cx, cy = intrinsics_gl
    cy_o3d = py - cy
    intrinsics_o3d = np.array([px, py, fx, fy, cx, cy_o3d])

    t_gl = cam_posquat[:3]       # camera position: (tx, ty, tz)
    q_gl = cam_posquat[3:]       # quaternion in [w, x, y, z] order

    # SciPy expects [x, y, z, w].
    q_gl_scipy = np.concatenate((q_gl[1:], q_gl[:1]))
    R_gl = R.from_quat(q_gl_scipy).as_matrix()

    # T_world: OpenGL world -> Open3D world (180 degrees about x).
    T_world = np.diag([1, -1, -1])
    # T_cam: OpenGL camera axes -> Open3D camera axes. Both y (up -> down) and
    # z (-forward -> +forward) flip. The old diag(1, 1, -1) flipped only z, which is
    # a reflection (det = -1), so R_o3d was not a valid rotation.
    T_cam = np.diag([1, -1, -1])

    R_o3d = T_world @ R_gl @ T_cam

    # Back to a quaternion; SciPy returns [x, y, z, w], reorder to [w, x, y, z].
    q_o3d_scipy = R.from_matrix(R_o3d).as_quat()
    q_o3d = np.concatenate((q_o3d_scipy[-1:], q_o3d_scipy[:-1]))

    t_o3d = T_world @ t_gl

    cam_posquat_o3d = np.concatenate((t_o3d, q_o3d))

    return intrinsics_o3d, cam_posquat_o3d


from scipy.spatial.transform import Rotation as R
def np2o3d_img2pcd(
        color: np.ndarray,
        depth: np.ndarray,
        intrinsic_gl: np.ndarray,
        cam_pq_gl: np.ndarray,
        normal: bool = True,
        depth_max: float = 3.0,
        depth_scale: float = 1.0,
):
    """Back-project a colour/depth image pair into a world-frame Open3D point cloud.

    The camera pose is given with OpenGL-style camera axes (x right, y up,
    camera looking along -z). It is converted to the Open3D/OpenCV camera axes
    (x right, y down, camera looking along +z) before Open3D unprojects the
    depth image.

    Args:
        color: Colour image. Cast to a C-contiguous ``uint8`` array before
            being wrapped as an Open3D image (float images in [0, 1] are NOT
            rescaled).
        depth: Depth image. Cast to a C-contiguous ``float32`` array.
        intrinsic_gl: (6,) array ``[px, py, fx, fy, cx, cy]``, where ``px`` and
            ``py`` are the image width and height. ``(cx, cy)`` is used as-is
            as the principal point; no vertical flip is applied. ``px`` and
            ``py`` are not used.
        cam_pq_gl: (7,) camera-to-world pose ``[tx, ty, tz, qx, qy, qz, qw]``.
            ``(tx, ty, tz)`` is the camera position in world coordinates. The
            quaternion is in SciPy's scalar-last ``[x, y, z, w]`` order, the
            layout returned by ``pb_viewmatrix_to_cam_posquat``.
        normal: Whether Open3D should also compute per-point normals.
        depth_max: Maximum depth (after ``depth_scale`` is applied) kept by
            Open3D.
        depth_scale: Factor that raw depth values are divided by inside
            Open3D. The default of 1.0 uses the depth values unchanged.

    Returns:
        A legacy ``open3d.geometry.PointCloud`` in world coordinates.

    Raises:
        ValueError: If ``intrinsic_gl`` does not hold 6 values or ``cam_pq_gl``
            does not hold 7 values.
    """
    intrinsic_gl = np.asarray(intrinsic_gl, dtype=np.float64).reshape(-1)
    cam_pq_gl = np.asarray(cam_pq_gl, dtype=np.float64).reshape(-1)
    if intrinsic_gl.shape[0] != 6:
        raise ValueError(
            f"intrinsic_gl must have 6 entries [px, py, fx, fy, cx, cy], "
            f"got {intrinsic_gl.shape[0]}")
    if cam_pq_gl.shape[0] != 7:
        raise ValueError(
            f"cam_pq_gl must have 7 entries [tx, ty, tz, qx, qy, qz, qw], "
            f"got {cam_pq_gl.shape[0]}")

    # ------------------------------------------------------------------
    # Intrinsics: pinhole matrix K. Image width/height are not needed here.
    _, _, fx, fy, cx, cy = intrinsic_gl
    K_np = np.array([
        [fx, 0.0, cx],
        [0.0, fy, cy],
        [0.0, 0.0, 1.0],
    ], dtype=np.float32)
    K = o3d.core.Tensor(K_np, dtype=o3d.core.Dtype.Float32)

    # ------------------------------------------------------------------
    # Extrinsics. The pose is camera->world with OpenGL camera axes.
    t_world_cam = cam_pq_gl[:3]
    # SciPy's from_quat expects scalar-last [x, y, z, w] and normalises.
    R_world_cam_gl = R.from_quat(cam_pq_gl[3:]).as_matrix()
    # Flip the camera's y and z axes: OpenGL (y up, looking along -z) ->
    # Open3D/OpenCV (y down, looking along +z).
    gl_to_cv = np.diag([1.0, -1.0, -1.0])
    R_world_cam = R_world_cam_gl @ gl_to_cv

    # Open3D's `extrinsics` is the world->camera transform, i.e. the inverse of
    # the camera->world pose. For a rigid transform [R | t] the inverse is
    # exactly [R^T | -R^T t]. It is computed in float64 and cast to float32
    # only at the end.
    R_cam_world = R_world_cam.T
    world_to_cam = np.eye(4, dtype=np.float64)
    world_to_cam[:3, :3] = R_cam_world
    world_to_cam[:3, 3] = -R_cam_world @ t_world_cam
    # Keep the float32 buffer in a local so it outlives the zero-copy tensor.
    world_to_cam_f32 = np.ascontiguousarray(world_to_cam, dtype=np.float32)
    extrinsics = o3d.core.Tensor.from_numpy(world_to_cam_f32)

    # ------------------------------------------------------------------
    # Build the RGBD image and unproject it into a point cloud.
    rgbd = o3d.t.geometry.RGBDImage(
        o3d.t.geometry.Image(np.ascontiguousarray(color, dtype=np.uint8)),
        o3d.t.geometry.Image(np.ascontiguousarray(depth, dtype=np.float32)),
    )
    pcd = o3d.t.geometry.PointCloud.create_from_rgbd_image(
        rgbd,
        K,
        depth_scale=depth_scale,
        with_normals=normal,
        extrinsics=extrinsics,
        depth_max=depth_max,
    )

    # Callers get a legacy (non-tensor) Open3D point cloud.
    return pcd.to_legacy()