import argparse
from distutils.util import strtobool

ENV_ID = [
    'Hopper-v2',
    'Walker2d-v2',
    'HalfCheetah-v2',
    'Ant-v2',
]

def get_parser():
    parser = argparse.ArgumentParser()
    
    parser.add_argument('--env_id', default='Hopper-v2', type=str, choices=ENV_ID)
    parser.add_argument('--dataset_dir', default='dataset', type=str)
    parser.add_argument('--expert_dataset_info', default=["expert-v2", 1])
    parser.add_argument('--suboptimal_dataset_name', nargs="+", default=["expert-v2", "random-v2"])
    parser.add_argument('--suboptimal_dataset_num', nargs="+", default=[400, 800],type=int)
    parser.add_argument('--resume', default=True, type=strtobool)

    parser.add_argument('--total_iterations', default=int(1e6), type=int)
    parser.add_argument('--save_interval', default=int(1e4), type=int)
    parser.add_argument('--log_interval', default=int(1e4), type=int)
    parser.add_argument('--critic_lr', default=3e-4, type=float)
    parser.add_argument('--actor_lr', default=3e-4, type=float)
    parser.add_argument('--gamma', default=0.99, type=float)
    parser.add_argument('--hidden_size', default=256, type=int)
    parser.add_argument('--batch_size', default=256, type=int)
    parser.add_argument('--using_absorbing', default=True, type=strtobool)
    parser.add_argument('--grad_reg_coeffs', nargs="+", default=[10.0, 1e-4], type=float)
    parser.add_argument('--use_last_layer_bias_cost', default=False, type=strtobool)
    parser.add_argument('--use_last_layer_bias_critic', default=False, type=strtobool)
    parser.add_argument('--kernel_initializer', default='he_normal', type=str)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--save_folder_name', default='', type=str)

    return parser