
from torch.utils.data import Dataset
import h5py
import torch 
import numpy as np 
from utils import  pose_to_mat, mat_to_pose10d, convert_pose_mat_rep


class MultiviewDataset(Dataset):
    """Multi-camera HDF5 dataset returning ``(state, action_chunk, target)`` samples.

    The file is expected to contain ``data/<demo>/actions``, ``data/<demo>/obs/<key>``
    and optionally ``data/<demo>/label``; demo names must end in ``_<int>`` (used for
    ordering). Samples are indexed globally across demos, in that order.

    Args:
        hdf5_file: path of the HDF5 file; it is opened read-only and kept open.
        action_chunk_length: number of consecutive actions returned per sample.
        cameras: observation keys returned in the state dict; each frame is transposed
            with ``(2, 0, 1)`` (channels-first, assuming frames are stored as (H, W, C)).
        padding: if True, every timestep of a demo is a sample and action chunks that
            run past the end of the demo are padded; if False, only start indices whose
            chunk fits inside the demo (with one step to spare) are used.
        mode: "embedding" -> target is the demo's last observation dict;
            "classifier" -> target is a one-hot vector built from the per-step ``label``.
        pad_mode: "zeros" pads with zeros, "repeat" repeats the demo's last action.
        proprio: optional observation key added (untransposed) to the state dicts.
        num_classes: length of the one-hot target in "classifier" mode.

    NOTE(review): the h5py handle is opened in ``__init__``; with multi-worker
    DataLoaders every worker inherits it — consider opening it lazily per worker.
    """

    def __init__(self, hdf5_file, action_chunk_length, cameras, padding = False, mode = "embedding", pad_mode = "zeros",
                 proprio = None, num_classes = 4):
        super().__init__()
        print(f"Using camera {cameras}")
        # padding allows querying up to the end of a demonstration regardless of chunk length
        self.padding = padding

        self.length = 0
        self.proprio = proprio
        self.h5_file = h5py.File(hdf5_file, "r")
        self.lengths_list = list()
        self.action_chunk_length = action_chunk_length
        self.pad_mode = pad_mode
        self.num_classes = num_classes

        self.cameras = cameras
        self.demos_list = sorted(self.h5_file["data"].keys(), key = lambda x: int(x.split("_")[-1]))
        self.build_lengths_list()
        self.length = self.__len__()
        self.mode = mode  # "embedding" -> (s, a, s_T); "classifier" -> (s, a, Y)

        self.sample_distribution = {}

    def build_lengths_list(self):
        """(Re)compute how many samples each demo in ``demos_list`` contributes.

        Demos without a readable ``actions`` dataset contribute 0 samples; an entry is
        still appended so ``lengths_list[i]`` always describes ``demos_list[i]``.
        """
        self.lengths_list = []
        for demo in self.demos_list:
            try:
                num_actions = self.h5_file["data"][demo]["actions"].shape[0]
            except (KeyError, IndexError):
                print(f"Skipped demo {demo} because it is empty!")
                self.lengths_list.append(0)  # keep index-aligned with demos_list
                continue
            if self.padding:
                self.lengths_list.append(num_actions)
            else:
                # The chunk must fit; demos shorter than a chunk contribute nothing
                # instead of a negative count that would corrupt __len__/parse_idx.
                self.lengths_list.append(max(0, num_actions - self.action_chunk_length))

    def __len__(self):
        return sum(self.lengths_list)

    def parse_idx(self, idx):
        """Map a global sample index to ``(demo_index, index_within_demo)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = sum(self.lengths_list)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"Index {requested} out of range for dataset of length {total}")
        remaining_idx = idx
        for demo, demo_length in enumerate(self.lengths_list):
            if remaining_idx < demo_length:
                return demo, remaining_idx
            remaining_idx -= demo_length
        raise IndexError(f"Index {requested} out of range for dataset of length {total}")

    def get_bounds_of_demo(self, demo):
        """Return the ``[start, end)`` global sample-index range covered by demo number ``demo``."""
        cum_demo = np.cumsum(self.lengths_list)
        start = cum_demo[demo] - self.lengths_list[demo]
        end = cum_demo[demo]
        return start, end

    def smooth_one_hot(self, size, index, epsilon=0.1):
        """Return a length-``size`` label-smoothed one-hot vector.

        ``index`` receives ``1 - epsilon``; ``epsilon`` is spread evenly over the other
        ``size - 1`` entries (``epsilon=0`` gives a plain one-hot).
        """
        vector = np.full(size, epsilon / (size - 1))
        vector[index] = 1 - epsilon
        return vector

    def _get_state(self, selected_demo, index):
        """Return the observation dict at ``index``: transposed camera frames plus optional proprio."""
        state = {camera: np.transpose(selected_demo["obs"][camera][index], (2, 0, 1)) for camera in self.cameras}
        if self.proprio is not None:
            state[self.proprio] = selected_demo["obs"][self.proprio][index]
        return state

    @staticmethod
    def _get_label(selected_demo, index):
        """Return the per-step ``label`` as int, else the demo's ``behavior`` attr, else None."""
        if "label" in selected_demo:
            return int(selected_demo["label"][index])
        if "behavior" in selected_demo.attrs:
            return selected_demo.attrs["behavior"]
        return None

    def _get_action_chunk(self, selected_demo, demo, remaining_idx):
        """Return ``action_chunk_length`` actions starting at ``remaining_idx``, padded if needed.

        Padding is only applied when ``self.padding`` is set and the chunk runs past
        the demo's end; padded rows are zeros or (``pad_mode == "repeat"``) the demo's
        last action. The padded array keeps the dtype of the stored actions so padded
        and unpadded samples collate consistently.
        """
        actions = selected_demo["actions"]
        end = remaining_idx + self.action_chunk_length
        if not self.padding or end <= self.lengths_list[demo]:
            return actions[remaining_idx:end]

        available = actions[remaining_idx:]
        num_valid = available.shape[0]
        chunk = np.zeros((self.action_chunk_length, actions.shape[1]), dtype=actions.dtype)
        chunk[:num_valid] = available
        if self.pad_mode == "repeat":
            chunk[num_valid:] = actions[-1]
        return chunk

    def __getitem__(self, idx):
        """Return ``(state, action_chunk, target)`` for global sample ``idx``.

        Raises:
            KeyError: in "classifier" mode when the demo has no ``label`` dataset.
        """
        demo, remaining_idx = self.parse_idx(idx)
        selected_demo = self.h5_file["data"][self.demos_list[demo]]
        selected_state = self._get_state(selected_demo, remaining_idx)

        if self.mode == "classifier":
            if "label" not in selected_demo:
                raise KeyError(f"Demo {self.demos_list[demo]} has no 'label' dataset, required in classifier mode")
            task_label = int(selected_demo["label"][remaining_idx])
            target = self.smooth_one_hot(self.num_classes, task_label, 0)  # epsilon=0: plain one-hot
        else:
            target = self._get_state(selected_demo, -1)  # last observation of the demo

        return selected_state, self._get_action_chunk(selected_demo, demo, remaining_idx), target

    def _get_cube_pos(self, idx): # only for pymunk!
        """Return ``(obs/states, obs/agent_pos)`` at sample ``idx`` (pymunk datasets only)."""
        demo, remaining_idx = self.parse_idx(idx)
        selected_demo = self.h5_file["data"][self.demos_list[demo]]
        return selected_demo["obs"]["states"][remaining_idx], selected_demo["obs"]["agent_pos"][remaining_idx]

    def get_labeled_item(self, idx, flatten_action = False):
        """Return ``(state, action_chunk, label)`` where label may be None if the demo has none.

        Raises:
            ValueError: if ``flatten_action`` is requested (no longer supported).
        """
        if flatten_action:
            raise ValueError("No longer supporting flattening actions!")

        demo, remaining_idx = self.parse_idx(idx)
        selected_demo = self.h5_file["data"][self.demos_list[demo]]
        label = self._get_label(selected_demo, remaining_idx)
        selected_state = self._get_state(selected_demo, remaining_idx)
        return selected_state, self._get_action_chunk(selected_demo, demo, remaining_idx), label


