#!/usr/bin/env python3
import tempfile

import dowel_wrapper

assert dowel_wrapper is not None
import dowel

import wandb

import argparse
import datetime
import functools
import os
import sys
import copy
import platform
import torch.multiprocessing as mp

# Rendering backend selection: on every platform whose description does not
# contain "mac", MuJoCo is pointed at the EGL backend.
if "mac" not in platform.platform():
    os.environ["MUJOCO_GL"] = "egl"
    slurm_step_gpus = os.environ.get("SLURM_STEP_GPUS")
    if slurm_step_gpus is not None:
        os.environ["EGL_DEVICE_ID"] = slurm_step_gpus
    # NOTE(review): this unconditional assignment overrides the SLURM-derived
    # value set just above, so the device id is always "0" — confirm intended.
    os.environ["EGL_DEVICE_ID"] = "0"  # fix

import better_exceptions
import numpy as np

better_exceptions.hook()

import torch

from garage import wrap_experiment
from garage.experiment.deterministic import set_seed
from garage.torch.distributions import TanhNormal

from garagei.replay_buffer.path_buffer_ex import PathBufferEx
from garagei.experiment.option_local_runner import OptionLocalRunner
from garagei.envs.consistent_normalized_env import consistent_normalize
from garagei.sampler.option_multiprocessing_sampler import OptionMultiprocessingSampler
from garagei.torch.modules.with_encoder import WithEncoder, Encoder, DimensionsSelector
from garagei.torch.modules.gaussian_mlp_module_ex import (
    GaussianMLPTwoHeadedModuleEx,
    GaussianMLPIndependentStdModuleEx,
    GaussianMLPModuleEx,
)
from garagei.torch.modules.gaussian_lstm_module_ex import (
    GaussianLSTMTwoHeadedModuleEx,
    GaussianLSTMIndependentStdModuleEx,
    GaussianLSTMModuleEx,
)
from garagei.torch.modules.parameter_module import ParameterModule
from garagei.torch.policies.policy_ex import PolicyEx, RecurrentPolicyEx
from garagei.torch.q_functions.continuous_mlp_q_function_ex import (
    ContinuousMLPQFunctionEx,
)
from garagei.torch.q_functions.discrete_mlp_q_function_ex import (
    DiscreteMLPQFunctionEx,
)

from garagei.torch.q_functions.continuous_lstm_q_function_ex import (
    ContinuousLSTMQFunctionEx,
)
from garagei.torch.optimizers.optimizer_group_wrapper import OptimizerGroupWrapper
from garagei.torch.utils import xavier_normal_ex
from iod.metra import METRA
from iod.recurrent_metra import RecurrentMETRA
from iod.dads import DADS
from iod.utils import get_normalizer_preset
from iod.apt_utils import ICM
from iod.rnd_utils import RND
from iod.disagreement import Disagreement

from envs import make_model

# Root directory for experiment outputs; the "./"-prefixed form is used when
# running under SLURM (SLURM_JOB_ID present in the environment).
EXP_DIR = "./exp/" if "SLURM_JOB_ID" in os.environ else "exp/"
# Multiprocessing start method; overridable via the START_METHOD env variable.
START_METHOD = os.environ.get("START_METHOD", "spawn")


