from ml_collections.config_dict import config_dict
import ml_collections

def td_get_config():
    """Return the base ConfigDict that the other config builders here extend.

    Fields set: ``actor_lr``, ``critic_lr``, ``hidden_dims``, ``discount``,
    ``num_qs``, ``tau`` and ``critic_layer_norm``. Callers such as
    ``get_sac_config`` add and override fields on the returned object.
    """
    base_fields = {
        "actor_lr": 3e-4,
        "critic_lr": 3e-4,
        "hidden_dims": (256, 256),
        "discount": 0.99,
        "num_qs": 2,
        # NOTE(review): presumably the soft target-update coefficient;
        # confirm against the learner that consumes it.
        "tau": 0.005,
        "critic_layer_norm": False,
    }

    cfg = ml_collections.ConfigDict()
    # Plain attribute assignment, one field at a time, in the order listed.
    for field_name, field_value in base_fields.items():
        setattr(cfg, field_name, field_value)
    return cfg


def get_sac_config():
    """Return the base TD config extended with SAC-learner fields.

    Starts from ``td_get_config()``, sets ``model_cls`` to ``"SACLearner"``
    and adds temperature-related fields. ``target_entropy`` and
    ``critic_weight_decay`` are float-typed placeholders whose initial
    value is None.
    """
    sac_cfg = td_get_config()

    sac_cfg.model_cls = "SACLearner"

    # Temperature settings.
    sac_cfg.temp_lr = 3e-4
    sac_cfg.init_temperature = 1.0
    # Typed placeholder: no value yet, but the field is declared as float.
    sac_cfg.target_entropy = config_dict.placeholder(float)

    sac_cfg.backup_entropy = False
    sac_cfg.critic_weight_decay = config_dict.placeholder(float)

    return sac_cfg

def get_rlpd_config():
    """Return the SAC config with RLPD-specific overrides applied.

    Overrides ``num_qs`` to 10, adds ``num_min_qs = 2``, enables
    ``critic_layer_norm`` and widens ``hidden_dims`` to three layers of 256.
    """
    rlpd = get_sac_config()
    overrides = dict(
        num_qs=10,
        # NOTE(review): presumably the number of Q-functions the min is
        # taken over; confirm in the learner.
        num_min_qs=2,
        critic_layer_norm=True,
        hidden_dims=(256, 256, 256),
    )
    for key, value in overrides.items():
        setattr(rlpd, key, value)
    return rlpd

def get_config():
    """Return the top-level config: RLPD settings plus RLPhiLearner fields.

    Switches ``model_cls`` to ``"RLPhiLearner"`` and adds
    ``mb_ensemble_size``, ``adjuster_hidden_dim`` and ``use_lstm``; their
    meaning is defined by the learner that reads them, not here.
    NOTE(review): presumably the entry point loaded by ml_collections
    ``config_flags``; confirm with the launching script.
    """
    cfg = get_rlpd_config()
    cfg.model_cls = "RLPhiLearner"
    cfg.mb_ensemble_size = 5
    cfg.adjuster_hidden_dim = 64
    cfg.use_lstm = False
    return cfg