""" Real world dataset loader.

NOTE: The dataset keys are hardcoded for our native real-world dataset in hdf5 format,
      double check before using other datasets.
"""

import torch
import numpy as np
import h5py
import zarr
import concurrent.futures
import pdb

from typing import Dict
from tqdm import tqdm
from collections import defaultdict

from cleandiffuser.dataset.imagecodecs import register_codecs, Jpeg2k
from cleandiffuser.dataset.base_dataset import BaseDataset
from cleandiffuser.dataset.replay_buffer import ReplayBuffer
from cleandiffuser.dataset.dataset_utils import SequenceSampler, dict_apply

# Register the custom image codecs from cleandiffuser.dataset.imagecodecs at
# import time. Jpeg2k from that module is used below as the zarr compressor for
# image observations; presumably registration makes it known to zarr/numcodecs
# so stored chunks can be decoded — confirm in imagecodecs.register_codecs.
register_codecs()


class MinMaxNormalizer:
    """Per-feature min-max normalizer mapping data into [-1, 1].

    Statistics are computed over every axis except the last one, i.e. the
    last axis is treated as the feature dimension.
    """

    def __init__(self, X):
        """Fit per-feature min/max on ``X``.

        Args:
            X: numpy array of shape (..., n_features).
        """
        X = X.reshape(-1, X.shape[-1]).astype(np.float32)
        self.min, self.max = np.min(X, axis=0), np.max(X, axis=0)
        self.range = self.max - self.min
        constant = self.range == 0
        if np.any(constant):
            # Use a unit range to avoid division by zero. A constant feature
            # always equals its min, so it normalizes to exactly -1.
            print("Warning: Some features have the same min and max value. These will be set to -1.")
            self.range[constant] = 1

    def normalize(self, x):
        """Map ``x`` from [min, max] to [-1, 1] (float32)."""
        x = x.astype(np.float32)
        # normalize to [0, 1]
        nx = (x - self.min) / self.range
        # then to [-1, 1]
        nx = nx * 2 - 1
        return nx

    def unnormalize(self, x):
        """Inverse of ``normalize``: map ``x`` from [-1, 1] back to [min, max]."""
        x = x.astype(np.float32)
        nx = (x + 1) / 2
        x = nx * self.range + self.min
        return x


class ImageNormalizer:
    """Rescale 8-bit pixel values between [0, 255] and [-1, 1].

    The input range is fixed, so the normalizer holds no statistics.
    """

    def __init__(self):
        """Nothing to fit: the pixel range is constant."""

    def normalize(self, x):
        """Map pixel values in [0, 255] to [-1, 1]."""
        unit = x / 255.0          # [0, 255] -> [0, 1]
        return unit * 2.0 - 1.0   # [0, 1] -> [-1, 1]

    def unnormalize(self, x):
        """Map values in [-1, 1] back to the [0, 255] pixel range."""
        unit = (x + 1.0) / 2.0    # [-1, 1] -> [0, 1]
        return unit * 255.0       # [0, 1] -> [0, 255]


class DepthNormalizer:
    """Rescale depth values between [0, max_depth] and [-1, 1].

    Args:
        max_depth: depth value that maps to +1 (default 1800.0).
    """

    def __init__(self, max_depth=1800.0):
        self.max_depth = max_depth

    def normalize(self, x):
        """Map depth in [0, max_depth] to [-1, 1]."""
        doubled_fraction = (x * 2.0) / self.max_depth   # [0, 2]
        return doubled_fraction - 1.0

    def unnormalize(self, x):
        """Map values in [-1, 1] back to depth in [0, max_depth]."""
        doubled_depth = (x + 1.0) * self.max_depth      # [0, 2 * max_depth]
        return doubled_depth / 2.0


