"""
Helper script to generate jsons for reproducing paper experiments.

Args:
    config_dir (str): Directory where generated configs will be placed. 
        Defaults to 'paper' subfolder in exps folder of repository

    dataset_dir (str): Base dataset directory where released datasets can be
        found on disk. Defaults to datasets folder in repository.

    output_dir (str): Base output directory for all training runs that will be 
        written to generated configs.

Example usage:
    # Assume datasets alredy exist in robomimic/../datasets folder. Configs will be generated under robomimic/exps/paper
    python generate_paper_configs.py --output_dir /tmp/experiment_results

    # Specify where datasets exist, and specify where configs should be generated.
    python generate_paper_configs.py --config_dir /tmp/configs --dataset_dir /tmp/datasets --output_dir /tmp/experiment_results
"""
import os
import argparse
import robomimic
from robomimic import DATASET_REGISTRY
from robomimic.config import Config, BCConfig, BCQConfig, CQLConfig, HBCConfig, IRISConfig, config_factory


def modify_config_for_default_low_dim_exp(config):
    """
    Modifies a Config object with experiment, training, and observation settings that
    were used across all low-dimensional experiments by default.

    Args:
        config (Config instance): config to modify
    """

    with config.experiment.values_unlocked():
        # save model during every evaluation (every 50 epochs)
        config.experiment.save.enabled = True
        config.experiment.save.every_n_epochs = 50

        # every epoch is 100 gradient steps, and validation epoch is 10 gradient steps
        config.experiment.epoch_every_n_steps = 100
        config.experiment.validation_epoch_every_n_steps = 10

        # do 50 evaluation rollouts every 50 epochs
        # NOTE: horizon will generally get set depending on the task and dataset type
        config.experiment.rollout.enabled = True
        config.experiment.rollout.n = 50
        config.experiment.rollout.horizon = 400
        config.experiment.rollout.rate = 50
        config.experiment.rollout.warmstart = 0
        config.experiment.rollout.terminate_on_success = True

    with config.train.values_unlocked():
        # assume entire dataset can fit in memory
        config.train.num_data_workers = 0
        config.train.hdf5_cache_mode = "all"

        # batch size 100 and 2000 training epochs
        config.train.batch_size = 100
        config.train.num_epochs = 2000

    with config.observation.values_unlocked():
        # default observation is eef pose, gripper finger position, and object information,
        # all of which are low-dim. 
        default_low_dim_obs = [
            "robot0_eef_pos", 
            "robot0_eef_quat", 
            "robot0_gripper_qpos", 
            "object",
        ]
        # handle hierarchical observation configs
        if config.algo_name == "hbc":
            configs_to_set = [
                config.observation.actor.modalities.obs,
                config.observation.planner.modalities.obs,
                config.observation.planner.modalities.subgoal,
            ]
        elif config.algo_name == "iris":
            configs_to_set = [
                config.observation.actor.modalities.obs,
                config.observation.value_planner.planner.modalities.obs,
                config.observation.value_planner.planner.modalities.subgoal,
                config.observation.value_planner.value.modalities.obs,
            ]
        else:
            configs_to_set = [config.observation.modalities.obs]
        # set all observations / subgoals to use the correct low-dim modalities
        for cfg in configs_to_set:
            cfg.low_dim = list(default_low_dim_obs)
            cfg.rgb = []

    return config


def modify_config_for_default_image_exp(config):
    """
    Modifies a Config object with experiment, training, and observation settings that
    were used across all image experiments by default.

    Args:
        config (Config instance): config to modify
    """
    assert config.algo_name not in ["hbc", "iris"], "no image training for HBC and IRIS"

    with config.experiment.values_unlocked():
        # save model during every evaluation (every 20 epochs)
        config.experiment.save.enabled = True
        config.experiment.save.every_n_epochs = 20

        # every epoch is 500 gradient steps, and validation epoch is 50 gradient steps
        config.experiment.epoch_every_n_steps = 500
        config.experiment.validation_epoch_every_n_steps = 50

        # do 50 evaluation rollouts every 20 epochs
        # NOTE: horizon will generally get set depending on the task and dataset type
        config.experiment.rollout.enabled = True
        config.experiment.rollout.n = 50
        config.experiment.rollout.horizon = 400
        config.experiment.rollout.rate = 20
        config.experiment.rollout.warmstart = 0
        config.experiment.rollout.terminate_on_success = True

    with config.train.values_unlocked():
        # only cache low-dim info, and use 2 data workers to increase fetch speed for image obs
        config.train.num_data_workers = 2
        config.train.hdf5_cache_mode = "low_dim"

        # batch size 16 and 600 training epochs
        config.train.batch_size = 16
        config.train.num_epochs = 600


    with config.observation.values_unlocked():
        # default low-dim observation is eef pose, gripper finger position
        # default image observation is external camera and wrist camera
        config.observation.modalities.obs.low_dim = [
            "robot0_eef_pos", 
            "robot0_eef_quat", 
            "robot0_gripper_qpos", 
        ]
        config.observation.modalities.obs.rgb = [
            "agentview_image",
            "robot0_eye_in_hand_image",
        ]
        config.observation.modalities.goal.low_dim = []
        config.observation.modalities.goal.rgb = []

        # default image encoder architecture is ResNet with spatial softmax
        config.observation.encoder.rgb.core_class = "VisualCore"
        config.observation.encoder.rgb.core_kwargs.feature_dimension = 64
        config.observation.encoder.rgb.core_kwargs.backbone_class = 'ResNet18Conv'                         # ResNet backbone for image observations (unused if no image observations)
        config.observation.encoder.rgb.core_kwargs.backbone_kwargs.pretrained = False                # kwargs for visual core
        config.observation.encoder.rgb.core_kwargs.backbone_kwargs.input_coord_conv = False
        config.observation.encoder.rgb.core_kwargs.pool_class = "SpatialSoftmax"                # Alternate options are "SpatialMeanPool" or None (no pooling)
        config.observation.encoder.rgb.core_kwargs.pool_kwargs.num_kp = 32                      # Default arguments for "SpatialSoftmax"
        config.observation.encoder.rgb.core_kwargs.pool_kwargs.learnable_temperature = False    # Default arguments for "SpatialSoftmax"
        config.observation.encoder.rgb.core_kwargs.pool_kwargs.temperature = 1.0                # Default arguments for "SpatialSoftmax"
        config.observation.encoder.rgb.core_kwargs.pool_kwargs.noise_std = 0.0

        # observation randomizer class - set to None to use no randomization, or 'CropRandomizer' to use crop randomization
        config.observation.encoder.rgb.obs_randomizer_class = "CropRandomizer"

        # kwargs for observation randomizers (for the CropRandomizer, this is size and number of crops)
        config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 76
        config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 76
        config.observation.encoder.rgb.obs_randomizer_kwargs.num_crops = 1
        config.observation.encoder.rgb.obs_randomizer_kwargs.pos_enc = False

    return config


