from typing import Dict, List

import numpy
import torch
import numpy as np
import h5py
from tqdm import tqdm
import zarr
import os
import shutil
import copy
import json
import hashlib
from filelock import FileLock
from threadpoolctl import threadpool_limits
import concurrent.futures
import multiprocessing
from omegaconf import OmegaConf
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.dataset.base_dataset import BaseImageDataset, LinearNormalizer
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.codecs.imagecodecs_numcodecs import register_codecs, Jpeg2k
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.sampler import SequenceSampler, get_val_mask
from diffusion_policy.common.normalize_util import (
    robomimic_abs_action_only_normalizer_from_stat,
    robomimic_abs_action_only_dual_arm_normalizer_from_stat,
    get_range_normalizer_from_stat,
    get_image_range_normalizer,
    get_identity_normalizer_from_stat,
    array_to_stats
)
import random
register_codecs()

class RobomimicReplayImageDataset(BaseImageDataset):
    def __init__(self,
            shape_meta: dict,
            dataset_path: str,
            horizon=1,
            pad_before=0,
            pad_after=0,
            n_obs_steps=None,
            abs_action=False,
            rotation_rep='rotation_6d', # ignored when abs_action=False
            use_legacy_normalizer=False,
            use_cache=False,
            seed=42,
            val_ratio=0.0,
            extra=False,
            subdataset=1
        ):
        """Load a robomimic dataset into an in-memory ReplayBuffer and build a sampler.

        Args:
            shape_meta: dict whose ``'obs'`` entry maps observation keys to
                attribute dicts; ``attr['type'] == 'rgb'`` marks image keys and
                ``'low_dim'`` (the default when absent) marks low-dim keys.
            dataset_path: source dataset path. With ``use_cache`` the converted
                buffer is cached at ``dataset_path + '.zarr.zip'``.
            horizon, pad_before, pad_after: forwarded to ``SequenceSampler``.
            n_obs_steps: if not None, the sampler only returns the first
                ``n_obs_steps`` frames of every rgb / low-dim key.
            abs_action: forwarded to the conversion routine; also selects the
                action normalizer in ``get_normalizer``.
            rotation_rep: target representation of the axis-angle
                ``RotationTransformer`` passed to the conversion routine.
            use_legacy_normalizer: read by ``get_normalizer``.
            use_cache: load/save the converted buffer from/to a zip cache,
                guarded by a file lock.
            seed, val_ratio: passed to ``get_val_mask`` for the train/val split.
            extra: forwarded to ``_convert_robomimic_to_replay``.
            subdataset: if > 1, keep only a subset of roughly
                ``n_episodes // subdataset`` episodes (see note below).
        """
        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = None
        if use_cache:
            cache_zarr_path = dataset_path + '.zarr.zip'
            cache_lock_path = cache_zarr_path + '.lock'
            print('Acquiring lock on cache.')
            with FileLock(cache_lock_path):
                if not os.path.exists(cache_zarr_path):
                    # cache does not exist yet
                    try:
                        print('Cache does not exist. Creating!')
                        replay_buffer = _convert_robomimic_to_replay(
                            store=zarr.MemoryStore(),
                            shape_meta=shape_meta,
                            dataset_path=dataset_path,
                            abs_action=abs_action,
                            rotation_transformer=rotation_transformer,
                            extra=extra)
                        print('Saving cache to disk.')
                        with zarr.ZipStore(cache_zarr_path) as zip_store:
                            replay_buffer.save_to_store(
                                store=zip_store
                            )
                    except Exception:
                        # Do not leave a partial cache behind. The cache is a
                        # zip *file*: shutil.rmtree fails on a file (and on a
                        # missing path), which would mask the original error.
                        if os.path.isdir(cache_zarr_path):
                            shutil.rmtree(cache_zarr_path)
                        elif os.path.exists(cache_zarr_path):
                            os.remove(cache_zarr_path)
                        raise
                else:
                    print('Loading cached ReplayBuffer from Disk.')
                    with zarr.ZipStore(cache_zarr_path, mode='r') as zip_store:
                        replay_buffer = ReplayBuffer.copy_from_store(
                            src_store=zip_store, store=zarr.MemoryStore())
                    print('Loaded!')
        else:
            # Pass `extra` here too so the flag behaves the same with or
            # without the cache.
            replay_buffer = _convert_robomimic_to_replay(
                store=zarr.MemoryStore(),
                shape_meta=shape_meta,
                dataset_path=dataset_path,
                abs_action=abs_action,
                rotation_transformer=rotation_transformer,
                extra=extra)

        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        # Optionally keep only a subset of the episodes.
        subdataset = int(subdataset)
        if subdataset > 1:
            ori_episode_ends = replay_buffer['/meta']['episode_ends'][:]
            ori_n_episodes = len(ori_episode_ends)

            # NOTE(review): this leading span runs from 0 to the end of
            # episode 1, i.e. it merges episodes 0 and 1 into a single episode,
            # and the evenly spaced picks below start at episode 2. Kept as-is;
            # confirm whether (0, ori_episode_ends[0]) was intended.
            subdataset_episode_start_ends = [(0, ori_episode_ends[1])]
            picked_episodes = np.linspace(
                2, ori_n_episodes - 1, ori_n_episodes // subdataset).astype(int)
            subdataset_episode_start_ends.extend(
                (ori_episode_ends[i - 1], ori_episode_ends[i]) for i in picked_episodes)

            episode_len = [end - start for start, end in subdataset_episode_start_ends]
            from itertools import accumulate
            replay_buffer['/meta']['episode_ends'] = np.array(list(accumulate(episode_len)))

            # Compact every observation array and the action array so they hold
            # only the selected spans, concatenated in order.
            for key in list(obs_shape_meta.keys()) + ['action']:
                src = replay_buffer[key]
                kept = np.concatenate(
                    [src[start:end] for start, end in subdataset_episode_start_ends], axis=0)
                src.resize(kept.shape[0], *src.shape[1:])
                src[:] = kept

        key_first_k = dict()
        if n_obs_steps is not None:
            # only take first k obs from images / low-dim keys
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=val_ratio,
            seed=seed)
        train_mask = ~val_mask
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        self.replay_buffer = replay_buffer
        self.sampler = sampler
        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.n_obs_steps = n_obs_steps
        self.train_mask = train_mask
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer
        self.val_ratio = val_ratio
        self.seed = seed

        self.aux_data = {'obs': {}, 'action': None}

        # Snapshot of the buffer as loaded. finetune_data only ever replaces
        # episodes after the first `ori_episode_num` ones.
        self.ori_len_buffer = self.replay_buffer['action'].shape[0]
        self.ori_episode_num = self.replay_buffer.n_episodes
        print('current_episode_long:', self.ori_episode_num)
        self.ori_episode_ends = self.replay_buffer['/meta/episode_ends']

        # Per-episode action arrays; finetune_data rejects trajectories whose
        # actions equal one of these. Read the arrays once, then slice.
        ends = self.ori_episode_ends[:]
        starts = np.concatenate(([0], ends[:-1]))
        all_actions = self.replay_buffer['action'][:]
        self.check_array = [all_actions[start:end] for start, end in zip(starts, ends)]

        # Episode-count threshold used by finetune_data: once adding an episode
        # would exceed it, the oldest appended episode is replaced instead.
        self.whole_epoch_num = 450
        


    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples the held-out episodes.

        The copy shares ``replay_buffer`` with ``self``. Only its ``sampler``
        and ``train_mask`` differ, and both are built from the complement of
        ``self.train_mask``. No ``key_first_k`` is passed to this sampler.
        """
        val_set = copy.copy(self)
        sampler_kwargs = dict(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask,
        )
        val_set.sampler = SequenceSampler(**sampler_kwargs)
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a LinearNormalizer covering actions, low-dim observations and images.

        Action:
            * ``abs_action`` False: identity normalizer (actions are treated as
              already normalized).
            * ``abs_action`` True: the robomimic absolute-action normalizer,
              using the dual-arm variant when the action dimension is > 10.
              If ``use_legacy_normalizer`` is set, it is then replaced by
              ``normalizer_from_stat``.
        Low-dim keys:
            * names ending in ``'pos'`` (which includes ``'...qpos'``) get a
              range normalizer.
            * names ending in ``'quat'`` get an identity normalizer.
            * any other key raises ``RuntimeError('unsupported')``.
        RGB keys:
            * ``get_image_range_normalizer()``.

        ``kwargs`` is accepted for interface compatibility and ignored.
        """
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                action_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                action_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                # NOTE(review): normalizer_from_stat is not imported at the top
                # of this file; presumably it is defined later in this module.
                action_normalizer = normalizer_from_stat(stat)
        else:
            # already normalized
            action_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = action_normalizer

        # obs
        for key in self.lowdim_keys:
            stat = array_to_stats(self.replay_buffer[key])
            # 'qpos' keys also end with 'pos', so they take the first branch;
            # the former separate (identical) 'qpos' branch was unreachable.
            if key.endswith('pos'):
                this_normalizer = get_range_normalizer_from_stat(stat)
            elif key.endswith('quat'):
                # quaternion is in [-1,1] already
                this_normalizer = get_identity_normalizer_from_stat(stat)
            else:
                raise RuntimeError('unsupported')
            normalizer[key] = this_normalizer

        # image
        for key in self.rgb_keys:
            normalizer[key] = get_image_range_normalizer()
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every action stored in the replay buffer as a torch tensor.

        The buffer arrays are zarr arrays here (the buffer is created on a
        ``zarr.MemoryStore``), and ``torch.from_numpy`` only accepts
        ``np.ndarray``. The array is therefore materialized with ``[:]`` first.
        The result is a copy and does not alias the buffer.
        """
        return torch.from_numpy(np.asarray(self.replay_buffer['action'][:]))

    def __len__(self):
        """Number of training sequences the current sampler can produce."""
        n_sequences = len(self.sampler)
        return n_sequences

    def finetune_data(self, data):
        """Add one rollout trajectory to the replay buffer and rebuild the sampler.

        Args:
            data: mapping read like an hdf5 demo group. ``data['obs'][key][:]``
                and ``data['actions'][:]`` must yield numpy arrays whose first
                axis is the trajectory length. RGB keys (``self.rgb_keys``) are
                expected channel-first with values in [0, 1] (assumed
                (traj_len, C, H, W), TODO confirm). They are stored as
                channel-last uint8.

        Behaviour:
            * If the actions exactly equal an entry of ``self.check_array``, the
              trajectory is rejected.
            * While adding an episode keeps the count within
              ``self.whole_epoch_num``, the trajectory is appended as a new
              episode.
            * Otherwise the oldest episode appended after the first
              ``self.ori_episode_num`` episodes is removed, the later episodes
              are shifted down, and the trajectory is written at the end, so
              the episode count stays constant. If nothing has been appended
              yet there is nothing to remove, and the trajectory is appended.

        NOTE(review): a removed episode's actions stay in ``self.check_array``.

        Returns:
            True if the trajectory was added, False if it was a duplicate.

        Raises:
            AssertionError: if, after the update, the action array length
                disagrees with ``robot0_eye_in_hand_image`` or with the last
                episode end.
        """
        action = data['actions'][:]
        for demo_array in self.check_array:
            if np.array_equal(action, demo_array):
                return False

        obs_dict = {key: data['obs'][key][:] for key in data['obs'].keys()}
        for key in self.rgb_keys:
            # Float channel-first in [0, 1] -> uint8 channel-last. Round rather
            # than truncate, so a value that went through /255 maps back to the
            # same integer. Clip so out-of-range floats cannot wrap around.
            scaled = np.rint(np.asarray(obs_dict[key], dtype=np.float32) * 255)
            obs_dict[key] = np.moveaxis(np.clip(scaled, 0, 255).astype(np.uint8), 1, -1)
        # (buffer array name, new values); observations first, then action.
        new_values = list(obs_dict.items()) + [('action', action)]
        traj_len = action.shape[0]

        meta = self.replay_buffer['/meta']
        n_episodes = self.replay_buffer.n_episodes
        n_appended = n_episodes - self.ori_episode_num
        if n_episodes + 1 > self.whole_epoch_num and n_appended > 0:
            # Replace the oldest appended episode (index ori_episode_num).
            ends = meta['episode_ends']
            ends_np = ends[:]
            k = self.ori_episode_num
            delete_start = ends_np[k - 1] if k > 0 else 0
            delete_episode_len = ends_np[k] - delete_start
            # Ends of the surviving appended episodes, shifted down.
            revised_episode_ends = ends_np[k + 1:] - delete_episode_len
            previous_dataset_len = ends_np[-1] - delete_episode_len
            current_dataset_len = previous_dataset_len + traj_len

            if len(revised_episode_ends) > 0:
                ends[k:-1] = revised_episode_ends
            ends[-1] = current_dataset_len

            shift_src = self.ori_len_buffer + delete_episode_len
            for key, value in new_values:
                arr = self.replay_buffer[key]
                if previous_dataset_len > self.ori_len_buffer:
                    # Move the surviving appended data over the removed episode.
                    arr[self.ori_len_buffer:previous_dataset_len] = arr[shift_src:]
                arr.resize(current_dataset_len, *arr.shape[1:])
                arr[previous_dataset_len:] = value
        else:
            old_ends = meta['episode_ends'][:]
            previous_dataset_len = old_ends[-1] if len(old_ends) > 0 else 0
            current_dataset_len = previous_dataset_len + traj_len

            for key, value in new_values:
                arr = self.replay_buffer[key]
                arr.resize(current_dataset_len, *arr.shape[1:])
                arr[previous_dataset_len:] = value

            # Grow episode_ends by one entry through the scratch array
            # 'episode_ends_new', then copy it back over 'episode_ends'.
            n_new = len(old_ends) + 1
            if 'episode_ends_new' not in meta:
                meta.create('episode_ends_new', shape=(n_new,), dtype='int64')
            else:
                meta['episode_ends_new'].resize(n_new)
            ends_new = meta['episode_ends_new']
            if len(old_ends) > 0:
                ends_new[:-1] = old_ends
            ends_new[-1] = current_dataset_len
            meta['episode_ends'] = ends_new

        key_first_k = dict()
        if self.n_obs_steps is not None:
            # only take first k obs from images / low-dim keys
            for key in self.rgb_keys + self.lowdim_keys:
                key_first_k[key] = self.n_obs_steps
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=self.val_ratio,
            seed=self.seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)
        # Keep the stored mask in sync with the current episode list so that
        # get_validation_dataset builds a mask of the right length.
        self.train_mask = train_mask

        # Explicit raises (not `assert`) so the checks survive `python -O`.
        n_actions = self.replay_buffer['action'].shape[0]
        final_ends = self.replay_buffer['/meta']['episode_ends'][:]
        if n_actions != self.replay_buffer['robot0_eye_in_hand_image'].shape[0]:
            raise AssertionError('dimension false')
        if n_actions != final_ends[-1]:
            raise AssertionError('dimension false testes')

        # Remember the newly stored episode (as read back from the buffer).
        start = final_ends[-2] if len(final_ends) > 1 else 0
        self.check_array.append(self.replay_buffer['action'][start:final_ends[-1]])
        return True


    def init_data(self, data, first=False):
        """Write one trajectory into the replay buffer and rebuild the sampler.

        Args:
            data: mapping read like an hdf5 demo group. ``data['obs'][key][:]``
                and ``data['actions'][:]`` must yield numpy arrays whose first
                axis is the trajectory length. RGB keys are expected
                channel-first with values in [0, 1] (assumed
                (traj_len, C, H, W), TODO confirm). They are stored as
                channel-last uint8.
            first: if True, every array named in ``data['obs']`` plus
                ``'action'`` is truncated to this single trajectory, and
                ``episode_ends`` becomes ``[traj_len]``. Otherwise the
                trajectory is appended as a new episode.

        Afterwards ``ori_len_buffer``, ``ori_episode_num`` and
        ``ori_episode_ends`` describe the current buffer, so ``finetune_data``
        treats every episode written so far as one it never replaces.

        NOTE(review): ``self.check_array`` is neither reset when ``first`` is
        True nor extended with trajectories written here; confirm this is
        intended.

        Returns:
            True.
        """
        obs_dict = {key: data['obs'][key][:] for key in data['obs'].keys()}
        action = data['actions'][:]
        traj_len = action.shape[0]
        for key in self.rgb_keys:
            # Float channel-first in [0, 1] -> uint8 channel-last. Round rather
            # than truncate, so a value that went through /255 maps back to the
            # same integer. Clip so out-of-range floats cannot wrap around.
            scaled = np.rint(np.asarray(obs_dict[key], dtype=np.float32) * 255)
            obs_dict[key] = np.moveaxis(np.clip(scaled, 0, 255).astype(np.uint8), 1, -1)

        meta = self.replay_buffer['/meta']
        if first:
            # Start the buffer over with just this trajectory.
            old_ends = np.zeros(0, dtype=np.int64)
        else:
            old_ends = meta['episode_ends'][:]
        previous_dataset_len = old_ends[-1] if len(old_ends) > 0 else 0
        current_dataset_len = previous_dataset_len + traj_len

        for key, value in list(obs_dict.items()) + [('action', action)]:
            arr = self.replay_buffer[key]
            arr.resize(current_dataset_len, *arr.shape[1:])
            arr[previous_dataset_len:] = value

        # Rebuild episode_ends through the scratch array 'episode_ends_new',
        # creating it if this is the first time it is needed.
        n_new = len(old_ends) + 1
        if 'episode_ends_new' not in meta:
            meta.create('episode_ends_new', shape=(n_new,), dtype='int64')
        else:
            meta['episode_ends_new'].resize(n_new)
        ends_new = meta['episode_ends_new']
        if len(old_ends) > 0:
            ends_new[:-1] = old_ends
        ends_new[-1] = current_dataset_len
        meta['episode_ends'] = ends_new

        key_first_k = dict()
        if self.n_obs_steps is not None:
            # only take first k obs from images / low-dim keys
            for key in self.rgb_keys + self.lowdim_keys:
                key_first_k[key] = self.n_obs_steps
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=self.val_ratio,
            seed=self.seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)
        # Keep the stored mask in sync with the current episode list so that
        # get_validation_dataset builds a mask of the right length.
        self.train_mask = train_mask

        # 'action' is always rewritten above, so it is the reliable length
        # source (this matches how __init__ sets ori_len_buffer).
        self.ori_len_buffer = self.replay_buffer['action'].shape[0]
        self.ori_episode_num = self.replay_buffer.n_episodes
        self.ori_episode_ends = self.replay_buffer['/meta/episode_ends']
        return True

    def finetune_data_buffer(self, replay_buffer, key_list):
        """Append every episode of ``replay_buffer`` to ``self.replay_buffer``.

        The extra buffer's episode ends are shifted by the current dataset
        length and appended after the existing ``/meta/episode_ends``. Each
        array named in ``key_list`` is grown and the extra data is written
        after the existing steps. The training sampler is then rebuilt over
        the combined buffer, and ``self.check_array`` is refilled with one
        ``action`` slice per episode.

        Args:
            replay_buffer: buffer holding the extra episodes. It is indexed
                with ``'action'``, ``'/meta/episode_ends'`` and every key in
                ``key_list``.
            key_list: names of the data arrays to copy over. The consistency
                checks below compare ``action`` with
                ``robot0_eye_in_hand_image``, so both presumably need to be in
                this list.

        Returns:
            True once the buffer, sampler and ``check_array`` are updated.

        Raises:
            AssertionError: if the ``action`` and ``robot0_eye_in_hand_image``
                lengths differ, or if the ``action`` length differs from the
                last episode end.
        """
        meta = self.replay_buffer['/meta']

        extra_replay_buffer_len = replay_buffer['action'].shape[0]
        extra_episode_ends = np.asarray(
            replay_buffer['/meta/episode_ends'][:], dtype=np.int64)

        previous_episode_ends = np.asarray(meta['episode_ends'][:], dtype=np.int64)
        previous_episode_num = len(previous_episode_ends)
        # An empty buffer starts at step 0 instead of failing on episode_ends[-1].
        previous_dataset_len = (
            int(previous_episode_ends[-1]) if previous_episode_num else 0)
        current_dataset_len = previous_dataset_len + extra_replay_buffer_len

        # Keep the existing boundaries and append the shifted new ones. Filling
        # only the tail would leave the old entries at the array's fill value.
        new_episode_ends = np.concatenate(
            [previous_episode_ends, previous_dataset_len + extra_episode_ends])
        # overwrite=True: the scratch array may already exist from an earlier call.
        episode_ends_new = meta.create(
            'episode_ends_new', shape=new_episode_ends.shape,
            dtype='int64', overwrite=True)
        episode_ends_new[:] = new_episode_ends
        print('previopus episode len', len(meta['episode_ends']))
        meta['episode_ends'] = episode_ends_new
        print('after episode len', len(meta['episode_ends']))

        # Grow each data array and write the extra steps after the existing ones.
        for key in key_list:
            target = self.replay_buffer[key]
            target.resize(current_dataset_len, *target.shape[1:])
            target[previous_dataset_len:] = replay_buffer[key][:]

        key_first_k = dict()
        if self.n_obs_steps is not None:
            # only take first k obs from images
            for key in self.rgb_keys + self.lowdim_keys:
                key_first_k[key] = self.n_obs_steps
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=self.val_ratio,
            seed=self.seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

        # Explicit raise so the checks survive `python -O`.
        n_actions = self.replay_buffer['action'].shape[0]
        if n_actions != self.replay_buffer['robot0_eye_in_hand_image'].shape[0]:
            raise AssertionError('dimension false')
        episode_ends = meta['episode_ends'][:]
        if n_actions != episode_ends[-1]:
            raise AssertionError('dimension false testes')

        # Read the actions once, then slice them per episode [start, end).
        actions = self.replay_buffer['action'][:]
        episode_starts = np.concatenate([[0], episode_ends[:-1]])
        self.check_array = [
            actions[start:end] for start, end in zip(episode_starts, episode_ends)
        ]

        return True
    def init_data_test(self, data_list=None):
        """Debug helper: reset the episode index, then append trajectories.

        Iteration 0 (re)creates ``/meta/episode_ends_new`` with a single entry
        and installs it as ``/meta/episode_ends``. A freshly created array
        holds the zarr fill value. An existing one keeps its first value after
        the shrink. ``data_list[0]`` is never read.

        Iterations 1-4 each append ``data_list[traj_num]`` after the current
        end of every observation array and of ``action``, then extend
        ``/meta/episode_ends`` by one entry. Finally the training sampler is
        rebuilt over the modified buffer.

        Args:
            data_list: sequence of at least 5 dicts. Each has an ``'obs'``
                mapping (observation key -> per-step array, first axis = time)
                and an ``'actions'`` array whose first axis is the trajectory
                length. Values go through ``astype``/``np.moveaxis``, so they
                are presumably numpy arrays (TODO confirm). Image keys
                (``self.rgb_keys``) are multiplied by 255, cast to uint8 and
                moved from axis 1 to the last axis. They are presumably float
                ``(T, C, H, W)`` in [0, 1] (TODO confirm). Keys listed in
                earlier notes: robot{0,1}_eef_pos, robot{0,1}_eef_quat,
                robot{0,1}_eye_in_hand_image, robot{0,1}_gripper_qpos,
                shouldercamera{0,1}_image.

        Raises:
            ValueError: if ``data_list`` is None. This is checked before the
                buffer is touched.
        """
        if data_list is None:
            raise ValueError('init_data_test requires data_list')

        meta = self.replay_buffer['/meta']
        for traj_num in range(5):
            if traj_num == 0:
                # Membership test: a zero-length existing array is falsy and
                # must not be mistaken for a missing one.
                if 'episode_ends_new' not in meta:
                    meta.create('episode_ends_new', shape=(1,), dtype='int64')
                else:
                    meta['episode_ends_new'].resize(1, )
                meta['episode_ends'] = meta['episode_ends_new']
                continue

            data = data_list[traj_num]
            obs_dict = copy.deepcopy(data['obs'])
            action = copy.deepcopy(data['actions'])
            traj_len = action.shape[0]

            for key in self.rgb_keys:
                img = (obs_dict[key] * 255).astype(np.uint8)
                obs_dict[key] = np.moveaxis(img, 1, -1)

            # NOTE(review): data is written after the current `action` length,
            # but the episode end recorded below is episode_ends[-1] + traj_len.
            # After the iteration-0 reset these two disagree. Confirm intent.
            write_start = self.replay_buffer['action'].shape[0]
            write_end = write_start + traj_len

            for key, value in obs_dict.items():
                target = self.replay_buffer[key]
                target.resize(write_end, *target.shape[1:])
                target[write_start:] = value

            action_arr = self.replay_buffer['action']
            action_arr.resize(write_end, *action_arr.shape[1:])
            action_arr[write_start:] = action

            old_ends = meta['episode_ends']
            ends_new = meta['episode_ends_new']
            ends_new.resize(self.replay_buffer.n_episodes + 1, )
            ends_new[:-1] = old_ends[:]
            ends_new[-1] = old_ends[-1] + traj_len
            meta['episode_ends'] = ends_new

        key_first_k = dict()
        if self.n_obs_steps is not None:
            # only take first k obs from images
            for key in self.rgb_keys + self.lowdim_keys:
                key_first_k[key] = self.n_obs_steps
        val_mask = get_val_mask(
            n_episodes=self.replay_buffer.n_episodes,
            val_ratio=self.val_ratio,
            seed=self.seed)
        train_mask = ~val_mask
        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=train_mask,
            key_first_k=key_first_k)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return sample ``idx`` as ``{'obs': {...}, 'action': tensor}``.

        Only the first ``n_obs_steps`` steps of each observation are kept
        (all of them when ``n_obs_steps`` is None). The rest would be
        discarded downstream, so slicing here saves RAM. Image frames are
        read channel-last (T,H,W,C per the sampler) and returned channel-first
        (T,C,H,W) as float32 in [0, 1]. Low-dim observations and the action
        are returned as float32.
        """
        threadpool_limits(1)
        sample = self.sampler.sample_sequence(idx)

        # slice(None) keeps every timestep
        obs_window = slice(self.n_obs_steps)

        obs = dict()
        for rgb_key in self.rgb_keys:
            # pop frees the raw frames from the sample dict as we go
            frames = sample.pop(rgb_key)[obs_window]
            obs[rgb_key] = np.moveaxis(frames, -1, 1).astype(np.float32) / 255.
        for low_key in self.lowdim_keys:
            obs[low_key] = sample.pop(low_key)[obs_window].astype(np.float32)

        action = sample['action'].astype(np.float32)
        return {
            'obs': dict_apply(obs, torch.from_numpy),
            'action': torch.from_numpy(action),
        }



