import pickle
import gzip
import numpy as np
import collections
import cv2
import os
import jax
import gym
from collections import defaultdict
from tqdm import tqdm

# d4rl
from jaxrl_m.dataset import Dataset
from sgcrl.hiql_utils import d4rl_utils, d4rl_ant, ant_diagnostics, viz_utils
from jaxrl_m.evaluation import supply_rng, evaluate_with_trajectories, EpisodeMonitor
from jaxrl_m.evaluation import kitchen_render, add_to, flatten

# Procgen
from sgcrl.environments.envs.procgen_viz import get_xy_single

# Roboverse
from gym.wrappers import ClipAction
from roboverse.envs.sawyer_rig_affordances_v6 import SawyerRigAffordancesV6
from rlkit.envs.images import EnvRenderer
from rlkit.envs.images import InsertImageEnv
from rlkit.launchers.contextual.util import get_gym_env
from rlkit.envs.contextual.goal_conditioned import PresampledPathDistribution
from rlkit.envs.images import EnvRenderer
from rlkit.envs.images import InsertImageEnv
from rlkit.launchers.contextual.util import get_gym_env
from rlkit.envs.reward_fns import GoalReachingRewardFn
from rlkit.envs.contextual_env import ContextualEnv
from rlkit.utils.logging import logger as logging