def modify_config_for_dataset(config, task_name, dataset_type, hdf5_type, base_dataset_dir, filter_key=None):
    """
    Modifies a Config object with experiment, training, and observation settings to
    correspond to experiment settings for the dataset collected on @task_name with
    dataset source @dataset_type (e.g. ph, mh, mg), and hdf5 type @hdf5_type (e.g. low_dim
    or image).

    Args:
        config (Config instance): config to modify

        task_name (str): identify task that dataset was collected on

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        filter_key (str): if not None, use the provided filter key to select a subset of the
            provided dataset
    """
    assert task_name in DATASET_REGISTRY, \
        "task {} not found in dataset registry!".format(task_name)
    assert dataset_type in DATASET_REGISTRY[task_name], \
        "dataset type {} not found for task {} in dataset registry!".format(dataset_type, task_name)
    assert hdf5_type in DATASET_REGISTRY[task_name][dataset_type], \
        "hdf5 type {} not found for dataset type {} and task {} in dataset registry!".format(hdf5_type, dataset_type, task_name)

    is_real_dataset = "real" in task_name
    if is_real_dataset:
        assert config.algo_name == "bc", "we only ran BC-RNN on real robot"
    else:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    with config.experiment.values_unlocked():

        # look up rollout evaluation horizon in registry and set it
        config.experiment.rollout.horizon = DATASET_REGISTRY[task_name][dataset_type][hdf5_type]["horizon"]

        if dataset_type == "mg":
            # machine-generated datasets did not use validation
            config.experiment.validate = False
        else:
            # all other datasets used validation
            config.experiment.validate = True

        if is_real_dataset:
            # no evaluation rollouts for real robot training
            config.experiment.rollout.enabled = False

    with config.train.values_unlocked():
        # set dataset path and possibly filter keys
        url = DATASET_REGISTRY[task_name][dataset_type][hdf5_type]["url"]
        if url is None:
            # infer file_name
            if task_name in ["lift", "can", "square", "tool_hang", "transport"]:
                file_name = "{}_v141.hdf5".format(hdf5_type)
            elif task_name in ["lift_real", "can_real", "tool_hang_real"]:
                file_name = "{}.hdf5".format(hdf5_type)
            else:
                raise ValueError("Unknown dataset type")
        else:
            file_name = url.split("/")[-1]
        config.train.data = os.path.join(base_dataset_dir, task_name, dataset_type, file_name)
        config.train.hdf5_filter_key = None if filter_key is None else filter_key
        config.train.hdf5_validation_filter_key = None
        if config.experiment.validate:
            # set train and valid keys for validation
            config.train.hdf5_filter_key = "train" if filter_key is None else "{}_train".format(filter_key)
            config.train.hdf5_validation_filter_key = "valid" if filter_key is None else "{}_valid".format(filter_key)

    with config.observation.values_unlocked():
        # maybe modify observation names and randomization sizes (since image size might be different)

        if is_real_dataset:
            # modify observation names for real robot datasets
            config.observation.modalities.obs.low_dim = [
                "ee_pose", 
                "gripper_position", 
            ]

            if task_name == "tool_hang_real":
                # side and wrist camera
                config.observation.modalities.obs.rgb = [
                    "image_side",
                    "image_wrist",
                ]
                # 240x240 images -> crops should be 216x216
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 216
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 216
            else:
                # front and wrist camera
                config.observation.modalities.obs.rgb = [
                    "image",
                    "image_wrist",
                ]
                # 120x120 images -> crops should be 108x108
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 108
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 108

        elif hdf5_type in ["image", "image_sparse", "image_dense"]:
            if task_name == "transport":
                # robot proprioception per arm
                config.observation.modalities.obs.low_dim = [
                    "robot0_eef_pos", 
                    "robot0_eef_quat", 
                    "robot0_gripper_qpos", 
                    "robot1_eef_pos", 
                    "robot1_eef_quat", 
                    "robot1_gripper_qpos", 
                ]

                # shoulder and wrist cameras per arm
                config.observation.modalities.obs.rgb = [
                    "shouldercamera0_image",
                    "robot0_eye_in_hand_image",
                    "shouldercamera1_image",
                    "robot1_eye_in_hand_image",
                ]
            elif task_name == "tool_hang":
                # side and wrist camera
                config.observation.modalities.obs.rgb = [
                    "sideview_image",
                    "robot0_eye_in_hand_image",
                ]
                # 240x240 images -> crops should be 216x216
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_height = 216
                config.observation.encoder.rgb.obs_randomizer_kwargs.crop_width = 216

        elif hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
            if task_name == "transport":
                # robot proprioception per arm
                default_low_dim_obs = [
                    "robot0_eef_pos", 
                    "robot0_eef_quat", 
                    "robot0_gripper_qpos", 
                    "robot1_eef_pos", 
                    "robot1_eef_quat", 
                    "robot1_gripper_qpos", 
                    "object",
                ]
                # handle hierarchical observation configs
                if config.algo_name == "hbc":
                    configs_to_set = [
                        config.observation.actor.modalities.obs,
                        config.observation.planner.modalities.obs,
                        config.observation.planner.modalities.subgoal,
                    ]
                elif config.algo_name == "iris":
                    configs_to_set = [
                        config.observation.actor.modalities.obs,
                        config.observation.value_planner.planner.modalities.obs,
                        config.observation.value_planner.planner.modalities.subgoal,
                        config.observation.value_planner.value.modalities.obs,
                    ]
                else:
                    configs_to_set = [config.observation.modalities.obs]
                # set all observations / subgoals to use the correct low-dim modalities
                for obs_key_config in configs_to_set:
                    obs_key_config.low_dim = list(default_low_dim_obs)
                    obs_key_config.rgb = []

    return config


