import pybullet as p
import numpy as np
import time
import einops
import jax.numpy as jnp
import jax

import pybullet as p
import numpy as np
from moviepy import *
import time

import pybullet as pb
import os
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

from modules import shakey_module
import util.broad_phase as broad_phase


def way_points_to_trajectory(waypnts, resolution, cos_transition=True):
    """Resample the polyline through ``waypnts`` at ``resolution`` points.

    The path parameter runs over [0, 1] in proportion to cumulative Euclidean
    length along the polyline. With ``cos_transition`` the uniform parameter is
    remapped through (1 - cos(pi t)) / 2, so samples cluster near both ends.

    Args:
        waypnts: waypoints stacked along the leading axis, e.g. shape (N, D);
            segment lengths are Euclidean norms over the last axis.
        resolution: number of output samples.
        cos_transition: apply the cosine ease-in/ease-out remapping.

    Returns:
        Array of shape (resolution, D). Row 0 is exactly ``waypnts[0]`` and the
        last row is exactly ``waypnts[-1]``. If all waypoints coincide, every row
        equals ``waypnts[0]`` (previously the interior rows were NaN).
    """
    epsilon = 1e-8
    # Per-segment lengths as fractions of the total path length.
    wp_len = jnp.linalg.norm(waypnts[1:] - waypnts[:-1], axis=-1)
    wp_len = wp_len / jnp.sum(wp_len).clip(epsilon)
    # Zero out numerically negligible segments and renormalise. The clip keeps the
    # all-coincident case (total length 0) at 0 instead of producing 0/0 = NaN.
    wp_len = jnp.where(wp_len < epsilon * 10, 0, wp_len)
    wp_len = wp_len / jnp.sum(wp_len).clip(epsilon)
    # Path parameter at each waypoint: 0 at the first, forced to exactly 1 at the last.
    wp_len_cumsum = jnp.concatenate([jnp.zeros((1,), dtype=wp_len.dtype), jnp.cumsum(wp_len)], 0)
    wp_len_cumsum = wp_len_cumsum.at[-1].set(1.0)

    indicator = jnp.linspace(0, 1, resolution)
    if cos_transition:
        # Ease-in/ease-out: maps [0, 1] onto [0, 1] with zero slope at both ends.
        indicator = (-jnp.cos(indicator * jnp.pi) + 1) / 2.
    # Segment index of each sample = number of segment end-parameters strictly below it.
    included_idx = jnp.sum(indicator[..., None] > wp_len_cumsum[1:], axis=-1)

    # Weight of the segment's start waypoint: 1 at the segment start, 0 at its end.
    upper_residual = (wp_len_cumsum[included_idx + 1] - indicator) / wp_len[included_idx].clip(epsilon)
    upper_residual = upper_residual.clip(0., 1.)
    bottom_residual = 1. - upper_residual

    traj = (waypnts[included_idx] * upper_residual[..., None]
            + waypnts[included_idx + 1] * bottom_residual[..., None])
    # Segments shorter than 1e-4 of the total length snap to their start waypoint.
    traj = jnp.where(wp_len[included_idx][..., None] < 1e-4, waypnts[included_idx], traj)
    # Pin the endpoints exactly.
    traj = traj.at[0].set(waypnts[0])
    traj = traj.at[-1].set(waypnts[-1])
    return traj

def evaluate_full_trajectory(control_points, samples_per_segment):
    """Fit an SE(3) interpolant through ``control_points`` and sample it densely.

    Args:
        control_points: control points stacked on the second-to-last axis
            (e.g. shape (N, 6)).
        samples_per_segment: number of samples per control-point interval.

    Returns:
        Whatever ``broad_phase.SE3_interpolation_eval`` returns when evaluated at
        ``(N - 1) * samples_per_segment + 1`` parameters evenly spaced in [0, 1],
        endpoints included. N is the ORIGINAL control-point count.

    When exactly two control points are given, they are first expanded into
    ``samples_per_segment`` evenly spaced points along the straight line between
    them. The interpolant is then fitted through those points.
    """
    num_ctrl = control_points.shape[-2]
    fit_points = control_points
    if num_ctrl == 2:
        fit_points = way_points_to_trajectory(fit_points, samples_per_segment, cos_transition=False)
    # The sample count deliberately uses num_ctrl from before any expansion.
    query_params = jnp.linspace(0, 1, (num_ctrl - 1) * samples_per_segment + 1, endpoint=True)
    spline_coeffs = broad_phase.SE3_interpolation_coeffs(fit_points)
    return broad_phase.SE3_interpolation_eval(*spline_coeffs, query_params)


def _is_held_object_contact(cp, robot_uid, robot, obj_in_hand):
    """Return True if closest-point record ``cp`` pairs an end-effector link with the object it holds.

    Such contacts are expected while grasping and are not treated as self-collision.
    Assumes ``obj_in_hand[i]`` is held by link ``robot.ee_idx[i]``, which is the pairing
    the caller's data implies — confirm against ``Shakey``.
    """
    if not obj_in_hand:
        return False
    # getClosestPoints records: [1]/[2] are body uids A/B, [3]/[4] are link indices A/B.
    if cp[1] == robot_uid and cp[2] in obj_in_hand:
        if any(cp[3] == ee and cp[2] == obj for ee, obj in zip(robot.ee_idx, obj_in_hand)):
            return True
    if cp[2] == robot_uid and cp[1] in obj_in_hand:
        if any(cp[4] == ee and cp[1] == obj for ee, obj in zip(robot.ee_idx, obj_in_hand)):
            return True
    return False