class HIQLEnv(gym.Wrapper):
    def __init__(
        self,
        env_name: str = 'topview',
        discount: float = 1.0,
        visual: bool = True,
        amz_dataset_dir: str = 'antmaze_topview_6_60',
        epsilon: float = 0
    ):
        """Create the wrapped environment for ``env_name`` and load its base observation.

        The state-based supported envs are:

         - antmaze-medium-diverse-v2
         - antmaze-medium-play-v2
         - antmaze-large-diverse-v2
         - antmaze-large-play-v2
         - antmaze-ultra-diverse-v0
         - antmaze-ultra-play-v0
         - antmaze-extreme-diverse-v0
         - antmaze-extreme-play-v0
         - kitchen-partial-v0
         - kitchen-mixed-v0
         - calvin

        The image-based supported envs are:

         - procgen-maze-500-train
         - procgen-maze-500-test
         - procgen-maze-1000-train
         - procgen-maze-1000-test
         - topview-antmaze-large-diverse-v2
         - topview-antmaze-large-play-v2
         - roboverse

        The environment family is chosen by substring match on ``env_name``,
        checked in this order: 'antmaze', 'kitchen', 'calvin', 'procgen',
        'roboverse'.

        NOTE(review): when no cached base observation exists,
        ``get_base_observation`` only accepts the exact procgen names
        'procgen-500' / 'procgen-1000', not the names listed above — confirm
        which naming is intended.

        Args:
            env_name: Environment identifier (see the lists above).
            discount: Forwarded to ``get_base_observation``.
            visual: Must be True for antmaze names containing 'onehot',
                'visual' or 'topview' (checked with ``assert``).
            amz_dataset_dir: Directory under ``datasets/`` holding the topview
                AntMaze data; forwarded to ``get_base_observation``.
            epsilon: Stored as ``self.epsilon``.

        Raises:
            NotImplementedError: If ``env_name`` matches no supported family.
            AssertionError: If an image-based antmaze variant is requested
                with ``visual=False``.
        """
        if 'antmaze' in env_name:
            env = self._make_antmaze_env(env_name, visual)
        elif 'kitchen' in env_name:
            env = d4rl_utils.make_env(env_name)
        elif 'calvin' in env_name:
            env = self._make_calvin_env()
        elif 'procgen' in env_name:
            env = self._make_procgen_env()
        elif 'roboverse' in env_name:
            env = self._make_roboverse_env()
        else:
            raise NotImplementedError(f'Unsupported env_name: {env_name!r}')

        # NOTE(review): gym.Wrapper.__init__ is never called, so attributes such as
        # action_space / observation_space are only reachable if the installed gym
        # forwards them to self.env through Wrapper.__getattr__ — confirm for the
        # gym version in use.
        self.env = env
        self.env_name = env_name
        self.epsilon = epsilon
        self.base_observation, self.goal_infos = self.get_base_observation(
            env_name, discount, visual, amz_dataset_dir)

    @staticmethod
    def _make_antmaze_env(env_name, visual):
        """Create an AntMaze env, recolor it for topview variants and configure its camera."""
        # Prefixed names (e.g. 'topview-antmaze-large-diverse-v2') drop their first
        # '-'-separated token to recover the underlying gym id.
        if env_name.startswith('antmaze'):
            gym_env_name = env_name
        else:
            gym_env_name = '-'.join(env_name.split('-')[1:])

        # Each substring must be tested separately: `'ultra' or 'extreme' in name`
        # is always truthy.
        if 'ultra' in env_name or 'extreme' in env_name:
            # Imported for its side effects; presumably registers the extended
            # maze ids with gym — confirm in d4rl_ext.
            import d4rl_ext  # noqa: F401
            env = EpisodeMonitor(gym.make(gym_env_name))
        else:
            env = d4rl_utils.make_env(gym_env_name)

        if 'topview' in env_name:
            HIQLEnv._recolor_topview_textures(env)

        # Render once so that env.viewer exists before its camera is tuned.
        env.render(mode='rgb_array', width=64, height=64)
        HIQLEnv._configure_antmaze_camera(env, env_name)

        if 'onehot' in env_name or 'visual' in env_name or 'topview' in env_name:
            assert visual, f'{env_name} is image-based and requires visual=True'
        return env

    @staticmethod
    def _recolor_topview_textures(env):
        """Overwrite every 2D texture of ``env.model`` with a red/green gradient.

        Red decreases with the texel row and green increases with the texel
        column, each truncated to int and clipped to [0, 192]; blue is fixed at
        128. The texture repeat of material 0 is reset to 1.
        """
        model = env.model
        # amz-large gradient extents, in texel coordinates.
        sx, sy, ex, ey = 15, 45, 55, 100
        max_val = 192
        for i, tex_type in enumerate(model.tex_type):
            if tex_type != 0:  # 0 == mjTEXTURE_2D in MuJoCo
                continue
            height = model.tex_height[i]
            width = model.tex_width[i]
            start = model.tex_adr[i]
            for x in range(height):
                # Red depends only on the row, so compute it once per row.
                r = np.clip(int((ex - x) / (ex - sx) * max_val), 0, max_val)
                for y in range(width):
                    g = np.clip(int((y - sy) / (ey - sy) * max_val), 0, max_val)
                    offset = start + (x * width + y) * 3
                    model.tex_rgb[offset:offset + 3] = [r, g, 128]
        model.mat_texrepeat[0, :] = 1

    @staticmethod
    def _configure_antmaze_camera(env, env_name):
        """Set ``env.viewer.cam`` for the maze size; topview only customizes 'large'."""
        cam = env.viewer.cam
        if 'topview' in env_name:
            if 'large' in env_name:
                cam.azimuth = 90.
                cam.distance = 6
                cam.elevation = -60
            return

        # Straight-down view; (lookat_x, lookat_y, distance) per maze size.
        if 'large' in env_name:
            lookat_x, lookat_y, distance = 18, 12, 50
        elif 'ultra' in env_name:
            lookat_x, lookat_y, distance = 26, 18, 70
        elif 'extreme' in env_name:
            lookat_x, lookat_y, distance = 42, 25, 120
        else:
            lookat_x, lookat_y, distance = 18, 12, 50
        cam.lookat[0] = lookat_x
        cam.lookat[1] = lookat_y
        cam.distance = distance
        cam.elevation = -90

    @staticmethod
    def _make_calvin_env():
        """Build the CALVIN env from the Hydra 'calvin' config and apply the gym wrappers."""
        from sgcrl.environments.envs.calvin import CalvinEnv
        from hydra import compose, initialize
        from hydra.core.global_hydra import GlobalHydra
        from sgcrl.environments.envs.gym_env import GymWrapper
        from sgcrl.environments.envs.gym_env import wrap_env

        # Hydra refuses to initialize twice in one process; clear any leftover
        # state so that constructing several envs works.
        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()
        initialize(config_path='../environments/envs/conf')
        try:
            cfg = compose(config_name='calvin')
            env = CalvinEnv(**cfg)
            env.max_episode_steps = cfg.max_episode_steps = 360
            env = GymWrapper(
                env=env,
                from_pixels=cfg.pixel_ob,
                from_state=cfg.state_ob,
                height=cfg.screen_size[0],
                width=cfg.screen_size[1],
                channels_first=False,
                frame_skip=cfg.action_repeat,
                return_state=False,
            )
            env = wrap_env(env, cfg)
        finally:
            # Release the global state so the initialize() call made by
            # get_base_observation does not fail with "already initialized".
            GlobalHydra.instance().clear()
        return env

    @staticmethod
    def _make_procgen_env():
        """Build a single-process Procgen 'maze' env (matplotlib forced to the Agg backend)."""
        from sgcrl.environments.envs.procgen_env import ProcgenWrappedEnv
        import matplotlib

        matplotlib.use('Agg')

        n_processes = 1
        return ProcgenWrappedEnv(n_processes, 'maze', 1, 1)

    @staticmethod
    def _make_roboverse_env():
        """Build the roboverse ``SawyerRigAffordancesV6`` ContextualEnv with presampled image goals.

        Returns:
            A ``ContextualEnv`` over an ``InsertImageEnv`` (48x48 renders) of the
            ``ClipAction``-wrapped state env, with goals drawn from the
            'eval_goals' pickle and a sparse image goal-reaching reward.
        """
        env_id = None
        env_class = SawyerRigAffordancesV6
        env_kwargs = {
            'test_env': True,
            'downsample': True,
            'env_obs_img_dim': 196,
            'test_env_command': {
                'drawer_open': False,
                'drawer_yaw': 148.969115489417,
                'drawer_quadrant': 1,
                'small_object_pos': [0.56538715, 0.13505916, -0.35201065],
                'small_object_pos_randomness': {
                    'low': [-0.0, -0.0, 0],
                    'high': [0.0, 0.0, 0],
                },
                'large_object_quadrant': 2,
                'command_sequence': [('move_obj_pnp', {'target_location': 'top'})],
            },
            'reset_interval': 1,
        }
        reward_kwargs = {
            'obs_type': 'image',
            'reward_type': 'sparse',
            'epsilon': 2.0,
            'terminate_episode': False,
        }
        renderer_kwargs = {
            'create_image_format': 'HWC',
            'output_image_format': 'CWH',
            'flatten_image': True,
            'width': 48,
            'height': 48,
        }
        init_camera = None
        evaluation_goal_sampling_mode = 'presampled_images'
        goals_path = ('./data/env6_td_pnp_push_dataset/goals_early_stop/'
                      'td_pnp_push_scripted_goals_seed12.pkl')
        presampled_goal_kwargs = {
            'eval_goals': goals_path,
            'eval_goals_kwargs': {},
            'expl_goals': goals_path,
            'expl_goals_kwargs': {},
            'training_goals': goals_path,
            'training_goals_kwargs': {},
        }
        # Only image observations are supported.
        obs_key = 'image_observation'

        logging.info('Preparing the [evaluation] env and contextual distrib...')
        logging.info('sampling mode: %r', evaluation_goal_sampling_mode)
        logging.info('presampled goals: %r',
                     presampled_goal_kwargs['eval_goals'])
        logging.info('presampled goals kwargs: %r',
                     presampled_goal_kwargs['eval_goals_kwargs'])

        state_env = get_gym_env(
            env_id,
            env_class=env_class,
            env_kwargs=env_kwargs,
        )
        state_env = ClipAction(state_env)
        renderer = EnvRenderer(init_camera=init_camera, **renderer_kwargs)
        image_env = InsertImageEnv(state_env, renderer=renderer)

        diagnostics = state_env.get_contextual_diagnostics
        context_distribution = PresampledPathDistribution(
            presampled_goal_kwargs['eval_goals'], None, initialize_encodings=False)
        reward_fn = GoalReachingRewardFn(state_env.env, **reward_kwargs)

        return ContextualEnv(
            image_env,
            context_distribution=context_distribution,
            reward_fn=reward_fn,
            observation_key=obs_key,
            contextual_diagnostics_fns=(
                diagnostics if isinstance(diagnostics, list) else [diagnostics]),
        )
    
    def get_base_observation(
        self,
        env_name: str = 'topview',
        discount: float = 1.0,
        visual: bool = True,
        amz_dataset_dir: str = 'antmaze_topview_6_60',
    ):

        if env_name is None: 
            assert self.env_name is not None
            env_name = self.env_name

        base_observation, goal_infos = None, None

        base_observations_path = f'base_observations/{self.env_name}.npy'
        if os.path.exists(base_observations_path):
            base_observation = np.load(base_observations_path,allow_pickle=True)

        goal_infos_path = f'goal_infos/{self.env_name}.pkl'
        if os.path.exists(goal_infos_path):
            with open(goal_infos_path, 'rb') as file:
                goal_infos = pickle.load(file)

        if (base_observation is None) or (goal_infos is None):
            goal_info = None

            # For antmaze env
            if 'antmaze' in env_name:
                
                # Get gym environment name
                if env_name.startswith('antmaze'):
                    gym_env_name = env_name
                else:
                    gym_env_name = '-'.join(env_name.split('-')[1:]) # In the case of topview env for instance

                # Add d4rl extended environements if needed
                if 'ultra' in env_name:
                    import d4rl_ext
                    import gym
                    env = gym.make(gym_env_name)
                    env = EpisodeMonitor(env)
                else:
                    env = d4rl_utils.make_env(gym_env_name)

                if 'topview' in env_name:
                    # Update colors
                    l = len(env.model.tex_type)
                    # amz-large
                    sx, sy, ex, ey = 15, 45, 55, 100
                    for i in range(l):
                        if env.model.tex_type[i] == 0:
                            height = env.model.tex_height[i]
                            width = env.model.tex_width[i]
                            s = env.model.tex_adr[i]
                            for x in range(height):
                                for y in range(width):
                                    cur_s = s + (x * width + y) * 3
                                    R = 192
                                    r = int((ex - x) / (ex - sx) * R)
                                    g = int((y - sy) / (ey - sy) * R)
                                    r = np.clip(r, 0, R)
                                    g = np.clip(g, 0, R)
                                    env.model.tex_rgb[cur_s:cur_s + 3] = [r, g, 128]
                    env.model.mat_texrepeat[0, :] = 1
                    orig_env_name = env_name.split('topview-')[1]

                    if amz_dataset_dir is not None:
                        dataset = dict(np.load(f'datasets/{amz_dataset_dir}/{orig_env_name}.npz'))
                    else:
                        dataset = dict(np.load(f'datasets/antmaze_topview_6_60/{orig_env_name}.npz'))

                    dataset = Dataset.create(
                        observations= dataset['images'],
                        actions=dataset['actions'],
                        rewards=dataset['rewards'],
                        masks=dataset['masks'],
                        dones_float=dataset['dones_float'],
                        next_observations=dataset['next_images']
                    )
                    # (Precomputed index) The closest observation to the original goal
                    if 'large-diverse' in env_name:
                        target_idx = 38190
                    elif 'large-play' in env_name:
                        target_idx = 798118
                    elif 'ultra-diverse' in env_name:
                        target_idx = 352934
                    elif 'ultra-play' in env_name:
                        target_idx = 77798
                    else:
                        raise NotImplementedError
                    goal_info = {
                        'ob': dataset['observations'][target_idx]
                    }
                else:
                    # onehot or vanilla
                    dataset = d4rl_utils.get_dataset(env, env_name)
                
                dataset = dataset.copy({'rewards': dataset['rewards'] - 1.0})
                goal_infos = [goal_info]
                
                env.render(mode='rgb_array', width=200, height=200)
                if 'large' in env_name:
                    if 'topview' not in env_name:
                        env.viewer.cam.lookat[0] = 18
                        env.viewer.cam.lookat[1] = 12
                        env.viewer.cam.distance = 50
                        env.viewer.cam.elevation = -90
                    else:
                        env.viewer.cam.azimuth = 90.
                        env.viewer.cam.distance = 6
                        env.viewer.cam.elevation = -60

                    viz_env, viz_dataset = d4rl_ant.get_env_and_dataset(gym_env_name)
                    viz = ant_diagnostics.Visualizer(gym_env_name, viz_env, viz_dataset, discount=discount)
                    init_state = np.copy(viz_dataset['observations'][0])
                    init_state[:2] = (12.5, 8)
                elif 'ultra' in env_name:
                    if 'topview' not in env_name:
                        env.viewer.cam.lookat[0] = 26
                        env.viewer.cam.lookat[1] = 18
                        env.viewer.cam.distance = 70
                        env.viewer.cam.elevation = -90
                else:
                    if 'topview' not in env_name:
                        env.viewer.cam.lookat[0] = 18
                        env.viewer.cam.lookat[1] = 12
                        env.viewer.cam.distance = 50
                        env.viewer.cam.elevation = -90
                if 'onehot' in env_name or 'visual' in env_name or 'topview' in env_name:
                    assert visual
                    visual_hybrid = True
                base_observation = jax.tree_map(lambda arr: arr[0], dataset['observations'])

            # For kitchen env
            elif 'kitchen' in env_name:
                env = d4rl_utils.make_env(env_name)
                dataset = d4rl_utils.get_dataset(env, env_name, filter_terminals=True)
                dataset = dataset.copy({'observations': dataset['observations'][:, :30], 'next_observations': dataset['next_observations'][:, :30]})
                goal_infos = [goal_info]
                base_observation = jax.tree_map(lambda arr: arr[0], dataset['observations'])
            
            # For calvin env
            elif 'calvin' in env_name:
                from sgcrl.environments.envs.calvin import CalvinEnv
                from hydra import compose, initialize
                from sgcrl.environments.envs.gym_env import GymWrapper
                from sgcrl.environments.envs.gym_env import wrap_env
                initialize(config_path='../environments/envs/conf')
                cfg = compose(config_name='calvin')
                env = CalvinEnv(**cfg)
                env.max_episode_steps = cfg.max_episode_steps = 360
                env = GymWrapper(
                    env=env,
                    from_pixels=cfg.pixel_ob,
                    from_state=cfg.state_ob,
                    height=cfg.screen_size[0],
                    width=cfg.screen_size[1],
                    channels_first=False,
                    frame_skip=cfg.action_repeat,
                    return_state=False,
                )
                env = wrap_env(env, cfg)

                data = pickle.load(gzip.open('datasets/calvin/calvin.gz', "rb"))
                ds = []
                for i, d in enumerate(data):
                    if len(d['obs']) < len(d['dones']):
                        continue  # Skip incomplete trajectories.
                    # Only use the first 21 states of non-floating objects.
                    d['obs'] = d['obs'][:, :21]
                    new_d = dict(
                        observations=d['obs'][:-1],
                        next_observations=d['obs'][1:],
                        actions=d['actions'][:-1],
                    )
                    num_steps = new_d['observations'].shape[0]
                    new_d['rewards'] = np.zeros(num_steps)
                    new_d['terminals'] = np.zeros(num_steps, dtype=bool)
                    new_d['terminals'][-1] = True
                    ds.append(new_d)
                dataset = dict()
                for key in ds[0].keys():
                    dataset[key] = np.concatenate([d[key] for d in ds], axis=0)
                dataset = d4rl_utils.get_dataset(None, env_name, dataset=dataset)
                goal_infos = [goal_info]
                base_observation = jax.tree_map(lambda arr: arr[0], dataset['observations'])
            
            # For procgen env
            elif 'procgen' in env_name:
                from sgcrl.environments.envs.procgen_env import ProcgenWrappedEnv, get_procgen_dataset
                import matplotlib

                matplotlib.use('Agg')

                n_processes = 1
                env = ProcgenWrappedEnv(n_processes, 'maze', 1, 1)

                if env_name == 'procgen-500':
                    dataset = get_procgen_dataset('datasets/procgen/level500.npz', state_based=('state' in env_name))
                    min_level, max_level = 0, 499
                elif env_name == 'procgen-1000':
                    dataset = get_procgen_dataset('datasets/procgen/level1000.npz', state_based=('state' in env_name))
                    min_level, max_level = 0, 999
                else:
                    raise NotImplementedError

                # Test on large levels having >=20 border states
                large_levels = [12, 34, 35, 55, 96, 109, 129, 140, 143, 163, 176, 204, 234, 338, 344, 369, 370, 374, 410, 430, 468, 470, 476, 491] + [5034, 5046, 5052, 5080, 5082, 5142, 5244, 5245, 5268, 5272, 5283, 5335, 5342, 5366, 5375, 5413, 5430, 5474, 5491]
                goal_infos = []
                goal_infos.append({'eval_level': [level for level in large_levels if min_level <= level <= max_level], 'eval_level_name': 'train'})
                goal_infos.append({'eval_level': [level for level in large_levels if level > max_level], 'eval_level_name': 'test'})

                dones_float = 1.0 - dataset['masks']
                dones_float[-1] = 1.0
                dataset = dataset.copy({
                    'dones_float': dones_float
                })

                discrete = True
                example_action = np.max(dataset['actions'], keepdims=True)
                base_observation = jax.tree_map(lambda arr: arr[0], dataset['observations'])
            
            # For roboverse env
            elif 'roboverse' in env_name:
                
                # Load default values
                # max_path_length
                # qf_kwargs
                # trainer_kwargs
                # replay_buffer_kwargs
                # online_offline_split_replay_buffer_kwargs
                # policy_kwargs
                # algo_kwargs
                # network_type
                use_image=False
                env_id=None
                env_class=None
                env_kwargs=None
                reward_kwargs=None
                path_loader_kwargs=None
                exploration_policy_kwargs=None
                evaluation_goal_sampling_mode=None

                # Video parameters
                expl_save_video_kwargs=None
                eval_save_video_kwargs=None
                renderer_kwargs=None
                presampled_goal_kwargs=None
                init_camera=None

                # Load additional variables
                env_id = None

                # Load variant variables
                env_kwargs={'test_env': True, 'downsample': True, 'env_obs_img_dim': 196, 'test_env_command': {'drawer_open': False, 'drawer_yaw': 148.969115489417, 'drawer_quadrant': 1, 'small_object_pos': [0.56538715, 0.13505916, -0.35201065], 'small_object_pos_randomness': {'low': [-0.0, -0.0, 0], 'high': [0.0, 0.0, 0]}, 'large_object_quadrant': 2, 'command_sequence': [('move_obj_pnp', {'target_location': 'top'})]}, 'reset_interval': 1}
                reward_kwargs={'obs_type': 'image', 'reward_type': 'sparse', 'epsilon': 2.0, 'terminate_episode': False}
                expl_save_video_kwargs={'save_video_period': 25, 'pad_color': 0}
                eval_save_video_kwargs={'save_video_period': 25, 'pad_color': 0}
                path_loader_kwargs={'delete_after_loading': True, 'recompute_reward': True, 'demo_paths': [{'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_60_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_66_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_6_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_18_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_47_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_7_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_5_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_31_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_51_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_34_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_48_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_69_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_53_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_14_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': 
'./data/env6_td_pnp_push_dataset/scene0_view0_mixed_54_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_59_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_12_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_56_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_42_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}, {'path': './data/env6_td_pnp_push_dataset/scene0_view0_mixed_13_demos.pkl', 'obs_dict': True, 'is_demo': True, 'use_latents': False}], 'split_max_steps': None, 'demo_train_split': 0.95, 'add_demos_to_replay_buffer': True, 'demos_saving_path': None}
                renderer_kwargs={'create_image_format': 'HWC', 'output_image_format': 'CWH', 'flatten_image': True, 'width': 48, 'height': 48}
                evaluation_goal_sampling_mode='presampled_images'
                presampled_goal_kwargs={'eval_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 'eval_goals_kwargs': {}, 'expl_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 'expl_goals_kwargs': {}, 'training_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 'training_goals_kwargs': {}}
                use_image=True
                env_class=SawyerRigAffordancesV6
                init_camera = None
                presampled_goal_kwargs = {
                    'eval_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 
                    'eval_goals_kwargs': {}, 
                    'expl_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 
                    'expl_goals_kwargs': {}, 
                    'training_goals': './data/env6_td_pnp_push_dataset/goals_early_stop/td_pnp_push_scripted_goals_seed12.pkl', 
                    'training_goals_kwargs': {}
                }
                renderer_kwargs = {'create_image_format': 'HWC', 'output_image_format': 'CWH', 'flatten_image': True, 'width': 48, 'height': 48}
                use_image = True

                # Kwarg Definitions
                if exploration_policy_kwargs is None:
                    exploration_policy_kwargs = {}
                if presampled_goal_kwargs is None:
                    presampled_goal_kwargs = \
                        {'eval_goals': '', 'expl_goals': '', 'training_goals': ''}
                if path_loader_kwargs is None:
                    path_loader_kwargs = {}
                if not expl_save_video_kwargs:
                    expl_save_video_kwargs = {}
                if not eval_save_video_kwargs:
                    eval_save_video_kwargs = {}
                if not renderer_kwargs:
                    renderer_kwargs = {}

                if use_image:
                    obs_type = 'image'
                else:
                    raise NotImplementedError

                obs_key = '%s_observation' % obs_type

                # Make env
                def contextual_env_distrib_and_reward(
                    env_id,
                    env_class,
                    env_kwargs,
                    goal_sampling_mode,
                    presampled_goals_path,
                    reward_kwargs,
                ):
                    state_env = get_gym_env(
                        env_id,
                        env_class=env_class,
                        env_kwargs=env_kwargs,
                    )
                    state_env = ClipAction(state_env)
                    renderer = EnvRenderer(
                        init_camera=init_camera,
                        **renderer_kwargs)

                    env = InsertImageEnv(
                        state_env,
                        renderer=renderer)

                    if goal_sampling_mode == 'presampled_images':
                        diagnostics = state_env.get_contextual_diagnostics
                        context_distribution = PresampledPathDistribution(
                            presampled_goals_path, None, initialize_encodings=False)
                    else:
                        raise NotImplementedError

                    reward_fn = GoalReachingRewardFn(
                        state_env.env,
                        **reward_kwargs
                    )

                    contextual_env = ContextualEnv(
                        env,
                        context_distribution=context_distribution,
                        reward_fn=reward_fn,
                        observation_key=obs_key,
                        contextual_diagnostics_fns=[diagnostics] if not isinstance(
                            diagnostics, list) else diagnostics,
                    )

                    return contextual_env, context_distribution, reward_fn
            
                # Environment Definitions
                logging.info('Preparing the [evaluation] env and contextual distrib...')
                logging.info('Preparing the eval env and contextual distrib...')
                logging.info('sampling mode: %r', evaluation_goal_sampling_mode)
                logging.info('presampled goals: %r',
                            presampled_goal_kwargs['eval_goals'])
                logging.info('presampled goals kwargs: %r',
                            presampled_goal_kwargs['eval_goals_kwargs'],
                            )
                env, context_distrib, reward = contextual_env_distrib_and_reward(
                    env_id,
                    env_class,
                    env_kwargs,
                    evaluation_goal_sampling_mode,
                    presampled_goal_kwargs['eval_goals'],
                    reward_kwargs,
                )
                    
                goal_infos = [goal_info]
                base_observation = None

            else:
                raise NotImplementedError(0)

            np.save(base_observations_path,base_observation,allow_pickle=True)
            with open(goal_infos_path, 'wb') as file:
                pickle.dump(goal_infos, file)
            print('Saved base_observations and goal_infos')
        
        return base_observation, goal_infos

    def reset(self, goal_info = None):
        """Reset the wrapped env and pick a goal for the new episode.

        Args:
            goal_info: goal description to use. If None, one is drawn at random
                from ``self.goal_infos``. Keys read depend on the env family:
                ``'eval_level'`` for procgen and ``'ob'`` for topview antmaze;
                the other families do not read it.

        Returns:
            The initial observation, post-processed per env family (see the
            branches below). The matching goal observation is stored in
            ``self.obs_goal``.

        Raises:
            NotImplementedError: if ``self.env_name`` matches no supported family.
        """
         
        if goal_info is None:
            # NOTE(review): np.random.choice converts the list to an array first;
            # fine for a list of dicts, but it would fail for a list of
            # equal-shape arrays -- confirm the element type of goal_infos.
            goal_info = np.random.choice(self.goal_infos)

        if 'procgen' in self.env_name:
            from sgcrl.environments.envs.procgen_env import ProcgenWrappedEnv
            from sgcrl.environments.envs.procgen_viz import ProcgenLevel
            # Pick one level at random from the goal's evaluation levels.
            eval_level = goal_info['eval_level']
            cur_level = eval_level[np.random.choice(len(eval_level))]

            level_details = ProcgenLevel.create(cur_level)
            # Candidate targets: location indices i for which at most 2 locations
            # (counting i itself) lie within Manhattan distance < 7 of i.
            # Presumably these are dead ends / edges of the maze -- unverified.
            border_states = [i for i in range(len(level_details.locs)) if len([1 for j in range(len(level_details.locs)) if abs(level_details.locs[i][0] - level_details.locs[j][0]) + abs(level_details.locs[i][1] - level_details.locs[j][1]) < 7]) <= 2]
            target_state = border_states[np.random.choice(len(border_states))]
            goal_img = level_details.imgs[target_state]
            # Stored so that step() can detect goal reaching and render() can draw it.
            self.goal_loc = level_details.locs[target_state]
            # A fresh procgen env is created for the chosen level on every reset.
            self.env = ProcgenWrappedEnv(1, 'maze', cur_level, 1)

        observation, done = self.env.reset(), False

        # Set goal
        if 'antmaze' in self.env_name:
            if 'topview' not in self.env_name:
                # Goal = base observation with its first two entries replaced by
                # the env's target_goal.
                goal = self.env.wrapped_env.target_goal
                obs_goal = self.base_observation.copy()
                obs_goal[:2] = goal
            else:
                # Image-based variant: observation is a 64x64 RGB render and the
                # goal image comes from goal_info.
                observation = self.env.render(mode='rgb_array', width=64, height=64)
                obs_goal = goal_info['ob']
        elif 'kitchen' in self.env_name:
            # First 30 entries are the current state, the rest the goal state.
            observation, obs_goal = observation[:30], observation[30:]
            # Copy the first 9 goal entries from base_observation (presumably the
            # robot joints, so only the object configuration defines the goal --
            # confirm).
            obs_goal[:9] = self.base_observation[:9]
        elif 'calvin' in self.env_name:
            observation = observation['ob']
            # Fixed goal values written into entries 15:21 of the base observation;
            # the meaning of these entries is not visible here -- TODO document.
            goal = np.array([0.25, 0.15, 0, 0.088, 1, 1])
            obs_goal = self.base_observation.copy()
            obs_goal[15:21] = goal
        elif 'procgen' in self.env_name:
            # Drop the leading (single-env) axis of the vectorized observation.
            observation = observation[0]
            obs_goal = goal_img
        elif 'roboverse' in self.env_name:
            # Flattened goal image -> (3, 48, 48) -> swap axes 0 and 2.
            # Note: the returned observation is the unmodified env observation.
            obs_goal = observation['image_desired_goal']
            obs_goal = np.swapaxes(np.reshape(obs_goal,(3,48,48)),0,2)
        else:
            raise NotImplementedError
        
        self.obs_goal = obs_goal
        return observation
    
    def sample_action(self):
        """Return a uniformly random action for the current env family.

        Procgen draws from the discrete ids 2, 3, 5, 6 (presumably the maze
        movement actions -- confirm); every other family samples from the
        wrapped env's ``action_space``.
        """
        if 'procgen' not in self.env_name:
            return self.env.action_space.sample()
        return np.random.choice([2, 3, 5, 6])

    def step(self, action):
        """Advance the wrapped environment by one step.

        Observations are post-processed per env family so that they match what
        ``reset`` returns:

        - antmaze: raw observation; for ``topview`` the camera is re-centred on
          the agent's xy and a 64x64 RGB render is returned instead.
        - kitchen: only the first 30 entries of the observation are kept.
        - calvin: non-dict actions are wrapped as ``{'ac': np.array(action)}``;
          the ``'ob'`` entry is returned and ``robot_info`` / ``scene_info`` are
          dropped from ``info``.
        - procgen: with probability ``self.epsilon`` the action is replaced by a
          random one from {2, 3, 5, 6}; reward is 1 and the episode ends once
          the agent's xy is within L2 distance < 4 of ``self.goal_loc``,
          otherwise reward is 0. ``info`` is always an empty dict.
        - roboverse: raw observation, also cached for ``render``.

        Returns:
            ``(next_observation, reward, done, info)``.

        Raises:
            NotImplementedError: if ``self.env_name`` matches no supported family.
        """
        if 'antmaze' in self.env_name:
            next_observation, r, done, info = self.env.step(action)

            if 'topview' in self.env_name:
                # Keep the top-down camera centred on the agent before rendering.
                self.env.viewer.cam.lookat[0] = next_observation[0]
                self.env.viewer.cam.lookat[1] = next_observation[1]
                self.env.viewer.cam.lookat[2] = 0
                next_observation = self.env.render(mode='rgb_array', width=64, height=64)

        elif 'kitchen' in self.env_name:
            next_observation, r, done, info = self.env.step(action)
            # Keep robot qpos + objects qpos; the trailing goal part is dropped
            # (the goal is fixed at reset time in self.obs_goal).
            next_observation = next_observation[:30]

        elif 'calvin' in self.env_name:
            # Any mapping (dict / OrderedDict) is assumed to already be in the
            # env's action format; raw arrays are wrapped under 'ac'.
            if not isinstance(action, dict):
                action = {'ac': np.array(action)}
            next_observation, r, done, info = self.env.step(action)
            next_observation = next_observation['ob']
            info.pop('robot_info', None)
            info.pop('scene_info', None)

        elif 'procgen' in self.env_name:
            # Epsilon-random exploration.
            if np.random.random() < self.epsilon:
                action = np.random.choice([2, 3, 5, 6])

            next_observation, r, done, info = self.env.step(np.array([action]))
            # Unwrap the single-env vectorized outputs.
            next_observation = next_observation[0]
            r = 0.
            done = done[0]
            info = dict()

            # Sparse goal-reaching reward based on the agent's xy location.
            loc = get_xy_single(next_observation)
            if np.linalg.norm(loc - self.goal_loc) < 4:
                r = 1.
                done = True

            # Cached for render().
            self.cur_render = next_observation

        elif 'roboverse' in self.env_name:
            next_observation, r, done, info = self.env.step(action)
            # Cached for render().
            self.next_observation = next_observation

        else:
            raise NotImplementedError

        return next_observation, r, done, info
    
    def render(self, headless=False):
        """Return an RGB frame of the current state and optionally display it.

        Args:
            headless: if False, the frame is also shown (converted RGB -> BGR)
                in an OpenCV window titled 'Gym Environment'.

        Returns:
            The rendered frame:

            - procgen: copy of the last observation cached by ``step`` with a 3x3
              marker at ``self.goal_loc`` (channel index 2 set to 255, indices
              0-1 set to 0).
            - roboverse: the flattened ``'image_observation'`` cached by ``step``,
              reshaped to (3, 48, 48), axes 0/2 swapped and scaled to uint8.
            - antmaze: ``env.render`` at 64x64 (topview) or 1024x1024.
            - kitchen: ``kitchen_render`` at 1024.
            - calvin: ``env.render(mode='rgb_array')``.

            procgen and roboverse require ``step`` to have been called first.

        Raises:
            NotImplementedError: if ``self.env_name`` matches no supported family.
        """
        if 'procgen' in self.env_name:
            cur_frame = self.cur_render.copy()
            # NOTE(review): this indexes the first axis as the colour channel,
            # i.e. assumes a channel-first frame, while cv2.cvtColor below expects
            # (H, W, C) -- confirm the procgen observation layout.
            cur_frame[2, self.goal_loc[1]-1:self.goal_loc[1]+2, self.goal_loc[0]-1:self.goal_loc[0]+2] = 255
            cur_frame[:2, self.goal_loc[1]-1:self.goal_loc[1]+2, self.goal_loc[0]-1:self.goal_loc[0]+2] = 0

        elif 'roboverse' in self.env_name:
            cur_frame = self.next_observation['image_observation']
            cur_frame = np.swapaxes(np.reshape(cur_frame,(3,48,48)),0,2)
            cur_frame = (255*cur_frame).astype(np.uint8)

        elif 'antmaze' in self.env_name:
            # Use the instance's env name (the module-level `env_name` only exists
            # when this file is run as a script).
            size = 64 if 'topview' in self.env_name else 1024
            cur_frame = self.env.render(mode='rgb_array', width=size, height=size).copy()
            # if use_waypoints and not config['use_rep'] and ('large' in env_name or 'ultra' in env_name):
            #     def xy_to_pixxy(x, y):
            #         if 'large' in env_name:
            #             pixx = (x / 36) * (0.93 - 0.07) + 0.07
            #             pixy = (y / 24) * (0.21 - 0.79) + 0.79
            #         elif 'ultra' in env_name:
            #             pixx = (x / 52) * (0.955 - 0.05) + 0.05
            #             pixy = (y / 36) * (0.19 - 0.81) + 0.81
            #         return pixx, pixy
            #     x, y = cur_obs_goal_rep[:2]
            #     pixx, pixy = xy_to_pixxy(x, y)
            #     cur_frame[0, int((pixy - 0.02) * size):int((pixy + 0.02) * size), int((pixx - 0.02) * size):int((pixx + 0.02) * size)] = 255
            #     cur_frame[1:3, int((pixy - 0.02) * size):int((pixy + 0.02) * size), int((pixx - 0.02) * size):int((pixx + 0.02) * size)] = 0

        elif 'kitchen' in self.env_name:
            cur_frame = kitchen_render(self.env, wh=1024)
        elif 'calvin' in self.env_name:
            cur_frame = self.env.render(mode='rgb_array')
        else:
            raise NotImplementedError

        if not headless:
            image_bgr = cv2.cvtColor(cur_frame, cv2.COLOR_RGB2BGR)
            cv2.imshow('Gym Environment', image_bgr)
            cv2.waitKey(1)  # Wait for 1 ms to update the window

        return cur_frame
    
    def set_state(self, state):
        """Set the simulator state.

        For kitchen envs the first 30 entries of ``state`` are written into
        ``sim.data.qpos`` (velocities are left unchanged), then
        ``sim.forward()`` is called. Every other env family delegates to the
        wrapped env's ``set_state``.

        Args:
            state: state vector; for kitchen only ``state[:30]`` is used.
        """
        if 'kitchen' in self.env_name:
            # Only positions are overwritten; the current velocities are kept
            # as they are (read directly rather than by slicing get_state()).
            self.sim.data.qpos[:] = state[:30]
            self.sim.forward()
        else:
            # Delegate to the wrapped env; calling self.set_state here would
            # recurse forever.
            return self.env.set_state(state)

    def close(self):
        """Tear down any OpenCV preview window, then close the wrapped env.

        Returns whatever the wrapped env's ``close`` returns.
        """
        cv2.destroyAllWindows()
        inner_env = self.env
        return inner_env.close()

def evaluate_hiql_env(
    env: gym.Env,
    env_name: str,
    num_episodes: int = 1,
    num_video_episodes: int = 1,
    max_steps: int = None,
    headless: bool = True
):
    """Roll out episodes in an ``HIQLEnv`` and collect statistics.

    Runs ``num_episodes + num_video_episodes`` episodes. Actions currently come
    from ``env.sample_action()`` (the goal-conditioned policy code inside the
    loop is commented out). Every step is rendered via ``env.render``; the
    frames are returned only for the last ``num_video_episodes`` episodes.

    Args:
        env: wrapped env exposing ``reset``, ``obs_goal``, ``sample_action``,
            ``step``, ``render(headless=...)`` and ``close``.
        env_name: for calvin and procgen, an episode ``return`` (sum of rewards)
            is added to the final info.
        num_episodes: number of episodes whose renders are discarded.
        num_video_episodes: number of extra episodes whose renders are returned.
        max_steps: optional per-episode step cap; None runs until ``done``.
        headless: forwarded to ``env.render``.

    Returns:
        ``(stats, trajectories, renders)``: ``stats`` maps every flattened info
        key (per-step keys and the episode-end keys flattened under
        ``parent_key="final"``) to its mean; ``trajectories`` holds one dict of
        per-step lists per episode; ``renders`` holds one frame array per video
        episode.

    Note:
        ``env`` is closed before returning.
    """
    trajectories = []
    stats = defaultdict(list)

    renders = []
    for i in tqdm(range(num_episodes + num_video_episodes),desc='Evaluating episodes'):
        trajectory = defaultdict(list)

        observation, done = env.reset(), False
        obs_goal = env.obs_goal  # consumed by the (commented-out) goal-conditioned policy
        # Fresh per-episode info: an episode that stops before its first step
        # (e.g. max_steps=0) would otherwise hit UnboundLocalError, or reuse and
        # mutate the previous episode's info dict.
        info = {}

        render = []
        step = 0
        while not done:

            if (max_steps is not None) and (step >= max_steps):
                break

            print(f'episode {i}: {step} / {max_steps}')

            # Here you select the action
            # if not use_waypoints:
            #     cur_obs_goal = obs_goal
            #     if config['use_rep']:
            #         cur_obs_goal_rep = policy_rep_fn(targets=cur_obs_goal, bases=observation)
            #     else:
            #         cur_obs_goal_rep = cur_obs_goal
            # else:
            #     cur_obs_goal = high_policy_fn(observations=observation, goals=obs_goal, temperature=eval_temperature)
            #     if config['use_rep']:
            #         cur_obs_goal = cur_obs_goal / np.linalg.norm(cur_obs_goal, axis=-1, keepdims=True) * np.sqrt(cur_obs_goal.shape[-1])
            #     else:
            #         cur_obs_goal = observation + cur_obs_goal
            #     cur_obs_goal_rep = cur_obs_goal

            # action = policy_fn(observations=observation, goals=cur_obs_goal_rep, low_dim_goals=True, temperature=eval_temperature)

            action = env.sample_action()

            next_observation, r, done, info = env.step(action)

            step += 1

            # Render
            render.append(env.render(headless=headless))

            transition = dict(
                observation=observation,
                next_observation=next_observation,
                action=action,
                reward=r,
                done=done,
                info=info,
            )
            add_to(trajectory, transition)
            add_to(stats, flatten(info))
            observation = next_observation

        if 'calvin' in env_name or 'procgen' in env_name:
            info['return'] = sum(trajectory['reward'])
        add_to(stats, flatten(info, parent_key="final"))
        trajectories.append(trajectory)
        # Only the trailing num_video_episodes episodes keep their frames.
        if i >= num_episodes:
            renders.append(np.array(render))
    env.close()
    for k, v in stats.items():
        stats[k] = np.mean(v)

    return stats, trajectories, renders
import matplotlib.pyplot as plt
def save_video_array(
    video_array: np.ndarray,
    file_path: str,
    video_name: str,
    fps: int = 25,
    convert_to_uint8: bool = False
):
    """Write an RGB frame array to ``<file_path>/<video_name>.mp4`` (mp4v codec).

    Args:
        video_array: frames of shape (T, H, W, C) in RGB channel order; must be
            uint8 unless ``convert_to_uint8`` is set.
        file_path: output directory; created if it does not exist.
        video_name: file name without the ``.mp4`` extension.
        fps: frames per second of the output video.
        convert_to_uint8: if True, values are assumed to be floats in [0, 1];
            they are scaled by 255, clipped to [0, 255] and cast to uint8.
    """
    # get array shape
    T, H, W, C = video_array.shape

    # if needed, convert to uint8 (clip so out-of-range values don't wrap around)
    if convert_to_uint8:
        video_array = np.clip(video_array * 255, 0, 255).astype(np.uint8)

    # OpenCV writes BGR frames: reverse the channel axis and make the result
    # contiguous (the reversed view has a negative stride).
    video_array = np.ascontiguousarray(video_array[..., ::-1])

    if file_path:
        os.makedirs(file_path, exist_ok=True)

    # save video; OpenCV's frameSize is (width, height)
    out = cv2.VideoWriter(
        os.path.join(file_path, f'{video_name}.mp4'),
        cv2.VideoWriter_fourcc(*'mp4v'),
        fps,
        (W, H),
        C > 1,
    )
    try:
        for t in range(T):
            out.write(video_array[t])
    finally:
        out.release()

def renders_to_array(renders=None, skip_frames=1):
    """Stack variable-length frame sequences into one array.

    Each sequence is zero-padded at the end (same dtype) to the length of the
    longest sequence, then subsampled with ``[::skip_frames]``. The input list
    is not modified.

    Args:
        renders: list of arrays of shape (T_i, ...) with identical trailing
            shapes. None or an empty list yields an empty array.
        skip_frames: keep every ``skip_frames``-th frame.

    Returns:
        Array of shape (len(renders), ceil(max_i T_i / skip_frames), ...).
    """
    if renders is None or len(renders) == 0:
        return np.array([])

    max_length = max(len(render) for render in renders)
    padded = []
    for render in renders:
        render = np.asarray(render)
        padding = np.zeros((max_length - render.shape[0], *render.shape[1:]), dtype=render.dtype)
        padded.append(np.concatenate([render, padding], axis=0)[::skip_frames])
    return np.array(padded)

if __name__ == '__main__':

    # Select env
    # env_name = 'topview-antmaze-large-play-v2'
    # env_name = 'antmaze-large-play-v2'
    # env_name = 'antmaze-ultra-play-v0'
    # env_name = 'antmaze-extreme-play-v0'
    env_name = 'kitchen-mixed-v0'
    # env_name = 'calvin'
    # env_name = 'procgen-500'
    # env_name = 'roboverse'

    # Get env
    env = HIQLEnv(env_name)

    # Smoke test: a few random steps with on-screen rendering.
    env.reset()
    for step in range(1, 11):
        env.step(env.sample_action())
        env.render()
        print(step)
    # The env is intentionally NOT closed here: evaluate_hiql_env reuses it and
    # closes it when it is done.

    # Evaluate env
    eval_info, trajs, renders = evaluate_hiql_env(
        env=env,
        env_name=env_name,
        num_episodes=1,
        num_video_episodes=1,
        max_steps=5,
        headless=False
    )

    # Record videos (make sure the output directory exists first)
    video_dir = 'videos/evals'
    os.makedirs(video_dir, exist_ok=True)
    video_arrays = renders_to_array(renders)
    for i, video_array in enumerate(video_arrays):
        save_video_array(
            video_array=video_array,
            file_path=video_dir,
            video_name=f'{env_name}_{i}',
            fps=25,
            convert_to_uint8=False,
        )