class RealWorldDataset(BaseDataset):
    """Low-dimensional real-world dataset backed by an in-memory ReplayBuffer.

    Each ``demo_{i}`` group of the HDF5 file is converted by ``_data_to_obs``
    into one episode holding a concatenated ``obs`` vector and an ``action``
    array. Fixed-length windows are drawn by ``SequenceSampler``.
    """

    def __init__(self,
                 dataset_dir,
                 horizon=1,
                 pad_before=0,
                 pad_after=0,
                 obs_keys=('ee_pose', 'ee_quat', 'gripper_state'),
                 abs_action=False,
                 ):
        """
        Args:
            dataset_dir: path to the HDF5 file with top-level ``demo_{i}`` groups.
            horizon: length of each sampled sequence.
            pad_before: padding steps allowed before an episode start.
            pad_after: padding steps allowed after an episode end.
            obs_keys: keys under ``demo['obs']`` concatenated into the state.
            abs_action: forwarded to ``_data_to_obs`` for action re-packing.
        """
        super().__init__()

        self.replay_buffer = ReplayBuffer.create_empty_numpy()
        # Open read-only explicitly; older h5py versions default to append mode.
        with h5py.File(dataset_dir, 'r') as demos:
            # NOTE: assumes every top-level key is named demo_0 .. demo_{N-1}.
            for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
                demo = demos[f'demo_{i}']
                episode = _data_to_obs(
                    raw_obs=demo['obs'],
                    raw_actions=demo['actions'][:].astype(np.float32),
                    obs_keys=obs_keys,
                    abs_action=abs_action)
                self.replay_buffer.add_episode(episode)

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after)

        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.abs_action = abs_action
        self.normalizer = self.get_normalizer()

    def undo_transform_action(self, action):
        """Re-assemble an action into its per-arm [pos(3), rot, gripper(1)] layout.

        This dataset applies no rotation-representation change, so for actions
        with at least 4 values per arm the output has the same shape and values
        as ``action``. A trailing dimension of 20 is treated as two arms of 10.
        """
        raw_shape = action.shape
        is_dual_arm = raw_shape[-1] == 20
        if is_dual_arm:
            action = action.reshape(-1, 2, 10)

        d_rot = action.shape[-1] - 4
        pos = action[..., :3]
        rot = action[..., 3:3 + d_rot]
        gripper = action[..., [-1]]

        uaction = np.concatenate([pos, rot, gripper], axis=-1)

        if is_dual_arm:
            # No rotation conversion happens here, so each arm keeps all its
            # values. Restore the caller's leading dims and infer the width
            # (a fixed width of 14 would only fit after a 6D->3D rotation
            # inverse and always failed to reshape).
            uaction = uaction.reshape(*raw_shape[:-1], -1)

        return uaction

    def get_normalizer(self):
        """Build min-max normalizers for the state vector and the actions.

        Both ``abs_action`` settings use the same data-driven normalizers.
        """
        state_normalizer = MinMaxNormalizer(self.replay_buffer['obs'][:])  # (N, obs_dim)
        action_normalizer = MinMaxNormalizer(self.replay_buffer['action'][:])  # (N, action_dim)
        return {
            "obs": {
                "state": state_normalizer
            },
            "action": action_normalizer
        }

    def sample_to_data(self, sample):
        """Normalize a sampled window into ``{'obs': {'state': ...}, 'action': ...}``."""
        state = sample['obs'].astype(np.float32)
        state = self.normalizer['obs']['state'].normalize(state)

        action = sample['action'].astype(np.float32)
        action = self.normalizer['action'].normalize(action)
        data = {
            'obs': {
                'state': state
            },
            'action': action,
        }
        return data

    def __str__(self) -> str:
        return f"Keys: {self.replay_buffer.keys()} Steps: {self.replay_buffer.n_steps} Episodes: {self.replay_buffer.n_episodes}"

    def __len__(self) -> int:
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.sampler.sample_sequence(idx)
        data = self.sample_to_data(sample)
        torch_data = dict_apply(data, torch.tensor)
        return torch_data