class MultiviewDatasetUMI(Dataset):
    """UMI-style multi-camera HDF5 dataset whose actions are derived from proprioception.

    Instead of reading an ``actions`` dataset, each sample's action chunk is built from
    ``obs/robot0_eef_pos``, ``obs/robot0_eef_rot_axis_angle`` and
    ``obs/robot0_gripper_width`` sampled every ``downsample`` steps after the current
    step, expressed relative to the first pose of that window (see
    ``proprio_to_action``). The target is the observation reached after executing the
    whole chunk, i.e. ``downsample * action_chunk_length`` steps ahead, clamped to the
    demo's last sample. Demo lengths come from each demo group's ``num_samples`` attr.

    Args:
        hdf5_file: path of the HDF5 file; it is opened read-only and kept open.
        action_chunk_length: number of (downsampled) actions per sample.
        cameras: observation keys returned in the state dicts; each frame is transposed
            with ``(2, 0, 1)`` (channels-first, assuming frames are stored as (H, W, C)).
        padding: if True, every timestep is a sample and short chunks at the end of a
            demo are padded; if False, only start indices with a full chunk are used.
        mode: stored for API parity; ``__getitem__`` always returns the goal observation.
        pad_mode: must be "repeat" (repeat the last valid action when padding).
        proprio: optional observation key added (untransposed) to the state dicts.
        downsample: stride, in raw steps, between consecutive actions.
        end_sampling: stored but currently unused.
    """

    def __init__(self, hdf5_file, action_chunk_length, cameras, padding = False, mode = "embedding", pad_mode = "repeat",
                 proprio = None, downsample = 3, end_sampling = 10):
        super().__init__()
        print(f"Using camera {cameras}")
        # Checked before opening the file so a bad argument does not leak a handle.
        assert pad_mode == "repeat", "MultiviewDatasetUMI only supports pad_mode='repeat'"
        # padding allows querying up to the end of a demonstration regardless of chunk length
        self.padding = padding

        self.length = 0
        self.proprio = proprio
        self.h5_file = h5py.File(hdf5_file, "r")
        self.lengths_list = list()
        self.num_samples_list = list()  # raw per-demo sample counts (num_samples attr)
        self.action_chunk_length = action_chunk_length
        self.pad_mode = pad_mode
        # downsample is needed by build_lengths_list, so it must be set before the call.
        self.downsample = downsample
        self.end_sampling = end_sampling

        self.cameras = cameras
        self.demos_list = sorted(self.h5_file["data"].keys(), key = lambda x: int(x.split("_")[-1]))
        self.build_lengths_list()
        self.length = self.__len__()
        self.mode = mode

        self.sample_distribution = {}

    def build_lengths_list(self):
        """(Re)compute per-demo raw sample counts and usable sample counts.

        Demos without a ``num_samples`` attr contribute 0 samples; an entry is still
        appended so both lists stay index-aligned with ``demos_list``. Without padding,
        a start index is usable only if all ``action_chunk_length`` strided future
        samples exist, i.e. ``num_samples - downsample * action_chunk_length`` indices.
        """
        self.lengths_list = []
        self.num_samples_list = []
        horizon = self.downsample * self.action_chunk_length
        for demo in self.demos_list:
            try:
                num_samples = int(self.h5_file["data"][demo].attrs["num_samples"])
            except KeyError:
                print(f"Skipped demo {demo} because it is empty!")
                num_samples = 0
            self.num_samples_list.append(num_samples)
            if self.padding:
                self.lengths_list.append(num_samples)
            else:
                self.lengths_list.append(max(0, num_samples - horizon))

    def __len__(self):
        return sum(self.lengths_list)

    def parse_idx(self, idx):
        """Map a global sample index to ``(demo_index, index_within_demo)``.

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = sum(self.lengths_list)
        requested = idx
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"Index {requested} out of range for dataset of length {total}")
        remaining_idx = idx
        for demo, demo_length in enumerate(self.lengths_list):
            if remaining_idx < demo_length:
                return demo, remaining_idx
            remaining_idx -= demo_length
        raise IndexError(f"Index {requested} out of range for dataset of length {total}")

    def get_bounds_of_demo(self, demo):
        """Return the ``[start, end)`` global sample-index range covered by demo number ``demo``."""
        cum_demo = np.cumsum(self.lengths_list)
        start = cum_demo[demo] - self.lengths_list[demo]
        end = cum_demo[demo]
        return start, end

    def smooth_one_hot(self, size, index, epsilon=0.1):
        """Return a length-``size`` label-smoothed one-hot vector.

        ``index`` receives ``1 - epsilon``; ``epsilon`` is spread evenly over the other
        ``size - 1`` entries (``epsilon=0`` gives a plain one-hot).
        """
        vector = np.full(size, epsilon / (size - 1))
        vector[index] = 1 - epsilon
        return vector

    def _get_state(self, selected_demo, index):
        """Return the observation dict at ``index``: transposed camera frames plus optional proprio."""
        state = {camera: np.transpose(selected_demo["obs"][camera][index], (2, 0, 1)) for camera in self.cameras}
        if self.proprio is not None:
            state[self.proprio] = selected_demo["obs"][self.proprio][index]
        return state

    @staticmethod
    def _get_label(selected_demo, index):
        """Return the per-step ``label`` as int, else the demo's ``behavior`` attr, else None."""
        if "label" in selected_demo:
            return int(selected_demo["label"][index])
        if "behavior" in selected_demo.attrs:
            return selected_demo.attrs["behavior"]
        return None

    def _read_proprio_window(self, selected_demo, remaining_idx):
        """Read up to ``action_chunk_length`` strided future proprio samples.

        Samples start one stride after ``remaining_idx`` and are taken every
        ``downsample`` steps; the slice is truncated at the end of the demo.
        Returns ``(eef_pos, axis_angle, gripper_width)``.
        """
        step = self.downsample
        window = slice(remaining_idx + step, remaining_idx + step * (self.action_chunk_length + 1), step)
        poses = selected_demo["obs/robot0_eef_pos"][window]
        gripper = selected_demo["obs/robot0_gripper_width"][window]
        axis_angle = selected_demo["obs/robot0_eef_rot_axis_angle"][window]
        return poses, axis_angle, gripper

    def _pad_action_chunk(self, valid_actions):
        """Pad ``valid_actions`` (n, 10), n <= action_chunk_length, to a full chunk.

        A full chunk is returned unchanged (no rows are overwritten). An empty input
        yields an all-zero chunk. Otherwise missing rows are zeros, or copies of the
        last valid action when ``pad_mode == "repeat"``.
        """
        num_valid = valid_actions.shape[0]
        if num_valid >= self.action_chunk_length:
            return valid_actions
        padded = np.zeros((self.action_chunk_length, 10))
        if num_valid == 0:
            # HACK: at the very end of a demo there is no future pose; return zeros.
            return padded
        padded[:num_valid] = valid_actions
        if self.pad_mode == "repeat":
            padded[num_valid:] = valid_actions[-1]
        return padded

    def _get_action_chunk(self, selected_demo, remaining_idx):
        """Build the (action_chunk_length, 10) relative action chunk for ``remaining_idx``."""
        poses, axis_angle, gripper = self._read_proprio_window(selected_demo, remaining_idx)
        if poses.shape[0] == 0:
            return self._pad_action_chunk(np.zeros((0, 10)))
        valid_actions = self.proprio_to_action(poses, axis_angle, gripper, relative = True)
        return self._pad_action_chunk(valid_actions)

    def __getitem__(self, idx):
        """Return ``(state, action_chunk, goal_state)`` for global sample ``idx``."""
        demo, remaining_idx = self.parse_idx(idx)
        selected_demo = self.h5_file["data"][self.demos_list[demo]]

        # Goal = the state reached after executing the whole downsampled chunk, clamped
        # to the demo's last real sample (not to lengths_list, which is shorter when
        # padding is disabled).
        horizon = self.downsample * self.action_chunk_length
        goal_idx = min(remaining_idx + horizon, self.num_samples_list[demo] - 1)

        selected_state = self._get_state(selected_demo, remaining_idx)
        target = self._get_state(selected_demo, goal_idx)
        return selected_state, self._get_action_chunk(selected_demo, remaining_idx), target

    def proprio_to_action(self, eef_pos, axis_angle, gripper_width, relative = True):
        """Convert already-strided proprio samples into (n, 10) actions.

        Poses are built from ``eef_pos`` + ``axis_angle`` via ``pose_to_mat``, made
        relative to the first pose in the window, and converted with ``mat_to_pose10d``
        into the first 9 columns; the last column is the gripper width. Only
        ``relative=True`` is supported.
        NOTE(review): ``gripper_width`` presumably has shape (n, 1) — a flat (n,) array
        would not broadcast into ``action[:, -1:]`` for n > 1; confirm against the data.

        Raises:
            Exception: if ``relative`` is False.
        """
        if not relative:
            raise Exception("Shouldn't be here")

        pose_mat = pose_to_mat(np.concatenate([eef_pos, axis_angle], axis=-1))  # homogeneous matrices
        rel_pose_mat = convert_pose_mat_rep(
            pose_mat,
            base_pose_mat=pose_mat[0],
            pose_rep="relative",
            backward=False,
        )
        action = np.zeros((eef_pos.shape[0], 10))
        action[:, :-1] = mat_to_pose10d(rel_pose_mat)
        action[:, -1:] = gripper_width
        return action

    def get_labeled_item(self, idx, flatten_action = False):
        """Return ``(state, action_chunk, label)`` where label may be None if the demo has none.

        Raises:
            ValueError: if ``flatten_action`` is requested (no longer supported).
        """
        if flatten_action:
            raise ValueError("No longer supporting flattening actions!")

        demo, remaining_idx = self.parse_idx(idx)
        selected_demo = self.h5_file["data"][self.demos_list[demo]]
        label = self._get_label(selected_demo, remaining_idx)
        selected_state = self._get_state(selected_demo, remaining_idx)
        return selected_state, self._get_action_chunk(selected_demo, remaining_idx), label