def modify_bc_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a BCConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (BCConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, BCConfig), "must be BCConfig"
    assert config.algo_name == "bc", "must be BCConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    is_real_dataset = "real" in task_name
    if not is_real_dataset:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    with config.algo.values_unlocked():
        # base parameters that may get modified
        config.algo.optim_params.policy.learning_rate.initial = 1e-4            # learning rate 1e-4
        config.algo.actor_layer_dims = (1024, 1024)                             # MLP size (1024, 1024)
        config.algo.gmm.enabled = True                                          # enable GMM

        if dataset_type == "mg":
            # machine-generated datasets don't use GMM
            config.algo.gmm.enabled = False                                     # disable GMM
            if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                # low-dim mg uses LR 1e-3
                config.algo.optim_params.policy.learning_rate.initial = 1e-3    # learning rate 1e-3

    return config


def modify_bc_rnn_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a BCConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (BCConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, BCConfig), "must be BCConfig"
    assert config.algo_name == "bc", "must be BCConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    is_real_dataset = "real" in task_name
    if not is_real_dataset:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    with config.train.values_unlocked():
        # make sure RNN is enabled with sequence length 10
        config.train.seq_length = 10

    with config.algo.values_unlocked():
        # make sure RNN is enabled with sequence length 10
        config.algo.rnn.enabled = True
        config.algo.rnn.horizon = 10

        # base parameters that may get modified
        config.algo.optim_params.policy.learning_rate.initial = 1e-4            # learning rate 1e-4
        config.algo.actor_layer_dims = ()                                       # no MLP layers between rnn layer and output
        config.algo.gmm.enabled = True                                          # enable GMM
        config.algo.rnn.hidden_dim = 400                                        # rnn dim 400

        if dataset_type == "mg":
            # update hyperparams for machine-generated datasets
            config.algo.gmm.enabled = False                                     # disable GMM
            if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                # image datasets use RNN dim 1000
                config.algo.rnn.hidden_dim = 1000                               # rnn dim 1000
        else:
            # update hyperparams for all other dataset types (ph, mh, paired)
            if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                # image datasets use RNN dim 1000
                config.algo.rnn.hidden_dim = 1000                               # rnn dim 1000

    return config


def modify_bcq_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a BCQConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (BCQConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, BCQConfig), "must be BCQConfig"
    assert config.algo_name == "bcq", "must be BCQConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    is_real_dataset = "real" in task_name
    assert not is_real_dataset, "we only ran BC-RNN on real robot"
    if not is_real_dataset:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    with config.algo.values_unlocked():
        # base parameters that may get modified further
        config.algo.optim_params.critic.learning_rate.initial = 1e-4                # all learning rates 1e-3
        config.algo.optim_params.action_sampler.learning_rate.initial = 1e-4
        config.algo.optim_params.actor.learning_rate.initial = 1e-3
        config.algo.actor.enabled = False                                           # disable actor by default
        config.algo.action_sampler.vae.enabled = True                               # use VAE action sampler
        config.algo.action_sampler.gmm.enabled = False
        config.algo.action_sampler.vae.kl_weight = 0.05                             # beta 0.05 for VAE
        config.algo.action_sampler.vae.latent_dim = 14                              # latent dim 14
        config.algo.action_sampler.vae.prior.learn = False                          # N(0, 1) prior
        config.algo.critic.layer_dims = (300, 400)                                  # all MLP sizes at (300, 400)
        config.algo.action_sampler.vae.encoder_layer_dims = (300, 400)
        config.algo.action_sampler.vae.decoder_layer_dims = (300, 400)
        config.algo.actor.layer_dims = (300, 400)
        config.algo.target_tau = 5e-4                                               # tau 5e-4
        config.algo.discount = 0.99                                                 # discount 0.99
        config.algo.critic.num_action_samples = 10                                  # number of action sampler samples at train and test
        config.algo.critic.num_action_samples_rollout = 100

        if dataset_type == "mg":
            # update hyperparams for machine-generated datasets
            config.algo.optim_params.critic.learning_rate.initial = 1e-3            # all learning rates 1e-3
            config.algo.optim_params.action_sampler.learning_rate.initial = 1e-3
            config.algo.optim_params.actor.learning_rate.initial = 1e-3
            config.algo.action_sampler.vae.kl_weight = 0.5                          # beta 0.5 for VAE
            config.algo.target_tau = 5e-3                                           # tau 5e-3

            if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                # enable actor only on low-dim
                config.algo.actor.enabled = True
        else:
            # make some modifications where needed for human datasets
            if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                if dataset_type in ["mh", "paired"]:
                    # low-dim, MH had higher layer sizes
                    config.algo.critic.layer_dims = (1024, 1024)
                    config.algo.action_sampler.vae.encoder_layer_dims = (1024, 1024)
                    config.algo.action_sampler.vae.decoder_layer_dims = (1024, 1024)
                    config.algo.action_sampler.vae.prior_layer_dims = (1024, 1024)

                    config.algo.action_sampler.vae.kl_weight = 0.5

                    # use learned GMM prior for MH dataset
                    config.algo.action_sampler.vae.prior.learn = True
                    config.algo.action_sampler.vae.prior.is_conditioned = True
                    config.algo.action_sampler.vae.prior.use_gmm = True
                    config.algo.action_sampler.vae.prior.gmm_learn_weights = True
            else:
                if dataset_type == "ph":
                    # image, PH used higher critic LR of 1e-3
                    config.algo.optim_params.critic.learning_rate.initial = 1e-3
                # image datasets used bigger VAE
                config.algo.action_sampler.vae.encoder_layer_dims = (1024, 1024)
                config.algo.action_sampler.vae.decoder_layer_dims = (1024, 1024)
                if dataset_type in ["mh", "paired"]:
                    # image, MH also had bigger critic
                    config.algo.critic.layer_dims = (1024, 1024)

    return config


