import pybullet as p
import pybullet_data
import numpy as np
import open3d as o3d
import time
import os, sys
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy.spatial.transform import Rotation as sciR

BASEDIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASEDIR not in sys.path:
    sys.path.insert(0, BASEDIR)

import util.camera_util as cutil
import modules.encoder as encoder

# --------------------------------------------------------------------------
# Utility Functions
# --------------------------------------------------------------------------

def setup_pybullet(obj_path, mesh_scale=2.0):
    """
    Connect to the PyBullet GUI and load a ground plane plus a static OBJ mesh.

    Args:
        obj_path: path to the .obj file used for both the collision and the
            visual shape.
        mesh_scale: uniform scale (scalar) or per-axis [sx, sy, sz] applied to
            the mesh. Defaults to 2.0, the value previously hard-coded.

    Returns:
        The PyBullet body id of the loaded object.
    """
    # GUI mode so the simulation can be watched while views are rendered.
    p.connect(p.GUI)
    p.setGravity(0, 0, -9.81)
    # pybullet_data ships plane.urdf and other stock assets.
    p.setAdditionalSearchPath(pybullet_data.getDataPath())

    # Ground plane for reference; its body id is not needed afterwards.
    p.loadURDF("plane.urdf")

    # Accept a scalar or a 3-vector and expand it to per-axis scales.
    scale = np.broadcast_to(np.asarray(mesh_scale, dtype=float), (3,)).tolist()

    collision_shape_id = p.createCollisionShape(shapeType=p.GEOM_MESH,
                                                fileName=obj_path,
                                                meshScale=scale)
    visual_shape_id = p.createVisualShape(shapeType=p.GEOM_MESH,
                                          fileName=obj_path,
                                          meshScale=scale)
    # baseMass=0 makes the body static. The +90 deg rotation about x is
    # presumably converting a Y-up OBJ into PyBullet's Z-up world — confirm
    # for the asset set in use.
    obj_id = p.createMultiBody(baseMass=0,
                               baseCollisionShapeIndex=collision_shape_id,
                               baseVisualShapeIndex=visual_shape_id,
                               basePosition=[0, 0, 0],
                               baseOrientation=sciR.from_euler('x', np.pi/2).as_quat())
    return obj_id

def random_camera_pose(radius=1.0, target=(0.0, 0.0, 0.5)):
    """
    Sample a random camera position on a sphere around `target`, aimed at it.

    Azimuth theta is drawn uniformly from [-pi, 0) and the polar angle phi
    (measured from +Z) from [pi/6, 2*pi/3). Because phi can exceed pi/2, the
    camera is sometimes below the height of `target` (with the default target
    and radius > 1, possibly even below z = 0).

    Args:
        radius: distance between the camera and `target`.
        target: 3D point that is both the sphere center and the look-at point.
            Defaults to (0, 0, 0.5), the value previously hard-coded.

    Returns:
        (cam_pos, view_matrix): the camera position as a (3,) array and the
        view matrix from p.computeViewMatrix (up vector +Z).
    """
    target = np.asarray(target, dtype=float)
    theta = np.random.uniform(-np.pi, 0.0)
    phi = np.random.uniform(np.pi/6, np.pi/2 + np.pi/6)
    x = radius * np.sin(phi) * np.cos(theta)
    y = radius * np.sin(phi) * np.sin(theta)
    z = radius * np.cos(phi)
    cam_pos = target + np.array([x, y, z])
    up = np.array([0, 0, 1])
    view_matrix = p.computeViewMatrix(cameraEyePosition=cam_pos,
                                      cameraTargetPosition=target,
                                      cameraUpVector=up)
    return cam_pos, view_matrix

def get_projection_matrix(width, height, fov=60, near=0.1, far=3.1):
    """
    Build a PyBullet perspective projection matrix from a field of view.

    Args:
        width, height: image size in pixels; only their ratio is used.
        fov: field of view in degrees, as interpreted by
            p.computeProjectionMatrixFOV.
        near, far: near and far clipping distances.

    Returns:
        The matrix produced by p.computeProjectionMatrixFOV.
    """
    return p.computeProjectionMatrixFOV(
        fov=fov,
        aspect=width / height,
        nearVal=near,
        farVal=far,
    )

