import collections
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils
import robosuite.utils.transform_utils as T

from matplotlib import pyplot as plt
import os
import os.path as osp
import h5py
from tqdm import tqdm
import numpy as np
import imageio
import cv2

import robosuite.utils.transform_utils as trans

def stable_orientation_error(des_quat, curr_quat):
    """
    Compute the axis-angle vector of the rotation taking ``curr_quat`` to ``des_quat``.

    Both quaternions must use robosuite's (x, y, z, w) convention. That is the
    convention of the ``robot0_eef_quat`` observation and of the
    ``robosuite.utils.transform_utils`` helpers called here (``quat_multiply``,
    ``quat2axisangle``). The relative rotation is ``des * curr^{-1}``, and the
    returned vector is ``axis * angle`` with ``angle`` in [0, pi], i.e. always
    the short way round.

    Args:
        des_quat (array-like): (4,) desired orientation, (x, y, z, w).
        curr_quat (array-like): (4,) current orientation, (x, y, z, w).

    Returns:
        np.ndarray: (3,) axis-angle error (zeros when the orientations coincide).
    """
    # For a unit quaternion the conjugate is the inverse. Normalize first so
    # that this also holds for slightly de-normalized observations.
    curr_quat = np.asarray(curr_quat, dtype=np.float64)
    curr_quat = curr_quat / np.linalg.norm(curr_quat)
    # (x, y, z, w) conjugate: negate the vector part, keep w (last entry).
    inv_curr = np.array([-curr_quat[0], -curr_quat[1], -curr_quat[2], curr_quat[3]])

    # Relative rotation q_err = des * curr^{-1}, still in (x, y, z, w).
    q_err = np.asarray(trans.quat_multiply(des_quat, inv_curr), dtype=np.float64)

    # q and -q encode the same rotation. Choosing w >= 0 (w is index 3 here)
    # keeps the angle from quat2axisangle in [0, pi].
    if q_err[3] < 0.0:
        q_err = -q_err

    return trans.quat2axisangle(q_err)


def move_eef_to_pose(
    env,
    controller,
    curr_pos,
    curr_quat,
    desired_pos,            # (3,)
    desired_quat,           # (4,) same convention as obs['robot0_eef_quat'] (robosuite: [x, y, z, w])
    pos_threshold=0.02,
    rot_threshold=0.02,
    max_iterations=1000,
):
    """
    Drive the end effector from its current pose to (desired_pos, desired_quat).

    Each step sends a 7D delta action. The translation part and the rotation
    part are rescaled independently so that their largest-magnitude component
    sits exactly at the controller's input bound. The arm therefore moves at
    the maximum rate the input range allows. A part whose error is already
    within tolerance is commanded to zero instead. The orientation error comes
    from ``stable_orientation_error``, so the commanded rotation never takes
    the long way round.

    Args:
        env: environment whose ``step`` takes the 7D action and returns a
            4-tuple ``(obs, _, _, _)`` with ``obs['robot0_eef_pos']`` and
            ``obs['robot0_eef_quat']``.
        controller: object exposing indexable ``input_min`` / ``input_max``.
            Bounds are assumed symmetric, with translation in entries 0:3
            and rotation in entries 3:6.
        curr_pos: (3,) starting EEF position.
        curr_quat: (4,) starting EEF orientation.
        desired_pos: (3,) target position.
        desired_quat: (4,) target orientation; normalized here.
        pos_threshold: convergence tolerance on the position-error norm.
        rot_threshold: convergence tolerance on the axis-angle error norm.
        max_iterations: maximum number of ``env.step`` calls.

    Returns:
        bool: True if both tolerances were met, else False.
    """
    # Make sure desired_quat is a normalized array (lists are accepted too).
    desired_quat = np.asarray(desired_quat, dtype=np.float64)
    desired_quat = desired_quat / np.linalg.norm(desired_quat)

    # Read the input clamp range once. Only the magnitudes of input_min[0]
    # (translation) and input_min[3] (rotation) are used (assumed symmetric).
    input_min, input_max = controller.input_min, controller.input_max
    print('input_min ', input_min)
    trans_bound = np.abs(input_min[0])
    rot_bound = np.abs(input_min[3])

    def _saturate(vec, bound):
        """Rescale ``vec`` so its largest-magnitude entry equals ``bound``, keeping its direction."""
        peak = np.max(np.abs(vec))
        if peak == 0.0:
            # Nothing to correct here; dividing by ``peak`` would give 0/0 = NaN.
            return np.zeros_like(vec, dtype=np.float64)
        return vec * (bound / peak)

    for i in range(max_iterations):
        print('======= {} ======='.format(i))
        print('curr_pos ', curr_pos)
        print('goal_pos ', desired_pos)

        # Position error and its norm.
        pos_error = np.asarray(desired_pos) - np.asarray(curr_pos)
        dist_to_go = np.linalg.norm(pos_error)

        # Orientation error as an axis-angle vector and its norm.
        delta_axisangle = stable_orientation_error(desired_quat, curr_quat)
        rot_to_go = np.linalg.norm(delta_axisangle)
        print(f"dist_to_go: {dist_to_go:.4f}, rot_to_go: {rot_to_go:.4f}")

        if dist_to_go < pos_threshold and rot_to_go < rot_threshold:
            print("Reached desired pose!")
            return True

        print('raw trans ', pos_error)
        # Step "as fast as possible" by saturating each subspace at its bound.
        # A subspace already inside its tolerance gets a zero command: saturating
        # its tiny residual would push it back out while the other one converges.
        if dist_to_go < pos_threshold:
            trans_cmd = np.zeros(3)
        else:
            trans_cmd = _saturate(pos_error, trans_bound)
        if rot_to_go < rot_threshold:
            rot_cmd = np.zeros(3)
        else:
            rot_cmd = _saturate(delta_axisangle, rot_bound)

        # The 7th entry is the gripper command, sent as -1 exactly as before.
        # NOTE(review): in robosuite -1 usually means "open"; confirm intent.
        clamped_action_7d = np.concatenate([trans_cmd, rot_cmd, [-1]])
        print('clamped trans ', clamped_action_7d[:3])

        # The controller maps [input_min, input_max] to its output range
        # internally, then computes torques.
        obs, _, _, _ = env.step(clamped_action_7d)
        curr_pos = obs['robot0_eef_pos']
        curr_quat = obs['robot0_eef_quat']

    print("Warning: move_eef_to_pose() did not converge within max_iterations!")
    return False



