""" Define parameters for algorithms. """

import argparse


def str2bool(v):
    """
    Parse a command-line boolean.

    Case-insensitively (surrounding whitespace ignored) accepts ``true``,
    ``t``, ``yes``, ``y``, ``1`` as True and ``false``, ``f``, ``no``, ``n``,
    ``0`` as False.  A value that is already a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports
            a usage error instead of silently treating a typo as False.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {v!r}")


def str2intlist(value):
    """
    Parse a comma-separated string such as ``"256,256"`` into a list of ints.

    Whitespace around items and empty items (e.g. from a trailing comma) are
    ignored, so an empty string yields ``[]``.  ``None`` is returned as-is and
    an existing list/tuple is converted element-wise.
    """
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return [int(item) for item in value]
    return [int(token) for token in value.split(",") if token.strip()]


def str2list(value):
    """
    Parse a comma-separated string such as ``"a,b"`` into a list of strings.

    Items are stripped of surrounding whitespace and empty items are dropped,
    so an empty string yields ``[]``.  ``None`` is returned as-is and an
    existing list/tuple is returned as a list.
    """
    if value is None:
        return None
    if isinstance(value, (list, tuple)):
        return list(value)
    return [token.strip() for token in value.split(",") if token.strip()]


def create_parser():
    """
    Create the argument parser with environment and method arguments.

    Use the returned parser to add further arguments later.  Note that
    ``add_method_arguments`` calls ``parser.parse_known_args()`` on
    ``sys.argv`` while it registers arguments, so which arguments exist
    depends on the command line (e.g. ``--algo``).

    Returns:
        argparse.ArgumentParser: the populated parser.
    """
    parser = argparse.ArgumentParser(
        # Passed by keyword: a positional first argument is argparse's
        # ``prog`` and would replace the program name in usage messages.
        description="Robot Learning Algorithms",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # environment
    parser.add_argument(
        "--env",
        type=str,
        default="maze2d-open-yellow-v0",
        help="environment name",
    )
    parser.add_argument(
        "--env_type",
        type=str,
        default=None,
        choices=["general", "close1", "close2", "faraway", "task8", "task12"],
    )
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--num_envs", type=int, default=1)

    add_method_arguments(parser)

    return parser


def add_method_arguments(parser):
    """
    Register algorithm, training, evaluation, logging and per-algorithm
    arguments on ``parser``.

    Which per-algorithm arguments get registered depends on the command line:
    this function calls ``parser.parse_known_args()`` (reading ``sys.argv``)
    to look up ``--algo`` and, for the adversarial imitation family,
    ``--gail_rl_algo``.

    Args:
        parser: ``argparse.ArgumentParser`` extended in place.

    Returns:
        The same ``parser``.
    """
    # algorithm
    parser.add_argument(
        "--algo",
        type=str,
        default="gail-v2",
        choices=[
            "sac",
            "ppo",
            "ddpg",
            "td3",
            "dqn",
            "bc",
            "mt-bc",
            "gail",
            "sqil",
            "airl",
            "acgail",
            "dac",
            "gail-v1",
            "gail-v2",
            "reachable_gail",
            "reachable_gail-v0",
            "iqlearn",
            "prox",
        ],
    )

    # training
    parser.add_argument("--is_train", type=str2bool, default=True)
    parser.add_argument("--resume", type=str2bool, default=True)
    parser.add_argument("--init_ckpt_path", type=str, default=None)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument(
        "--other_demo_path",
        type=str,
        default=None,
        help="path to other task demos used for "
        "pretraining discriminator of gail or initializing unreach buffer of reachable_gail-v0",
    )

    # pre-training policy
    parser.add_argument("--with_taskID", type=str2bool, default=False)
    parser.add_argument(
        "--num_task", type=int, default=4, help="length of the one-hot encoder"
    )  # change for different envs
    parser.add_argument("--target_taskID", type=str2intlist, default=None)

    # evaluation
    parser.add_argument("--ckpt_num", type=int, default=None)
    parser.add_argument(
        "--num_eval", type=int, default=10, help="number of episodes for evaluation"
    )

    # environment -- these may already be registered on the caller's parser.
    # Each one is guarded on its own so an existing option does not prevent
    # the remaining ones from being added.
    _add_argument_if_absent(parser, "--screen_width", type=int, default=480)
    _add_argument_if_absent(parser, "--screen_height", type=int, default=480)
    _add_argument_if_absent(
        parser, "--is_metaworld", type=str2bool, default=False
    )  # change for different envs
    _add_argument_if_absent(parser, "--is_2factor", type=str2bool, default=False)
    parser.add_argument("--action_repeat", type=int, default=1)

    parser.add_argument(
        "--bc_loss_coeff", type=float, default=1.0, help="blend bc loss with rl loss"
    )
    parser.add_argument(
        "--bc_step", type=int, default=50000, help="before bc_step, bc_loss_coeff * 10"
    )
    parser.add_argument("--evaluate_bc_test_loss", type=str2bool, default=False)

    # misc
    parser.add_argument("--run_prefix", type=str, default=None)
    parser.add_argument("--notes", type=str, default="")

    # log
    parser.add_argument("--average_info", type=str2bool, default=True)
    parser.add_argument("--log_interval", type=int, default=1)
    parser.add_argument("--evaluate_interval", type=int, default=10)
    parser.add_argument("--ckpt_interval", type=int, default=100)
    parser.add_argument(
        "--ood_traj_interval",
        type=int,
        default=20,
        help="interval for saving ood trajectories, value should be larger than evaluate_interval",
    )
    parser.add_argument("--log_root_dir", type=str, default="log")
    parser.add_argument(
        "--wandb",
        type=str2bool,
        default=True,
        help="set it True if you want to use wandb",
    )
    parser.add_argument("--wandb_entity", type=str, default=None)
    parser.add_argument("--wandb_project", type=str, default=None)
    parser.add_argument("--record_video", type=str2bool, default=False)
    parser.add_argument("--record_ood_video", type=str2bool, default=False)
    parser.add_argument("--record_video_caption", type=str2bool, default=True)
    parser.add_argument(
        "--record_video_start_step",
        type=int,
        default=0,
        help="start step for recording video in each episode",
    )
    parser.add_argument(
        "--record_video_every_step",
        type=int,
        default=4,
        help="record video every n steps",
    )

    _add_argument_if_absent(parser, "--record_demo", type=str2bool, default=False)

    # observation normalization
    parser.add_argument("--ob_norm", type=str2bool, default=True)
    parser.add_argument("--max_ob_norm_step", type=int, default=int(1e8))
    parser.add_argument(
        "--clip_obs", type=float, default=200, help="the clip range of observation"
    )
    parser.add_argument(
        "--clip_range",
        type=float,
        default=10,
        help="the clip range after normalization of observation",
    )

    parser.add_argument("--max_global_step", type=int, default=int(7e6))
    parser.add_argument(
        "--batch_size", type=int, default=128, help="the sample batch size"
    )

    add_policy_arguments(parser)

    # arguments specific to algorithms
    args, _ = parser.parse_known_args()
    _add_algo_arguments(parser, args.algo)

    adversarial_il_algos = (
        "gail",
        "gaifo",
        "gaifo-s",
        "dac",
        "airl",
        "acgail",
        "gail-v1",
        "gail-v2",
        "reachable_gail",
        "reachable_gail-v0",
        "iqlearn",
        "prox",
    )
    if args.algo in adversarial_il_algos:
        # Re-parse: the algorithm branch above may have registered
        # --gail_rl_algo.  Algorithms whose branch is disabled never register
        # it, hence getattr instead of attribute access.
        args, _ = parser.parse_known_args()
        _add_gail_rl_arguments(parser, getattr(args, "gail_rl_algo", None))

    return parser


def _add_argument_if_absent(parser, *name_or_flags, **kwargs):
    """Add an argument unless an option with the same string already exists."""
    try:
        parser.add_argument(*name_or_flags, **kwargs)
    except argparse.ArgumentError:
        # Already registered (e.g. by the caller's parser); keep that definition.
        pass


def _add_algo_arguments(parser, algo):
    """Register the arguments specific to ``algo`` (the ``--algo`` value)."""
    if algo == "sac":
        add_sac_arguments(parser)
    elif algo == "dqn":
        add_dqn_arguments(parser)
    elif algo == "ppo":
        add_ppo_arguments(parser)
    elif algo in ("bc", "mt-bc"):
        # TODO: mt-bc still needs its own demo path arguments.
        add_il_arguments(parser)
        add_bc_arguments(parser)
    elif algo in ("gail", "gaifo", "gaifo-s"):
        add_il_arguments(parser)
        add_gail_arguments(parser)
    elif algo == "gail-v2":
        add_il_arguments(parser)
        add_gail_v1_arguments(parser)
    elif algo == "reachable_gail":  # our old version: dvd (gail-v2) + reach reward
        add_il_arguments(parser)
        add_gail_v1_arguments(parser)
        add_reachable_arguments(parser)
    elif algo == "reachable_gail-v0":  # original gail + reach reward
        add_il_arguments(parser)
        add_gail_arguments(parser)
        add_reachable_arguments(parser)
    elif algo == "prox":
        add_il_arguments(parser)
        parser.add_argument(
            "--gail_rl_algo", type=str, default="ppo", choices=["ppo", "sac", "td3"]
        )
        parser.add_argument("--gail_entropy_loss_coeff", type=float, default=0.0)
        add_discriminator_arguments(parser)
        add_reachable_arguments(parser)
    elif algo == "sqil":
        add_il_arguments(parser)
        add_sac_arguments(parser)
        add_sqil_arguments(parser)
    # NOTE: the ddpg, td3, airl, acgail, dac, gail-v1 and iqlearn branches are
    # currently disabled, so those algorithms only receive the common arguments.


def _add_gail_rl_arguments(parser, gail_rl_algo):
    """Register the arguments of the RL learner used inside adversarial IL."""
    if gail_rl_algo == "ppo":
        add_ppo_arguments(parser)
    elif gail_rl_algo == "sac":
        add_sac_arguments(parser)
    elif gail_rl_algo == "td3":
        add_td3_arguments(parser)
    elif gail_rl_algo == "ddpg":
        add_ddpg_arguments(parser)
    elif gail_rl_algo == "dqn":
        add_dqn_arguments(parser)


def add_policy_arguments(parser):
    """
    Register policy/critic network, observation-encoder and actor-critic
    optimisation arguments.

    Calls ``parser.parse_known_args()`` (reading ``sys.argv``) so that
    ``--encoder_type cnn`` and ``--demo_conditioned_policy`` can register
    extra options and adjust existing defaults.
    """
    # network
    parser.add_argument(
        "--policy_mlp_dim", type=str2intlist, default=[256, 256, 256, 256]
    )
    parser.add_argument("--critic_mlp_dim", type=str2intlist, default=[256, 256])
    parser.add_argument("--critic_ensemble", type=int, default=1)
    parser.add_argument(
        "--policy_activation", type=str, default="relu", choices=["relu", "elu", "tanh"]
    )
    parser.add_argument("--tanh_policy", type=str2bool, default=True)
    parser.add_argument("--gaussian_policy", type=str2bool, default=True)
    parser.add_argument("--demo_conditioned_policy", type=str2bool, default=False)

    # encoder
    parser.add_argument(
        "--encoder_type", type=str, default="mlp", choices=["mlp", "cnn"]
    )
    # The image size needs to be odd so that sizes match when doing encode-decode.
    parser.add_argument("--encoder_image_size", type=int, default=85)
    parser.add_argument("--random_crop", type=str2bool, default=False)
    parser.add_argument("--encoder_conv_dim", type=int, default=32)
    parser.add_argument("--encoder_kernel_size", type=str2intlist, default=[3, 3, 3, 3])
    parser.add_argument("--encoder_stride", type=str2intlist, default=[2, 1, 1, 1])
    parser.add_argument(
        "--encoder_conv_output_dim", type=int, default=1000, choices=[50, 1000]
    )
    parser.add_argument("--encoder_soft_update_weight", type=float, default=0.95)
    parser.add_argument(
        "--pretrained_encoder",
        type=str,
        default="none",
        choices=["none", "resnet", "vae", "r3m"],
    )
    parser.add_argument("--encoder_pretrain_steps", type=int, default=5000)
    parser.add_argument(
        "--encoder_lr",
        type=float,
        default=1e-4,
        # Previously copy-pasted as "learning rate for bc".
        help="learning rate for the encoder",
    )

    args, _ = parser.parse_known_args()
    if args.encoder_type == "cnn":
        # Image observations: smaller renders, wider MLP heads, no ob_norm.
        parser.set_defaults(screen_width=120, screen_height=120)
        parser.set_defaults(policy_mlp_dim=[1024, 1024])
        parser.set_defaults(critic_mlp_dim=[1024, 1024])
        parser.add_argument("--asym_ac", type=str2bool, default=False)
        parser.add_argument("--frame_stack", type=int, default=3)
        parser.add_argument("--img_augment", type=str2bool, default=True)
        parser.set_defaults(batch_size=128)
        parser.set_defaults(ob_norm=False)

    if args.demo_conditioned_policy:
        parser.add_argument("--lstm_hidden_dim", type=int, default=128)
        parser.add_argument("--lstm_num_layers", type=int, default=2)
        parser.add_argument("--lstm_output_dim", type=int, default=128)
        parser.add_argument("--traj_length", type=int, default=50)
        parser.add_argument(
            "--traj_interval",
            type=int,
            default=1,
            help="interval for sampling H step trajectories",
        )

    # actor-critic
    parser.add_argument(
        "--actor_lr", type=float, default=3e-4, help="the learning rate of the actor"
    )
    parser.add_argument(
        "--critic_lr", type=float, default=3e-4, help="the learning rate of the critic"
    )
    parser.add_argument(
        "--critic_soft_update_weight",
        type=float,
        default=0.995,
        help="the average coefficient",
    )

    parser.add_argument("--log_std_min", type=float, default=-10)
    parser.add_argument("--log_std_max", type=float, default=2)

    # absorbing state
    parser.add_argument("--absorbing_state", type=str2bool, default=False)


def add_rl_arguments(parser):
    """Register the options common to every RL learner.

    Adds ``--rl_discount_factor`` (float, default 0.99) and
    ``--warm_up_steps`` (int, default 0) to ``parser``.
    """
    add = parser.add_argument
    add("--rl_discount_factor", default=0.99, type=float, help="the discount factor")
    add("--warm_up_steps", default=0, type=int)


def add_on_policy_arguments(parser):
    """Register rollout-length, GAE lambda and advantage-normalisation options."""
    specs = (
        ("--rollout_length", int, 2000),
        ("--gae_lambda", float, 0.95),
        ("--advantage_norm", str2bool, True),
    )
    for flag, converter, default in specs:
        parser.add_argument(flag, type=converter, default=default)


def add_off_policy_arguments(parser):
    """Register replay-buffer and update-schedule options for off-policy learners.

    Also raises the ``warm_up_steps`` default to 1000.
    """
    parser.add_argument(
        "--buffer_size",
        type=int,
        default=int(1e6),
        help="the size of the buffer",
    )
    parser.set_defaults(warm_up_steps=1000)
    for flag in ("--num_env_steps_per_update", "--num_actor_updates"):
        parser.add_argument(flag, type=int, default=1)
    parser.add_argument(
        "--is_relabel_rew",
        default=True,
        type=str2bool,
        help="whether to use relabel reward for training agent",
    )


def add_sac_arguments(parser):
    """
    Register Soft Actor-Critic arguments (together with the shared RL and
    off-policy ones) and apply SAC-specific defaults.
    """
    add_rl_arguments(parser)
    add_off_policy_arguments(parser)

    parser.add_argument("--reward_scale", type=float, default=1.0, help="reward scale")
    parser.add_argument("--critic_target_update_freq", type=int, default=2)
    parser.add_argument("--target_entropy", type=float, default=None)
    parser.add_argument("--alpha_init_temperature", type=float, default=0.1)
    parser.add_argument(
        "--alpha_lr",
        type=float,
        default=1e-4,
        # Previously mislabelled as "the learning rate of the actor".
        help="the learning rate of the entropy temperature alpha",
    )
    parser.set_defaults(
        actor_lr=3e-4,
        critic_lr=3e-4,
        evaluate_interval=40000,
        ckpt_interval=10000,
        log_interval=500,
        critic_soft_update_weight=0.99,
        buffer_size=100000,
        critic_ensemble=2,
        ob_norm=True,
    )


def add_ppo_arguments(parser):
    """Register PPO arguments (plus the shared RL / on-policy ones) and PPO defaults."""
    add_rl_arguments(parser)
    add_on_policy_arguments(parser)

    # clip range and loss coefficients
    for flag, default in (
        ("--ppo_clip", 0.2),
        ("--value_loss_coeff", 0.5),
        ("--action_loss_coeff", 1.0),
        ("--entropy_loss_coeff", 1e-4),
    ):
        parser.add_argument(flag, type=float, default=default)

    # update schedule
    parser.add_argument("--ppo_epoch", type=int, default=5)
    parser.add_argument("--max_grad_norm", type=float, default=None)
    parser.add_argument("--actor_update_freq", type=int, default=1)
    # evaluate_interval changed from 20 to 40 to save running time.
    # TODO: change back after all the parameters are tuned.
    parser.set_defaults(ob_norm=True, evaluate_interval=40, ckpt_interval=500)

    parser.add_argument("--target_kl", type=float, default=None)
    parser.add_argument(
        "--share_mlp",
        default=False,
        type=str2bool,
        help="whether to share mlp's weights for actor and critic",
    )


def add_dqn_arguments(parser):
    """Register DQN arguments (plus the shared RL / off-policy ones) and DQN defaults."""
    add_rl_arguments(parser)
    add_off_policy_arguments(parser)

    parser.add_argument("--critic_target_update_freq", type=int, default=2)
    parser.set_defaults(critic_soft_update_weight=0.995)
    parser.add_argument("--max_grad_norm", type=float, default=40.0)

    # epsilon-greedy exploration
    parser.add_argument("--epsilon_greedy", type=str2bool, default=False)
    for flag, default in (
        ("--epsilon_greedy_eps", 0.3),
        ("--policy_exploration_noise", 0.1),
    ):
        parser.add_argument(flag, type=float, default=default)

    parser.set_defaults(
        gaussian_policy=False,
        ob_norm=False,
        evaluate_interval=10000,
        ckpt_interval=50000,
        log_interval=1000,
    )


def add_ddpg_arguments(parser):
    """Register DDPG arguments (plus the shared RL / off-policy ones) and DDPG defaults."""
    add_rl_arguments(parser)
    add_off_policy_arguments(parser)

    # actor / target-network update schedule
    parser.add_argument("--actor_update_delay", type=int, default=2000)
    for flag in (
        "--actor_update_freq",
        "--actor_target_update_freq",
        "--critic_target_update_freq",
    ):
        parser.add_argument(flag, type=int, default=2)
    parser.add_argument(
        "--actor_soft_update_weight",
        default=0.995,
        type=float,
        help="the average coefficient",
    )
    parser.set_defaults(critic_soft_update_weight=0.995)
    parser.add_argument("--max_grad_norm", type=float, default=40.0)

    # epsilon-greedy exploration
    parser.add_argument("--epsilon_greedy", type=str2bool, default=False)
    for flag, default in (
        ("--epsilon_greedy_eps", 0.3),
        ("--policy_exploration_noise", 0.1),
    ):
        parser.add_argument(flag, type=float, default=default)

    parser.set_defaults(
        gaussian_policy=False,
        ob_norm=False,
        evaluate_interval=10000,
        ckpt_interval=50000,
        log_interval=1000,
    )


def add_td3_arguments(parser):
    """Register TD3 arguments: the DDPG set, two critics and policy-noise options."""
    add_ddpg_arguments(parser)

    parser.set_defaults(critic_ensemble=2)

    for flag, default in (("--policy_noise", 0.2), ("--policy_noise_clip", 0.5)):
        parser.add_argument(flag, type=float, default=default)


def add_il_arguments(parser):
    """Register the demonstration-loading options shared by imitation methods.

    Also turns ``evaluate_bc_test_loss`` on by default.
    """
    parser.set_defaults(evaluate_bc_test_loss=True)

    add = parser.add_argument
    add("--demo_path", type=str, default=None, help="path to demos")
    add(
        "--demo_low_level",
        default=False,
        type=str2bool,
        help="use low level actions for training",
    )
    add(
        "--demo_subsample_interval",
        default=1,  # GAIL used 20
        type=int,
        help="subsample interval of expert transitions",
    )
    for bound, default in (("start", 0.0), ("end", 1.0)):
        add(
            f"--demo_sample_range_{bound}",
            type=float,
            default=default,
            help="sample demo range",
        )
    add(
        "--num_target_demos",
        default=None,
        type=int,
        help="how many target demos to use for training, if None, use all demos",
    )
    add(
        "--target_demo_path",
        default=None,
        type=str,
        help="path to target demos; every time load demos, check if they are from target task. if so, sampling target demos",
    )
    add(
        "--task_num_in_dataset",
        default=18,
        type=int,
        help="number of tasks in the multi-task dataset",
    )


def add_bc_arguments(parser):
    """Register behaviour-cloning options and apply BC defaults.

    ``evaluate_interval`` is set to 1 to check how many demos BC needs in this
    env; ``ob_norm`` is switched on to align with PPO.
    """
    parser.set_defaults(
        gaussian_policy=False,
        max_global_step=100,
        bc_step=10,
        evaluate_interval=1,
        ob_norm=True,
    )

    add = parser.add_argument
    add("--bc_lr", default=1e-3, type=float, help="learning rate for bc")
    add(
        "--val_split",
        default=0,
        type=float,
        help="how much of dataset to leave for validation set",
    )
    add(
        "--multitask_bc_loss_coeff",
        default=0.0,
        type=float,
        help="blend target bc loss with multitask bc loss",
    )
    add(
        "--mt_balance",
        default=0.5,
        type=float,
        help="Data balancing for MT BC, proportion of data sampled from other tasks",
    )

    # optional BC pretraining on the multitask dataset
    add("--pretrain_BC", default=False, type=str2bool, help="whether to pretrain BC")
    add("--pretrain_bc_max_step", default=100, type=int)


def add_discriminator_arguments(parser):
    """
    Register the discriminator arguments: pretraining and blending options,
    network architecture and optimisation settings.

    Also sets the ``max_global_step`` default to 7e6.
    """
    parser.add_argument(
        "--is_frozen",
        type=str2bool,
        default=False,
        help="whether to train for dvd baseline instead of ablation",
    )
    parser.add_argument("--pre_dvd_step", default=50, type=int)
    parser.add_argument(
        "--pretrain_discriminator",
        default=0,
        type=int,
        choices=[0, 1, 2],
        # Adjacent literals are concatenated verbatim, so each piece needs its
        # own separating space (previously rendered "forthe" / "usetarget").
        help="0-no pretrain, 1-pretrain discriminators for "
        "the other tasks as well, 2-only use "
        "target task as positive examples",
    )
    parser.set_defaults(max_global_step=int(7e6))
    parser.add_argument(
        "--blend_steps",
        type=int,
        default=10000,
        help="how many steps to blend the other task demos",
    )
    parser.add_argument(
        "--blend_ratio",
        type=float,
        default=0.0,
        help="how much to blend the policy data with other task demos",
    )
    parser.add_argument("--is_saliency_map", type=str2bool, default=False)
    parser.add_argument("--output_dim", type=int, default=32)
    parser.add_argument("--target_task_index_in_demo_path", type=int, default=0)
    parser.add_argument("--discriminator_lr", type=float, default=1e-3)  # 1e-4->1e-3
    parser.add_argument("--discriminator_mlp_dim", type=str2intlist, default=[256, 256])
    parser.add_argument(
        "--discriminator_activation",
        type=str,
        default="tanh",
        choices=["relu", "elu", "tanh"],
    )
    parser.add_argument("--discriminator_update_freq", type=int, default=4)


def add_gail_arguments(parser):
    """Register GAIL reward/loss options, the inner RL learner choice and the
    discriminator arguments."""
    add = parser.add_argument
    add("--gail_entropy_loss_coeff", default=0.0, type=float)
    add("--gail_reward", default="d", type=str, choices=["vanilla", "gan", "d"])

    add("--gail_no_action", default=False, type=str2bool)
    for flag, default in (
        ("--gail_env_reward", 0.0),
        ("--gail_grad_penalty_coeff", 10.0),
    ):
        add(flag, default=default, type=float)
    add(
        "--gail_rl_algo",
        default="ppo",
        type=str,
        choices=["ppo", "sac", "ddpg", "td3", "dqn"],
    )
    add_discriminator_arguments(parser)


def add_gail_v1_arguments(parser):
    """
    Register the GAIL arguments plus the trajectory (LSTM) options used by the
    ``gail-v2`` / ``reachable_gail`` variants.

    ``add_policy_arguments`` already registers ``--lstm_hidden_dim``,
    ``--lstm_num_layers``, ``--traj_length`` and ``--traj_interval`` (with the
    same defaults) when ``--demo_conditioned_policy`` is true.  Those options
    are therefore only added here when missing; registering them twice raised
    an argparse conflict error.
    """
    add_gail_arguments(parser)
    parser.set_defaults(batch_size=128)

    def add_if_absent(*flags, **kwargs):
        try:
            parser.add_argument(*flags, **kwargs)
        except argparse.ArgumentError:
            pass  # already registered with an identical definition

    add_if_absent("--lstm_hidden_dim", type=int, default=128)
    add_if_absent("--lstm_num_layers", type=int, default=2)
    add_if_absent("--traj_length", type=int, default=50)
    add_if_absent(
        "--traj_interval",
        type=int,
        default=1,
        help="interval for sampling H step trajectories",
    )
    parser.add_argument(
        "--average_num",
        type=int,
        default=1,
        help="average over how many trajectories. don't use this for now",
    )


def add_reachable_arguments(parser):
    """
    Register the reachability (proximity) discriminator arguments used by
    ``reachable_gail``, ``reachable_gail-v0`` and ``prox``.

    Calls ``parser.parse_known_args()`` to read ``--gail_rl_algo``.  When the
    inner learner is off-policy, ``--buffer_size`` is left for that learner's
    argument adder to register.
    """
    parser.add_argument("--reachability_threshold", type=float, default=0.8)
    parser.add_argument("--Dr_ensemble_size", type=int, default=1)
    args, _ = parser.parse_known_args()
    if getattr(args, "gail_rl_algo", None) in ("sac", "td3", "ddpg", "dqn"):
        # These learners add --buffer_size themselves (add_off_policy_arguments)
        # after this function runs.  Adding it here as well raised an argparse
        # conflict error for td3/ddpg/dqn.
        # NOTE(review): an explicit default given later wins over this value,
        # and add_sac_arguments sets buffer_size=100000 -- confirm intent.
        parser.set_defaults(buffer_size=int(1e6))
    else:
        parser.add_argument(
            "--buffer_size", type=int, default=int(1e6), help="the size of the buffer"
        )
    parser.add_argument(
        "--reach_discriminator_input",
        type=int,
        default=0,
        choices=[0, 1, 2],
        help="0: s' only, 1: s' + a', 2: s' + a ",
    )

    ### backward relabelling arguments
    parser.add_argument(
        "--backwards_relabelling",
        type=str2bool,
        default=True,
        help="whether to use backwards relabelling for training and eval",
    )
    parser.add_argument(
        "--relabel_skip_threshold",
        type=float,
        default=1.0,
        help="if sigmoid(Prox(state)) > relabel_skip_threshold, skip the state and don't relabel its reward",
    )
    parser.add_argument("--k", type=int, default=None, help="sample every k steps for backward relabelling")
    parser.add_argument("--k_divide", type=int, default=5)
    parser.add_argument("--is_decay_K", type=str2bool, default=False)
    parser.add_argument(
        "--decay_method",
        type=str,
        default="linear",
        choices=["linear", "exp"],
    )

    ### pretraining arguments
    parser.add_argument(
        "--is_all_exp_in_reach_buffer",
        type=str2bool,
        default=False,
        help="whether to put all expert data in reach buffer for pretraining",
    )
    parser.add_argument("--pretrain_prox", type=str2bool, default=True)
    parser.add_argument(
        "--pretrain_n_epochs", type=int, default=5
    )  # 50 is too large and method already converges in 5 epochs
    parser.add_argument("--pretrain_inner_n_epochs", type=int, default=50)
    parser.add_argument("--pretrain_warmup_inner_n_epochs", type=int, default=100)
    parser.add_argument(
        "--reachability_update_frequency", type=int, default=50, help="fine tune"
    )  ## set to pretrain_inner_n_epochs for once per training epoch
    parser.add_argument(
        "--reachability_update_steps", type=int, default=-1
    )  ## -1 for entire dataset

    ### visualization arguments
    parser.add_argument("--eval_ood_states_reachability", type=str2bool, default=False)
    parser.add_argument(
        "--eval_ood_states_reachability_interval", type=int, default=500000
    )
    parser.add_argument(
        "--ood_traj_path",
        type=str,
        default=None,
        help="path to ood trajs used for evaluation of reachable discriminator",
    )
    parser.add_argument("--is_visual_reachability", type=str2bool, default=False)

    ### finetuning arguments
    parser.add_argument("--input_reachdata2replaybuffer", type=str2bool, default=False)
    parser.add_argument(
        "--reachability_discriminator_update_freq",
        type=int,
        default=4,
        help="fine tune",
    )
    parser.add_argument(
        "--reachability_online_update_steps", type=int, default=4, help="fine tune"
    )  ## -1 for entire dataset
    parser.add_argument(
        "--reachability_online_update_frequency", type=int, default=1, help="fine tune"
    )  ## set to pretrain_inner_n_epochs for once per training epoch
    parser.add_argument(
        "--reach_reward",
        type=str,
        default="post_sigmoid",
        choices=[
            "sum",
            "sum_negative",
            "tier_sum",
            "reach_only",
            "multiplicative",
            "sum_positive",
            "post_sigmoid",
        ],
    )
    parser.add_argument("--gail_discriminator_threshold", type=float, default=-0.5)
    parser.add_argument(
        "--is_othertask_demos2unreach_buffer",
        type=str2bool,
        default=True,
        help="if input other task demos into unreach buffer",
    )

    ### store trajectories arguments
    parser.add_argument("--presort_policy_samples", type=str2bool, default=False)
    parser.add_argument("--is_DVD_relabel", type=str2bool, default=True)
    parser.add_argument(
        "--dvd_relabel_threshold",
        type=float,
        default=0.8,
        help="threshold for dvd relabelling",
    )
    parser.add_argument("--is_dvd_relabel_anneal", type=str2bool, default=False)
    parser.add_argument("--dvd_relabel_sanity_check", type=str2bool, default=False)

    ### dense reward arguments
    parser.add_argument("--dense_reward_type", type=str, default=None, choices=["max"])
    parser.add_argument(
        "--dense_reward_scale",
        type=float,
        default=0.001,
        help="scale the dense reward for backward relabelling",
    )
    parser.add_argument(
        "--backward_relabel_threshold",
        type=float,
        default=0.0,
        help="threshold for backward relabelling",
    )
    parser.add_argument(
        "--reach_reward_scale",
        type=float,
        default=6.0,
        help="scale reach reward when combining with GAIL",
    )
    parser.add_argument(
        "--reach_reward_constant",
        type=float,
        default=1.0,
        help="reach reward subtract a constant",
    )
    parser.add_argument(
        "--gail_reward_scale",
        type=float,
        default=1.0,
        help="scale GAIL reward when combining with Proximity",
    )

    ### for single task training
    parser.add_argument("--single_task_training", type=str2bool, default=False)
    parser.add_argument(
        "--add_more_demos2reachbuffer",
        type=int,
        default=None,
        help="add more demos to reach buffer for verifying prox.'s ability",
    )
    


def add_acgail_arguments(parser):
    """
    Register the AC-GAIL arguments.

    ``--discriminator_lr``, ``--discriminator_mlp_dim``,
    ``--discriminator_activation`` and ``--discriminator_update_freq`` are
    provided by ``add_discriminator_arguments``.  Registering them here too
    made this function raise an argparse conflict error, so only AC-GAIL's
    differing default (a single 256-unit hidden layer) is applied on top.
    """
    parser.add_argument("--gail_entropy_loss_coeff", type=float, default=0.0)
    parser.add_argument(
        "--gail_reward", type=str, default="d", choices=["vanilla", "gan", "d"]
    )
    parser.add_argument("--discriminator_output_dim", type=int, default=256)
    parser.add_argument("--gail_no_action", type=str2bool, default=False)
    parser.add_argument("--gail_env_reward", type=float, default=0.0)
    parser.add_argument("--gail_grad_penalty_coeff", type=float, default=10.0)

    parser.add_argument(
        "--gail_rl_algo", type=str, default="ppo", choices=["ppo", "sac", "td3"]
    )
    parser.add_argument("--label_taskID", type=str2bool, default=False)
    parser.add_argument("--label_goal_obs", type=str2bool, default=False)
    try:
        # Don't use acgail anymore; this borrows the argument name.
        parser.add_argument(
            "--target_demo_path",
            type=str,
            default=None,
            help="path to target task demos used for "
            "changing the goal of other tasks' demos",
        )
    except argparse.ArgumentError:
        # Already registered by add_il_arguments; reuse that definition.
        pass

    add_discriminator_arguments(parser)
    parser.set_defaults(discriminator_mlp_dim=[256])


def add_airl_arguments(parser):
    """
    Register the AIRL arguments.

    ``--discriminator_activation``, ``--discriminator_update_freq`` and
    ``--discriminator_lr`` are provided by ``add_discriminator_arguments``.
    Registering them here too made this function raise an argparse conflict
    error, so AIRL's differing learning-rate default (1e-4) is applied on top.
    """
    parser.add_argument(
        "--discriminator_mlp_dim_r", type=str2intlist, default=[256, 256]
    )
    parser.add_argument(
        "--discriminator_mlp_dim_v", type=str2intlist, default=[256, 256]
    )
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument(
        "--gail_rl_algo", type=str, default="ppo", choices=["ppo", "sac", "td3"]
    )

    parser.add_argument("--gail_env_reward", type=float, default=0.0)
    parser.add_argument("--gail_entropy_loss_coeff", type=float, default=0.0)
    parser.add_argument(
        "--gail_reward", type=str, default="d", choices=["vanilla", "gan", "d"]
    )

    parser.add_argument("--gail_no_action", type=str2bool, default=False)
    parser.add_argument("--gail_grad_penalty_coeff", type=float, default=10.0)
    add_discriminator_arguments(parser)
    parser.set_defaults(discriminator_lr=1e-4)


def add_dac_arguments(parser):
    """Register DAC arguments: the GAIL set with TD3 as the inner learner and
    DAC-specific defaults (absorbing states, warm-up, actor lr, batch size)."""
    add_gail_arguments(parser)
    # NOTE(review): --actor_update_delay is registered later by
    # add_ddpg_arguments with an explicit default (2000), which takes
    # precedence over the value set here -- confirm intent.
    parser.set_defaults(
        gail_rl_algo="td3",
        absorbing_state=True,
        warm_up_steps=1000,
        actor_lr=1e-3,
        actor_update_delay=1000,
        batch_size=100,
        gail_reward="d",
    )


def add_iqlearn_arguments(parser):
    """Register IQ-Learn options.

    The ``iq_learn_alpha`` default follows
    https://github.com/Div99/IQ-Learn/blob/main/iq_learn/conf/method/iq.yaml
    """
    add = parser.add_argument
    add("--iq_learn_alpha", default=0.5, type=float)
    for flag in ("--iq_learn_use_target", "--separate_V"):
        add(flag, default=True, type=str2bool)
    add("--reward_reg", default="chi2", type=str, choices=["chi2", "abs"])

    add("--oldcode", default=True, type=str2bool)

    add("--BC_loss_coeff", default=0.0, type=float)


def add_sqil_arguments(parser):
    """Register SQIL options and its defaults (SAC learner, no reward relabelling)."""
    parser.add_argument("--other_task_data_proportion", default=0, type=float)
    parser.set_defaults(
        max_global_step=int(7e6),
        gail_rl_algo="sac",
        is_relabel_rew=False,
    )


def argparser():
    """Build the full parser and parse ``sys.argv``, tolerating unknown flags.

    Returns:
        tuple: ``(args, unparsed)`` -- the parsed namespace and the list of
        arguments the parser did not recognise.
    """
    parsed, leftovers = create_parser().parse_known_args()
    return parsed, leftovers