def capture_depth_image(view_matrix, projection_matrix, width=640, height=480,
                        near=0.1, far=3.1):
    """
    Render the scene from one camera and return the RGB image and linear depth.

    Args:
        view_matrix: view matrix, e.g. from p.computeViewMatrix.
        projection_matrix: projection matrix, e.g. from
            p.computeProjectionMatrixFOV.
        width, height: image size in pixels.
        near, far: clipping distances that were used to build
            `projection_matrix`. They must match it, or the linearized depth
            will be wrong. Defaults are the previously hard-coded 0.1 / 3.1.

    Returns:
        rgb: (height, width, 3) color image (alpha channel dropped).
        depth: (height, width) depth recovered from the [0, 1] depth buffer.
    """
    img_arr = p.getCameraImage(width=width,
                               height=height,
                               viewMatrix=view_matrix,
                               projectionMatrix=projection_matrix,
                            #    renderer=p.ER_BULLET_HARDWARE_OPENGL,
                               )
    # img_arr[2] is RGBA and img_arr[3] is the non-linear depth buffer.
    rgb = np.reshape(img_arr[2], (height, width, 4))[..., :3]
    depth_buffer = np.reshape(img_arr[3], (height, width))

    # Invert the perspective depth mapping:
    # depth = far * near / (far - (far - near) * depth_buffer)
    depth = far * near / (far - (far - near) * depth_buffer)
    return rgb, depth

def depth_to_point_cloud(depth, intrinsic, view_matrix=None):
    """
    Back-project a depth image (metres) into an Open3D point cloud.

    Open3D's create_from_depth_image reads a uint16 image with a default
    depth_scale of 1000, so depth is converted to rounded millimetres.
    Non-finite, non-positive, or too-large (> 65.535 m) values are written as 0
    instead of wrapping around in the uint16 cast. Open3D does not emit points
    for zero depth.

    Args:
        depth: (H, W) depth image in metres.
        intrinsic: camera intrinsics passed straight to Open3D (presumably an
            o3d.camera.PinholeCameraIntrinsic — confirm with callers).
        view_matrix: optional PyBullet view matrix (16 floats as returned by
            p.computeViewMatrix). When given, points are mapped to world frame.

    Returns:
        open3d.geometry.PointCloud.
    """
    depth_mm = np.round(np.asarray(depth, dtype=np.float64) * 1000.0)
    valid = np.isfinite(depth_mm) & (depth_mm > 0) & (depth_mm <= np.iinfo(np.uint16).max)
    depth_mm = np.where(valid, depth_mm, 0).astype(np.uint16)
    depth_o3d = o3d.geometry.Image(depth_mm)
    pcd = o3d.geometry.PointCloud.create_from_depth_image(depth_o3d, intrinsic)
    if view_matrix is not None:
        # PyBullet stores the matrix flattened so that reshape(4, 4).T gives
        # the world->camera transform in the OpenGL camera convention
        # (looking down -z, y up).
        view = np.array(view_matrix, dtype=np.float64).reshape(4, 4).T
        # Open3D's back-projected points are in a z-forward, y-down camera
        # frame; flip y and z to reach the OpenGL camera frame first.
        cv_to_gl = np.diag([1.0, -1.0, -1.0, 1.0])
        pcd.transform(np.linalg.inv(view) @ cv_to_gl)
    return pcd

def compute_ray_direction(u, v, fx, fy, cx, cy):
    """
    Return unit pinhole-camera ray directions for pixel coordinates (u, v).

    The unnormalized direction is ((u - cx) / fx, (v - cy) / fy, 1), so the
    optical axis is +z. `u` and `v` may be scalars or broadcastable arrays.

    Returns:
        Array of shape broadcast(u, v).shape + (3,); shape (3,) for scalars.
    """
    x = (np.asarray(u, dtype=np.float64) - cx) / fx
    y = (np.asarray(v, dtype=np.float64) - cy) / fy
    x, y = np.broadcast_arrays(x, y)
    ray_dir = np.stack([x, y, np.ones_like(x)], axis=-1)
    ray_dir /= np.linalg.norm(ray_dir, axis=-1, keepdims=True)
    return ray_dir