def create_env(env_meta, shape_meta, enable_render=True):
    """
    Build a robomimic environment from dataset metadata.

    Observation keys listed in ``shape_meta['obs']`` are grouped by their
    ``'type'`` entry (``'low_dim'`` when absent). The grouping is registered
    with ObsUtils before the env is created. Onscreen rendering is always
    disabled; ``enable_render`` toggles both offscreen rendering and image
    observations.

    Args:
        env_meta: environment metadata, e.g. from a robomimic dataset.
        shape_meta: dict with an ``'obs'`` mapping of key -> spec dict.
        enable_render: whether to render offscreen and return image obs.

    Returns:
        The environment returned by ``EnvUtils.create_env_from_metadata``.
    """
    keys_by_modality = collections.defaultdict(list)
    for obs_key, obs_spec in shape_meta['obs'].items():
        modality = obs_spec.get('type', 'low_dim')
        keys_by_modality[modality].append(obs_key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(keys_by_modality)

    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )

def path_planning(
    output,
    file_path,
    model_xml_path="square_model_file_1.4.xml",
    goal_index=50,
    max_demos=1,
):
    """
    Reset to recorded demo start states and servo the EEF to a recorded pose.

    For each of the first ``max_demos`` demos (``data/demo_0``, ``data/demo_1``,
    ...) in the robomimic HDF5 file ``file_path``, this function:

    1. Resets the env to the demo's first simulator state, using the model XML
       read from ``model_xml_path``.
    2. Calls ``move_eef_to_pose`` to drive the end effector to the EEF pose
       recorded at step ``goal_index`` of that demo.

    Args:
        output: currently unused; kept for call compatibility.
        file_path: path to the robomimic HDF5 dataset.
        model_xml_path: model XML injected into each reset state. A relative
            path is resolved against the current working directory.
        goal_index: demo step whose recorded EEF pose is used as the target.
        max_demos: how many demos to process, starting from ``demo_0``.
    """
    with open(model_xml_path, 'r', encoding='utf-8') as f:
        model_xml = f.read()

    shape_meta = {'obs':
                  {'agentview_image': {'shape': [3, 224, 224], 'type': 'rgb'},
                   'robot0_eye_in_hand_image': {'shape': [3, 224, 224], 'type': 'rgb'},
                   'object': {'shape': [14]},
                   'robot0_eef_pos': {'shape': [3]},
                   'robot0_eef_quat': {'shape': [4]},
                   'robot0_gripper_qpos': {'shape': [2]}}}

    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=file_path)
    # move_eef_to_pose sends delta actions, so force delta control.
    env_meta['env_kwargs']['controller_configs']['control_delta'] = True

    print('env_meta ', env_meta)
    env = create_env(
        env_meta=env_meta,
        shape_meta=shape_meta
    )

    with h5py.File(file_path, 'r') as h5_file:
        demos = h5_file['data']
        print('total sample ', demos.attrs["total"])

        num_demos = min(len(demos), max_demos)
        for i in tqdm(range(num_demos), desc="Loading hdf5 to ReplayBuffer"):
            demo = demos[f'demo_{i}']
            # Reset to the recorded initial sim state with the chosen model XML.
            initial_state = dict(states=demo['states'][0], model=model_xml)
            env.reset()
            obs = env.reset_to(initial_state)

            print('current eef pos ', obs['robot0_eef_pos'])
            print('current eef quat ', obs['robot0_eef_quat'])

            curr_pos = obs['robot0_eef_pos']
            curr_quat = obs['robot0_eef_quat']
            goal_pos = demo['obs']['robot0_eef_pos'][goal_index]
            goal_quat = demo['obs']['robot0_eef_quat'][goal_index]
            print('goal_pos ', goal_pos)
            print('goal_quat ', goal_quat)
            move_eef_to_pose(env, env.env.robots[0].controller, curr_pos, curr_quat, goal_pos, goal_quat)


def save_agentview_video(images, video_path, fps=30):
    """
    Write a sequence of frames to ``video_path`` with imageio.

    Frames are written as given; no resizing is performed. The parent
    directory of ``video_path`` is created first if it does not exist.

    Args:
        images: sequence (or stacked array) of frames, each (H, W, 3).
            Presumably uint8 RGB; confirm for your caller.
        video_path: output file path (imageio infers the format from it).
        fps: playback frame rate; defaults to 30, the previous fixed value.
    """
    parent_dir = osp.dirname(video_path)
    if parent_dir:
        # Without this, mimsave fails when the target directory is missing.
        os.makedirs(parent_dir, exist_ok=True)
    imageio.mimsave(video_path, images, fps=fps)
