from ml_collections import config_dict


def get_config(algorithm_name) -> config_dict.ConfigDict:
    """Build the default hyperparameter ConfigDict for a training run.

    Args:
        algorithm_name: Stored verbatim as ``config.name``.

    Returns:
        A new ``ml_collections.config_dict.ConfigDict`` populated with the
        default values below. A fresh object (including a fresh
        ``hidden_layers`` list) is built on every call, so callers may
        mutate the result without affecting later calls.

    Note:
        What each field controls is decided by the training code that reads
        this config; the field names here are the only contract visible in
        this module.
    """
    config = config_dict.ConfigDict()

    config.name = algorithm_name

    config.hidden_layers = [512, 256]
    config.lr = 1e-4
    config.num_envs = 2048
    config.num_steps = 50
    # Written as 1e8 (the same float value as 10e7, but harder to misread).
    # ConfigDict type-checks later assignments against the type set here
    # (only int -> float widening is allowed), so keep this a float unless
    # the consumers are updated as well.
    config.total_timesteps = 1e8
    config.update_epochs = 4
    config.num_minibatches = 32
    config.gamma = 0.99
    config.gae_lambda = 0.95
    config.clip_eps = 0.2
    config.init_std = 0.8
    config.learnable_std = True
    config.ent_coef = 0.0
    config.vf_coef = 0.5
    config.max_grad_norm = 0.5
    config.activation = "tanh"
    config.anneal_lr = False
    config.weight_decay = 0.0
    config.normalize_env = True
    config.debug = False
    config.n_seeds = 1  # will automatically take seeds from 1 to n_seeds
    config.vmap_across_seeds = True
    # NOTE(review): validation is controlled by both `validation_active` and
    # `validation_num`; how the two flags interact is decided by the consumer
    # of this config — confirm before relying on either one alone.
    config.validation_active = False
    config.validation_num_steps = 100
    config.validation_num_envs = 100
    config.validation_num = 10  # set to 0 to disable validation

    return config
