#-*- coding:utf-8 -*-

import collections
import sys 
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from tqdm import tqdm
import numpy as np
import torch
import h5py
import zarr

from dataset.tasks import Lift, ControlType
from dataset.robomimic_image__dataset import _convert_actions
from diffusion_policy.rotation_transformer import RotationTransformer 
from diffusion_policy.imagecodecs_numcodecs import register_codecs, Jpeg2k
from diffusion_policy.common.replay_buffer import ReplayBuffer
from diffusion_policy.gym_util.video_recording_wrapper import VideoRecordingWrapper
from diffusion_policy.real_world.video_recorder import VideoRecorder
from diffusion_policy.robomimic_image_wrapper import RobomimicImageWrapper
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils
import wandb.sdk.data_types.video as wv
import concurrent.futures
import multiprocessing
import pathlib

import matplotlib.pyplot as plt

# When True, the zarr image datasets are created with no compressor;
# otherwise each frame is compressed with Jpeg2k(level=50).
IGNORE_COMPRESSOR = False

def create_env(
        task,
        seed: int,
        enable_render: bool = True,
        output_dir: str = "results",
        render_obs_key: str = 'agentview_image',
        fps: int = 10,
        crf: int = 22
    ):
    """Build a seeded, video-recording robomimic image environment for ``task``.

    The robomimic env is created from the metadata stored in
    ``task.dataset_path``, configured for absolute (non-delta) control with
    object-state observations disabled, wrapped in ``RobomimicImageWrapper``
    and then in ``VideoRecordingWrapper``.

    Args:
        task: Object providing ``get_shape_meta()`` (with an ``'obs'`` mapping)
            and ``dataset_path`` (an HDF5 file readable by robomimic).
        seed: Seed passed to ``env.seed``. The wrapper's ``init_state`` is
            cleared so that resets are seed-driven rather than state-driven.
        enable_render: When True the env renders offscreen with image obs,
            and a fresh ``<output_dir>/media/<id>.mp4`` path is assigned as
            the video output file. When False no video path is set.
        output_dir: Root directory for recorded videos; created if missing.
        render_obs_key: Observation key the image wrapper renders from.
        fps: Target video frame rate; determines how many simulator steps
            (assumed to run at 20 Hz) elapse between recorded frames.
        crf: H.264 constant-rate-factor passed to the video recorder.

    Returns:
        The outermost ``VideoRecordingWrapper``.
    """
    shape_meta = task.get_shape_meta()

    env_meta = FileUtils.get_env_metadata_from_dataset(task.dataset_path)
    # Use absolute (non-delta) controller actions.
    env_meta['env_kwargs']['controller_configs']['control_delta'] = False
    # Disable object state observation.
    env_meta['env_kwargs']['use_object_obs'] = False

    # Tell robomimic which observation keys are of which modality
    # (missing 'type' defaults to 'low_dim').
    modality_mapping = collections.defaultdict(list)
    for key, attr in shape_meta['obs'].items():
        modality_mapping[attr.get('type', 'low_dim')].append(key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(modality_mapping)

    robomimic_env = EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )
    # Robosuite's hard reset causes excessive memory consumption.
    # Disabled to run more envs.
    # https://github.com/ARISE-Initiative/robosuite/blob/92abf5595eddb3a845cd1093703e5a3ccd01e77e/robosuite/environments/base.py#L247-L248
    robomimic_env.hard_reset = False

    robosuite_fps = 20
    steps_per_render = max(robosuite_fps // fps, 1)

    image_env = RobomimicImageWrapper(
        env=robomimic_env,
        shape_meta=shape_meta,
        init_state=None,
        render_obs_key=render_obs_key,
    )
    env = VideoRecordingWrapper(
        image_env,
        video_recoder=VideoRecorder.create_h264(
            fps=fps,
            codec='h264',
            input_pix_fmt='rgb24',
            crf=crf,
            thread_type='FRAME',
            thread_count=1,
        ),
        file_path=None,
        steps_per_render=steps_per_render,
    )

    env.video_recoder.stop()
    env.file_path = None
    if enable_render:
        video_path = pathlib.Path(output_dir).joinpath(
            'media', wv.util.generate_id() + ".mp4")
        # parents=True: output_dir itself (default "results") may not exist yet.
        video_path.parent.mkdir(parents=True, exist_ok=True)
        env.file_path = str(video_path)

    # Switch to seed reset: clear the stored initial state on the
    # RobomimicImageWrapper itself (the object the assert checks).
    assert isinstance(env.env, RobomimicImageWrapper)
    env.env.init_state = None
    env.seed(seed)
    return env

def _split_obs_keys(obs_shape_meta):
    """Return ``(rgb_keys, lowdim_keys)`` from an obs shape-meta mapping.

    A missing ``'type'`` counts as ``'low_dim'``; keys of any other type are
    left out of both lists.
    """
    rgb_keys = list()
    lowdim_keys = list()
    for key, attr in obs_shape_meta.items():
        obs_type = attr.get('type', 'low_dim')
        if obs_type == 'rgb':
            rgb_keys.append(key)
        elif obs_type == 'low_dim':
            lowdim_keys.append(key)
    return rgb_keys, lowdim_keys


def _episode_ends(demos):
    """Return cumulative end indices for ``demo_0 .. demo_{len(demos)-1}``.

    Episode lengths are taken from each demo's ``'actions'`` dataset.
    """
    ends = list()
    total = 0
    for i in range(len(demos)):
        total += demos[f'demo_{i}']['actions'].shape[0]
        ends.append(total)
    return ends


def _img_copy(zarr_arr, zarr_idx, hdf5_arr, hdf5_idx):
    """Copy one frame from HDF5 into zarr, then read it back to verify decoding.

    Any exception from the write or the read-back propagates to the caller.
    """
    zarr_arr[zarr_idx] = hdf5_arr[hdf5_idx]
    # make sure we can successfully decode
    _ = zarr_arr[zarr_idx]


def _check_completed(completed):
    """Raise ``RuntimeError`` (chained to the cause) if any future failed."""
    for future in completed:
        try:
            future.result()
        except Exception as err:
            raise RuntimeError('Failed to encode image!') from err


def _copy_images(demos, data_group, rgb_keys, obs_shape_meta,
                 episode_starts, n_steps, n_workers):
    """Copy every rgb observation frame from the HDF5 ``demos`` group into zarr.

    For each key in ``rgb_keys`` a ``(n_steps, h, w, c)`` uint8 dataset with
    one frame per chunk is created in ``data_group``; shape meta is read as
    ``(c, h, w)``. Frames are written by a thread pool with at most
    ``5 * n_workers`` tasks in flight.

    Returns:
        The created zarr arrays, in ``rgb_keys`` order.

    Raises:
        RuntimeError: if any frame fails to encode or decode back.
    """
    max_inflight_tasks = n_workers * 5
    img_arrays = []
    with tqdm(total=int(n_steps * len(rgb_keys)), desc="Loading image data",
              mininterval=1.0) as pbar:
        # Each task writes exactly one chunk (chunks=(1, h, w, c)), so tasks
        # never share a chunk and need no extra synchronization.
        with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as executor:
            futures = set()
            for key in rgb_keys:
                c, h, w = tuple(obs_shape_meta[key]['shape'])
                compressor = None if IGNORE_COMPRESSOR else Jpeg2k(level=50)
                img_arr = data_group.require_dataset(
                    name=key,
                    shape=(n_steps, h, w, c),
                    chunks=(1, h, w, c),
                    compressor=compressor,
                    dtype=np.uint8,
                )
                img_arrays.append(img_arr)
                for episode_idx in range(len(demos)):
                    hdf5_arr = demos[f'demo_{episode_idx}']['obs'][key]
                    for hdf5_idx in range(hdf5_arr.shape[0]):
                        if len(futures) >= max_inflight_tasks:
                            # limit number of inflight tasks
                            completed, futures = concurrent.futures.wait(
                                futures,
                                return_when=concurrent.futures.FIRST_COMPLETED)
                            _check_completed(completed)
                            pbar.update(len(completed))
                        zarr_idx = episode_starts[episode_idx] + hdf5_idx
                        futures.add(executor.submit(
                            _img_copy, img_arr, zarr_idx, hdf5_arr, hdf5_idx))
            completed, futures = concurrent.futures.wait(futures)
            _check_completed(completed)
            pbar.update(len(completed))
    return img_arrays


def _print_stats(arr):
    """Print max, min, mean and standard deviation of ``arr``."""
    print("MAX:", np.max(arr))
    print("MIN:", np.min(arr))
    print("MEAN:", np.mean(arr))
    print("STD:", np.std(arr))


def _random_rollout(env, obs_horizon, n_iters=10):
    """Drive ``env`` with uniform-random 7-D actions and collect frames.

    Each iteration samples one action, steps the env ``len(action)`` times
    with it, and records a copy of the ``'agentview_image'`` of the oldest
    observation in the ``obs_horizon``-long history window.
    """
    obs = env.reset()
    obs_deque = collections.deque([obs] * obs_horizon, maxlen=obs_horizon)
    frames = []
    for _ in tqdm(range(n_iters)):
        action = np.random.rand(7,)
        # NOTE(review): repeats the same action len(action) (=7) times; this
        # looks like the action dim standing in for an action horizon -- confirm.
        for _ in range(len(action)):
            obs, _, _, _ = env.step(action)
            obs_deque.append(obs)
        frames.append(np.array(obs_deque[0]['agentview_image']))
    return frames


def main():
    """Compare image statistics of an HDF5 demo dataset with a random rollout.

    Usage: ``python <script> <dataset.hdf5>``

    Copies the rgb observations of every demo into an in-memory zarr store,
    prints their max/min/mean/std and shows one frame, then runs a short
    random-action rollout in a ``Lift`` image env and prints the same
    statistics for its (x255-scaled) agentview frames.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <dataset.hdf5>")
    dataset_path = sys.argv[1]

    task = Lift(ctype=ControlType.IMAGE)
    shape_meta = task.get_shape_meta()
    obs_shape_meta = shape_meta['obs']
    rgb_keys, _lowdim_keys = _split_obs_keys(obs_shape_meta)

    root = zarr.group(zarr.MemoryStore())
    data_group = root.require_group('data', overwrite=True)
    meta_group = root.require_group('meta', overwrite=True)

    n_workers = multiprocessing.cpu_count()
    print("Num Workers:", n_workers)

    # Explicit read-only mode: h5py < 3.0 defaulted to a writable mode.
    with h5py.File(dataset_path, 'r') as file:
        demos = file['data']
        episode_ends = _episode_ends(demos)
        if not episode_ends:
            raise ValueError(f"No demos found in {dataset_path}")
        n_steps = episode_ends[-1]
        episode_starts = [0] + episode_ends[:-1]
        meta_group.array('episode_ends', episode_ends, dtype=np.int64,
                         compressor=None, overwrite=True)
        # Must finish inside the `with`: worker threads read from the open file.
        img_arrays = _copy_images(demos, data_group, rgb_keys, obs_shape_meta,
                                  episode_starts, n_steps, n_workers)

    # NOTE(review): result unused; kept in case the constructor validates
    # the data/meta layout of `root` -- confirm.
    _ = ReplayBuffer(root)

    if img_arrays:
        print(img_arrays[0].shape, len(img_arrays))
        _print_stats(np.asarray(img_arrays))
        # First frame of the second rgb key, or of the only one.
        plt.imshow(img_arrays[min(1, len(img_arrays) - 1)][0])
        plt.show()

    env = create_env(task, seed=100000)
    try:
        rollout_images = _random_rollout(env, task.obs_horizon, n_iters=10)
    finally:
        # Finalize any video being recorded and release the env.
        env.video_recoder.stop()
        env.close()

    print(rollout_images[0].shape, len(rollout_images))
    _print_stats(np.asarray(rollout_images) * 255)


# Script entry point; main() reads the HDF5 dataset path from sys.argv[1].
if __name__ == '__main__':
    main()