def modify_cql_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a CQLConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (CQLConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, CQLConfig), "must be CQLConfig"
    assert config.algo_name == "cql", "must be CQLConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    is_real_dataset = "real" in task_name
    assert not is_real_dataset, "we only ran BC-RNN on real robot"
    if not is_real_dataset:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    with config.train.values_unlocked():
        # CQL uses batch size 1024 (for low-dim) and 8 (for image)
        if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
            config.train.batch_size = 1024
        else:
            config.train.batch_size = 8

    with config.algo.values_unlocked():
        # base parameters that may get modified further
        config.algo.optim_params.critic.learning_rate.initial = 1e-3                # learning rates
        config.algo.optim_params.actor.learning_rate.initial = 3e-4
        config.algo.actor.target_entropy = "default"                                # use automatic entropy tuning to default target value
        config.algo.critic.deterministic_backup = True                              # deterministic Q-backup
        config.algo.critic.target_q_gap = 5.0                                       # use Lagrange, with threshold 5.0
        config.algo.critic.min_q_weight = 1.0
        config.algo.target_tau = 5e-3                                               # tau 5e-3
        config.algo.discount = 0.99                                                 # discount 0.99
        config.algo.critic.layer_dims = (300, 400)                                  # all MLP sizes at (300, 400)
        config.algo.actor.layer_dims = (300, 400)

        if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
            # update policy LR to 1e-4 for image runs
            config.algo.optim_params.actor.learning_rate.initial = 1e-4

    return config


def modify_hbc_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a HBCConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (HBCConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, HBCConfig), "must be HBCConfig"
    assert config.algo_name == "hbc", "must be HBCConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    assert hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"], "HBC only runs on low-dim"
    is_real_dataset = "real" in task_name
    assert not is_real_dataset, "we only ran BC-RNN on real robot"

    with config.algo.values_unlocked():
        # base parameters that may get modified further
        config.algo.actor.optim_params.policy.learning_rate.initial = 1e-3          # learning rates
        config.algo.planner.optim_params.goal_network.learning_rate.initial = 1e-3

        config.algo.planner.vae.enabled = True                                      # goal VAE settings
        config.algo.planner.vae.kl_weight = 5e-4                                    # beta 5e-4
        config.algo.planner.vae.latent_dim = 16                                     # latent dim 16
        config.algo.planner.vae.prior.learn = True                                  # learn GMM prior with 10 modes
        config.algo.planner.vae.prior.is_conditioned = True
        config.algo.planner.vae.prior.use_gmm = True
        config.algo.planner.vae.prior.gmm_learn_weights = True
        config.algo.planner.vae.prior.gmm_num_modes = 10
        config.algo.planner.vae.encoder_layer_dims = (1024, 1024)                   # VAE network sizes
        config.algo.planner.vae.decoder_layer_dims = (1024, 1024)
        config.algo.planner.vae.prior_layer_dims = (1024, 1024)

        config.algo.actor.rnn.hidden_dim = 400                                      # actor RNN dim
        config.algo.actor.actor_layer_dims = ()                                     # no MLP layers between rnn layer and output

        if dataset_type == "mg":
            # update hyperparams for machine-generated datasets
            config.algo.actor.rnn.hidden_dim = 100
            config.algo.actor.actor_layer_dims = (1024, 1024)

    return config


def modify_iris_config_for_dataset(config, task_name, dataset_type, hdf5_type):
    """
    Modifies a IRISConfig object for training on a particular kind of dataset. This function
    just sets algorithm hyperparameters in the algo config depending on the kind of 
    dataset.

    Args:
        config (IRISConfig instance): config to modify

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 
    """
    assert isinstance(config, IRISConfig), "must be IRISConfig"
    assert config.algo_name == "iris", "must be IRISConfig"
    assert dataset_type in ["ph", "mh", "mg", "paired"], "invalid dataset type"
    assert hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"], "IRIS only runs on low-dim"
    is_real_dataset = "real" in task_name
    assert not is_real_dataset, "we only ran BC-RNN on real robot"

    with config.algo.values_unlocked():
        # base parameters that may get modified further
        config.algo.actor.optim_params.policy.learning_rate.initial = 1e-3                          # learning rates
        config.algo.value_planner.planner.optim_params.goal_network.learning_rate.initial = 1e-3
        config.algo.value_planner.value.optim_params.critic.learning_rate.initial = 1e-3
        config.algo.value_planner.value.optim_params.action_sampler.learning_rate.initial = 1e-4

        config.algo.value_planner.planner.vae.enabled = True                                        # goal VAE settings
        config.algo.value_planner.planner.vae.kl_weight = 5e-4                                      # beta 5e-4
        config.algo.value_planner.planner.vae.latent_dim = 14                                       # latent dim 14
        config.algo.value_planner.planner.vae.prior.learn = True                                    # learn GMM prior with 10 modes
        config.algo.value_planner.planner.vae.prior.is_conditioned = True
        config.algo.value_planner.planner.vae.prior.use_gmm = True
        config.algo.value_planner.planner.vae.prior.gmm_learn_weights = True
        config.algo.value_planner.planner.vae.prior.gmm_num_modes = 10
        config.algo.value_planner.planner.vae.encoder_layer_dims = (1024, 1024)                     # VAE network sizes
        config.algo.value_planner.planner.vae.decoder_layer_dims = (1024, 1024)
        config.algo.value_planner.planner.vae.prior_layer_dims = (1024, 1024)

        config.algo.value_planner.value.target_tau = 5e-4                                           # Value tau
        config.algo.value_planner.value.action_sampler.vae.kl_weight = 0.5                          # Value KL
        config.algo.value_planner.value.action_sampler.vae.latent_dim = 16
        config.algo.value_planner.value.action_sampler.actor_layer_dims = (300, 400)

        config.algo.actor.rnn.hidden_dim = 400                                                      # actor RNN dim
        config.algo.actor.actor_layer_dims = ()                                                     # no MLP layers between rnn layer and output

        if dataset_type in ["mh", "paired"]:
            # value LR 1e-4, KL weight is 0.05 for multi-human datasets
            config.algo.value_planner.value.optim_params.critic.learning_rate.initial = 1e-4
            config.algo.value_planner.value.action_sampler.vae.kl_weight = 0.05

        if dataset_type in ["mg"]:
            # Enable value actor and set larger target tau
            config.algo.value_planner.value.actor.enabled = True
            config.algo.value_planner.value.optim_params.actor.learning_rate.initial = 1e-3
            config.algo.value_planner.value.target_tau = 5e-3

    return config


