import json
import numpy as np
from collect_square_exp_data import save_agentview_video
import h5py
import collections
import robomimic.utils.file_utils as FileUtils
import robomimic.utils.env_utils as EnvUtils
import robomimic.utils.obs_utils as ObsUtils
import robosuite.utils.transform_utils as T
import robomimic.utils.tensor_utils as TensorUtils

from matplotlib import pyplot as plt
import os
import os.path as osp
from tqdm import tqdm
import imageio
import cv2

def create_env(env_meta, shape_meta, enable_render=True):
    """Register observation modalities from ``shape_meta`` and build a robomimic env.

    Every observation key in ``shape_meta['obs']`` is grouped under the value of
    its ``'type'`` entry (``'low_dim'`` when absent). The grouping is handed to
    ``ObsUtils.initialize_obs_modality_mapping_from_dict`` before the
    environment is created.

    Args:
        env_meta: environment metadata forwarded to
            ``EnvUtils.create_env_from_metadata``.
        shape_meta: dict whose ``'obs'`` entry maps observation key -> attribute dict.
        enable_render: when True, request offscreen rendering and image observations.

    Returns:
        The environment returned by ``EnvUtils.create_env_from_metadata``
        (on-screen rendering is always disabled).
    """
    keys_by_modality = collections.defaultdict(list)
    for obs_key, obs_attr in shape_meta['obs'].items():
        modality = obs_attr.get('type', 'low_dim')
        keys_by_modality[modality].append(obs_key)
    ObsUtils.initialize_obs_modality_mapping_from_dict(keys_by_modality)

    return EnvUtils.create_env_from_metadata(
        env_meta=env_meta,
        render=False,
        render_offscreen=enable_render,
        use_image_obs=enable_render,
    )

def get_obs(obs, obs_keys):
    """Return a new dict holding only the entries of *obs* listed in *obs_keys*.

    Entries appear in the order of *obs_keys*. A key missing from *obs* is
    skipped, and a warning line is printed for it.
    """
    selected = {}
    for wanted in obs_keys:
        if wanted not in obs:
            print(f"Key {wanted} not found in observation")
            continue
        selected[wanted] = obs[wanted]
    return selected

# Per-dimension Gaussian noise std for the first six entries of one arm's
# absolute action. The x3 scale gives std 0.03 on entries 0-2 and 0.3 on 3-5.
_ACTION_NOISE_STDS = np.array([0.01, 0.01, 0.01, 0.1, 0.1, 0.1], dtype=np.float32) * 3
_ACTION_NOISE_MEANS = np.zeros(6, dtype=np.float32)


def _rgb_meta():
    """Return a fresh shape-meta entry for a 3x140x140 RGB observation."""
    return {'shape': [3, 140, 140], 'type': 'rgb'}


def _build_shape_meta(file_path):
    """Return ``(shape_meta, model_file_path)`` for the task named in *file_path*.

    The task is detected by substring match on the dataset path, checked in the
    order 'square', 'tool_hang', 'transport'.

    Raises:
        ValueError: if *file_path* names none of the supported tasks.
    """
    shape_meta = {
        'obs': {
            'robot0_eef_pos': {'shape': [3]},
            'robot0_eef_quat': {'shape': [4]},
            'robot0_gripper_qpos': {'shape': [2]},
        },
        'action': {'shape': [10]},
    }
    obs_meta = shape_meta['obs']
    if 'square' in file_path:
        obs_meta['agentview_image'] = _rgb_meta()
        obs_meta['robot0_eye_in_hand_image'] = _rgb_meta()
        shape_meta['action'] = {'shape': [10]}
        model_file_path = 'square_model_file_1.4.xml'
    elif 'tool_hang' in file_path:
        obs_meta['sideview_image'] = _rgb_meta()
        obs_meta['robot0_eye_in_hand_image'] = _rgb_meta()
        shape_meta['action'] = {'shape': [10]}
        model_file_path = 'tool_hang_model_file_1.4.xml'
    elif 'transport' in file_path:
        obs_meta['robot0_eye_in_hand_image'] = _rgb_meta()
        obs_meta['robot1_eye_in_hand_image'] = _rgb_meta()
        obs_meta['shouldercamera0_image'] = _rgb_meta()
        obs_meta['shouldercamera1_image'] = _rgb_meta()
        obs_meta['robot1_eef_pos'] = {'shape': [3]}
        obs_meta['robot1_eef_quat'] = {'shape': [4]}
        obs_meta['robot1_gripper_qpos'] = {'shape': [2]}
        shape_meta['action'] = {'shape': [20]}
        model_file_path = 'transport_model_file_1.4.xml'
    else:
        raise ValueError(
            f"Unsupported dataset {file_path!r}: expected 'square', 'tool_hang' or 'transport' in the path"
        )
    return shape_meta, model_file_path


def _perturb_action(action, is_bimanual, noise_prob):
    """With probability *noise_prob*, add Gaussian noise in place to *action*.

    Entries 0-5 are perturbed. For bimanual tasks, entries 7-12 also receive an
    independent draw. NOTE(review): this assumes a per-arm layout of
    [pos(3), rot(3), gripper(1)] in ``abs_actions``. TODO: confirm against the dataset.

    Returns:
        The same *action* array, possibly modified.
    """
    # One uniform draw per step decides whether this step is perturbed.
    if np.random.uniform(0, 1) >= noise_prob:
        return action
    action[:6] += np.random.normal(loc=_ACTION_NOISE_MEANS, scale=_ACTION_NOISE_STDS)
    if is_bimanual:
        action[7:13] += np.random.normal(loc=0, scale=_ACTION_NOISE_STDS)
    return action


