from ml_collections import config_dict


def get_config(algorithm_name):
    """Build and return the experiment configuration.

    Args:
        algorithm_name: Stored verbatim under the ``name`` key.

    Returns:
        A ``config_dict.ConfigDict`` holding the general training settings,
        followed by the NCSN-specific settings, in the order listed below.
    """
    # General training / evaluation settings.
    training_settings = {
        "name": algorithm_name,
        "device": "gpu",  # cpu, gpu
        "nr_parallel_seeds": 1,
        "total_timesteps": 2e9,
        "learning_rate": 4e-4,
        "anneal_learning_rate": True,
        "nr_steps": 128,
        "nr_epochs": 10,
        "minibatch_size": 32768,
        "gamma": 0.99,
        "gae_lambda": 0.97,
        "clip_range": 0.1,
        "entropy_coef": 0.0001,
        "critic_coef": 1.0,
        "max_grad_norm": 10.0,
        "std_dev": 1.0,
        "action_clipping_and_rescaling": False,
        "evaluation_and_save_frequency": 17301504,  # -1 to disable
        "evaluation_active": True,
    }

    # NCSN params.
    ncsn_settings = {
        "batch_size_ncsn": 256,
        "minibatch_size_ncsn": 64,
        "total_samples_ncsn": 25e6,
        "nr_epochs_ncsn": 20,  # number of NCSN epochs
        "anneal_power_ncsn": 2.0,
        "sigma_begin_ncsn": 10.0,
        "sigma_end_ncsn": 0.01,
        "L_ncsn": 20,
        "nr_hidden_units_encoder_ncsn": [256, 512, 1024, 2048],
        "nr_hidden_units_decoder_ncsn": [1024, 512, 128, 64, 32],
        "learning_rate_ncsn": 0.00011787105541232714,
        "sigma_inference_ncsn": 5,  # alternative value previously tried: -1
        "anneal_threshold": 0.008,
        "env_reward_frac": 0.0,
        "handle_absorbing_states": True,
        "state_based": False,
        "ncsnv1": True,
        "data_path": "../expert_data/30_episodes/expert_dataset_Ant-v5_30_PPO.npz",
    }

    config = config_dict.ConfigDict()
    # setattr goes through the same ConfigDict.__setattr__ path as
    # `config.<key> = value`, so insertion order and type handling match.
    for section in (training_settings, ncsn_settings):
        for key, value in section.items():
            setattr(config, key, value)

    return config