def _convert_actions(raw_actions, abs_action, rotation_transformer):
    actions = raw_actions
    if abs_action:
        is_dual_arm = False
        if raw_actions.shape[-1] == 14:
            # dual arm
            raw_actions = raw_actions.reshape(-1,2,7)
            is_dual_arm = True

        pos = raw_actions[...,:3]
        rot = raw_actions[...,3:6]
        gripper = raw_actions[...,6:]
        rot = rotation_transformer.forward(rot)
        raw_actions = np.concatenate([
            pos, rot, gripper
        ], axis=-1).astype(np.float32)
    
        if is_dual_arm:
            raw_actions = raw_actions.reshape(-1,20)
        actions = raw_actions
    return actions


def _convert_robomimic_to_replay(store, shape_meta, dataset_path, abs_action, rotation_transformer,
        n_workers=None, max_inflight_tasks=None, extra=False):
    """Convert a robomimic-style HDF5 file into a zarr-backed ReplayBuffer.

    Every group under ``data`` in the HDF5 file is one episode. Episodes are
    processed in ``file['data'].keys()`` order, which for h5py is typically
    lexicographic by name (so ``demo_10`` comes before ``demo_2``), not
    numeric. Cumulative episode lengths go to ``meta/episode_ends``.
    Low-dim observations and actions are stored as uncompressed,
    single-chunk float32 arrays under ``data/<key>``. RGB observations are
    copied frame by frame into uint8 ``(n_steps, h, w, c)`` arrays compressed
    with ``Jpeg2k(level=50)``, using a thread pool.

    Args:
        store: zarr store to write into. The ``data`` and ``meta`` groups are
            required (created if missing) under its root.
        shape_meta: dict with ``'obs'`` (key -> {'shape', optional 'type'};
            type is 'rgb' or 'low_dim', default 'low_dim') and ``'action'``
            ({'shape'}). RGB shapes are given as (c, h, w).
        dataset_path: path to the HDF5 file, opened read-only.
        abs_action: forwarded to ``_convert_actions``.
        rotation_transformer: forwarded to ``_convert_actions``.
        n_workers: image-copy threads; defaults to ``cpu_count()``.
        max_inflight_tasks: cap on queued image tasks; defaults to
            ``5 * n_workers``.
        extra: when True, actions are stored unconverted and the shape checks
            against ``shape_meta`` are skipped.

    Returns:
        ReplayBuffer wrapping the written zarr root.

    Raises:
        RuntimeError: if any image fails to encode or decode. The worker's
            exception is chained as ``__cause__``.
    """
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # split observation keys by modality
    rgb_keys = list()
    lowdim_keys = list()
    for key, attr in shape_meta['obs'].items():
        obs_type = attr.get('type', 'low_dim')
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)

    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    # Explicit read-only mode: h5py < 3 defaulted to append mode.
    with h5py.File(dataset_path, 'r') as file:
        demos = file['data']
        demo_list = [demos[name] for name in demos.keys()]

        # count total steps
        episode_ends = list()
        prev_end = 0
        for demo in demo_list:
            prev_end += demo['actions'].shape[0]
            episode_ends.append(prev_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
            dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'actions' if key == 'action' else 'obs/' + key
            this_data = np.concatenate(
                [demo[data_key][:].astype(np.float32) for demo in demo_list],
                axis=0)
            if not extra:
                if key == 'action':
                    this_data = _convert_actions(
                        raw_actions=this_data,
                        abs_action=abs_action,
                        rotation_transformer=rotation_transformer
                    )
                    assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape'])
                else:
                    assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
            # make sure we can successfully decode
            _ = zarr_arr[zarr_idx]
            return True

        with tqdm(total=n_steps*len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            def collect(done):
                # Re-raise worker failures with their original cause attached.
                for fut in done:
                    try:
                        fut.result()
                    except Exception as err:
                        raise RuntimeError('Failed to encode image!') from err
                pbar.update(len(done))

            # one chunk per thread, therefore no synchronization needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    c, h, w = tuple(shape_meta['obs'][key]['shape'])
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, h, w, c),
                        chunks=(1, h, w, c),
                        compressor=Jpeg2k(level=50),
                        dtype=np.uint8
                    )
                    for episode_idx, demo in enumerate(demo_list):
                        hdf5_arr = demo['obs'][key]
                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(
                                    futures,
                                    return_when=concurrent.futures.FIRST_COMPLETED)
                                collect(completed)

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(img_copy,
                                    img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                collect(completed)

    replay_buffer = ReplayBuffer(root)
    return replay_buffer

def normalizer_from_stat(stat):
    """Build a symmetric, zero-offset normalizer from min/max statistics.

    Every dimension is scaled by the same factor ``1 / max_abs``, where
    ``max_abs`` is the larger of ``stat['max'].max()`` and
    ``abs(stat['min']).max()``. Values inside [min, max] therefore land in
    [-1, 1] without being shifted.

    Args:
        stat: dict with per-dimension numpy arrays ``'max'`` and ``'min'``.
            It is also passed through as ``input_stats_dict``.

    Returns:
        SingleFieldLinearNormalizer whose ``scale`` is filled with
        ``1 / max_abs`` (1 when every statistic is zero) and whose ``offset``
        is all zeros, both shaped like ``stat['max']``.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    # An all-zero field would give scale=inf and NaN outputs; leave it unscaled.
    if max_abs == 0:
        max_abs = 1.0
    scale = np.full_like(stat['max'], fill_value=1/max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )
