from ml_collections import config_dict


def get_config(algorithm_name):
    """Build the default hyperparameter ConfigDict for one algorithm.

    A fresh ``ConfigDict`` is created on every call, ``algorithm_name`` is
    stored unchanged under ``name``, and every other field is filled with a
    hard-coded default.

    NOTE(review): the section headings below group fields by their names
    only; the exact meaning of each field is defined by the training code
    that consumes this config, which is not part of this file.

    Args:
        algorithm_name: Value stored in ``config.name`` (presumably a string
            identifier of the algorithm -- confirm against the caller).

    Returns:
        The populated ``config_dict.ConfigDict``.
    """
    cfg = config_dict.ConfigDict()

    cfg.name = algorithm_name

    # --- Inactive settings from an earlier setup, kept for reference ------
    # cfg.nr_hidden_units_disc = 256
    # cfg.learning_rate_disc = 1e-05
    # cfg.nr_epochs_disc = 10  # number of disc epochs
    # cfg.env_reward_frac = 0.0
    # cfg.handle_absorbing_states = True
    # cfg.gp_lambda = 0.1
    # cfg.gp_alpha = 0.5
    # cfg.data_path = "../expert_data/30_episodes/expert_dataset_Ant-v5_30_PPO.npz"

    # --- Networks, learning rates and rollout/update schedule -------------
    cfg.hidden_layers = [512, 256]
    cfg.lr = 1e-4
    cfg.disc_lr = 5e-5
    # cfg.num_envs = 2048  (inactive; not set by this function)
    cfg.num_steps = 14
    cfg.total_timesteps = 75e6  # stored as a float, exactly as written
    cfg.update_epochs = 4
    cfg.train_disc_interval = 3
    cfg.disc_minibatch_size = 2048
    # 0.0 means the policy receives no environment reward; the environment
    # reward is then only used for evaluation.
    cfg.proportion_env_reward = 0.0
    cfg.n_disc_epochs = 10
    cfg.num_minibatches = 32

    # --- Loss / optimisation coefficients ---------------------------------
    cfg.gamma = 0.99
    cfg.gae_lambda = 0.95
    cfg.clip_eps = 0.2
    cfg.init_std = 0.125
    cfg.learnable_std = False
    cfg.ent_coef = 0.0
    cfg.disc_ent_coef = 0.0
    cfg.vf_coef = 0.5
    cfg.max_grad_norm = 0.5
    cfg.activation = "tanh"
    cfg.anneal_lr = False
    cfg.weight_decay = 0.0

    # --- Run setup and data locations -------------------------------------
    cfg.normalize_env = True
    cfg.debug = False
    cfg.n_seeds = 1  # will automatically take seeds from 1 to n_seeds
    cfg.vmap_across_seeds = True
    cfg.data_path = "../expert_data/30_episodes/expert_dataset_MjxUnitreeGo2_30_PPO.npz"
    cfg.mocap_data_path = "../expert_data/mocap_cache/"

    # --- Discriminator regularisation and reward ---------------------------
    cfg.gp_lambda = 0.04
    # Options: state-action, state-based, shaped, shaped-sa, uncorrelated
    cfg.reward_type = 'state-action'
    cfg.handle_absorbing_states = True

    # --- Validation --------------------------------------------------------
    # NOTE(review): how `validation_active` interacts with `validation_num`
    # is not visible here -- confirm which one actually gates validation.
    cfg.validation_active = False
    cfg.validation_num_steps = 100
    cfg.validation_num_envs = 100
    cfg.validation_num = 10  # set to 0 to disable validation

    return cfg