def generate_experiment_config(
    base_exp_name, 
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_name, 
    algo_config_modifier, 
    task_name, 
    dataset_type, 
    hdf5_type,
    filter_key=None,
    additional_name=None,
    additional_config_modifier=None,
):
    """
    Helper function to generate a config for a particular experiment.

    Args:
        base_exp_name (str): name that identifies this set of experiments

        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_name (str): identifies the algorithm - one of ["bc", "bc_rnn", "bcq", "cql", hbc", "iris"]

        algo_config_modifier (function): function to modify config to add algo hyperparameter
            settings, given the task, dataset, and hdf5 types.

        task_name (str): identify task that dataset was collected on. Only used to distinguish
            between simulation and real-world, for an assert statement

        dataset_type (str): dataset type for this dataset (e.g. ph, mh, mg, paired).

        hdf5_type (str): hdf5 type for this dataset (e.g. raw, low_dim, image). 

        filter_key (str): if not None, use the provided filter key to select a subset of the
            provided dataset

        additional_name (str): if provided, will add this name to the generated experiment name, and
            the name of the generated config json

        additional_config_modifier (function): if provided, run this last function on the config
            to make final modifications before generating the json.
    """
    if "real" not in task_name:
        assert hdf5_type != "raw", "cannot train on raw demonstrations"

    # decide whether to use low-dim or image training defaults
    modifier_for_obs = modify_config_for_default_image_exp
    if hdf5_type in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
        modifier_for_obs = modify_config_for_default_low_dim_exp

    algo_config_name = "bc" if algo_name == "bc_rnn" else algo_name
    config = config_factory(algo_name=algo_config_name)
    # turn into default config for observation modalities (e.g.: low-dim or rgb)
    config = modifier_for_obs(config)
    # add in config based on the dataset
    config = modify_config_for_dataset(
        config=config, 
        task_name=task_name, 
        dataset_type=dataset_type, 
        hdf5_type=hdf5_type, 
        base_dataset_dir=base_dataset_dir,
        filter_key=filter_key,
    )
    # add in algo hypers based on dataset
    config = algo_config_modifier(
        config=config, 
        task_name=task_name, 
        dataset_type=dataset_type, 
        hdf5_type=hdf5_type,
    )
    if additional_config_modifier is not None:
        # use additional config modifier if provided
        config = additional_config_modifier(config)

    # account for filter key in experiment naming and directory naming
    filter_key_str = "_{}".format(filter_key) if filter_key is not None else ""
    dataset_type_dir = "{}/{}".format(dataset_type, filter_key) if filter_key is not None else dataset_type

    # account for @additional_name
    additional_name_str = "_{}".format(additional_name) if additional_name is not None else ""
    json_name = "{}{}".format(algo_name, additional_name_str)

    # set experiment name
    with config.experiment.values_unlocked():
        config.experiment.name = "{}_{}_{}_{}{}_{}{}".format(base_exp_name, algo_name, task_name, dataset_type, filter_key_str, hdf5_type, additional_name_str)
    # set output folder
    with config.train.values_unlocked():
        if base_output_dir is None:
            base_output_dir = config.train.output_dir
        config.train.output_dir = os.path.join(base_output_dir, base_exp_name, algo_name, task_name, dataset_type_dir, hdf5_type, "trained_models")
    
    # save config to json file
    dir_to_save = os.path.join(base_config_dir, base_exp_name, task_name, dataset_type_dir, hdf5_type)
    os.makedirs(dir_to_save, exist_ok=True)
    json_path = os.path.join(dir_to_save, "{}.json".format(json_name))
    config.dump(filename=json_path)

    return config, json_path