def _data_to_obs(raw_obs, raw_actions, obs_keys, abs_action):
    """Build one episode dict from a demo's observations and actions.

    Args:
        raw_obs: mapping (e.g. an h5py group) from key to per-step arrays whose
            first axis is time.
        raw_actions: action array of shape (T, action_dim).
        obs_keys: keys of ``raw_obs`` to concatenate, in order.
        abs_action: if True, re-pack the actions per arm as [pos, rot, gripper].
            A 14-dim action is treated as two 7-dim arms.

    Returns:
        ``{'obs': float32 (T, sum of feature dims), 'action': actions}``.
    """
    obs_list = []
    for key in obs_keys:
        arr = np.array(raw_obs[key])
        if arr.ndim < 2:
            # 1-D per-step scalars (e.g. gripper state) -> (T, 1) column
            arr = arr.reshape(arr.shape[0], 1)
        obs_list.append(arr)
    obs = np.concatenate(obs_list, axis=-1).astype(np.float32)

    if abs_action:
        is_dual_arm = raw_actions.shape[-1] == 14
        if is_dual_arm:
            raw_actions = raw_actions.reshape(-1, 2, 7)

        pos = raw_actions[..., :3]
        rot = raw_actions[..., 3:6]
        gripper = raw_actions[..., 6:]

        raw_actions = np.concatenate([
            pos, rot, gripper
        ], axis=-1).astype(np.float32)

        if is_dual_arm:
            # two arms x 7 values; no rotation conversion changes the width
            raw_actions = raw_actions.reshape(-1, 14)

    data = {
        'obs': obs,
        'action': raw_actions
    }
    return data