def _collect_collisions(target_uids, robot_uid, robot, obj_in_hand, self_collision,
                        max_pen_depth, penetration_depth_limit=0.0, query_distance=0.1):
    """Find penetrations involving ``target_uids`` at the current PyBullet state.

    Returns:
        (collision_pairs, max_pen_depth): ``collision_pairs`` lists (target_uid, other_uid)
        for every closest point whose contact distance (``cp[8]``) is below
        ``penetration_depth_limit``. ``max_pen_depth`` is the running minimum (most
        negative) contact distance, including the value passed in.
    """
    all_bodies = {pb.getBodyUniqueId(i) for i in range(pb.getNumBodies())}
    env_ids = list(all_bodies - set(target_uids))
    collision_pairs = []

    for target_uid in target_uids:
        # Robot / held objects vs. everything else in the scene.
        for other_uid in env_ids:
            for cp in pb.getClosestPoints(target_uid, other_uid, distance=query_distance):
                if cp[8] < penetration_depth_limit:
                    max_pen_depth = min(max_pen_depth, cp[8])
                    collision_pairs.append((target_uid, other_uid))

        if not self_collision:
            continue
        # Robot vs. held objects (and held objects vs. each other), ignoring grasp contacts.
        for other_uid in target_uids:
            if other_uid == target_uid:
                continue
            for cp in pb.getClosestPoints(target_uid, other_uid, distance=query_distance):
                if _is_held_object_contact(cp, robot_uid, robot, obj_in_hand):
                    continue
                if cp[8] < penetration_depth_limit:
                    max_pen_depth = min(max_pen_depth, cp[8])
                    collision_pairs.append((target_uid, other_uid))

    return collision_pairs, max_pen_depth