def generate_core_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for core set of experiments.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """
    core_json_paths = Config() # use for convenient nested dict
    for task in DATASET_REGISTRY:
        for dataset_type in DATASET_REGISTRY[task]:
            for hdf5_type in DATASET_REGISTRY[task][dataset_type]:
                # if not real robot dataset, skip raw hdf5
                is_real_dataset = ("real" in task)
                if not is_real_dataset and hdf5_type == "raw":
                    continue
                
                # get list of algorithms to generate configs for, for this hdf5 dataset
                algos_to_generate = ["bc", "bc_rnn", "bcq", "cql", "hbc", "iris"]
                if hdf5_type not in ["low_dim", "low_dim_sparse", "low_dim_dense"]:
                    # no hbc or iris for image runs
                    algos_to_generate = algos_to_generate[:-2]
                if is_real_dataset:
                    # we only ran BC-RNN on real robot
                    algos_to_generate = ["bc_rnn"]

                for algo_name in algos_to_generate:

                    # generate config for this experiment
                    config, json_path = generate_experiment_config(
                        base_exp_name="core",
                        base_config_dir=base_config_dir,
                        base_dataset_dir=base_dataset_dir,
                        base_output_dir=base_output_dir,
                        algo_name=algo_name, 
                        algo_config_modifier=algo_to_config_modifier[algo_name], 
                        task_name=task, 
                        dataset_type=dataset_type, 
                        hdf5_type=hdf5_type,
                    )

                    # save json path into dict
                    core_json_paths[task][dataset_type][hdf5_type][algo_name] = json_path

    return core_json_paths


def generate_subopt_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for the suboptimal human subsets of the multi-human datasets.
    Note that while the paper includes the results on the can-paired dataset along with results on these
    datasets, the configs for runs on the can-paired dataset is in the "core" set of runs.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """
    subopt_json_paths = Config() # use for convenient nested dict
    for task in ["lift", "can", "square", "transport"]:
        # only generate configs for multi-human data subsets
        for dataset_type in ["mh"]:
            # only low-dim / image
            for hdf5_type in ["low_dim", "image"]:

                # get list of algorithms to generate configs for, for this hdf5 dataset
                algos_to_generate = ["bc", "bc_rnn", "bcq", "cql", "hbc", "iris"]
                if hdf5_type == "image":
                    # no hbc or iris for image runs
                    algos_to_generate = algos_to_generate[:-2]

                for algo_name in algos_to_generate:

                    for fk in ["worse", "okay", "better", "worse_okay", "worse_better", "okay_better"]:

                        # generate config for this experiment
                        config, json_path = generate_experiment_config(
                            base_exp_name="subopt",
                            base_config_dir=base_config_dir,
                            base_dataset_dir=base_dataset_dir,
                            base_output_dir=base_output_dir,
                            algo_name=algo_name, 
                            algo_config_modifier=algo_to_config_modifier[algo_name], 
                            task_name=task, 
                            dataset_type=dataset_type, 
                            hdf5_type=hdf5_type,
                            filter_key=fk,
                        )

                        # save json path into dict
                        dataset_type_dir = "{}/{}".format(dataset_type, fk)
                        subopt_json_paths[task][dataset_type_dir][hdf5_type][algo_name] = json_path

    return subopt_json_paths


def generate_dataset_size_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for the dataset size ablation experiments, where BC-RNN models
    were trained on 20% and 50% dataset sizes.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """
    size_ablation_json_paths = Config() # use for convenient nested dict
    for task in ["lift", "can", "square", "transport"]:
        for dataset_type in ["ph", "mh"]:
            for hdf5_type in ["low_dim", "image"]:

                # only bc-rnn
                algo_name = "bc_rnn"
                for fk in ["20_percent", "50_percent"]:

                    # generate config for this experiment
                    config, json_path = generate_experiment_config(
                        base_exp_name="dataset_size",
                        base_config_dir=base_config_dir,
                        base_dataset_dir=base_dataset_dir,
                        base_output_dir=base_output_dir,
                        algo_name=algo_name, 
                        algo_config_modifier=algo_to_config_modifier[algo_name], 
                        task_name=task, 
                        dataset_type=dataset_type, 
                        hdf5_type=hdf5_type,
                        filter_key=fk,
                    )

                    # save json path into dict
                    dataset_type_dir = "{}/{}".format(dataset_type, fk)
                    size_ablation_json_paths[task][dataset_type_dir][hdf5_type][algo_name] = json_path

    return size_ablation_json_paths


def generate_obs_ablation_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for the observation ablation experiments, where BC and BC-RNN models
    were trained on different versions of low-dim and image observations.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """

    # observation config modifiers for these experiments
    def add_eef_vel(config):
        with config.observation.values_unlocked():
            old_low_dim_mods = list(config.observation.modalities.obs.low_dim)
            old_low_dim_mods.extend(["robot0_eef_vel_lin", "robot0_eef_vel_ang", "robot0_gripper_qvel"])
            if "robot1_eef_pos" in old_low_dim_mods:
                old_low_dim_mods.extend(["robot1_eef_vel_lin", "robot1_eef_vel_ang", "robot1_gripper_qvel"])
            config.observation.modalities.obs.low_dim = old_low_dim_mods
        return config

    def add_proprio(config):
        with config.observation.values_unlocked():
            old_low_dim_mods = list(config.observation.modalities.obs.low_dim)
            old_low_dim_mods.extend(["robot0_joint_pos_cos", "robot0_joint_pos_sin", "robot0_joint_vel"])
            if "robot1_eef_pos" in old_low_dim_mods:
                old_low_dim_mods.extend(["robot1_joint_pos_cos", "robot1_joint_pos_sin", "robot1_joint_vel"])
            config.observation.modalities.obs.low_dim = old_low_dim_mods
        return config

    def remove_wrist(config):
        with config.observation.values_unlocked():
            old_image_mods = list(config.observation.modalities.obs.rgb)
            config.observation.modalities.obs.rgb = [m for m in old_image_mods if "eye_in_hand" not in m]
        return config

    def remove_rand(config):
        with config.observation.values_unlocked():
            config.observation.encoder.rgb.obs_randomizer_class = None
        return config

    obs_ablation_json_paths = Config() # use for convenient nested dict
    for task in ["square", "transport"]:
        for dataset_type in ["ph", "mh"]:
            for hdf5_type in ["low_dim", "image"]:

                # observation modifiers to apply
                if hdf5_type == "low_dim":
                    obs_modifiers = [add_eef_vel, add_proprio]
                else:
                    obs_modifiers = [add_eef_vel, add_proprio, remove_wrist, remove_rand]

                # only bc and bc-rnn
                algos_to_generate = ["bc", "bc_rnn"]
                for algo_name in algos_to_generate:
                    for obs_modifier in obs_modifiers:
                        # generate config for this experiment
                        config, json_path = generate_experiment_config(
                            base_exp_name="obs_ablation",
                            base_config_dir=base_config_dir,
                            base_dataset_dir=base_dataset_dir,
                            base_output_dir=base_output_dir,
                            algo_name=algo_name, 
                            algo_config_modifier=algo_to_config_modifier[algo_name], 
                            task_name=task, 
                            dataset_type=dataset_type, 
                            hdf5_type=hdf5_type,
                            additional_name=obs_modifier.__name__,
                            additional_config_modifier=obs_modifier,
                        )

                        # save json path into dict
                        algo_name_str = "{}_{}".format(algo_name, obs_modifier.__name__)
                        obs_ablation_json_paths[task][dataset_type][hdf5_type][algo_name_str] = json_path

    return obs_ablation_json_paths


