import copy
import os
import os.path as osp
from diffusion_policy.dataset.robomimic_replay_image_dataset import _convert_actions
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
import h5py
from tqdm import tqdm
import numpy as np

from diffusion_policy.common.replay_buffer import ReplayBuffer
from collect_square_exp_data import save_agentview_video

def create_replay_buffer(output, hdf5_file_path, string):
    """Append every demo of an HDF5 dataset to a ReplayBuffer and export videos.

    Each group ``data/demo_{i}`` (i = 0 .. len(data) - 1) becomes one
    replay-buffer episode with the keys ``agentview_image``,
    ``robot0_eye_in_hand_image``, ``robot0_eef_pos``, ``robot0_eef_quat``,
    ``robot0_gripper_qpos``, ``raw_action`` (the demo's ``actions`` dataset)
    and ``converted_action``. ``converted_action`` is the output of
    ``_convert_actions`` with an axis_angle -> rotation_6d transformer.
    For every episode an agent-view video is also written via
    ``save_agentview_video``.

    Args:
        output: Path handed to ``ReplayBuffer.create_from_path``. It is opened
            with ``mode='a'``, so episodes are appended to any existing buffer.
        hdf5_file_path: Source HDF5 file, opened read-only.
        string: Sub-directory name under
            ``visualizations/square/test_data_collection/`` for the videos.

    Raises:
        ValueError: if an observation dataset of a demo has fewer timesteps
            than its ``actions`` dataset.
    """
    replay_buffer = ReplayBuffer.create_from_path(output, mode='a')
    rotation_transformer = RotationTransformer(
        from_rep='axis_angle', to_rep='rotation_6d')

    # Make sure the video destination exists before anything tries to write there.
    video_dir = f'visualizations/square/test_data_collection/{string}'
    os.makedirs(video_dir, exist_ok=True)

    # Observation datasets copied into each episode (dict order = stored key order).
    obs_keys = (
        'agentview_image',
        'robot0_eye_in_hand_image',
        'robot0_eef_pos',
        'robot0_eef_quat',
        'robot0_gripper_qpos',
    )

    with h5py.File(hdf5_file_path, 'r') as file:
        print('total sample ', file['data'].attrs["total"])
        demos = file['data']
        for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
            demo = demos[f'demo_{i}']
            n_steps = len(demo['actions'])
            if n_steps == 0:
                # Nothing to store, and there is no frame to derive a video size from.
                tqdm.write(f'skipping empty demo_{i}')
                continue

            # Prefer absolute actions when the file provides them.
            action_key = 'abs_actions' if 'abs_actions' in demo else 'actions'
            raw_actions = demo[action_key][:]
            # NOTE(review): abs_action=True is passed even when only the
            # 'actions' dataset exists; confirm that is intended for such files.
            converted_actions = _convert_actions(
                raw_actions=copy.deepcopy(raw_actions),
                abs_action=True,
                rotation_transformer=rotation_transformer
            )

            # One sliced read per dataset instead of one h5py call per timestep.
            data_dict = {key: demo['obs'][key][:n_steps] for key in obs_keys}
            # 'raw_action' always stores the 'actions' dataset, even when
            # 'abs_actions' was used for the conversion.
            data_dict['raw_action'] = demo['actions'][:n_steps]
            data_dict['converted_action'] = np.asarray(converted_actions)[:n_steps]

            for key, value in data_dict.items():
                if len(value) != n_steps:
                    raise ValueError(
                        f'demo_{i}/{key} has {len(value)} steps, expected {n_steps}')

            replay_buffer.add_episode(data_dict, compressors='disk')

            agentview_images = data_dict['agentview_image']
            # Axis 1 is the per-frame leading dimension (presumably height for
            # HWC frames; confirm against the dataset layout).
            img_h = agentview_images.shape[1]
            video_path_view_1 = osp.join(
                video_dir, 'episode_{}_view_1_{}.mp4'.format(i, img_h))
            print('agentview_images shape ', agentview_images.shape)
            save_agentview_video(agentview_images, video_path_view_1)
