from ml_collections import config_dict


def get_config(algorithm_name):
    config = config_dict.ConfigDict()

    config.name = algorithm_name

    config.device = "gpu"  # cpu, gpu
    config.nr_parallel_seeds = 1
    config.total_timesteps = 2e9
    config.learning_rate = 4e-4
    config.anneal_learning_rate = True
    config.nr_steps = 128
    config.nr_epochs = 10
    config.minibatch_size = 32768
    config.gamma = 0.99
    config.gae_lambda = 0.9
    config.clip_range = 0.1
    config.entropy_coef = 0.001
    config.critic_coef = 1.0
    config.max_grad_norm = 5.0
    config.std_dev = 1.0
    config.action_clipping_and_rescaling = False
    config.evaluation_and_save_frequency = 17301504  # -1 to disable
    config.evaluation_active = True

    # TRIRL Params
    config.nr_hidden_units_disc = 256
    config.learning_rate_disc = 0.0002015531860350999
    config.nr_epochs_disc = 30  # Number of disc epochs
    config.env_reward_frac = 0.0
    config.data_path = "../expert_data/30_episodes/expert_dataset_Ant-v5_30_PPO.npz"

    config.epsilon = 0.6931393089159753
    config.disc_buffer_capacity = 100
    # config.init_eta = 10.0
    config.beta = float(1/config.entropy_coef)

    config.mean_bound = 0.0002042444272419476
    config.cov_bound = 0.0044635849738194145
    config.trust_region_coef = 0.6815786322202716
    config.nr_epochs_rew = 30
    config.learning_rate_reward_fn = 5.7797838797843984e-05
    config.gp_lambda = 0.03267622091691947
    config.gp_alpha = 0.5

    config.handle_absorbing_states = True
    config.reward_fn_approximator = True
    config.on_demand_etas = True
    config.reward_type = 'state-action' # options: state-action, state-based, shaped, shaped-sa, uncorrelated
    config.reward_approximator_type = 'state-action' # options: state-action, state-based, shaped, shaped-sa
    # config.maximum_eta = True # automatically set to false if using on_demand_etas

    # config.global_rew_experiment = False
    # config.subsampling_cutoff = 1

    return config