def generate_hyper_ablation_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for the hyperparameter sensitivity experiments, 
    where BC-RNN models were trained on different ablations.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """

    # observation config modifiers for these experiments
    def change_lr(config):
        with config.algo.values_unlocked():
            config.algo.optim_params.policy.learning_rate.initial = 1e-3
        return config

    def change_gmm(config):
        with config.algo.values_unlocked():
            config.algo.gmm.enabled = False
        return config

    def change_mlp(config):
        with config.algo.values_unlocked():
            config.algo.actor_layer_dims = (1024, 1024) 
        return config

    def change_conv(config):
        with config.observation.values_unlocked():
            config.observation.encoder.rgb.core_class = 'ShallowConv'
            config.observation.encoder.rgb.core_kwargs = Config()
        return config

    def change_rnnd_low_dim(config):
        with config.algo.values_unlocked():
            config.algo.rnn.hidden_dim = 100
        return config

    def change_rnnd_image(config):
        with config.algo.values_unlocked():
            config.algo.rnn.hidden_dim = 400
        return config

    hyper_ablation_json_paths = Config() # use for convenient nested dict
    for task in ["square", "transport"]:
        for dataset_type in ["ph", "mh"]:
            for hdf5_type in ["low_dim", "image"]:

                # observation modifiers to apply
                if hdf5_type == "low_dim":
                    hyper_modifiers = [change_lr, change_gmm, change_mlp, change_rnnd_low_dim]
                else:
                    hyper_modifiers = [change_lr, change_gmm, change_conv, change_rnnd_image]

                # only bc and bc-rnn
                algo_name = "bc_rnn"
                for hyper_modifier in hyper_modifiers:
                    # generate config for this experiment
                    config, json_path = generate_experiment_config(
                        base_exp_name="hyper_ablation",
                        base_config_dir=base_config_dir,
                        base_dataset_dir=base_dataset_dir,
                        base_output_dir=base_output_dir,
                        algo_name=algo_name, 
                        algo_config_modifier=algo_to_config_modifier[algo_name], 
                        task_name=task, 
                        dataset_type=dataset_type, 
                        hdf5_type=hdf5_type,
                        additional_name=hyper_modifier.__name__,
                        additional_config_modifier=hyper_modifier,
                    )

                    # save json path into dict
                    algo_name_str = "{}_{}".format(algo_name, hyper_modifier.__name__)
                    hyper_ablation_json_paths[task][dataset_type][hdf5_type][algo_name_str] = json_path

    return hyper_ablation_json_paths


def generate_d4rl_configs(
    base_config_dir, 
    base_dataset_dir, 
    base_output_dir, 
    algo_to_config_modifier, 
):
    """
    Helper function to generate all configs for reproducing BCQ, CQL, and TD3-BC runs on some D4RL
    environments.

    Args:
        base_config_dir (str): base directory to place generated configs

        base_dataset_dir (str): path to directory where datasets are on disk.
            Directory structure is expected to be consistent with the output
            of @make_dataset_dirs in the download_datasets.py script.

        base_output_dir (str): directory to save training results to. If None, will use the directory
            from the default algorithm configs.

        algo_to_config_modifier (dict): dictionary that maps algo name to a function that modifies configs 
            to add algo hyperparameter settings, given the task, dataset, and hdf5 types.
    """

    def bcq_algo_config_modifier(config):
        with config.algo.values_unlocked():
            # all LRs 1e-3, enable actor
            config.algo.optim_params.critic.learning_rate.initial = 1e-3
            config.algo.optim_params.action_sampler.learning_rate.initial = 1e-3
            config.algo.optim_params.actor.learning_rate.initial = 1e-3
            config.algo.actor.enabled = True
            config.algo.action_sampler.vae.kl_weight = 0.5
        return config

    def cql_algo_config_modifier(config):
        with config.algo.values_unlocked():
            # taken from TD3-BC settings described in their paper
            config.algo.optim_params.critic.learning_rate.initial = 3e-4
            config.algo.optim_params.actor.learning_rate.initial = 3e-5
            config.algo.actor.bc_start_steps = 40000                        # pre-training steps for actor
            config.algo.critic.target_q_gap = None                          # no Lagrange, and fixed weight of 10.0
            config.algo.critic.cql_weight = 10.0
            config.algo.critic.min_q_weight = 1.0 
            config.algo.critic.deterministic_backup = True                  # deterministic backup (no entropy in Q-target)
            config.algo.actor.layer_dims = (256, 256, 256)                  # MLP sizes
            config.algo.critic.layer_dims = (256, 256, 256)
        return config
    
    def iql_algo_config_modifier(config):
        with config.algo.values_unlocked():
            # taken from IQL settings described in their paper
            config.algo.target_tau = 0.005
            config.algo.vf_quantile = 0.7
            config.algo.adv.beta = 3.0
            config.algo.optim_params.critic.learning_rate.initial = 3e-4
            config.algo.optim_params.vf.learning_rate.initial = 3e-4
            config.algo.optim_params.actor.learning_rate.initial = 3e-4
            config.algo.actor.layer_dims = (256, 256, 256)                  # MLP sizes
            config.algo.critic.layer_dims = (256, 256, 256)
        return config

    d4rl_tasks = [
        # "halfcheetah-random-v2",
        # "hopper-random-v2",
        # "walker2d-random-v2",
        "halfcheetah-medium-v2",
        "hopper-medium-v2",
        "walker2d-medium-v2",
        "halfcheetah-expert-v2",
        "hopper-expert-v2",
        "walker2d-expert-v2",
        "halfcheetah-medium-expert-v2",
        "hopper-medium-expert-v2",
        "walker2d-medium-expert-v2",
        # "halfcheetah-medium-replay-v2",
        # "hopper-medium-replay-v2",
        # "walker2d-medium-replay-v2",
    ]
    d4rl_json_paths = Config() # use for convenient nested dict
    for task_name in d4rl_tasks:
        for algo_name in ["bcq", "cql", "td3_bc", "iql"]:
            config = config_factory(algo_name=algo_name)

            # hack: copy experiment and train sections from td3-bc, since that has defaults for training with D4RL
            if algo_name != "td3_bc":
                ref_config = config_factory(algo_name="td3_bc")
                with config.values_unlocked():
                    config.experiment = ref_config.experiment
                    config.train = ref_config.train
                    config.observation = ref_config.observation
                    config.train.hdf5_normalize_obs = False # only TD3-BC uses observation normalization

            # modify algo section for d4rl defaults
            if algo_name == "bcq":
                config = bcq_algo_config_modifier(config)
            elif algo_name == "cql":
                config = cql_algo_config_modifier(config)
            elif algo_name == "iql":
                config = iql_algo_config_modifier(config)

            # set experiment name
            with config.experiment.values_unlocked():
                config.experiment.name = "{}_{}_{}".format("d4rl", algo_name, task_name)
            # set output folder and dataset
            with config.train.values_unlocked():
                if base_output_dir is None:
                    base_output_dir_for_algo = "../{}_trained_models".format(algo_name)
                else:
                    base_output_dir_for_algo = base_output_dir
                config.train.output_dir = os.path.join(base_output_dir_for_algo, "d4rl", algo_name, task_name, "trained_models")
                config.train.data = os.path.join(base_dataset_dir, "d4rl", "converted", 
                    "{}.hdf5".format(task_name.replace("-", "_")))

            # save config to json file
            dir_to_save = os.path.join(base_config_dir, "d4rl", task_name)
            os.makedirs(dir_to_save, exist_ok=True)
            json_path = os.path.join(dir_to_save, "{}.json".format(algo_name))
            config.dump(filename=json_path)

            # save json path into dict
            d4rl_json_paths[task_name][""][""][algo_name] = json_path

    return d4rl_json_paths


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Directory where generated configs will be placed
    parser.add_argument(
        "--config_dir",
        type=str,
        default=None,
        help="Directory where generated configs will be placed. Defaults to 'paper' subfolder in exps folder of repository",
    )

    # directory where released datasets are located
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default=None,
        help="Base dataset directory for released datasets. Defaults to datasets folder in repository.",
    )

    # output directory for training runs (will be written to configs)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Base output directory for all training runs that will be written to generated configs.",
    )

    args = parser.parse_args()

    # read args
    generated_configs_base_dir = args.config_dir
    if generated_configs_base_dir is None:
        generated_configs_base_dir = os.path.join(robomimic.__path__[0], "exps/paper")

    datasets_base_dir = args.dataset_dir
    if datasets_base_dir is None:
        datasets_base_dir = os.path.join(robomimic.__path__[0], "../datasets")

    output_base_dir = args.output_dir

    # algo to modifier
    algo_to_modifier = dict(
        bc=modify_bc_config_for_dataset, 
        bc_rnn=modify_bc_rnn_config_for_dataset,
        bcq=modify_bcq_config_for_dataset,
        cql=modify_cql_config_for_dataset,
        hbc=modify_hbc_config_for_dataset,
        iris=modify_iris_config_for_dataset,
    )

    # exp name to config generator
    exp_name_to_generator = dict(
        core=generate_core_configs,
        subopt=generate_subopt_configs,
        dataset_size=generate_dataset_size_configs,
        obs_ablation=generate_obs_ablation_configs,
        hyper_ablation=generate_hyper_ablation_configs,
        d4rl=generate_d4rl_configs,
    )

    # generate configs for each experiment name
    config_json_paths = Config() # use for convenient nested dict
    for exp_name in exp_name_to_generator:
        config_json_paths[exp_name] = exp_name_to_generator[exp_name](
            base_config_dir=generated_configs_base_dir, 
            base_dataset_dir=datasets_base_dir, 
            base_output_dir=output_base_dir, 
            algo_to_config_modifier=algo_to_modifier, 
        )

    # write output shell scripts
    for exp_name in config_json_paths:
        shell_path = os.path.join(generated_configs_base_dir, "{}.sh".format(exp_name))
        with open(shell_path, "w") as f:
            f.write("#!/bin/bash\n\n")
            f.write("# " + "=" * 10 + exp_name + "=" * 10 + "\n")
            train_script_loc = os.path.join(robomimic.__path__[0], "scripts/train.py")

            for task in config_json_paths[exp_name]:
                for dataset_type in config_json_paths[exp_name][task]:
                    for hdf5_type in config_json_paths[exp_name][task][dataset_type]:
                        f.write("\n")
                        f.write("#  task: {}\n".format(task))
                        if len(dataset_type) > 0:
                            f.write("#    dataset type: {}\n".format(dataset_type))
                        if len(hdf5_type) > 0:
                            f.write("#      hdf5 type: {}\n".format(hdf5_type))
                        for algo_name in config_json_paths[exp_name][task][dataset_type][hdf5_type]:
                            # f.write("#        {}\n".format(algo_name))
                            exp_json_path = config_json_paths[exp_name][task][dataset_type][hdf5_type][algo_name]
                            cmd = "python {} --config {}\n".format(train_script_loc, exp_json_path)
                            f.write(cmd)
            f.write("\n")