def generate_tsdf_for_pixel(depth_img, intrinsic,
                            max_depth=3.0, num_samples=50, truncation=0.1):
    """
    Sample points along one pixel's camera ray and compute truncated SDF values.

    The signed distance of a sample is (distance travelled along the ray) minus
    (distance along the ray to the observed surface), clipped to
    [-truncation, truncation]. It is negative in front of the surface and
    positive behind it.

    `depth_img` is treated as z-depth (distance along the optical axis, which is
    what capture_depth_image produces). The surface distance along a unit ray is
    therefore depth / ray_z. Comparing ray distance directly with z-depth would
    only be correct at the principal point.

    Args:
        depth_img: (H, W) depth image, indexed as depth_img[v, u].
        intrinsic: 6-tuple (pixel_u, pixel_v, fx, fy, cx, cy). It bundles the
            pixel coordinates (truncated to int) with the pinhole intrinsics.
        max_depth: largest ray distance sampled.
        num_samples: number of evenly spaced samples in [0, max_depth].
        truncation: clipping bound for the signed distance.

    Returns:
        query_points: (num_samples, 3) sample positions in the camera frame.
        sdf_values: (num_samples,) truncated signed distances.
    """
    pixel_u, pixel_v, fx, fy, cx, cy = intrinsic
    pixel_u = int(pixel_u)
    pixel_v = int(pixel_v)
    measured_depth = depth_img[pixel_v, pixel_u]
    ray_direction = compute_ray_direction(pixel_u, pixel_v, fx, fy, cx, cy)
    # Unit ray: its z-component is cos(angle to the optical axis) and is > 0.
    surface_distance = measured_depth / ray_direction[2]
    # Rays start at the camera center, so sample points are t * direction.
    t_vals = np.linspace(0, max_depth, num_samples)
    query_points = t_vals[:, None] * ray_direction[None, :]
    sdf_values = np.clip(t_vals - surface_distance, -truncation, truncation)
    return query_points, sdf_values

# --------------------------------------------------------------------------
# Main Execution
# --------------------------------------------------------------------------


def sample_pb_scene(nviews,
                    obj_path='/home/dongwon/research/object_set/own_assets/shelf/modified/shelf-045.obj',
                    radius=1.3):
    """
    Render `nviews` random views of an object in PyBullet.

    The PyBullet connection opened by setup_pybullet is always closed before
    returning, even if rendering raises.

    Args:
        nviews: number of camera views to capture.
        obj_path: OBJ mesh to load. Defaults to the previously hard-coded path.
        radius: camera distance from the look-at point (previously fixed 1.3).

    Returns:
        depths: list of (240, 424) depth images.
        cam_pqc: list of camera poses from cutil.pb_viewmatrix_to_cam_posquat
            (presumably position followed by quaternion — confirm in
            util.camera_util).
        intrinsics: list of np.ndarray intrinsics from cutil.pbfov_to_intrinsic,
            one independent copy per view.
        rgbs: list of (240, 424, 3) RGB images.
    """
    width, height = 424, 240
    fov = 60  # degrees
    # Must agree with capture_depth_image's near/far defaults used below.
    near = 0.1
    far = 3.1

    depths, cam_pqc, intrinsics, rgbs = [], [], [], []
    try:
        setup_pybullet(obj_path)
        projection_matrix = get_projection_matrix(width, height, fov, near, far)
        # Same image size and FOV for every view, so compute once.
        intrinsic = np.array(cutil.pbfov_to_intrinsic((height, width), fov))
        for _ in range(nviews):
            _, view_matrix = random_camera_pose(radius=radius)
            rgb, depth = capture_depth_image(view_matrix, projection_matrix, width, height)
            cam_pqc.append(cutil.pb_viewmatrix_to_cam_posquat(view_matrix))
            intrinsics.append(intrinsic.copy())
            rgbs.append(rgb)
            depths.append(depth)
    finally:
        # Avoid leaking the GUI connection on error; guard in case connect failed.
        if p.isConnected():
            p.disconnect()
    return depths, cam_pqc, intrinsics, rgbs