# class MultiviewARDataset(Dataset):
#     def __init__(self, hdf5_file, action_chunk_length, cameras, padding = False, negative_samples = False): # = ["agentview_image", "robot0_eye_in_hand_image"]):
#         super().__init__()
#         print(f"Using camera {cameras}")
#         self.length = 0
#         self.h5_file = h5py.File(hdf5_file, "r")
#         self.lengths_list = list()
#         self.action_chunk_length = action_chunk_length

#         self.cameras = cameras         
#         self.padding = padding 


#         self.build_lengths_list()
#         self.length = self.__len__()
#         self.demos_list = list(self.h5_file["data"].keys())
#         self.negative_samples = negative_samples

    
#     def build_lengths_list(self):
#         for demo in self.h5_file["data"]:
#             # -1 to acount for s' 
#             if self.padding:
#                 self.lengths_list.append(self.h5_file["data"][demo]["actions"].shape[0]) # the 1 is for the extra state 
#             else: 
#                 self.lengths_list.append(self.h5_file["data"][demo]["actions"].shape[0] - (self.action_chunk_length) - 1) # the 1 is for the extra state 

#     def __len__(self):
#         return sum(self.lengths_list)
    
#     def parse_idx(self, idx):
#         demo = 0
#         remaining_idx = idx 
#         while remaining_idx >= self.lengths_list[demo]:
#             remaining_idx -= self.lengths_list[demo]
#             demo += 1
#         return demo, remaining_idx