class RealWorldImageDataset(BaseDataset):
    """Image + low-dim real-world dataset backed by a zarr ReplayBuffer.

    Observation keys are split by ``shape_meta['obs'][key]['type']`` into
    ``'rgb'`` image keys and ``'low_dim'`` keys (the default type). Items are
    returned as ``{'obs': {key: tensor}, 'action': tensor}`` after
    normalization.
    """

    def __init__(self,
                 dataset_dir,
                 shape_meta: dict,
                 n_obs_steps=None,
                 horizon=1,
                 pad_before=0,
                 pad_after=0,
                 abs_action=False,
                 ):
        """
        Args:
            dataset_dir: path to the HDF5 file with top-level ``demo_{i}`` groups.
            shape_meta: dict with ``obs`` (per-key ``shape`` and optional
                ``type``) and ``action`` (``shape``) entries.
            n_obs_steps: if given, only the first ``n_obs_steps`` frames of each
                observation key are returned per item.
            horizon: length of each sampled sequence.
            pad_before: padding steps allowed before an episode start.
            pad_after: padding steps allowed after an episode end.
            abs_action: forwarded to the HDF5 -> zarr conversion.
        """
        super().__init__()

        self.replay_buffer = _convert_data_to_replay(
            store=zarr.MemoryStore(),
            shape_meta=shape_meta,
            dataset_path=dataset_dir,
            abs_action=abs_action)

        rgb_keys = list()
        lowdim_keys = list()
        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            obs_type = attr.get('type', 'low_dim')
            if obs_type == 'rgb':
                rgb_keys.append(key)
            elif obs_type == 'low_dim':
                lowdim_keys.append(key)

        # check if proper keys have been found and imported
        print(f"Lowdim keys: {lowdim_keys} RGB keys: {rgb_keys}")

        key_first_k = dict()
        if n_obs_steps is not None:
            # request only the first n_obs_steps frames for every obs key
            # (image and low-dim); the remainder is discarded in __getitem__
            for key in rgb_keys + lowdim_keys:
                key_first_k[key] = n_obs_steps

        self.sampler = SequenceSampler(
            replay_buffer=self.replay_buffer,
            sequence_length=horizon,
            pad_before=pad_before,
            pad_after=pad_after,
            key_first_k=key_first_k
        )

        self.shape_meta = shape_meta
        self.rgb_keys = rgb_keys
        self.lowdim_keys = lowdim_keys
        self.abs_action = abs_action
        self.horizon = horizon
        self.pad_before = pad_before
        self.pad_after = pad_after
        self.n_obs_steps = n_obs_steps

        self.normalizer = self.get_normalizer()

    def get_normalizer(self):
        """Build per-key normalizers for observations and actions.

        Low-dim keys and actions use data-driven min-max statistics; image keys
        use fixed-range normalizers chosen by key name.

        Raises:
            ValueError: for an image key without a known normalizer.
        """
        normalizer = defaultdict(dict)
        for key in self.lowdim_keys:
            normalizer['obs'][key] = MinMaxNormalizer(self.replay_buffer[key][:])
        for key in self.rgb_keys:
            if key == "depth":  # depth image
                normalizer['obs'][key] = DepthNormalizer(max_depth=1800.0)
            elif key in ("agentview", "tactile"):  # 8-bit rgb / tactile image
                normalizer['obs'][key] = ImageNormalizer()
            else:
                raise ValueError(f"Unknown rgb key {key} for normalizer")
        normalizer['action'] = MinMaxNormalizer(self.replay_buffer['action'][:])

        return normalizer

    def __str__(self) -> str:
        return f"Keys: {self.replay_buffer.keys()} Steps: {self.replay_buffer.n_steps} Episodes: {self.replay_buffer.n_episodes}"

    def __len__(self) -> int:
        return len(self.sampler)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = self.sampler.sample_sequence(idx)

        # To save RAM only the first n_obs_steps observation frames are kept,
        # since the rest would be discarded anyway. When n_obs_steps is None
        # this slice takes everything.
        T_slice = slice(self.n_obs_steps)

        obs_dict = dict()
        for key in self.rgb_keys:
            # (T, H, W, C) uint8 -> (T, C, H, W) float32
            obs_dict[key] = np.moveaxis(sample[key][T_slice], -1, 1
                                        ).astype(np.float32)
            del sample[key]
            obs_dict[key] = self.normalizer['obs'][key].normalize(obs_dict[key])

        for key in self.lowdim_keys:
            obs_dict[key] = sample[key][T_slice].astype(np.float32)
            del sample[key]
            obs_dict[key] = self.normalizer['obs'][key].normalize(obs_dict[key])

        # action
        action = sample['action'].astype(np.float32)
        action = self.normalizer['action'].normalize(action)

        torch_data = {
            'obs': dict_apply(obs_dict, torch.tensor),
            'action': torch.tensor(action)
        }
        return torch_data

    def undo_transform_action(self, action):
        """Re-assemble an action into its per-arm [pos(3), rot, gripper(1)] layout.

        This dataset applies no rotation-representation change, so for actions
        with at least 4 values per arm the output has the same shape and values
        as ``action``. A trailing dimension of 20 is treated as two arms of 10.
        """
        raw_shape = action.shape
        is_dual_arm = raw_shape[-1] == 20
        if is_dual_arm:
            action = action.reshape(-1, 2, 10)

        d_rot = action.shape[-1] - 4
        pos = action[..., :3]
        rot = action[..., 3:3 + d_rot]
        gripper = action[..., [-1]]

        uaction = np.concatenate([pos, rot, gripper], axis=-1)

        if is_dual_arm:
            # No rotation conversion happens here, so each arm keeps all its
            # values. Restore the caller's leading dims and infer the width
            # (a fixed width of 14 would only fit after a 6D->3D rotation
            # inverse and always failed to reshape).
            uaction = uaction.reshape(*raw_shape[:-1], -1)

        return uaction


