from typing import Dict, List
import torch
import numpy as np
import h5py
from tqdm import tqdm
import copy
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.dataset.base_dataset import BaseLowdimDataset, LinearNormalizer
from diffusion_policy.model.common.normalizer import LinearNormalizer, SingleFieldLinearNormalizer
from diffusion_policy.model.common.rotation_transformer import RotationTransformer
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)
from diffusion_policy.common.normalize_util import (
    robomimic_abs_action_only_normalizer_from_stat,
    robomimic_abs_action_only_dual_arm_normalizer_from_stat,
    get_identity_normalizer_from_stat,
    array_to_stats
)


class RobomimicReplayLowdimDataset(BaseLowdimDataset):
    """Low-dimensional robomimic dataset held in an in-memory ReplayBuffer.

    Every ``demo_{i}`` group under ``data`` in the HDF5 file is converted by
    ``_data_to_obs`` into an episode with the keys ``'obs'`` and ``'action'``
    and appended to a numpy-backed replay buffer. Training sequences are
    drawn from that buffer through a ``SequenceSampler``.

    The dataset can later be extended with additional trajectories via
    ``finetune_data``; those are always layered on top of the buffer that was
    built in ``__init__`` (``ori_replay_buffer``).
    """

    # Total number of episodes (file episodes + extra episodes) the dataset is
    # allowed to hold once ``finetune_data`` is used.
    MAX_TOTAL_EPISODES = 450

    def __init__(self,
                 dataset_path: str,
                 horizon=1,
                 pad_before=0,
                 pad_after=0,
                 n_obs_steps=2,
                 obs_keys: List[str] = [
                     'object',
                     'robot0_eef_pos',
                     'robot0_eef_quat',
                     'robot0_gripper_qpos'],
                 abs_action=False,
                 rotation_rep='rotation_6d',
                 use_legacy_normalizer=False,
                 seed=42,
                 val_ratio=0.0,
                 max_train_episodes=None,
                 subdataset=False
                 ):
        """Load the HDF5 file and build the training sampler.

        Args:
            dataset_path: Path of the robomimic HDF5 file; it must contain a
                ``data`` group with ``demo_0 .. demo_{N-1}`` sub-groups.
            horizon: Sequence length passed to the sampler.
            pad_before: Padding passed to the sampler (``pad_before``).
            pad_after: Padding passed to the sampler (``pad_after``).
            n_obs_steps: Accepted for interface compatibility; not used here.
            obs_keys: Observation keys concatenated (in this order) along the
                last axis to form ``'obs'``. The default list is copied, so
                it is never mutated.
            abs_action: When true, actions are treated as absolute and their
                rotation part is converted from axis-angle to
                ``rotation_rep`` (see ``_data_to_obs``).
            rotation_rep: Target rotation representation for ``abs_action``.
            use_legacy_normalizer: When true (and ``abs_action`` is true) the
                action normalizer is built with ``normalizer_from_stat``.
            seed: Seed forwarded to ``get_val_mask`` / ``downsample_mask``.
            val_ratio: Validation ratio forwarded to ``get_val_mask``.
            max_train_episodes: ``max_n`` forwarded to ``downsample_mask``.
            subdataset: When true, only a subset of the loaded episodes is
                kept (see ``_keep_episode_subset``).
        """
        obs_keys = list(obs_keys)

        rotation_transformer = RotationTransformer(
            from_rep='axis_angle', to_rep=rotation_rep)

        replay_buffer = ReplayBuffer.create_empty_numpy()
        # Open read-only explicitly: loading must never modify the dataset
        # file, regardless of the h5py version's default mode.
        with h5py.File(dataset_path, 'r') as file:
            demos = file['data']
            for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action,
                    rotation_transformer=rotation_transformer)
                replay_buffer.add_episode(episode)
                # Shape of the most recently loaded episode's observations.
                self.obs_shape = episode['obs'].shape

        if subdataset:
            self._keep_episode_subset(replay_buffer)

        self.abs_action = abs_action
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.use_legacy_normalizer = use_legacy_normalizer
        self.obs_keys = obs_keys
        self.rotation_transformer = rotation_transformer
        self.val_ratio = val_ratio
        self.seed = seed
        self.max_train_episodes = max_train_episodes

        self.replay_buffer = replay_buffer
        # Pristine buffer that ``finetune_data`` starts from on every call.
        # It is the same object as ``replay_buffer`` until the first
        # ``finetune_data`` call, which switches ``replay_buffer`` to a copy.
        self.ori_replay_buffer = replay_buffer
        self.train_mask, self.sampler = self._build_train_sampler(replay_buffer)

        self.ori_episode_num = self.replay_buffer.n_episodes
        # Number of extra episodes that may be retained; clamped at zero for
        # files that already exceed the budget (no extra episode is kept).
        self.extra_episode_num = max(
            0, self.MAX_TOTAL_EPISODES - self.ori_episode_num)
        self.extra_episode = []

    @staticmethod
    def _keep_episode_subset(replay_buffer):
        """Shrink ``replay_buffer`` in place to a subset of its episodes.

        Keeps the frames ``[0, episode_ends[1])`` plus roughly every tenth
        later episode (indices from ``np.linspace``), and rewrites both the
        data arrays and ``meta['episode_ends']`` accordingly.
        """
        episode_ends = replay_buffer.episode_ends[:]
        n_episodes = len(episode_ends)

        # NOTE(review): the first span ends at episode_ends[1], so it covers
        # the frames of the first two stored episodes as ONE episode. Kept as
        # in the original selection -- confirm this is intended.
        spans = [(0, episode_ends[1])]
        picked = np.linspace(2, n_episodes - 1, n_episodes // 10).astype(int)
        spans.extend((episode_ends[i - 1], episode_ends[i]) for i in picked)

        for key in list(replay_buffer.data.keys()):
            replay_buffer.data[key] = np.concatenate(
                [replay_buffer[key][start:end] for start, end in spans],
                axis=0)
        # New episode boundaries are the running sum of the kept lengths.
        replay_buffer.meta['episode_ends'] = np.cumsum(
            [end - start for start, end in spans])

    def _build_train_sampler(self, replay_buffer):
        """Return ``(train_mask, sampler)`` for ``replay_buffer``.

        The train mask is the complement of the validation mask, optionally
        downsampled to ``max_train_episodes``.
        """
        val_mask = get_val_mask(
            n_episodes=replay_buffer.n_episodes,
            val_ratio=self.val_ratio,
            seed=self.seed)
        train_mask = downsample_mask(
            mask=~val_mask,
            max_n=self.max_train_episodes,
            seed=self.seed)
        sampler = SequenceSampler(
            replay_buffer=replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=train_mask)
        return train_mask, sampler

    @staticmethod
    def _as_numpy_episode(traj):
        """Return a shallow copy of ``traj`` whose ``'action'`` is float32 numpy.

        Accepts the action either as a torch tensor or as anything
        ``np.asarray`` understands; the caller's mapping is left untouched.
        """
        episode = dict(traj)
        action = episode['action']
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        episode['action'] = np.asarray(action).astype(np.float32)
        return episode

    def finetune_data(self, traj_list):
        """Add trajectories on top of the originally loaded episodes.

        Each item of ``traj_list`` is a mapping with at least the keys
        ``'obs'`` and ``'action'``; ``'action'`` may be a torch tensor or a
        numpy array. Only the most recent ``extra_episode_num`` trajectories
        (across all calls) are retained. The replay buffer is rebuilt from
        the original episodes plus the retained extras, and the train mask
        and sampler are recomputed.
        """
        self.extra_episode.extend(
            self._as_numpy_episode(traj) for traj in traj_list)

        # Drop the oldest extras once the budget is exceeded.
        overflow = len(self.extra_episode) - self.extra_episode_num
        if overflow > 0:
            self.extra_episode = self.extra_episode[overflow:]

        # Start from a private copy of the pristine buffer so that extras
        # from earlier calls are not accumulated or duplicated in it.
        self.replay_buffer = copy.deepcopy(self.ori_replay_buffer)
        for episode in self.extra_episode:
            self.replay_buffer.add_episode(episode)
            self.obs_shape = episode['obs'].shape

        self.train_mask, self.sampler = self._build_train_sampler(
            self.replay_buffer)

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that samples the episodes
        excluded by ``train_mask``."""
        val_set = copy.copy(self)
        val_set.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=self.horizon,
            pad_before=self.pad_before,
            pad_after=self.pad_after,
            episode_mask=~self.train_mask
        )
        val_set.train_mask = ~self.train_mask
        return val_set

    def get_normalizer(self, **kwargs) -> LinearNormalizer:
        """Build a ``LinearNormalizer`` with ``'action'`` and ``'obs'`` fields.

        Actions: with ``abs_action`` the robomimic absolute-action normalizer
        is used (the dual-arm variant when the action dimension exceeds 10),
        unless ``use_legacy_normalizer`` selects ``normalizer_from_stat``;
        otherwise an identity normalizer is used. Observations always use
        ``normalizer_from_stat`` on the aggregated statistics.
        """
        normalizer = LinearNormalizer()

        # action
        stat = array_to_stats(self.replay_buffer['action'])
        if self.abs_action:
            if stat['mean'].shape[-1] > 10:
                # dual arm
                this_normalizer = robomimic_abs_action_only_dual_arm_normalizer_from_stat(stat)
            else:
                this_normalizer = robomimic_abs_action_only_normalizer_from_stat(stat)

            if self.use_legacy_normalizer:
                this_normalizer = normalizer_from_stat(stat)
        else:
            # already normalized
            this_normalizer = get_identity_normalizer_from_stat(stat)
        normalizer['action'] = this_normalizer

        # aggregate obs stats
        obs_stat = array_to_stats(self.replay_buffer['obs'])
        normalizer['obs'] = normalizer_from_stat(obs_stat)
        return normalizer

    def get_all_actions(self) -> torch.Tensor:
        """Return every stored action as a single tensor."""
        return torch.from_numpy(self.replay_buffer['action'])

    def __len__(self):
        """Number of sequences the current sampler can produce."""
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Sample sequence ``idx`` and convert its arrays to torch tensors."""
        data = self.sampler.sample_sequence(idx)
        torch_data = dict_apply(data, torch.from_numpy)
        return torch_data


def normalizer_from_stat(stat):
    """Build a zero-offset normalizer that scales by the largest magnitude.

    A single scale ``1 / max_abs`` is shared by every dimension, where
    ``max_abs`` is the largest absolute value found in ``stat['max']`` and
    ``stat['min']``.

    Args:
        stat: Mapping with at least numpy arrays ``'max'`` and ``'min'``; it
            is also forwarded unchanged as ``input_stats_dict``.

    Returns:
        The result of ``SingleFieldLinearNormalizer.create_manual``.
    """
    max_abs = np.maximum(stat['max'].max(), np.abs(stat['min']).max())
    if max_abs == 0:
        # All-zero data: 1 / 0 would yield an infinite scale and NaNs after
        # normalization, so fall back to a unit scale (data stays zero).
        max_abs = 1
    scale = np.full_like(stat['max'], fill_value=1 / max_abs)
    offset = np.zeros_like(stat['max'])
    return SingleFieldLinearNormalizer.create_manual(
        scale=scale,
        offset=offset,
        input_stats_dict=stat
    )


def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action, rotation_transformer):
    obs = np.concatenate([
        raw_obs[key] for key in obs_keys
    ], axis=-1).astype(np.float32)
    # print(obs.shape)
    # print(obs_keys)
    
    if abs_action:
        is_dual_arm = False
        if raw_actions.shape[-1] == 14:
            # dual arm
            raw_actions = raw_actions.reshape(-1, 2, 7)
            is_dual_arm = True

        pos = raw_actions[..., :3]
        rot = raw_actions[..., 3:6]
        gripper = raw_actions[..., 6:]
        rot = rotation_transformer.forward(rot)
        raw_actions = np.concatenate([
            pos, rot, gripper
        ], axis=-1).astype(np.float32)

        if is_dual_arm:
            raw_actions = raw_actions.reshape(-1, 20)

    data = {
        'obs': obs,
        'action': raw_actions
    }
    return data