#     def __getitem__(self, idx):
#         demo, remaining_idx = self.parse_idx(idx)
#         selected_demo = self.h5_file["data"][self.demos_list[demo]]
#         negative_states = torch.zeros([1])
#         if self.negative_samples:
#             K = 5
#             selected_index = np.random.randint(self.action_chunk_length, 3 * self.action_chunk_length)
#             if remaining_idx + selected_index + K >= self.lengths_list[demo]:
#                 direction = -1
#             elif remaining_idx - selected_index < 0:
#                 direction = 1
#             else:
#                 direction = 1 if np.random.rand() > 0.5 else -1 

            
#             back_index = remaining_idx + direction * selected_index # small change of selecting something closet
#             back_index = min(max(0, back_index), self.lengths_list[demo] - 1) # clip in case of some crazy edge cases 
#             negative_states = {camera: np.transpose(selected_demo["obs"][camera][back_index : back_index + K], (0, 3, 1, 2)) for camera in self.cameras}
#             # print(back_index, back_index + K, self.lengths_list[demo], negative_states["robot0_eye_in_hand_image"].shape)
#             # print(negative_states["robot0_eye_in_hand_image"].shape)
#             # negative_states = {camera : np.transpose(selected_demo["obs"][camera][back_index], (2, 0, 1)) for camera in self.cameras}


#         if not self.padding or remaining_idx + self.action_chunk_length < self.lengths_list[demo]: 
#             selected_actions = selected_demo["actions"][remaining_idx  : remaining_idx  + self.action_chunk_length]
#             selected_states = {camera: np.transpose(selected_demo["obs"][camera][remaining_idx  : remaining_idx  + self.action_chunk_length + 1], (0, 3, 1, 2)) for camera in self.cameras}
#             return selected_states, selected_actions, negative_states


#         amount_to_pad = remaining_idx + self.action_chunk_length - self.lengths_list[demo]
#         selected_actions = np.zeros((self.action_chunk_length, selected_demo["actions"].shape[1]))
#         selected_states = {camera : np.zeros((self.action_chunk_length + 1, selected_demo["obs"][camera].shape[3], 
#                                               selected_demo["obs"][camera].shape[1],
#                                               selected_demo["obs"][camera].shape[2]), dtype = np.uint8) for camera in self.cameras}
#         if amount_to_pad == 0:
#             selected_actions = selected_demo["actions"][remaining_idx : ]
#         else:
#             selected_actions[:-amount_to_pad] = selected_demo["actions"][remaining_idx : ]
        
#         for camera in self.cameras: 
#             selected_states[camera][:-(amount_to_pad + 1)] = np.transpose(selected_demo["obs"][camera][remaining_idx  : remaining_idx  + self.action_chunk_length + 1], (0, 3, 1, 2))
#             selected_states[camera][-(amount_to_pad + 1):] = selected_states[camera][-(amount_to_pad + 2)].copy() 
#         # selected_states = {camera: np.transpose(selected_demo["obs"][camera][remaining_idx  : remaining_idx  + self.action_chunk_length + 1], (0, 3, 1, 2)) for camera in self.cameras}
#         # print(selected_states["robot0_eye_in_hand_image"].shape, selected_actions.)
#         return selected_states, selected_actions, negative_states


