from algorithms.ppo import PPO
from environments import make_vec_env
from setup import AttrDict, parse_arguments, set_seed, set_device, setup_logger


def get_config():
    """Return the experiment configuration for a PPO run.

    Builds an ``AttrDict`` holding the default settings listed below, passes
    it through ``parse_arguments`` and returns whatever that call returns.
    (Presumably ``parse_arguments`` applies command-line overrides on top of
    these defaults — confirm in ``setup``.)
    """
    defaults = {
        # Run / experiment identification
        "algo": "ppo",
        "env_id": "HalfCheetah-v2",
        "expr_name": "default",
        "seed": 0,
        "use_gpu": True,
        "pixel_obs": False,
        # PPO
        "num_steps": 2000000,
        "num_envs": 4,
        "rollout_length": 10000,
        "batch_size": 500,
        "train_epochs": 10,
        "hidden_size": 256,
        "gamma": 0.99,
        "gae_lambda": 0.95,
        "policy_lr": 1e-4,  # 1e-5 for cnn
        "value_lr": 1e-4,  # 1e-5 for cnn
        "clip_range": 0.1,
        "ent_coef": 0.0,
        "eval_freq": 10,  # in number of epochs
        "checkpoint_freq": 10,  # in number of epochs
        "num_eval_episodes": 16,
    }

    config = AttrDict()
    # Assign one attribute at a time, in the order listed above; setattr goes
    # through the same __setattr__ as ``config.<name> = value``.
    for name, value in defaults.items():
        setattr(config, name, value)
    return parse_arguments(config)


def main():
    """Run one PPO training job with the configuration from ``get_config``.

    Applies the seed and device settings, creates the logger, builds the
    training and evaluation vectorized environments, then constructs the
    ``PPO`` agent and calls its ``train`` method.
    """
    cfg = get_config()
    set_seed(cfg.seed)
    set_device(cfg.use_gpu)

    logger = setup_logger(cfg)

    # Training environments: cfg.num_envs copies of cfg.env_id.
    train_envs = make_vec_env(cfg.env_id, cfg.num_envs, cfg.seed, cfg.pixel_obs)

    # Evaluation environment: a single copy. Every "train" substring in the id
    # is swapped for "test" (ids without "train" are used unchanged), and the
    # seed is offset by 1000 — presumably to keep evaluation seeding disjoint
    # from the training envs; confirm against make_vec_env's seeding scheme.
    eval_envs = make_vec_env(
        cfg.env_id.replace("train", "test"),
        1,
        cfg.seed + 1000,
        cfg.pixel_obs,
    )

    PPO(cfg, train_envs, eval_envs, logger).train()


if __name__ == "__main__":
    main()