def _convert_actions(raw_actions, abs_action, action_shape):
    """Optionally re-pack absolute actions, then keep the first ``action_shape`` dims.

    Args:
        raw_actions: array of shape (N, action_dim).
        abs_action: if True, re-pack per arm as [pos(3), rot(3), gripper] in
            float32. A 14-dim action is treated as two 7-dim arms.
        action_shape: number of leading action dims to keep.

    Returns:
        Array of shape (N, min(action_dim, action_shape)).
    """
    actions = raw_actions
    if abs_action:
        is_dual_arm = raw_actions.shape[-1] == 14
        if is_dual_arm:
            actions = actions.reshape(-1, 2, 7)

        pos = actions[..., :3]
        rot = actions[..., 3:6]
        gripper = actions[..., 6:]

        actions = np.concatenate([
            pos, rot, gripper
        ], axis=-1).astype(np.float32)

        if is_dual_arm:
            # two arms x 7 values; no rotation conversion changes the width
            actions = actions.reshape(-1, 14)
    return actions[:, :action_shape]


def _convert_data_to_replay(store, shape_meta, dataset_path, abs_action,
                            n_workers=None, max_inflight_tasks=None):
    """Load an HDF5 demo file into a zarr-backed ReplayBuffer.

    Low-dim observations and actions are stored uncompressed as one chunk per
    key. Image observations are stored channel-last (N, H, W, C) as uint8, with
    one Jpeg2k-compressed chunk per frame, encoded by a thread pool.

    Args:
        store: zarr store to write into (e.g. ``zarr.MemoryStore()``).
        shape_meta: dict with ``obs`` (per-key ``shape`` and optional ``type``,
            ``'rgb'`` or ``'low_dim'`` (default)) and ``action`` (``shape``).
        dataset_path: path to the HDF5 file with top-level ``demo_{i}`` groups.
        abs_action: forwarded to ``_convert_actions``.
        n_workers: encoding threads; defaults to ``multiprocessing.cpu_count()``.
        max_inflight_tasks: cap on pending encode tasks; defaults to
            ``5 * n_workers``.

    Returns:
        ReplayBuffer wrapping the populated zarr root group.

    Raises:
        ValueError: if the file contains no demos.
        RuntimeError: if any image fails to encode or decode.
    """
    import multiprocessing
    if n_workers is None:
        n_workers = multiprocessing.cpu_count()
    if max_inflight_tasks is None:
        max_inflight_tasks = n_workers * 5

    # parse shape_meta into image and low-dim keys
    rgb_keys = list()
    lowdim_keys = list()
    for key, attr in shape_meta['obs'].items():
        obs_type = attr.get('type', 'low_dim')
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)

    print(f"Lowdim keys: {lowdim_keys} RGB keys: {rgb_keys}")

    # create zarr group
    root = zarr.group(store)
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    # Open read-only explicitly; older h5py versions default to append mode.
    with h5py.File(dataset_path, 'r') as demos:
        # NOTE: assumes every top-level key is named demo_0 .. demo_{N-1}.
        n_demos = len(demos)
        if n_demos == 0:
            raise ValueError(f"No demos found in {dataset_path}")

        # episode boundaries as cumulative step counts
        episode_ends = list()
        prev_end = 0
        for i in range(n_demos):
            prev_end += demos[f'demo_{i}']['actions'].shape[0]
            episode_ends.append(prev_end)
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        _ = meta_group.array('episode_ends', episode_ends,
                             dtype=np.int64, compressor=None, overwrite=True)

        # save lowdim data
        for key in tqdm(lowdim_keys + ['action'], desc="Loading lowdim data"):
            data_key = 'actions' if key == 'action' else 'obs/' + key
            this_data = list()
            for i in range(n_demos):
                raw_data = demos[f'demo_{i}'][data_key][:].astype(np.float32)
                if raw_data.ndim == 1:
                    # single value per step, e.g. gripper state
                    raw_data = raw_data.reshape(-1, 1)
                this_data.append(raw_data)
            this_data = np.concatenate(this_data, axis=0)
            if key == 'action':
                this_data = _convert_actions(
                    raw_actions=this_data,
                    abs_action=abs_action,
                    action_shape=tuple(shape_meta['action']['shape'])[0]
                )
                assert this_data.shape == (n_steps,) + tuple(shape_meta['action']['shape']), f"Expected action shape {shape_meta['action']['shape']}, {n_steps} but got {this_data.shape}"
            else:
                assert this_data.shape == (n_steps,) + tuple(shape_meta['obs'][key]['shape'])
            _ = data_group.array(
                name=key,
                data=this_data,
                shape=this_data.shape,
                chunks=this_data.shape,
                compressor=None,
                dtype=this_data.dtype
            )

        def img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
            """Copy one frame into zarr and read it back to verify decoding.

            Returns False (after printing the error) instead of raising, so
            the submitting thread can report the failure.
            """
            try:
                zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
                # make sure we can successfully decode
                _ = zarr_arr[zarr_idx]
                return True
            except Exception as e:
                print(f'Exception: {e}')
                return False

        with tqdm(total=n_steps * len(rgb_keys), desc="Loading image data", mininterval=1.0) as pbar:
            # each frame is its own zarr chunk (chunks=(1, h, w, c)), so
            # concurrent writers touch disjoint chunks; no extra locking needed
            with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
                futures = set()
                for key in rgb_keys:
                    c, h, w = tuple(shape_meta['obs'][key]['shape'])
                    this_compressor = Jpeg2k(level=50)
                    img_arr = data_group.require_dataset(
                        name=key,
                        shape=(n_steps, h, w, c),
                        chunks=(1, h, w, c),
                        compressor=this_compressor,
                        dtype=np.uint8
                    )
                    for episode_idx in range(n_demos):
                        demo = demos[f'demo_{episode_idx}']
                        if key == 'agentview':
                            hdf5_arr = demo['obs'][key]['color']
                        elif key == 'tactile':
                            # take one of the two tactile images
                            hdf5_arr = demo['obs'][key]['finger_left'][:, 0, :, :]
                        elif key == "depth":
                            # depth is stored under agentview in this customized dataset
                            hdf5_arr = demo['obs']['agentview'][key]
                            hdf5_arr = np.expand_dims(hdf5_arr, axis=-1)  # add channel dim
                            # repeat depth to 3 channels for image backbones (e.g. resnet)
                            hdf5_arr = np.repeat(hdf5_arr, 3, axis=-1)
                            # NOTE(review): img_arr is uint8, while DepthNormalizer
                            # assumes depth up to 1800; if raw depth exceeds 255 it
                            # will not survive this storage — confirm dtype/range.
                        else:
                            raise ValueError(f"Unknown key {key} for RGB data")

                        for hdf5_idx in range(hdf5_arr.shape[0]):
                            if len(futures) >= max_inflight_tasks:
                                # limit number of inflight tasks
                                completed, futures = concurrent.futures.wait(futures,
                                                                             return_when=concurrent.futures.FIRST_COMPLETED)
                                for f in completed:
                                    if not f.result():
                                        raise RuntimeError('Failed to encode image!')
                                pbar.update(len(completed))

                            zarr_idx = episode_starts[episode_idx] + hdf5_idx
                            futures.add(
                                executor.submit(img_copy,
                                                img_arr, zarr_idx, hdf5_arr, hdf5_idx))
                completed, futures = concurrent.futures.wait(futures)
                for f in completed:
                    if not f.result():
                        raise RuntimeError('Failed to encode image!')
                pbar.update(len(completed))

    replay_buffer = ReplayBuffer(root)
    return replay_buffer

if __name__ == "__main__":
    # Smoke test: load the low-dim dataset and print a one-line summary.
    dataset_path = "data/frankaslide.hdf5"
    print(RealWorldDataset(dataset_dir=dataset_path))