# class MultiviewLabeledDataset(Dataset):
#     def __init__(self, hdf5_file, action_chunk_length, cameras):
#         super().__init__()
#         self.length = 0
#         self.h5_file = h5py.File(hdf5_file, "r")
#         self.lengths_list = list()
#         self.action_chunk_length = action_chunk_length
#         self.cameras = cameras 

#         self.start_buffer = 0 # debugging to ignore the first X number of steps 

#         self.build_lengths_list()
#         self.length = self.__len__()
#         self.demos_list = list(self.h5_file["data"].keys())

    
#     def build_lengths_list(self):
#         for demo in self.h5_file["data"]:
#             # -1 to acount for s' 
#             self.lengths_list.append(self.h5_file["data"][demo]["actions"].shape[0] - (self.action_chunk_length) - self.start_buffer)

#     def __len__(self):
#         return sum(self.lengths_list)
    
#     def parse_idx(self, idx):
#         demo = 0
#         remaining_idx = idx 
#         while remaining_idx >= self.lengths_list[demo]:
#             remaining_idx -= self.lengths_list[demo]
#             demo += 1
#         return demo, remaining_idx

#     def __getitem__(self, idx):
#         raise Exception("Not modified yet!")
#         demo, remaining_idx = self.parse_idx(idx)
#         selected_demo = self.h5_file["data"][self.demos_list[demo]]
#         selected_action = selected_demo["actions"][remaining_idx + self.start_buffer : remaining_idx  + self.start_buffer+ self.action_chunk_length]
#         # selected_state = selected_demo["obs"]["agentview_image"][remaining_idx + self.start_buffer]
#         selected_state = selected_demo["obs"][self.camera][remaining_idx + self.start_buffer]
#         next_selected_state = selected_demo["obs"][self.camera][remaining_idx + self.start_buffer + self.action_chunk_length]
#         # label = 1 if selected_demo["label"][remaining_idx + self.start_buffer] else 0
#         return np.transpose(selected_state, (2, 0, 1)), selected_action.flatten(), np.transpose(next_selected_state, (2, 0, 1))
    
