from yacs.config import CfgNode as CN

# Root of the default configuration tree. Nested dicts below are turned into
# nested CfgNode sections by yacs, in the order written. Callers should obtain
# a private copy via get_cfg_defaults() instead of mutating this object.
#
# NOTE: yacs type-checks every override against the default's type. The
# step-count fields train.eval_freq / train.save_freq / train.max_timesteps
# default to floats (5e3, 1e5, 1e5), so overrides must also be floats
# (e.g. `2e5` or `200000.0`); a plain int such as `200000` is rejected.
_C = CN({
    'evaluate': False,
    'save_model': True,
    'load_model': False,
    'gym_type': 'gymnasium',

    # Weights & Biases logging settings
    'wandb': {
        'enable': False,
        'project_name': '',
        'entity_name': 'mechanistic_offline_rl',
        'name': '',
    },

    # Training-loop arguments (used when creating behavioral-policy data)
    'train': {
        'agent_path': '',
        'algorithm': '',
        'eval_freq': 5e3,
        'save_freq': 1e5,
        'start_timesteps': 100,
        'batch_size': 256,
        'max_timesteps': 1e5,
        'learning_rate': 1e-3,
        # NOTE(review): a second discount-like value also lives at
        # policy.discount; confirm which one each consumer reads.
        'gamma': 0.95,
    },

    # Policy-specific arguments
    'policy': {
        'expl_noise': 0.1,
        'discount': 0.99,
        'tau': 0.005,
        'policy_noise': 0.2,
        'noise_clip': 0.5,
        'policy_freq': 2,
        'normalize': True,
    },

    # Simulator model arguments
    'simulator': {
        'transform_list': [],
    },

    # System parameters
    'system': {
        'seed': 11,
        'cpu': False,
    },

    # Environment identifiers for training and evaluation
    'env': {
        'train_env': '',
        'eval_env': '',
    },
})


def get_cfg_defaults(config_file=None, config_list=None):
    """Build a frozen configuration from the module defaults plus overrides.

    The module-level default tree is cloned, so the defaults themselves are
    never modified.

    Args:
        config_file: Optional path to a YAML file merged via
            ``CfgNode.merge_from_file``.
        config_list: Optional flat ``[key, value, key, value, ...]`` list
            merged via ``CfgNode.merge_from_list``.

    Returns:
        A frozen (immutable) CfgNode.
    """
    merged = _C.clone()
    # The file is applied first, so list overrides take precedence over it.
    overrides = (
        (config_file, merged.merge_from_file),
        (config_list, merged.merge_from_list),
    )
    for source, merge in overrides:
        if source is None:
            continue
        merge(source)
    merged.freeze()
    return merged