def _ray_lineset(ray_points, ray_directions, color):
    """Build an Open3D LineSet with one segment from each ray point to point + direction."""
    num_rays = len(ray_points)
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(
        np.concatenate([ray_points, ray_points + ray_directions], axis=-2))
    line_set.lines = o3d.utility.Vector2iVector(
        np.stack([np.arange(num_rays), np.arange(num_rays) + num_rays], axis=-1))
    line_set.paint_uniform_color(color)
    return line_set


def _visualize_view_samples(train_data, max_points=100, max_rays=20):
    """Show one view's surface points with random subsets of its occupancy samples and rays."""
    sdf = train_data['signed_distance'].squeeze(-1)
    query_points = train_data['query_points']

    # Samples labelled inside (<= 0) in red, outside (> 0) in green.
    inside_idx = np.random.permutation(np.where(sdf <= 0)[0])[:max_points]
    occ_inside_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(query_points[inside_idx]))
    occ_inside_pcd.paint_uniform_color([1, 0, 0])

    outside_idx = np.random.permutation(np.where(sdf > 0)[0])[:max_points]
    occ_outside_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(query_points[outside_idx]))
    occ_outside_pcd.paint_uniform_color([0, 1, 0])

    surface_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(train_data['surface_points']))

    # Rays labelled < 0 in red, >= 0 in green.
    ray_points = train_data['ray_points']
    ray_directions = train_data['ray_directions']
    ray_gt = train_data['ray_hitting_gt']
    ray_idx = np.random.permutation(np.where(ray_gt < 0)[0])[:max_rays]
    ray_idx_hit = np.random.permutation(np.where(ray_gt >= 0)[0])[:max_rays]
    ray_pcd = _ray_lineset(ray_points[ray_idx], ray_directions[ray_idx], [1, 0, 0])
    ray_pcd_hit = _ray_lineset(ray_points[ray_idx_hit], ray_directions[ray_idx_hit], [0, 1, 0])

    o3d.visualization.draw_geometries([surface_pcd, ray_pcd, ray_pcd_hit, occ_inside_pcd, occ_outside_pcd])