#     def get_labeled_item(self, idx):
#         demo, remaining_idx = self.parse_idx(idx)
#         selected_demo = self.h5_file["data"][self.demos_list[demo]]
#         selected_action = selected_demo["actions"][remaining_idx + self.start_buffer : remaining_idx  + self.start_buffer+ self.action_chunk_length]
#         selected_state = {camera: np.transpose(selected_demo["obs"][camera][remaining_idx + self.start_buffer], (2, 0, 1)) for camera in self.cameras}

#         label = 1 if selected_demo["label"][remaining_idx + self.start_buffer] else 0
#         return selected_state, selected_action.flatten(), label 


# dataset = MultiviewARDataset("/yourfolderhere/dataset/scripted_multi_cube/multicube_reach_larger/data.hdf5", 
#                              action_chunk_length = 8, cameras = ["robot0_eye_in_hand_image"], padding = True)

# dataset = MultiviewDataset("/yourfolderhere/dataset/pymunktouch/pymunk_touch_res128_largercubes_valid_1k/data.hdf5", 
#                            action_chunk_length = 16, cameras = ["image"], padding = True, mode = "classifier", pad_mode = "repeat")

# dataset = MultiviewDatasetUMI("/yourfolderhere/dataset/UMI_Cup/data_valid.hdf5", 
#                            action_chunk_length = 16, cameras = ["camera0_rgb"], padding = True, mode = "classifier", pad_mode = "repeat", downsample = 3)

# import tqdm 
# # import matplotlib.pyplot as plt 
# for i in tqdm.tqdm(range(len(dataset))):
#     state, action, _ = dataset.__getitem__(i)

# # for i in tqdm.tqdm(range(600)):
# #     state, action, _ = dataset.__getitem__(i)
# #     # import ipdb 
# #     # ipdb.set_trace()
# #     # plt.plot(action[:, :2], alpha = 0.3)
# #     assert action.shape[0] == 16
# np.set_printoptions(precision=5, suppress = True)
# for i in range(50):
#     state, action, _ = dataset.__getitem__(i * 50)
#     print(action)
#     # import ipdb 
#     # ipdb.set_trace()
#     # plt.plot(action[:, :2], alpha = 0.3)
#     # plt.plot(action[:, 0], action[:, 1], linestyle='-', marker='o',alpha = 0.1)
#     # assert action.shape[0] == 16
# # plt.savefig("test.png")
# print("done!")