def simulate_in_pb(qs, robot_uid, robot: shakey_module.Shakey, sleep_time=0.01, evaluate=False,
                   obj_in_hand=None, ee_to_obj_pq=None, ignore_collision=None, self_collision=True, video_dir=None,
                   view_matrix=None):
    """Play back a trajectory in PyBullet and report whether it stays collision-free.

    Args:
        qs: trajectory control points, stacked on the second-to-last axis. They are
            densified with ``evaluate_full_trajectory`` using ``400 // qs.shape[-2]``
            samples per segment. Element ``[0]`` of that result is iterated as the
            sequence of configurations.
        robot_uid: PyBullet body uid of the robot.
        robot: robot wrapper. ``set_q_pb`` poses the robot (and grasped objects), and
            ``ee_idx`` lists its end-effector link indices.
        sleep_time: pause after each step when ``evaluate`` is False (for watching).
        evaluate: if True, skip all sleeps.
        obj_in_hand: optional list of body uids held by the robot. They are checked for
            collisions together with the robot.
        ee_to_obj_pq: grasp transform(s) forwarded to ``robot.set_q_pb``.
        ignore_collision: body uids excluded from checking for this run. They are parked
            at z=1000 and restored afterwards, even if playback raises.
        self_collision: also check robot vs. held objects (end-effector grasp contacts
            excluded).
        video_dir: optional output video FILE path. Every 4th step is recorded, and the
            file is saved with a ``"<success>_"`` prefix on its basename.
        view_matrix: optional camera view matrix for the recording.

    Returns:
        (success, max_pen_depth): ``success`` is False if any step had a penetration.
        ``max_pen_depth`` is the most negative contact distance seen (0.0 if none).
    """
    if ignore_collision is None:
        ignore_collision = []

    nintp = qs.shape[-2]
    interpolated_qs = evaluate_full_trajectory(qs, samples_per_segment=400 // nintp)[0]
    success = True
    max_pen_depth = 0.0

    recorder = None
    if video_dir is not None:
        # video_dir is a file path. os.makedirs("") raises, so only create a parent
        # directory when the path actually has one.
        video_parent = os.path.dirname(video_dir)
        if video_parent:
            os.makedirs(video_parent, exist_ok=True)
        recorder = VideoRecorderMoviePy(filename=video_dir, fps=30, resolution=(640, 480), view_matrix=view_matrix)

    target_uids = [robot_uid]
    if obj_in_hand is not None:
        target_uids += list(obj_in_hand)

    # Snapshot every pose before moving anything, so duplicate uids restore correctly.
    original_poses = [pb.getBasePositionAndOrientation(uid) for uid in ignore_collision]
    for uid in ignore_collision:
        # Park far outside the 0.1 m closest-point query range.
        pb.resetBasePositionAndOrientation(uid, (0, 0, 1000), [0, 0, 0, 1])

    try:
        for step, q in enumerate(interpolated_qs):
            robot.set_q_pb(robot_uid, q, grasped_obj_pbid=obj_in_hand, grasped_obj_pq=ee_to_obj_pq)
            if not evaluate:
                time.sleep(sleep_time)
            pb.performCollisionDetection()

            if recorder is not None and step % 4 == 0:
                recorder.record_frame()

            collision_pairs, max_pen_depth = _collect_collisions(
                target_uids, robot_uid, robot, obj_in_hand, self_collision, max_pen_depth)
            if collision_pairs:
                success = False
                if not evaluate:
                    time.sleep(0.25)  # linger on the colliding pose when watching
                print(f'collision!: {collision_pairs} / {max_pen_depth}')

        if recorder is not None:
            recorder.save(str(success))
    finally:
        for uid, (pos, orn) in zip(ignore_collision, original_poses):
            pb.resetBasePositionAndOrientation(uid, pos, orn)

    return success, max_pen_depth


class VideoRecorderMoviePy:
    """Capture RGB frames from the active PyBullet connection and encode them with MoviePy.

    Frames are kept in memory in ``self.frames`` until :meth:`save` writes them out.
    """

    def __init__(self, filename="output.mp4", fps=50, resolution=(1280, 720), view_matrix=None):
        """
        Initializes the video recorder.

        Args:
            filename (str): The output video path. ``save(prefix)`` writes into the same
                directory with ``"<prefix>_"`` prepended to the basename. This attribute
                itself is never modified.
            fps (int): Frames per second for the video.
            resolution (tuple): (width, height) of each captured frame in pixels.
            view_matrix: Optional camera view matrix passed to ``p.getCameraImage``.
                When given, a 70-degree-FOV perspective projection (near 0.1, far 100)
                is used. When None, the debug visualizer's current view and projection
                are used.
        """
        self.filename = filename
        self.fps = fps
        self.resolution = resolution  # (width, height)
        self.view_matrix = view_matrix
        self.frames = []  # captured (height, width, 3) uint8 RGB frames

    def record_frame(self):
        """
        Render the current PyBullet scene and append it to ``self.frames`` as RGB.

        Assumes the scene has already been updated by the caller.
        """
        width, height = self.resolution

        if self.view_matrix is not None:
            viewMatrix = self.view_matrix
            projectionMatrix = p.computeProjectionMatrixFOV(fov=70, aspect=width / height,
                                                            nearVal=0.1, farVal=100)
        else:
            # getDebugVisualizerCamera: index 2 is the view matrix, index 3 the projection.
            cam_info = p.getDebugVisualizerCamera()
            viewMatrix = cam_info[2]
            projectionMatrix = cam_info[3]

        # The hardware OpenGL renderer is only requested in GUI mode.
        if p.getConnectionInfo()['connectionMethod'] == p.GUI:
            img_arr = p.getCameraImage(width, height, viewMatrix, projectionMatrix, renderer=p.ER_BULLET_HARDWARE_OPENGL)
        else:
            img_arr = p.getCameraImage(width, height, viewMatrix, projectionMatrix)

        # img_arr[2] is the flattened RGBA pixel buffer; drop alpha.
        rgba = np.reshape(img_arr[2], (height, width, 4))
        self.frames.append(rgba[:, :, :3].astype(np.uint8))

    def _output_path(self, prefix=None):
        """Return the file ``save`` writes: ``filename``, with ``"<prefix>_"`` prepended to its basename if given."""
        if prefix is None:
            return self.filename
        head, tail = os.path.split(self.filename)
        return os.path.join(head, f"{prefix}_{tail}")

    def save(self, prefix=None):
        """
        Encode the captured frames to a video with MoviePy's ImageSequenceClip.

        Args:
            prefix: optional string prepended (with ``"_"``) to the output basename.
                ``self.filename`` is left unchanged, so repeated saves do not
                accumulate prefixes.

        Returns:
            The path written, or None if no frames were recorded.
        """
        if not self.frames:
            print("No frames recorded. Skipping video generation.")
            return None

        out_path = self._output_path(prefix)
        clip = ImageSequenceClip(self.frames, fps=self.fps)
        try:
            clip.write_videofile(out_path, codec="libx264")
        finally:
            clip.close()
        return out_path


# Example usage: record ~10 s of a plane-and-cube scene and write it to disk.
if __name__ == "__main__":
    import pybullet_data

    # A GUI connection lets the scene be watched while it is recorded.
    p.connect(p.GUI)
    # Make PyBullet's bundled example assets (plane.urdf, cube.urdf, ...) loadable by name.
    p.setAdditionalSearchPath(pybullet_data.getDataPath())

    plane_id = p.loadURDF("plane.urdf")
    cube_id = p.loadURDF("cube.urdf", [0, 0, 1])

    demo_fps = 30
    demo_recorder = VideoRecorderMoviePy(filename="pybullet_scene.mp4", fps=demo_fps, resolution=(640, 480))

    # 300 frames at 30 fps is roughly ten seconds of footage.
    frames_left = 300
    while frames_left > 0:
        p.stepSimulation()
        demo_recorder.record_frame()
        time.sleep(1 / demo_fps)  # pace the loop to roughly real time
        frames_left -= 1

    demo_recorder.save()
    p.disconnect()