def _rollout_episode(env, demo, model_xml, is_bimanual, noise_prob):
    """Reset *env* to *demo*'s first state and replay its noisy absolute actions.

    Returns:
        A dict with lists ``obs`` and ``states`` (T + 1 entries each, including
        the initial reset) and ``actions``, ``rewards``, ``dones`` (T entries
        each), where T = len(demo['actions']).
    """
    initial_state = {'states': demo['states'][0], 'model': model_xml}
    env.reset()
    obs = env.reset_to(initial_state)

    traj = {
        'obs': [obs],
        'states': [env.env.sim.get_state().flatten()],
        'actions': [],
        'rewards': [],
        'dones': [],
    }
    # Read the action array once instead of issuing one HDF5 read per step.
    # Each row is perturbed exactly once before it is recorded.
    abs_actions = demo['abs_actions'][()]
    for j in range(len(demo['actions'])):
        action = _perturb_action(abs_actions[j], is_bimanual, noise_prob)
        obs, reward, done, _info = env.step(action)
        traj['obs'].append(obs)
        traj['actions'].append(action)
        traj['rewards'].append(reward)
        traj['dones'].append(done)
        traj['states'].append(env.env.sim.get_state().flatten())
    return traj


def _write_episode(data_grp, ep_name, traj):
    """Write *traj* into a new group ``data_grp[ep_name]``; return its transition count."""
    ep_grp = data_grp.create_group(ep_name)
    num_samples = len(traj['actions'])

    obs_lists = TensorUtils.list_of_flat_dict_to_dict_of_list(traj['obs'])
    for key, values in obs_lists.items():
        # Drop the final (post-last-action) observation so obs aligns with actions.
        data = np.array(values[:num_samples])
        if data.ndim == 4:
            # 4-D observations are treated as image stacks assumed to be
            # (T, C, H, W) floats in [0, 1]; store them as (T, H, W, C) uint8.
            data = (np.moveaxis(data, 1, -1) * 255).astype(np.uint8)
        ep_grp.create_dataset("obs/{}".format(key), data=data)

    actions = np.array(traj['actions'])
    ep_grp.create_dataset("actions", data=actions)
    # The replayed actions are absolute, so the same array also serves as abs_actions.
    ep_grp.create_dataset("abs_actions", data=actions)
    ep_grp.create_dataset("states", data=np.array(traj['states'][:num_samples]))
    ep_grp.create_dataset("rewards", data=np.array(traj['rewards'][:num_samples]))
    ep_grp.create_dataset("dones", data=np.array(traj['dones'][:num_samples]))
    ep_grp.attrs["num_samples"] = num_samples  # number of transitions in this episode
    return num_samples


def gen_noisy_demos(file_path=None, save_noisy_demos=True, output_path=None, noise_prob=0.3):
    """Replay every demo in a robomimic dataset with noisy absolute actions.

    For each ``data/demo_{i}``, the env is reset to the demo's first state. The
    demo's ``abs_actions`` are then stepped, each one perturbed with Gaussian
    noise with probability *noise_prob*.

    When *save_noisy_demos* is True, the rollouts are written to *output_path*
    as ``data/demo_{i}/{obs/<key>, actions, abs_actions, states, rewards, dones}``.
    The per-episode attribute ``num_samples`` is written, along with the
    ``data`` attributes ``total`` and ``env_args``.

    Args:
        file_path: source hdf5 dataset. Its path must contain 'square',
            'tool_hang' or 'transport'.
        save_noisy_demos: whether to write the rollouts to *output_path*.
        output_path: destination hdf5 file (overwritten). Required when saving.
        noise_prob: per-step probability of perturbing the action.

    Raises:
        ValueError: if *file_path* is missing or names an unsupported task, or
            if *output_path* is missing while saving.
    """
    import contextlib

    if file_path is None:
        raise ValueError("file_path must point to a robomimic hdf5 dataset")
    if save_noisy_demos and output_path is None:
        raise ValueError("output_path is required when save_noisy_demos=True")

    shape_meta, model_file_path = _build_shape_meta(file_path)
    is_bimanual = 'transport' in file_path

    env_meta = FileUtils.get_env_metadata_from_dataset(dataset_path=file_path)
    # The demos are replayed from `abs_actions`, so disable delta control.
    env_meta['env_kwargs']['controller_configs']['control_delta'] = False
    with open(model_file_path, 'r') as f:
        model_xml = f.read()

    env = create_env(env_meta=env_meta, shape_meta=shape_meta)

    # Open the output only when saving; the context manager guarantees it is
    # flushed and closed even if a rollout fails.
    out_ctx = h5py.File(output_path, "w") if save_noisy_demos else contextlib.nullcontext()
    with out_ctx as f_out, h5py.File(file_path, 'r') as f_in:
        demos = f_in['data']
        data_grp = f_out.create_group("data") if save_noisy_demos else None
        total_samples = 0

        for i in tqdm(range(len(demos)), desc="Loading hdf5 to ReplayBuffer"):
            traj = _rollout_episode(env, demos[f'demo_{i}'], model_xml, is_bimanual, noise_prob)
            if not save_noisy_demos:
                continue

            current_ep = "demo_{}".format(i)
            print('writing current ep as ', current_ep)
            num_samples = _write_episode(data_grp, current_ep, traj)
            print('after trajectory actions ', num_samples)
            total_samples += num_samples
            print("ep {}: wrote {} transitions to {}".format(i, num_samples, current_ep))

        if save_noisy_demos:
            data_grp.attrs["total"] = total_samples
            print("Wrote {} total transitions".format(total_samples))
            data_grp.attrs["env_args"] = json.dumps(env.serialize(), indent=4)  # environment info
            print("Wrote {} trajectories to {}".format(len(demos), output_path))
        