def gen_tsdf_data(depths, cam_pqc, intrinsics, rgbs=None, depth_max=2.0, visualize=False):
    """
    Build occupancy samples and ray samples from posed depth images.

    Per view, pixels with depth in (0.05, depth_max) count as valid surface
    observations. Around each valid surface point, 32 samples are offset back
    toward the camera by up to 0.10 * ray direction (labelled +1, "outside").
    Another 32 are pushed past the surface by up to 0.02 * ray direction
    (labelled -1, "inside"). Every pixel, valid or not, also contributes one
    ray segment from its ray start to 0.05 * ray direction short of its
    back-projected point, labelled -1. Presumably that label means "no hit along
    this segment" — confirm against the encoder's loss.

    Args:
        depths: list of (H, W) depth images; all assumed to be the size of
            depths[0].
        cam_pqc: per-view camera poses; [:3] and [3:] are passed to
            cutil.pixel_ray as position and orientation.
        intrinsics: per-view intrinsics accepted by cutil.pcd_from_depth and
            cutil.pixel_ray.
        rgbs: unused; accepted for backward compatibility and may be None.
        depth_max: exclusive upper depth bound for a pixel to count as valid.
        visualize: if True, open Open3D windows per view and for the merged
            surface points.

    Returns:
        dict of float32 arrays concatenated over all views, with keys
        'query_points' (N, 3), 'signed_distance' (N, 1), 'surface_points'
        (M, 3), 'ray_points' (R, 3), 'ray_directions' (R, 3) and
        'ray_hitting_gt' (R,).

    Raises:
        ValueError: if `depths` is empty.
    """
    nviews = len(depths)
    if nviews == 0:
        raise ValueError("gen_tsdf_data requires at least one depth image")
    height, width = depths[0].shape[:2]

    all_tsdf_data = []
    for i in range(nviews):
        depth, cam_posquat, intrinsic = depths[i], cam_pqc[i], intrinsics[i]
        surface_pnts = cutil.pcd_from_depth(depth, intrinsic, (height, width), cam_posquat)
        surface_pnts = surface_pnts.reshape(-1, 3)

        # One ray per pixel, in the same flattened order as surface_pnts.
        pixel_rays_start, _, pixel_rays_direction = cutil.pixel_ray(
            (height, width), cam_posquat[:3], cam_posquat[3:], intrinsic,
            near=0.05, far=3.1, coordinate='opengl')
        pixel_rays_start = pixel_rays_start.reshape(-1, 3)
        pixel_rays_direction = pixel_rays_direction.reshape(-1, 3)
        assert surface_pnts.shape[0] == height * width

        flat_depth = depth.reshape(-1)
        valid_pixel = np.logical_and(flat_depth > 0.05, flat_depth < depth_max)
        valid_rays_direction = pixel_rays_direction[valid_pixel]
        valid_surface_points = surface_pnts[valid_pixel]
        num_valid = valid_surface_points.shape[0]

        # Negative offsets move toward the camera (free space, label +1);
        # positive offsets move past the surface (occupied, label -1).
        positive_depth = np.random.uniform(-0.100, 0, size=(num_valid, 32))  # outside
        negative_depth = np.random.uniform(0, 0.020, size=(num_valid, 32))  # inside
        entire_depth = np.concatenate([positive_depth, negative_depth], axis=-1)
        query_points = (valid_surface_points[..., None, :]
                        + valid_rays_direction[..., None, :] * entire_depth[..., None])
        sdf_values = np.concatenate([np.ones_like(positive_depth), -np.ones_like(negative_depth)], axis=-1)

        # Segments stop 0.05 * direction short of each pixel's surface point.
        query_ray_points = pixel_rays_start
        query_ray_directions = surface_pnts - pixel_rays_start - 0.05 * pixel_rays_direction
        ray_hit_values = -np.ones_like(pixel_rays_start[..., 0]).astype(np.float32)

        cur_train_data = {
            'query_points': query_points.reshape(-1, 3).astype(np.float32),
            'signed_distance': sdf_values.reshape(-1, 1).astype(np.float32),
            'surface_points': valid_surface_points.reshape(-1, 3).astype(np.float32),
            'ray_points': query_ray_points.astype(np.float32),
            'ray_directions': query_ray_directions.astype(np.float32),
            'ray_hitting_gt': ray_hit_values.astype(np.float32),
        }

        assert cur_train_data['query_points'].shape[0] == cur_train_data['signed_distance'].shape[0]
        assert cur_train_data['ray_hitting_gt'].shape[0] == cur_train_data['ray_points'].shape[0] == cur_train_data['ray_directions'].shape[0]

        if visualize:
            _visualize_view_samples(cur_train_data)

        all_tsdf_data.append(cur_train_data)
        print(f"View {i+1}: Captured depth image and computed TSDF for center pixel.")

    # Stack each key's arrays across views.
    all_tsdf_data = jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis=0), *all_tsdf_data)

    if visualize:
        vis_pnts = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(all_tsdf_data['surface_points']))
        o3d.visualization.draw_geometries([vis_pnts])

    return all_tsdf_data
    

if __name__ == "__main__":
    # Fixed NumPy seed so camera poses and sample offsets are reproducible.
    np.random.seed(0)

    NUM_VIEWS = 14
    DEPTH_MAX = 2.0

    depths, cam_pqc, intrinsics, rgbs = sample_pb_scene(NUM_VIEWS)
    tsdf_data = gen_tsdf_data(depths, cam_pqc, intrinsics, rgbs=rgbs,
                              depth_max=DEPTH_MAX, visualize=False)

    # Project-local modules only needed for encoding and reconstruction.
    import util.model_util as mutil
    import util.reconstruction_util as rcutil

    models = mutil.Models().load_pretrained_models()

    # Fit an oriCORN representation to the sampled occupancy / ray data.
    # (A previously cached result could instead be unpickled from
    # 'assets_oriCORNs/tmp.pkl'.)
    oriCORN = encoder.encode_mesh(models, None, tsdf_data, nfps_multiplier=10,
                                  depth_max=DEPTH_MAX, niter=4000)

    occ_decoder = jax.jit(models.occ_prediction)
    rcutil.create_scene_mesh_from_oriCORNs(oriCORN, dec=occ_decoder, level=0.0,
                                           qp_bound=2.0, density=200, ndiv=400,
                                           visualize=True)