

def build_lompo_config(conf, env_name, dataset_type, risky):
    """Fill ``conf`` with the LOMPO hyperparameters for one experiment setup.

    Only the ``'walker_walk'`` environment is configured. For it, the
    ``'expert'``, ``'medium'`` and ``'expert_replay'`` dataset types are
    supported, each with a risky and a non-risky variant. The two variants
    share every setting except ``num_actor_critic_loop`` and ``horizon``
    (the latter is only assigned for the risky variant).

    Args:
        conf: Object whose attributes are assigned in place.
        env_name: Environment identifier; must be ``'walker_walk'``.
        dataset_type: One of ``'expert'``, ``'medium'``, ``'expert_replay'``.
        risky: Truthy selects the risky hyperparameter set, falsy the
            non-risky one.

    Returns:
        The same ``conf`` object, after its attributes have been set.

    Raises:
        NotImplementedError: If ``env_name`` or ``dataset_type`` has no
            configuration. The offending value is the exception argument.
            ``conf`` is left unmodified in that case.
    """
    if env_name != 'walker_walk':
        print("We dont have any configuration for lompo and this environment.")
        raise NotImplementedError(env_name)

    # num_actor_critic_loop per dataset type; the values differ between the
    # risky and non-risky variants.
    if risky:
        loops_by_dataset = {
            'expert': 2050,
            'medium': 1100,
            'expert_replay': 800,
        }
    else:
        loops_by_dataset = {
            'expert': 650,
            'medium': 1250,
            'expert_replay': 1500,
        }

    # Validate before touching ``conf`` so that an unsupported dataset type
    # fails loudly for BOTH variants. Previously the non-risky variant fell
    # through and silently returned None.
    if dataset_type not in loops_by_dataset:
        print("We dont have any configuration for lompo and this type of dataset.")
        raise NotImplementedError(dataset_type)

    # --- Settings shared by the risky and non-risky variants ---------------
    conf.hidden_latent = 256
    conf.critic_num_layers = 3
    conf.critic_units = 256
    conf.act_critic = 'relu'
    conf.act_actor = 'relu'
    conf.num_layers_actor = 4
    conf.units_actor = 256

    conf.latent_batch_size = 64
    conf.latent_batch_length = 50

    conf.discount = 0.99
    conf.clip_rewards = 'none'
    conf.num_layers_rewards = 3
    conf.num_units_reward = 128

    conf.weight_decay = 0.0
    conf.grad_clip = 100

    conf.latent_model_lr = 6e-4
    conf.critic_lr = 3e-4
    conf.actor_lr = 3e-4
    conf.log_every = 1000
    conf.kl_scale = 1.0
    conf.actor_batch_size = 256

    conf.critic_hidden_units_state = 256
    conf.critic_hidden_units_action = 256

    conf.critic_gamma = 0.99
    conf.target_update_interval = 1
    conf.tau = 5e-3     # for soft update

    # --- Variant-specific settings -----------------------------------------
    if risky:
        # NOTE(review): horizon is only assigned for the risky variant; the
        # non-risky variant keeps whatever ``conf`` already holds (if
        # anything). Confirm this asymmetry is intentional.
        conf.horizon = 5

    conf.num_actor_critic_loop = loops_by_dataset[dataset_type]
    return conf