def get_argparser():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser: parser with every experiment option
        registered; defaults are shown in ``--help`` via
        ``ArgumentDefaultsHelpFormatter``.
    """

    class SeedAction(argparse.Action):
        """Action for ``--seed``: takes an optional int; a bare flag stores 0."""

        def __init__(
            self, option_strings, dest, default=None, required=False, help=None
        ):
            super().__init__(
                option_strings=option_strings,
                dest=dest,
                nargs="?",
                const=0,
                default=default,
                type=int,
                required=required,
                help=help,
            )

        def __call__(self, parser, namespace, values, option_string=None):
            # ``values`` is None when the flag was given without a value.
            chosen = self.const if values is None else values
            setattr(namespace, self.dest, chosen)

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    def opt(kind, flag, default=None, **extra):
        # Register one typed option with its default.
        parser.add_argument(flag, type=kind, default=default, **extra)

    def toggle(flag, default=0):
        # Register an integer option restricted to 0 / 1.
        opt(int, flag, default, choices=[0, 1])

    # --- general / environment ---
    opt(str, "--run_group", "Debug")
    opt(str, "--normalizer_type", "off", choices=["off", "preset"])
    opt(int, "--encoder", 0)
    opt(str, "--env", "maze")
    opt(int, "--frame_stack")
    opt(int, "--frame_stack_only_to_policy")

    opt(int, "--max_path_length", 200)

    toggle("--use_gpu", 1)
    toggle("--sample_cpu", 1)

    parser.add_argument("--seed", action=SeedAction, default=0)
    opt(int, "--n_parallel", 4)
    opt(int, "--n_thread", 1)

    # --- training schedule ---
    opt(int, "--n_epochs", 1000000)
    opt(int, "--traj_batch_size", 8)
    opt(int, "--trans_minibatch_size", 256)
    opt(int, "--trans_optimization_epochs", 200)

    # --- evaluation / logging / checkpoints ---
    opt(int, "--n_epochs_per_eval", 125)
    opt(int, "--n_epochs_per_log", 25)
    opt(int, "--n_epochs_per_save", 1000)
    opt(int, "--n_epochs_per_pt_save", 1000)
    opt(int, "--n_epochs_per_pkl_update")
    opt(int, "--num_random_trajectories", 48)
    opt(int, "--num_video_repeats", 2)
    opt(int, "--eval_record_video", 1)
    opt(float, "--eval_plot_axis", None, nargs="*")
    opt(int, "--video_skip_frames", 1)

    opt(int, "--dim_option", 2)

    # --- learning rates ---
    opt(float, "--common_lr", 1e-4)
    opt(float, "--lr_op")
    opt(float, "--lr_te")

    opt(float, "--alpha", 0.01)

    opt(str, "--algo", "metra", choices=["metra", "dads"])

    # --- SAC ---
    opt(float, "--sac_tau", 5e-3)
    opt(float, "--sac_lr_q")
    opt(float, "--sac_lr_a")
    opt(float, "--exploration_sac_lr_q")
    opt(float, "--exploration_sac_lr_a")
    opt(float, "--exploration_lr_op")
    opt(float, "--sac_discount", 0.99)
    opt(float, "--exploration_sac_discount", 0.99)
    opt(float, "--sac_scale_reward", 1.0)
    opt(float, "--sac_target_coef", 1.0)
    opt(int, "--sac_min_buffer_size", 10000)
    opt(int, "--sac_max_buffer_size", 300000)

    toggle("--spectral_normalization")

    # --- model sizes ---
    opt(int, "--model_master_dim", 1024)
    opt(int, "--model_master_num_layers", 2)
    opt(str, "--model_master_nonlinearity", None, choices=["relu", "tanh"])
    opt(int, "--sd_const_std", 1)
    toggle("--sd_batch_norm", 1)

    opt(int, "--num_alt_samples", 100)
    opt(int, "--split_group", 65536)

    toggle("--discrete")
    toggle("--inner", 1)
    toggle("--unit_length", 1)  # Only for continuous skills

    # --- dual regularisation ---
    toggle("--dual_reg", 1)
    opt(float, "--dual_lam", 30)
    opt(float, "--dual_slack", 1e-3)
    opt(
        str,
        "--dual_dist",
        "one",
        choices=["l2", "s2_from_s", "gt", "one", "quasimetric"],
    )
    opt(float, "--dual_lr")

    # --- recurrent_metra ---
    toggle("--recurrent")
    opt(int, "--traj_minibatch_size", 16)
    opt(str, "--description")
    toggle("--no_reset")
    opt(int, "--use_pure_rewards", 0, choices=[0, 1, 2])
    opt(int, "--traj_encoder_dims")
    opt(int, "--traj_encoder_num_layers")
    opt(int, "--traj_layer_normalization")
    opt(int, "--q_layer_normalization")
    opt(int, "--exp_q_layer_normalization")

    # max_path_length is important too (BPTT)
    opt(float, "--hurdle_height")
    opt(float, "--gap_size")
    opt(float, "--hill_height")

    # use asymmetric observation for encoder / policy
    toggle("--asymmetric")

    toggle("--use_model")
    toggle("--use_random_options_for_exploration")
    toggle("--dot_penalty")

    opt(str, "--from_dir")

    opt(int, "--option_freq", 0)
    opt(int, "--use_start_policy", 0)

    opt(
        int,
        "--exploration_type",
        0,
        choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    )

    toggle("--perpendicular")

    # --- APT / ICM ---
    toggle("--apt")
    toggle("--apt_icm")
    opt(float, "--icm_lr")
    toggle("--use_traj_for_apt_rep")

    # --- hierarchical ---
    toggle("--hierarchical")
    opt(int, "--hierarchical_dim", 16)
    opt(float, "--hierarchical_lr")

    # --- prevupd / debug ---
    toggle("--prevupd")
    opt(int, "--prevupd_dim_exp_option", 1)
    opt(float, "--prevupd_lr_q")
    opt(float, "--prevupd_lr_a")
    opt(int, "--prevupd_batch_size", 1024)
    opt(int, "--prevupd_freq", 500)
    opt(int, "--debug_noconst", 0)
    opt(int, "--debug_val", -1)
    toggle("--prevupd_use_same_policy")

    toggle("--no_plot")

    # --- HILP ---
    toggle("--hilp")
    opt(float, "--hilp_expectile", 0.95)
    opt(float, "--hilp_traj_encoder_tau", 0.005)
    toggle("--hilp_qrl")
    opt(float, "--hilp_p_trajgoal", 0.625)
    opt(float, "--hilp_discount", 0.99)

    toggle("--goal_reaching")
    toggle("--use_goal_checker")

    # --- child policy / downstream ---
    opt(str, "--cp_path")
    opt(int, "--cp_path_idx")  # For exp name
    opt(int, "--cp_multi_step", 1)
    opt(int, "--cp_unit_length", 0)

    opt(str, "--downstream_reward_type", "esparse")
    opt(int, "--downstream_num_goal_steps", 50)

    opt(float, "--goal_range", 50)

    toggle("--use_double_encoder")
    toggle("--encoder_layer_normalization", 1)
    toggle("--save_replay_buffer")

    toggle("--uniform_sample_replay_buffer")
    opt(str, "--restore_buffer_from")

    # --- per-network size overrides ---
    opt(int, "--qf_dims")
    opt(int, "--qf_num_layers")
    opt(int, "--policy_dims")
    opt(int, "--policy_num_layers")
    opt(int, "--exp_qf_dims")
    opt(int, "--exp_qf_num_layers")
    opt(int, "--exp_policy_dims")
    opt(int, "--exp_policy_num_layers")

    toggle("--visualize_rewards")
    toggle("--plot_tsne")

    toggle("--use_target_traj")
    toggle("--use_repelling")

    opt(int, "--knn_k", 12)

    opt(int, "--rnd", 0)
    opt(int, "--disagreement", 0)
    return parser


# Parse the command line once at import time; get_exp_name() and get_log_dir()
# below read this module-level namespace directly.
args = get_argparser().parse_args()
# Start time of this process as whole seconds since the epoch; embedded in the
# experiment name by get_exp_name().
g_start_time = int(datetime.datetime.now().timestamp())


def get_exp_name():
    """Compose the experiment name from the seed, SLURM ids and CLI options.

    Reads the module-level ``args`` and ``g_start_time``.

    Returns:
        tuple[str, str]: ``(exp_name, exp_name_prefix)`` where the prefix is
        the leading seed / SLURM job / SLURM proc part of the full name.
    """
    environ = os.environ

    # Prefix: seed, then the SLURM job and process ids when present.
    prefix_parts = [f"sd{args.seed:03d}_"]
    if "SLURM_JOB_ID" in environ:
        prefix_parts.append(f's_{environ["SLURM_JOB_ID"]}.')
    if "SLURM_PROCID" in environ:
        prefix_parts.append(f'{environ["SLURM_PROCID"]}.')
    exp_name_prefix = "".join(prefix_parts)

    # Remainder: restart counter (if any), start timestamp, env, algo, description.
    tail_parts = []
    if "SLURM_RESTART_COUNT" in environ:
        tail_parts.append(f'rs_{environ["SLURM_RESTART_COUNT"]}.')
    tail_parts.append(f"{g_start_time}")
    tail_parts.append("_" + args.env)
    tail_parts.append("_" + args.algo)
    if args.description:
        tail_parts.append("_" + args.description)

    return exp_name_prefix + "".join(tail_parts), exp_name_prefix


def get_log_dir():
    """Return the (not yet existing) log directory for this run.

    The path is ``EXP_DIR/<run_group>/<exp_name>`` with symlinks resolved.

    Raises:
        AssertionError: if the experiment name exceeds the filesystem's
            maximum file-name length, or the directory already exists.
    """
    exp_name, _ = get_exp_name()
    name_limit = os.pathconf("/", "PC_NAME_MAX")
    assert len(exp_name) <= name_limit

    # Resolve symlinks to prevent runs from crashing in case of home nfs crashing.
    unresolved = os.path.join(EXP_DIR, args.run_group, exp_name)
    log_dir = os.path.realpath(unresolved)
    assert not os.path.exists(log_dir), f"The following path already exists: {log_dir}"
    return log_dir


def get_gaussian_module_construction(
    args,
    *,
    hidden_sizes,
    const_std=False,
    hidden_nonlinearity=torch.relu,
    w_init=torch.nn.init.xavier_uniform_,
    init_std=1.0,
    min_std=1e-6,
    max_std=None,
    **kwargs,
):
    """Select a Gaussian module class and assemble its constructor kwargs.

    Args:
        args: parsed options; ``args.recurrent`` picks the LSTM variants over
            the MLP ones and ``args.spectral_normalization`` is forwarded.
        hidden_sizes: hidden layer sizes, reused for the std network when the
            std is not constant.
        const_std: if true, use the fixed-std module (``learn_std=False``);
            otherwise use the independent-std module.
        hidden_nonlinearity: activation for the hidden layers (and std layers).
        w_init: weight initialiser for hidden (and std) layers.
        init_std: initial standard deviation.
        min_std: lower std bound (independent-std modules only).
        max_std: upper std bound (independent-std modules only).
        **kwargs: extra constructor arguments, forwarded unchanged.

    Returns:
        tuple: ``(module_cls, module_kwargs)``.
    """
    is_recurrent = args.recurrent

    if const_std:
        module_cls = GaussianLSTMModuleEx if is_recurrent else GaussianMLPModuleEx
        std_kwargs = dict(
            learn_std=False,
            init_std=init_std,
        )
    else:
        if is_recurrent:
            module_cls = GaussianLSTMIndependentStdModuleEx
        else:
            module_cls = GaussianMLPIndependentStdModuleEx
        std_kwargs = dict(
            std_hidden_sizes=hidden_sizes,
            std_hidden_nonlinearity=hidden_nonlinearity,
            std_hidden_w_init=w_init,
            std_output_w_init=w_init,
            init_std=init_std,
            min_std=min_std,
            max_std=max_std,
        )

    # Settings shared by both variants; caller kwargs are appended last.
    shared_kwargs = dict(
        hidden_sizes=hidden_sizes,
        hidden_nonlinearity=hidden_nonlinearity,
        hidden_w_init=w_init,
        std_parameterization="exp",
        bias=True,
        spectral_normalization=args.spectral_normalization,
        **kwargs,
    )

    module_kwargs = dict(std_kwargs)
    module_kwargs.update(shared_kwargs)
    return module_cls, module_kwargs


# Gym id shared by most antmaze variants below.
_ANTMAZE_GYM_ID = "antmaze-large-playinthemiddle-v2"

# D4RL-backed environments keyed by ``--env`` name. Each value is
# (gym id, wrapper class name in envs.d4rl.pixel_wrappers, wrapper kwargs).
_D4RL_ENV_SPECS = {
    "maze": ("maze2d-umaze-v1", "RenderWrapper", {}),
    "maze-fixed": ("maze2dfixed-umaze-v1", "RenderWrapper", {}),
    "toymaze": ("maze2d-toymaze-v1", "RenderWrapper", {"camera_id": None}),
    "toymaze-pixel": ("maze2d-toymaze-v1", "RenderWrapper", {"pixel": True}),
    "pointmaze-large": (
        "maze2d-largefix-dense-v1",
        "MazeRenderWrapper",
        {"camera_id": None},
    ),
    "pointmaze-verylarge": (
        "maze2d-verylargefix-dense-v1",
        "RenderWrapper",
        {"camera_id": None},
    ),
    "antmaze-pixel": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"pixel": True, "floor_color": True},
    ),
    "antmaze-original": (_ANTMAZE_GYM_ID, "MazeRenderWrapper", {"pixel": False}),
    "antmaze-large-play": (
        "antmaze-custom-large-play-v2",
        "MazeRenderWrapper",
        {"pixel": False},
    ),
    "antmaze-ultra-play": (
        "antmaze-custom-ultra-play-v0",
        "MazeRenderWrapper",
        {"pixel": False},
    ),
    "antmaze": (_ANTMAZE_GYM_ID, "RenderWrapper", {"pixel": False}),
    "antmaze-pixel-egocentric": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"pixel": True, "floor_color": True, "camera_id": 3},
    ),
    "antmaze-pixel-wall": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"pixel": True, "wall_color": True},
    ),
    "antmaze-pixel-wall-egocentric": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"pixel": True, "wall_color": True, "camera_id": 3},
    ),
    "antmaze-hybrid": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "floor_color": True},
    ),
    "antmaze-hybrid-egocentric": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "floor_color": True, "camera_id": 3},
    ),
    "antmaze-hybrid-wall": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "wall_color": True},
    ),
    "antmaze-hybrid-wall-egocentric": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "wall_color": True, "camera_id": 3},
    ),
    "antmaze-hybrid-partialinfo": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "floor_color": True, "camera_id": 4, "partialinfo": True},
    ),
    "antmaze-hybrid-wall-partialinfo": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "wall_color": True, "camera_id": 4, "partialinfo": True},
    ),
    "antmaze-hybrid-wall-rotate": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "wall_color": True, "camera_id": 4},
    ),
    "antmaze-hybrid-noinfo": (
        _ANTMAZE_GYM_ID,
        "RenderWrapper",
        {"hybrid": True, "floor_color": True, "camera_id": 4, "noinfo": True},
    ),
    "ant_maze_diversestates": (
        "antmazediversestates-umaze-v2",
        "RenderWrapper",
        {},
    ),
}

# State-based DMC environments keyed by ``--env`` name. Each value is
# (dmc task name, whether to also pass task_kwargs={"random": seed}).
_DMC_STATE_TASKS = {
    "dmc_humanoid_state": ("humanoid_run_pure_state", False),
    "dmc_quadruped_state_escape": ("quadruped_escape", True),
    "dmc_quadruped_state_run_forward": ("quadruped_run_forward", True),
    "dmc_quadruped_state_step": ("quadruped_step", True),
    "dmc_quadruped_state_fetch": ("quadruped_fetch", True),
    "dmc_dog_run": ("dog_run", False),
    "dmc_jaco_state": ("jaco_reach_top_left", False),
}

# Pixel-based DMC environments: ``--env`` name -> dmc task name.
_DMC_PIXEL_TASKS = {
    "dmc_cheetah": "cheetah_run_forward_color",
    "dmc_quadruped": "quadruped_run_forward_color",
    "dmc_humanoid": "humanoid_run_color",
}

# Every AntEnv-based ``--env`` name handled by make_env.
_ANT_ENV_NAMES = (
    "ant",
    "ant_gap",
    "ant_steps",
    "ant_hill",
    "ant_hill_half",
    "ant_hill_halfstep_uniformreset",
    "ant_hill_uniformreset",
)

# Every HalfCheetahEnv-based ``--env`` name handled by make_env.
_HALF_CHEETAH_ENV_NAMES = (
    "half_cheetah",
    "half_cheetah_short",
    "half_cheetah_short_hurdle",
    "half_cheetah_short_hurdle_baseline",
)


def _make_d4rl_env(env_name):
    """Create the gym environment listed in ``_D4RL_ENV_SPECS`` and wrap it."""
    import gym
    import d4rl  # noqa: F401  # imported for its side effects (presumably env registration)
    from envs.d4rl import pixel_wrappers

    gym_id, wrapper_name, wrapper_kwargs = _D4RL_ENV_SPECS[env_name]
    wrapper_cls = getattr(pixel_wrappers, wrapper_name)
    return wrapper_cls(gym.make(gym_id), **wrapper_kwargs)


def _make_dmc_state_env(env_name, seed):
    """Create the state-observation DMC environment listed in ``_DMC_STATE_TASKS``."""
    from envs.custom_dmc_tasks import dmc

    task_name, seed_task = _DMC_STATE_TASKS[env_name]
    extra = {"task_kwargs": {"random": seed}} if seed_task else {}
    return dmc.make(
        task_name,
        obs_type="states",
        frame_stack=1,
        action_repeat=2,
        seed=seed,
        **extra,
    )


def _make_dmc_pixel_env(args):
    """Create a render-wrapped DMC environment; requires ``args.encoder``."""
    from envs.custom_dmc_tasks import dmc
    from envs.custom_dmc_tasks.pixel_wrappers import RenderWrapper

    assert args.encoder  # Only support pixel-based environments
    if args.env not in _DMC_PIXEL_TASKS:
        raise NotImplementedError
    env = dmc.make(
        _DMC_PIXEL_TASKS[args.env],
        obs_type="states",
        frame_stack=1,
        action_repeat=2,
        seed=args.seed,
    )
    return RenderWrapper(env)


def make_env(args, max_path_length):
    """Construct the environment named by ``args.env`` and wrap it for training.

    The raw environment is optionally frame-stacked (``args.frame_stack``) and
    then passed through ``consistent_normalize`` according to
    ``args.normalizer_type``.

    Args:
        args: parsed options. Always read: ``env``, ``frame_stack``,
            ``normalizer_type``. Read for specific environments: ``seed``,
            ``encoder``, ``hurdle_height``, ``gap_size``, ``hill_height``,
            ``goal_range``, ``downstream_num_goal_steps``,
            ``downstream_reward_type``.
        max_path_length: episode length forwarded to the goal-conditioned
            downstream environments.

    Returns:
        The wrapped environment.

    Raises:
        NotImplementedError: if ``args.env`` is not a known environment.
        AssertionError: if a pixel-only environment is requested while
            ``args.encoder`` is falsy.
    """
    name = args.env
    # Number of extra observation dims appended by downstream envs; only
    # consumed by the "preset" normalizer branch below.
    cp_num_truncate_obs = 0

    if name in _D4RL_ENV_SPECS:
        env = _make_d4rl_env(name)
    elif name in _HALF_CHEETAH_ENV_NAMES:
        from envs.mujoco.half_cheetah_env import HalfCheetahEnv

        if name == "half_cheetah":
            cheetah_kwargs = {}
        elif name == "half_cheetah_short":
            cheetah_kwargs = dict(
                model_path="half_cheetah_short.xml", fixed_initial_state=True
            )
        elif name == "half_cheetah_short_hurdle":
            cheetah_kwargs = dict(
                model_path="half_cheetah_short_hurdle.xml",
                hurdle_height=args.hurdle_height,
            )
        else:  # half_cheetah_short_hurdle_baseline
            cheetah_kwargs = dict(model_path="half_cheetah_short_hurdle_baseline.xml")
        env = HalfCheetahEnv(render_hw=100, **cheetah_kwargs)
    elif name == "walker":
        from envs.mujoco.walker_env import Walker2dEnv

        env = Walker2dEnv()
    elif name in _ANT_ENV_NAMES:
        from envs.mujoco.ant_env import AntEnv

        if name == "ant":
            ant_kwargs = {}
        elif name == "ant_gap":
            ant_kwargs = dict(model_path="ant_gap.xml", gap_size=args.gap_size)
        elif name == "ant_steps":
            ant_kwargs = dict(
                model_path="ant_steps.xml", gap_size=args.gap_size, task="forward"
            )
        elif name == "ant_hill_halfstep_uniformreset":
            ant_kwargs = dict(
                model_path="ant_hill_halfstep.xml",
                hill_height=args.hill_height,
                uniform_reset=True,
                halfstep=True,
            )
        else:  # ant_hill, ant_hill_half, ant_hill_uniformreset
            ant_kwargs = dict(model_path="ant_hill.xml", hill_height=args.hill_height)
            if name == "ant_hill_half":
                ant_kwargs["half"] = True
            elif name == "ant_hill_uniformreset":
                ant_kwargs["uniform_reset"] = True
        env = AntEnv(render_hw=100, **ant_kwargs)
    elif name == "lifelong_hopper":
        from envs.mujoco.hopper_env import LifelongHopperEnv

        env = LifelongHopperEnv()
    elif name in _DMC_STATE_TASKS:
        env = _make_dmc_state_env(name, args.seed)
    elif name == "ant_nav_prime":
        from envs.mujoco.ant_nav_prime_env import AntNavPrimeEnv

        env = AntNavPrimeEnv(
            max_path_length=max_path_length,
            goal_range=args.goal_range,
            num_goal_steps=args.downstream_num_goal_steps,
            reward_type=args.downstream_reward_type,
        )
        cp_num_truncate_obs = 2
    elif name == "half_cheetah_goal":
        from envs.mujoco.half_cheetah_goal_env import HalfCheetahGoalEnv

        env = HalfCheetahGoalEnv(
            max_path_length=max_path_length,
            goal_range=args.goal_range,
            reward_type=args.downstream_reward_type,
        )
        cp_num_truncate_obs = 1
    elif name == "half_cheetah_hurdle":
        from envs.mujoco.half_cheetah_hurdle_env import HalfCheetahHurdleEnv

        env = HalfCheetahHurdleEnv(
            reward_type=args.downstream_reward_type,
        )
        cp_num_truncate_obs = 2
    elif name.startswith("dmc"):
        # Any remaining "dmc*" name must be one of the pixel-based tasks.
        env = _make_dmc_pixel_env(args)
    elif name == "kitchen":
        sys.path.append("lexa")
        from envs.lexa.mykitchen import MyKitchenEnv

        assert args.encoder  # Only support pixel-based environments
        env = MyKitchenEnv(log_per_goal=True)
    elif name == "pybullet_ant":
        import envs.pomdp  # noqa: F401  # imported for its side effects
        import gym

        env = gym.make("AntBLT-F-v0")
    elif name == "fetchpush":
        from envs.general_env_wrapper import FetchPushEnv

        env = FetchPushEnv()
    elif name == "state_kitchen":
        sys.path.append("lexa")
        from envs.lexa.mystatekitchen import MyKitchenEnv

        env = MyKitchenEnv(log_per_goal=True)
    else:
        raise NotImplementedError

    if args.frame_stack is not None:
        from envs.custom_dmc_tasks.pixel_wrappers import FrameStackWrapper

        env = FrameStackWrapper(env, args.frame_stack)

    normalizer_type = args.normalizer_type
    normalizer_kwargs = {}

    if normalizer_type == "off":
        env = consistent_normalize(env, normalize_obs=False, **normalizer_kwargs)
    elif normalizer_type == "preset":
        # Downstream envs reuse the preset of their base env and pad it for
        # the extra observation dims they append.
        additional_dim = 0
        if name == "ant_nav_prime":
            normalizer_name = "ant"
            additional_dim = cp_num_truncate_obs
        elif name in ("half_cheetah_goal", "half_cheetah_hurdle"):
            normalizer_name = "half_cheetah"
            additional_dim = cp_num_truncate_obs
        elif name == "dmc_dog_run":
            normalizer_name = "dmc_dog_run"
            # presumably the preset covers 223 of 303 observation dims — TODO confirm
            additional_dim = 303 - 223
        else:
            normalizer_name = name

        # Look the preset up only AFTER the name has been remapped; doing it
        # earlier made the remapping above dead code.
        normalizer_mean, normalizer_std = get_normalizer_preset(
            f"{normalizer_name}_preset"
        )

        if additional_dim > 0:
            # Padded dims get mean 0 / std 1, i.e. they pass through unchanged.
            normalizer_mean = np.concatenate(
                [normalizer_mean, np.zeros(additional_dim)]
            )
            normalizer_std = np.concatenate([normalizer_std, np.ones(additional_dim)])
        env = consistent_normalize(
            env,
            normalize_obs=True,
            mean=normalizer_mean,
            std=normalizer_std,
            **normalizer_kwargs,
        )

    return env


def get_runner(args, ctxt):
    if "WANDB_API_KEY" in os.environ:
        wandb_output_dir = ctxt.snapshot_dir  # tempfile.mkdtemp()

    dowel.logger.log("ARGS: " + str(args))
    if args.n_thread is not None:
        torch.set_num_threads(args.n_thread)

    set_seed(args.seed)
    runner = OptionLocalRunner(ctxt)
    max_path_length = args.max_path_length
    if args.cp_path is not None:
        max_path_length *= args.cp_multi_step

    contextualized_make_env = functools.partial(
        make_env, args=args, max_path_length=max_path_length
    )
    env = contextualized_make_env()
    if args.cp_path is not None:
        cp_path = args.cp_path
        if not os.path.exists(cp_path):
            import glob

            cp_path = glob.glob(cp_path)[0]
        cp_dict = torch.load(cp_path, map_location="cpu")
        from garagei.envs.child_policy_env import ChildPolicyEnv

        env = ChildPolicyEnv(
            env,
            cp_dict,
            cp_action_range=1.5,
            cp_unit_length=args.cp_unit_length,
            cp_multi_step=args.cp_multi_step,
            cp_num_truncate_obs=cp_num_truncate_obs,
        )

    if args.use_model:
        model = make_model(contextualized_make_env, args.trans_minibatch_size)
    else:
        model = None
    if args.use_random_options_for_exploration:
        assert args.use_model
    example_ob = env.reset()
    print("example_ob.shape", example_ob.shape)

    if args.encoder:
        if hasattr(env, "ob_info"):
            if env.ob_info["type"] in ["hybrid", "pixel"]:
                pixel_shape = env.ob_info["pixel_shape"]
                pixel_dim = np.prod(pixel_shape)
                if env.ob_info["type"] in ["hybrid"]:
                    state_shape = env.ob_info["state_shape"]
                    assert len(state_shape) == 1
                    state_dim = state_shape[0]
                    print("state_dim", state_dim)
        else:
            pixel_shape = (64, 64, 3)
    else:
        pixel_shape = None
    device = torch.device("cuda" if args.use_gpu else "cpu")

    master_dims = [args.model_master_dim] * args.model_master_num_layers
    traj_master_dims = (
        None
        if not args.traj_encoder_dims
        else [args.traj_encoder_dims] * args.traj_encoder_num_layers
    )

    if args.model_master_nonlinearity == "relu":
        nonlinearity = torch.relu
    elif args.model_master_nonlinearity == "tanh":
        nonlinearity = torch.tanh
    else:
        nonlinearity = None

    obs_dim = env.spec.observation_space.flat_dim
    print("obs_dim:", obs_dim)
    action_dim = env.spec.action_space.flat_dim

    if args.encoder:

        def make_encoder(**kwargs):
            """Construct an ``Encoder`` using the enclosing ``pixel_shape``.

            ``use_atari_torso`` is enabled exactly when ``args.hilp`` is
            truthy. All keyword arguments are forwarded to ``Encoder``
            unchanged.
            """
            atari_torso = bool(args.hilp)
            return Encoder(
                pixel_shape=pixel_shape, use_atari_torso=atari_torso, **kwargs
            )

        def with_encoder(module, encoder=None):
            """Wrap ``module`` in a ``WithEncoder``.

            When no ``encoder`` is supplied, a fresh one is built through
            ``make_encoder`` with ``norm="layer"`` if
            ``args.encoder_layer_normalization`` is truthy, else ``norm="none"``.
            """
            # Caller-provided encoder: use it as-is.
            if encoder is not None:
                return WithEncoder(encoder=encoder, module=module)

            norm_kind = "layer" if args.encoder_layer_normalization else "none"
            default_encoder = make_encoder(norm=norm_kind)
            return WithEncoder(encoder=default_encoder, module=module)

        example_encoder = make_encoder()
        module_obs_dim = example_encoder(
            torch.as_tensor(example_ob).float().unsqueeze(0)
        ).shape[-1]
    else:
        module_obs_dim = obs_dim
    print("module_obs_dim:", module_obs_dim)

    option_info = {
        "dim_option": args.dim_option,
    }

    policy_kwargs = dict(
        name="option_policy",
        option_info=option_info,
        # clip_action=True,
    )
    module_kwargs = dict(
        hidden_sizes=(
            [args.policy_dims] * args.policy_num_layers
            if args.policy_dims is not None
            else master_dims
        ),
        layer_normalization=False,
    )
    if nonlinearity is not None:
        module_kwargs.update(hidden_nonlinearity=nonlinearity)

    module_cls = (
        GaussianMLPTwoHeadedModuleEx
        if not args.recurrent
        else GaussianLSTMTwoHeadedModuleEx
    )
    module_kwargs.update(
        dict(
            max_std=np.exp(2.0),
            normal_distribution_cls=TanhNormal,  # using TanhNormal guarantees -1~1 action range
            output_w_init=functools.partial(xavier_normal_ex, gain=1.0),
            init_std=1.0,
        )
    )

    policy_q_input_dim = (
        module_obs_dim + args.dim_option
        if not args.goal_reaching
        else module_obs_dim * 2
    )
    policy_module = module_cls(
        input_dim=policy_q_input_dim, output_dim=action_dim, **module_kwargs
    )
    if args.encoder:
        policy_encoder = make_encoder(
            hide_two_dims=True if args.asymmetric else False,
            encode_goal=True if args.goal_reaching else False,
        )
        policy_module = with_encoder(policy_module, encoder=policy_encoder)

    policy_kwargs["module"] = policy_module
    option_policy = (
        PolicyEx(**policy_kwargs)
        if not args.recurrent
        else RecurrentPolicyEx(**policy_kwargs)
    )

    output_dim = args.dim_option if not args.hierarchical else args.hierarchical_dim

    traj_encoder_obs_dim = module_obs_dim if not args.asymmetric else state_dim
    traj_nonlinearity = nonlinearity or torch.relu
    module_cls, module_kwargs = get_gaussian_module_construction(
        args,
        hidden_sizes=master_dims if not traj_master_dims else traj_master_dims,
        hidden_nonlinearity=(traj_nonlinearity if not args.hilp else torch.relu),
        w_init=torch.nn.init.xavier_uniform_,
        # output_w_init=functools.partial(torch.nn.init.uniform_, a=-1e-1, b=1e-1),
        input_dim=traj_encoder_obs_dim,
        output_dim=output_dim,
    )
    # main traj
    traj_encoder = module_cls(
        layer_normalization=True if args.traj_layer_normalization else False,
        **module_kwargs,
    )
    if args.encoder and not args.asymmetric:
        if args.spectral_normalization:
            te_encoder = make_encoder(spectral_normalization=True)
        else:
            te_encoder = None
        traj_encoder = with_encoder(traj_encoder, encoder=te_encoder)
    elif args.asymmetric:
        traj_encoder = DimensionsSelector(traj_encoder, start_dim=pixel_dim)

    if args.use_double_encoder:
        traj_encoder_2 = module_cls(**module_kwargs)
        if args.encoder and not args.asymmetric:
            if args.spectral_normalization:
                te_encoder_2 = make_encoder(spectral_normalization=True)
            else:
                te_encoder_2 = None
            traj_encoder_2 = with_encoder(traj_encoder_2, encoder=te_encoder_2)
        elif args.asymmetric:
            traj_encoder_2 = DimensionsSelector(traj_encoder_2, start_dim=pixel_dim)

    module_cls, module_kwargs = get_gaussian_module_construction(
        args,
        hidden_sizes=master_dims,
        hidden_nonlinearity=nonlinearity or torch.relu,
        w_init=torch.nn.init.xavier_uniform_,
        input_dim=obs_dim,
        output_dim=obs_dim,
        min_std=1e-6,
        max_std=1e6,
    )

    if args.dual_dist == "s2_from_s":
        dist_predictor = module_cls(**module_kwargs)
    elif args.dual_dist == "gt":

        def gt_dist(obs1, obs2):
            """Return the scaled L2 norm of ``obs2 - obs1`` over the last axis.

            The norm is taken with ``keepdim=True`` and then divided by 5 and
            multiplied by 100, in that order.
            """
            delta = obs2 - obs1
            euclidean = torch.norm(delta, dim=-1, keepdim=True)
            return euclidean / 5 * 100

        dist_predictor = gt_dist
    elif args.dual_dist == "quasimetric":
        raise NotImplementedError
    else:
        dist_predictor = None

    dual_lam = ParameterModule(torch.Tensor([np.log(args.dual_lam)]))
    if args.prevupd:
        prevupd_traj_encoder_obs_dim = (
            module_obs_dim if not args.asymmetric else state_dim
        )
        module_cls, module_kwargs = get_gaussian_module_construction(
            args,
            hidden_sizes=master_dims if not traj_master_dims else traj_master_dims,
            hidden_nonlinearity=nonlinearity or torch.relu,
            w_init=torch.nn.init.xavier_uniform_,
            input_dim=prevupd_traj_encoder_obs_dim,
            output_dim=args.prevupd_dim_exp_option,
        )
        prevupd_traj_encoder = module_cls(**module_kwargs)
        if args.encoder and not args.asymmetric:
            if args.spectral_normalization:
                te_encoder = make_encoder(spectral_normalization=True)
            else:
                te_encoder = None
            prevupd_traj_encoder = with_encoder(
                prevupd_traj_encoder, encoder=te_encoder
            )
        elif args.asymmetric:
            prevupd_traj_encoder = DimensionsSelector(
                prevupd_traj_encoder, start_dim=pixel_dim
            )

        prevupd_dual_lam_constraint = ParameterModule(
            torch.Tensor([np.log(args.dual_lam)])
        )
        prevupd_dual_lam = ParameterModule(torch.Tensor([np.log(args.dual_lam)]))

    # Skill dynamics do not support pixel obs
    sd_dim_option = args.dim_option
    skill_dynamics_obs_dim = obs_dim
    skill_dynamics_input_dim = skill_dynamics_obs_dim + sd_dim_option
    module_cls, module_kwargs = get_gaussian_module_construction(
        args,
        const_std=args.sd_const_std,
        hidden_sizes=master_dims,
        hidden_nonlinearity=nonlinearity or torch.relu,
        input_dim=skill_dynamics_input_dim,
        output_dim=skill_dynamics_obs_dim,
        min_std=0.3,
        max_std=10.0,
    )
    if args.algo == "dads":
        skill_dynamics = module_cls(**module_kwargs)
    else:
        skill_dynamics = None

    def _finalize_lr(lr):
        if lr is None:
            lr = args.common_lr
        else:
            assert bool(lr), "To specify a lr of 0, use a negative value"
        if lr < 0.0:
            dowel.logger.log(f"Setting lr to ZERO given {lr}")
            lr = 0.0
        return lr

    optimizers = {
        "option_policy": torch.optim.Adam(
            [
                {"params": option_policy.parameters(), "lr": _finalize_lr(args.lr_op)},
            ]
        ),
        "traj_encoder": torch.optim.Adam(
            [
                {"params": traj_encoder.parameters(), "lr": _finalize_lr(args.lr_te)},
            ]
        ),
        "dual_lam": torch.optim.Adam(
            [
                {"params": dual_lam.parameters(), "lr": _finalize_lr(args.dual_lr)},
            ]
        ),
    }
    if args.use_double_encoder:
        dual_lam_2 = ParameterModule(torch.Tensor([np.log(args.dual_lam)]))
        optimizers.update(
            {
                "traj_encoder_2": torch.optim.Adam(
                    [
                        {
                            "params": traj_encoder_2.parameters(),
                            "lr": _finalize_lr(args.lr_te),
                        }
                    ]
                ),
            }
        )
        optimizers.update(
            {
                "dual_lam_2": torch.optim.Adam(
                    [
                        {
                            "params": dual_lam_2.parameters(),
                            "lr": _finalize_lr(args.dual_lr),
                        },
                    ]
                ),
            }
        )
    if skill_dynamics is not None:
        optimizers.update(
            {
                "skill_dynamics": torch.optim.Adam(
                    [
                        {
                            "params": skill_dynamics.parameters(),
                            "lr": _finalize_lr(args.lr_te),
                        },
                    ]
                ),
            }
        )
    if dist_predictor is not None and hasattr(dist_predictor, "forward"):
        optimizers.update(
            {
                "dist_predictor": torch.optim.Adam(
                    [
                        {
                            "params": dist_predictor.parameters(),
                            "lr": _finalize_lr(args.lr_op),
                        },
                    ]
                ),
            }
        )
    if args.prevupd:
        optimizers.update(
            {
                "prevupd_traj_encoder": torch.optim.Adam(
                    [
                        {
                            "params": prevupd_traj_encoder.parameters(),
                            "lr": _finalize_lr(args.lr_te),
                        },
                    ]
                ),
                "prevupd_dual_lam": torch.optim.Adam(
                    [
                        {
                            "params": prevupd_dual_lam.parameters(),
                            "lr": _finalize_lr(args.dual_lr),
                        },
                    ]
                ),
                "prevupd_dual_lam_constraint": torch.optim.Adam(
                    [
                        {
                            "params": prevupd_dual_lam_constraint.parameters(),
                            "lr": _finalize_lr(args.dual_lr),
                        },
                    ]
                ),
            }
        )

    replay_buffer = PathBufferEx(
        capacity_in_transitions=int(args.sac_max_buffer_size),
        pixel_shape=pixel_shape,
        use_goal=args.hilp,
    )

    if args.algo in ["metra", "dads"]:
        qf1 = (
            ContinuousMLPQFunctionEx(
                obs_dim=policy_q_input_dim,
                action_dim=action_dim,
                hidden_sizes=(
                    [args.qf_dims] * args.qf_num_layers
                    if args.qf_dims is not None
                    else master_dims
                ),
                hidden_nonlinearity=nonlinearity or torch.relu,
                layer_normalization=True if args.q_layer_normalization else False,
            )
            if not args.recurrent
            else ContinuousLSTMQFunctionEx(
                obs_dim=policy_q_input_dim,
                action_dim=action_dim,
                hidden_sizes=(
                    [args.qf_dims] * args.qf_num_layers
                    if args.qf_dims is not None
                    else master_dims
                ),
                hidden_nonlinearity=nonlinearity or torch.relu,
            )
        )
        if args.encoder:
            qf1_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False,
                encode_goal=True if args.goal_reaching else False,
            )
            qf1 = with_encoder(qf1, encoder=qf1_encoder)
        qf2 = (
            ContinuousMLPQFunctionEx(
                obs_dim=policy_q_input_dim,
                action_dim=action_dim,
                hidden_sizes=(
                    [args.qf_dims] * args.qf_num_layers
                    if args.qf_dims is not None
                    else master_dims
                ),
                hidden_nonlinearity=nonlinearity or torch.relu,
                layer_normalization=True if args.q_layer_normalization else False,
            )
            if not args.recurrent
            else ContinuousLSTMQFunctionEx(
                obs_dim=policy_q_input_dim,
                action_dim=action_dim,
                hidden_sizes=(
                    [args.qf_dims] * args.qf_num_layers
                    if args.qf_dims is not None
                    else master_dims
                ),
                hidden_nonlinearity=nonlinearity or torch.relu,
            )
        )
        if args.encoder:
            qf2_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False,
                encode_goal=True if args.goal_reaching else False,
            )
            qf2 = with_encoder(qf2, encoder=qf2_encoder)

        log_alpha = ParameterModule(torch.Tensor([np.log(args.alpha)]))
        optimizers.update(
            {
                "qf": torch.optim.Adam(
                    [
                        {
                            "params": list(qf1.parameters()) + list(qf2.parameters()),
                            "lr": _finalize_lr(args.sac_lr_q),
                        },
                    ]
                ),
                "log_alpha": torch.optim.Adam(
                    [
                        {
                            "params": log_alpha.parameters(),
                            "lr": _finalize_lr(args.sac_lr_a),
                        },
                    ]
                ),
            }
        )

    if args.apt:
        if args.apt_icm:
            icm_rep_dim = 512
            hidden_dim = master_dims[0]
            icm = ICM(
                obs_dim=obs_dim,
                action_dim=action_dim,
                hidden_dim=hidden_dim,
                icm_rep_dim=icm_rep_dim,
            )
            icm.to(device)
            # TODO: add encoder to ICM for pixel states

            optimizers.update(
                {
                    "icm": torch.optim.Adam(
                        [
                            {
                                "params": icm.parameters(),
                                "lr": _finalize_lr(args.icm_lr),
                            },
                        ]
                    )
                }
            )
        elif args.rnd:
            assert not args.apt_icm and not args.use_traj_for_apt_rep
            rnd_rep_dim = 512
            hidden_dim = master_dims[0]
            rnd = RND(
                obs_dim=obs_dim,
                hidden_dim=hidden_dim,
                rnd_rep_dim=rnd_rep_dim,
                obs_shape=(obs_dim,),
                obs_type="states",
            ).to(device)

            optimizers.update(
                {
                    "rnd": torch.optim.Adam(
                        [
                            {
                                "params": rnd.parameters(),
                                "lr": _finalize_lr(args.common_lr),
                            },
                        ]
                    )
                }
            )
        elif args.disagreement:
            hidden_dim = 1024
            obs_dim = obs_dim
            disagreement = Disagreement(obs_dim, action_dim, hidden_dim).to(device)
            optimizers.update(
                {
                    "disagreement": torch.optim.Adam(
                        [
                            {
                                "params": disagreement.parameters(),
                                "lr": _finalize_lr(args.common_lr),
                            },
                        ]
                    )
                }
            )

        policy_kwargs = dict(
            name="exploration_policy",
            # clip_action=True,
        )
        module_kwargs = dict(
            hidden_sizes=master_dims,
            layer_normalization=False,
        )
        if nonlinearity is not None:
            module_kwargs.update(hidden_nonlinearity=nonlinearity)

        module_cls = GaussianMLPTwoHeadedModuleEx
        module_kwargs.update(
            dict(
                max_std=np.exp(2.0),
                normal_distribution_cls=TanhNormal,  # using TanhNormal guarantees -1~1 action range
                output_w_init=functools.partial(xavier_normal_ex, gain=1.0),
                init_std=1.0,
            )
        )

        policy_module = module_cls(
            input_dim=module_obs_dim * 2 if args.use_repelling else module_obs_dim,
            output_dim=action_dim,
            **module_kwargs,
        )
        if args.encoder:
            policy_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            policy_module = with_encoder(policy_module, encoder=policy_encoder)

        policy_kwargs["module"] = policy_module
        exploration_policy = PolicyEx(**policy_kwargs)

        optimizers.update(
            {
                "exploration_policy": torch.optim.Adam(
                    [
                        {
                            "params": exploration_policy.parameters(),
                            "lr": _finalize_lr(args.exploration_lr_op),
                        },
                    ]
                ),
            }
        )

        exploration_qf1 = ContinuousMLPQFunctionEx(
            obs_dim=module_obs_dim * 2 if args.use_repelling else module_obs_dim,
            action_dim=action_dim,
            hidden_sizes=(
                [args.exp_qf_dims] * args.exp_qf_num_layers
                if args.exp_qf_dims is not None
                else master_dims
            ),
            hidden_nonlinearity=nonlinearity or torch.relu,
            layer_normalization=True if args.exp_q_layer_normalization else False,
            # Open question: empirically the exploration Q-function may do better without layer norm, but gradients explode on humanoid, so this stays configurable for testing.
        )
        print("module_obs_dim:", module_obs_dim)
        print("action_dim:", action_dim)
        if args.encoder:
            exploration_qf1_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            exploration_qf1 = with_encoder(
                exploration_qf1, encoder=exploration_qf1_encoder
            )
        exploration_qf2 = ContinuousMLPQFunctionEx(
            obs_dim=module_obs_dim * 2 if args.use_repelling else module_obs_dim,
            action_dim=action_dim,
            hidden_sizes=(
                [args.exp_qf_dims] * args.exp_qf_num_layers
                if args.exp_qf_dims is not None
                else master_dims
            ),
            hidden_nonlinearity=nonlinearity or torch.relu,
            layer_normalization=True if args.exp_q_layer_normalization else False,
        )
        if args.encoder:
            exploration_qf2_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            exploration_qf2 = with_encoder(
                exploration_qf2, encoder=exploration_qf2_encoder
            )
        exploration_log_alpha = ParameterModule(torch.Tensor([np.log(args.alpha)]))

        optimizers.update(
            {
                "exploration_qf": torch.optim.Adam(
                    [
                        {
                            "params": list(exploration_qf1.parameters())
                            + list(exploration_qf2.parameters()),
                            "lr": _finalize_lr(args.exploration_sac_lr_q),
                        },
                    ]
                ),
                "exploration_log_alpha": torch.optim.Adam(
                    [
                        {
                            "params": exploration_log_alpha.parameters(),
                            "lr": _finalize_lr(args.exploration_sac_lr_a),
                        },
                    ]
                ),
            }
        )
    elif args.prevupd:
        # needs prevupd_traj_encoder too

        policy_kwargs = dict(
            name="exploration_policy",
            # clip_action=True,
        )
        module_kwargs = dict(
            hidden_sizes=(
                [args.exp_policy_dims] * args.exp_policy_num_layers
                if args.exp_policy_dims is not None
                else master_dims
            ),
            layer_normalization=False,
        )
        if nonlinearity is not None:
            module_kwargs.update(hidden_nonlinearity=nonlinearity)

        module_cls = GaussianMLPTwoHeadedModuleEx
        module_kwargs.update(
            dict(
                max_std=np.exp(2.0),
                normal_distribution_cls=TanhNormal,  # using TanhNormal guarantees -1~1 action range
                output_w_init=functools.partial(xavier_normal_ex, gain=1.0),
                init_std=1.0,
            )
        )

        policy_module = module_cls(
            input_dim=module_obs_dim + args.prevupd_dim_exp_option,
            output_dim=action_dim,
            **module_kwargs,
        )
        if args.encoder:
            policy_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            policy_module = with_encoder(policy_module, encoder=policy_encoder)

        policy_kwargs["module"] = policy_module
        exploration_policy = PolicyEx(**policy_kwargs)

        optimizers.update(
            {
                "exploration_policy": torch.optim.Adam(
                    [
                        {
                            "params": exploration_policy.parameters(),
                            "lr": _finalize_lr(args.lr_op),
                        },
                    ]
                ),
            }
        )

        exploration_qf1 = ContinuousMLPQFunctionEx(
            obs_dim=module_obs_dim + args.prevupd_dim_exp_option,
            action_dim=action_dim,
            hidden_sizes=master_dims,
            hidden_nonlinearity=nonlinearity or torch.relu,
        )
        if args.encoder:
            exploration_qf1_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            exploration_qf1 = with_encoder(
                exploration_qf1, encoder=exploration_qf1_encoder
            )
        exploration_qf2 = ContinuousMLPQFunctionEx(
            obs_dim=module_obs_dim + args.prevupd_dim_exp_option,
            action_dim=action_dim,
            hidden_sizes=master_dims,
            hidden_nonlinearity=nonlinearity or torch.relu,
        )
        if args.encoder:
            exploration_qf2_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            exploration_qf2 = with_encoder(
                exploration_qf2, encoder=exploration_qf2_encoder
            )
        exploration_log_alpha = ParameterModule(torch.Tensor([np.log(args.alpha)]))

        optimizers.update(
            {
                "exploration_qf": torch.optim.Adam(
                    [
                        {
                            "params": list(exploration_qf1.parameters())
                            + list(exploration_qf2.parameters()),
                            "lr": _finalize_lr(args.sac_lr_q),
                        },
                    ]
                ),
                "exploration_log_alpha": torch.optim.Adam(
                    [
                        {
                            "params": exploration_log_alpha.parameters(),
                            "lr": _finalize_lr(args.sac_lr_a),
                        },
                    ]
                ),
            }
        )

    if args.hierarchical:
        hierarchical_qf1 = DiscreteMLPQFunctionEx(
            obs_dim=module_obs_dim + args.hierarchical_dim,
            action_dim=args.dim_option,
            hidden_sizes=master_dims,
            hidden_nonlinearity=nonlinearity or torch.relu,  # TODO: check master dims
        )
        # print("module_obs_dim:", module_obs_dim)
        # print("action_dim:", action_dim)
        if args.encoder:
            hierarchical_qf1_encoder = make_encoder(
                hide_two_dims=True if args.asymmetric else False
            )
            hierarchical_qf1 = with_encoder(
                hierarchical_qf1, encoder=hierarchical_qf1_encoder
            )

        optimizers.update(
            {
                "hierarchical_qf": torch.optim.Adam(
                    [
                        {
                            "params": hierarchical_qf1.parameters(),
                            "lr": _finalize_lr(args.hierarchical_lr),
                        },
                    ]
                ),
            }
        )

    optimizer = OptimizerGroupWrapper(
        optimizers=optimizers,
        max_optimization_epochs=None,
    )

    algo_kwargs = dict(
        env_name=args.env,
        algo=args.algo,
        env_spec=env.spec,
        option_policy=option_policy,
        traj_encoder=traj_encoder,
        skill_dynamics=skill_dynamics,
        dist_predictor=dist_predictor,
        dual_lam=dual_lam,
        optimizer=optimizer,
        alpha=args.alpha,
        max_path_length=args.max_path_length,
        n_epochs_per_eval=args.n_epochs_per_eval,
        n_epochs_per_log=args.n_epochs_per_log,
        n_epochs_per_tb=args.n_epochs_per_log,
        n_epochs_per_save=args.n_epochs_per_save,
        n_epochs_per_pt_save=args.n_epochs_per_pt_save,
        n_epochs_per_pkl_update=(
            args.n_epochs_per_eval
            if args.n_epochs_per_pkl_update is None
            else args.n_epochs_per_pkl_update
        ),
        dim_option=args.dim_option,
        num_random_trajectories=args.num_random_trajectories,
        num_video_repeats=args.num_video_repeats,
        eval_record_video=args.eval_record_video,
        video_skip_frames=args.video_skip_frames,
        eval_plot_axis=args.eval_plot_axis,
        name="METRA",
        device=device,
        sample_cpu=args.sample_cpu,
        num_train_per_epoch=1,
        sd_batch_norm=args.sd_batch_norm,
        skill_dynamics_obs_dim=skill_dynamics_obs_dim,
        trans_minibatch_size=args.trans_minibatch_size,
        trans_optimization_epochs=args.trans_optimization_epochs,
        discount=args.sac_discount,
        exploration_sac_discount=args.exploration_sac_discount,
        discrete=args.discrete,
        unit_length=args.unit_length,
        use_pure_rewards=args.use_pure_rewards,
        model=model,
        use_random_options_for_exploration=args.use_random_options_for_exploration,
        dot_penalty=args.dot_penalty,
        option_freq=args.option_freq,
        exploration_type=args.exploration_type,
        perpendicular=args.perpendicular,
        apt=args.apt,
        icm=icm if args.apt_icm else None,
        exploration_policy=exploration_policy if args.apt or args.prevupd else None,
        exploration_qf1=exploration_qf1 if args.apt or args.prevupd else None,
        exploration_qf2=exploration_qf2 if args.apt or args.prevupd else None,
        exploration_log_alpha=(
            exploration_log_alpha if args.apt or args.prevupd else None
        ),
        use_traj_for_apt_rep=args.use_traj_for_apt_rep,
        use_start_policy=args.use_start_policy,
        hierarchical_qf1=hierarchical_qf1 if args.hierarchical else None,
        hierarchical_dim=args.hierarchical_dim,
        prevupd=args.prevupd,
        prevupd_dim_exp_option=args.prevupd_dim_exp_option if args.prevupd else None,
        prevupd_traj_encoder=prevupd_traj_encoder if args.prevupd else None,
        prevupd_dual_lam=prevupd_dual_lam if args.prevupd else None,
        prevupd_dual_lam_constraint=(
            prevupd_dual_lam_constraint if args.prevupd else None
        ),
        prevupd_freq=args.prevupd_freq if args.prevupd else None,
        prevupd_use_same_policy=args.prevupd_use_same_policy if args.prevupd else None,
        debug_noconst=args.debug_noconst,
        debug_val=args.debug_val,
        no_plot=args.no_plot,
        hilp=args.hilp,
        hilp_expectile=args.hilp_expectile if args.hilp else None,
        hilp_traj_encoder_tau=args.hilp_traj_encoder_tau if args.hilp else None,
        hilp_qrl=args.hilp_qrl if args.hilp else None,
        hilp_p_trajgoal=args.hilp_p_trajgoal if args.hilp else None,
        hilp_discount=args.hilp_discount if args.hilp else None,
        traj_encoder_2=traj_encoder_2 if args.use_double_encoder else None,
        dual_lam_2=dual_lam_2 if args.use_double_encoder else None,
        goal_reaching=args.goal_reaching,
        use_goal_checker=args.use_goal_checker,
        frame_stack=args.frame_stack,
        save_replay_buffer=args.save_replay_buffer,
        visualize_rewards=args.visualize_rewards,
        use_target_traj=args.use_target_traj,
        use_repelling=args.use_repelling,
        knn_k=args.knn_k,
        rnd=rnd if args.rnd else None,
        disagreement=disagreement if args.disagreement else None,
    )

    skill_common_args = dict(
        qf1=qf1,
        qf2=qf2,
        log_alpha=log_alpha,
        tau=args.sac_tau,
        scale_reward=args.sac_scale_reward,
        target_coef=args.sac_target_coef,
        replay_buffer=replay_buffer,
        min_buffer_size=args.sac_min_buffer_size,
        inner=args.inner,
        num_alt_samples=args.num_alt_samples,
        split_group=args.split_group,
        dual_reg=args.dual_reg,
        dual_slack=args.dual_slack,
        dual_dist=args.dual_dist,
        pixel_shape=pixel_shape,
    )

    if args.algo == "metra":
        if args.recurrent:
            algo = RecurrentMETRA(
                **algo_kwargs,
                **skill_common_args,
            )
        else:
            algo = METRA(
                **algo_kwargs,
                **skill_common_args,
            )
    elif args.algo == "dads":
        algo = DADS(
            **algo_kwargs,
            **skill_common_args,
        )
    else:
        raise NotImplementedError

    algo.initial_state = example_ob
    if args.sample_cpu:
        algo.option_policy.cpu()
        algo.traj_encoder.cpu()
    else:
        algo.option_policy.to(device)
        algo.traj_encoder.to(device)
    runner.setup(
        algo=algo,
        env=env,
        make_env=contextualized_make_env,
        sampler_cls=OptionMultiprocessingSampler,
        sampler_args=dict(n_thread=args.n_thread),
        n_workers=args.n_parallel,
        worker_args=dict(return_sim_states=args.use_model),
    )
    algo.option_policy.to(device)
    algo.traj_encoder.to(device)

    if args.hilp and args.debug_noconst == 6:
        import d4rl, gym

        def get_dataset(
            env: gym.Env,
            env_name: str,
            clip_to_eps: bool = True,
            eps: float = 1e-5,
            dataset=None,
            filter_terminals=False,
            obs_dtype=np.float32,
            goal_conditioned=True,
        ):
            """Build a flat transition dataset annotated with episode boundaries.

            Args:
                env: Environment passed to ``d4rl.qlearning_dataset`` when
                    ``dataset`` is not supplied.
                env_name: Selects the boundary rule: a name containing
                    ``"antmaze"`` detects boundaries from observation
                    discontinuities, anything else reuses ``terminals``.
                clip_to_eps: If true, clip actions to ``[-(1 - eps), 1 - eps]``.
                eps: Margin used by ``clip_to_eps``.
                dataset: Optional pre-loaded dict of arrays with the keys
                    ``observations``, ``next_observations``, ``actions``,
                    ``rewards`` and ``terminals``. NOTE: a supplied dict is
                    modified in place (clipped actions are stored back into it
                    and its ``terminals`` array is written to).
                filter_terminals: If true, drop every transition flagged as
                    terminal and flag the transition preceding each of them
                    instead. Relies on ``~terminals``, so this presumably
                    expects a boolean ``terminals`` array -- TODO confirm.
                obs_dtype: dtype the observation arrays are cast to.
                goal_conditioned: If true, the last transition is flagged as
                    terminal, antmaze dones ignore the dataset's ``terminals``
                    and ``masks`` is ``1 - dones_float``; otherwise antmaze
                    dones also include ``terminals`` and ``masks`` is
                    ``1 - terminals``.

            Returns:
                dict with the keys ``observations``, ``actions``, ``rewards``,
                ``masks``, ``dones_float``, ``next_observations`` and
                ``traj_ends``.
            """
            if dataset is None:
                dataset = d4rl.qlearning_dataset(env)

            if clip_to_eps:
                lim = 1 - eps
                dataset["actions"] = np.clip(dataset["actions"], -lim, lim)

            if goal_conditioned:
                # The final transition always closes the last trajectory.
                dataset["terminals"][-1] = 1

            if filter_terminals:
                # Drop terminal transitions. The kept indices are computed
                # before the flags are moved, so each transition preceding a
                # dropped one survives and now carries the terminal flag.
                non_last_idx = np.nonzero(~dataset["terminals"])[0]
                last_idx = np.nonzero(dataset["terminals"])[0]
                penult_idx = last_idx - 1
                new_dataset = dict()
                for k, v in dataset.items():
                    if k == "terminals":
                        v[penult_idx] = 1
                    new_dataset[k] = v[non_last_idx]
                dataset = new_dataset

            if "antmaze" in env_name:
                rewards = dataset["rewards"]
                dones_float = np.zeros_like(rewards)
                traj_ends = np.zeros_like(rewards)

                # A trajectory ends at step i when the stored next observation
                # does not match the observation of step i + 1. Computed for
                # all steps at once instead of looping over every transition;
                # the norm is taken over all non-leading axes.
                gap = dataset["observations"][1:] - dataset["next_observations"][:-1]
                gap_norm = np.sqrt(
                    np.sum(np.square(gap), axis=tuple(range(1, gap.ndim)))
                )
                is_traj_end = gap_norm > 1e-6

                # rewards / terminals are indexed per transition here, i.e.
                # assumed to be 1-D arrays.
                traj_ends[:-1] = is_traj_end
                if goal_conditioned:
                    dones_float[:-1] = is_traj_end
                else:
                    dones_float[:-1] = is_traj_end | (
                        dataset["terminals"][:-1] == 1.0
                    )
                # The last transition has no successor to compare against.
                dones_float[-1] = 1
                traj_ends[-1] = 1
            else:
                dones_float = dataset["terminals"].copy()
                traj_ends = dataset["terminals"].copy()

            observations = dataset["observations"].astype(obs_dtype)
            next_observations = dataset["next_observations"].astype(obs_dtype)

            if goal_conditioned:
                masks = 1.0 - dones_float
            else:
                masks = 1.0 - dataset["terminals"].astype(np.float32)

            return dict(
                observations=observations,
                actions=dataset["actions"].astype(np.float32),
                rewards=dataset["rewards"].astype(np.float32),
                masks=masks,
                dones_float=dones_float.astype(np.float32),
                next_observations=next_observations,
                traj_ends=traj_ends,
            )

        if args.env == "antmaze-original":
            dataset = get_dataset(env.unwrapped, "antmaze")
        else:
            dataset = d4rl.qlearning_dataset(env.unwrapped, terminate_on_end=True)
            dataset["dones_float"] = dataset["terminals"].copy()

        print("dones_float", np.sum(dataset["dones_float"]))

        # now parse the data as trajectories
        # reference: from process_samples
        obs = dataset["observations"]
        assert obs.shape[-1] == 4 or obs.shape[-1] == 29, obs.shape
        next_obs = dataset["next_observations"]
        actions = dataset["actions"]
        rewards = dataset["rewards"]
        terminals = dataset["dones_float"]

        if args.env == "antmaze-original":
            init_pos = np.zeros_like(obs)
            init_pos[:, :2] = np.array(
                [env.unwrapped._init_torso_x, env.unwrapped._init_torso_y]
            ) - np.array([3.5, 3])
            obs = obs - init_pos
            next_obs = next_obs - init_pos

        assert terminals.shape[0] == obs.shape[0], (terminals.shape, obs.shape)

        terminals_idx = np.flatnonzero(terminals)
        assert terminals_idx.ndim == 1
        print("terminals shape", terminals_idx.shape)

        print("processing data...")
        prev_idx = 0
        for i, terminal_idx in enumerate(terminals_idx):
            replay_buffer.add_path(
                {
                    "obs": obs[prev_idx : terminal_idx + 1],
                    "next_obs": next_obs[prev_idx : terminal_idx + 1],
                    "actions": actions[prev_idx : terminal_idx + 1],
                    "rewards": rewards[prev_idx : terminal_idx + 1, None],
                    "dones": terminals[prev_idx : terminal_idx + 1, None],
                }
            )
            prev_idx = terminal_idx + 1
            if i % 10000 == 0:
                print(i)
        print("done")
        print("total number of transitions:", replay_buffer.n_transitions_stored)
        assert replay_buffer.n_transitions_stored >= args.sac_min_buffer_size, (
            replay_buffer.n_transitions_stored,
            args.sac_min_buffer_size,
        )
        from matplotlib import figure

        fig = figure.Figure()
        ax = fig.add_subplot()
        env.scatter_trajectory(obs[:10000, :2], color="r", ax=ax)
        fig.savefig("viz_maze_dataset.png")

    if args.restore_buffer_from is not None:
        loaded_runner = OptionLocalRunner(ctxt)
        loaded_runner.restore(
            from_dir=args.restore_buffer_from,
            make_env=contextualized_make_env,
        )
        print("replay buffer status:")
        print(
            loaded_runner._algo.replay_buffer,
            loaded_runner._algo.replay_buffer.n_transitions_stored,
        )
        runner._algo.replay_buffer = loaded_runner._algo.replay_buffer
        runner._algo.replay_buffer.debug_noconst = args.debug_noconst
        del loaded_runner
        torch.cuda.empty_cache()

    return runner


@wrap_experiment(log_dir=get_log_dir(), name=get_exp_name()[0])
def run(ctxt=None):
    """Build the experiment runner and train it.

    All settings are read from the module-level ``args``.

    Args:
        ctxt: Experiment context forwarded to ``get_runner``; presumably
            supplied by ``wrap_experiment`` -- confirm against garage.
    """
    # The runner is built first; the training arguments are read afterwards.
    get_runner(args, ctxt=ctxt).train(
        n_epochs=args.n_epochs,
        batch_size=args.traj_batch_size,
    )


if __name__ == "__main__":
    # Script entry point: set the torch.multiprocessing start method first,
    # then launch the experiment. START_METHOD is defined outside this section.
    mp.set_start_method(START_METHOD)
    run()
