
import omnisafe
from envs.seaquest import Seaquest
import argparse

# Fallback settings, used when the matching command-line option is omitted.
env_id: str = 'SeaquestCMDP-v0'  # default for --env-id
algo_id: str = 'PPOLag'  # default for --algo-id
seed: int = 0  # default for --seed

# Per-environment hyperparameters and log locations, keyed by environment id.
ENV_PARAMETERS = {
    'SeaquestCMDP-v0': {
        'total_steps': 5000000,
        'gamma': 0.9967,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 1.0,
        'log_dir': 'omnisafe_seaquest_property_1/',
    },
    'SeaquestCMDP-v1': {
        'total_steps': 5000000,
        'gamma': 0.9967,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 1.0,
        'log_dir': 'omnisafe_seaquest_property_2/',
    },
}

# Algorithms this script knows how to configure.
SUPPORTED_ALGOS = ('PPOLag', 'CPO')


def build_custom_cfgs(algo_id: str, env_id: str, seed: int) -> dict:
    """Build the ``custom_cfgs`` dictionary passed to ``omnisafe.Agent``.

    Args:
        algo_id: Algorithm name; must be one of ``SUPPORTED_ALGOS``.
        env_id: Environment id; must be a key of ``ENV_PARAMETERS``.
        seed: Random seed, stored in the config and appended to the log dir.

    Returns:
        A freshly built nested dict (safe for the caller to mutate).

    Raises:
        NotImplementedError: If ``algo_id`` is not supported.
        ValueError: If ``env_id`` has no entry in ``ENV_PARAMETERS``.
    """
    if algo_id not in SUPPORTED_ALGOS:
        raise NotImplementedError(
            f"Unsupported algorithm {algo_id!r}; "
            f"expected one of {', '.join(SUPPORTED_ALGOS)}."
        )
    if env_id not in ENV_PARAMETERS:
        raise ValueError(
            f"Unsupported environment {env_id!r}; "
            f"expected one of {', '.join(sorted(ENV_PARAMETERS))}."
        )
    params = ENV_PARAMETERS[env_id]

    # Settings shared by every supported algorithm.
    custom_cfgs = {
        'seed': seed,
        'train_cfgs': {
            'total_steps': params['total_steps'],
            'vector_env_nums': 1,
            'parallel': 1,
        },
        'algo_cfgs': {
            'steps_per_epoch': 20000,  # [For additional experiments]: 20000
            'gamma': params['gamma'],
            # The cost return is discounted with the same factor as the reward.
            'cost_gamma': params['gamma'],
            'lam': 0.95,
            'lam_c': 0.95,
            'use_max_grad_norm': True,
            'max_grad_norm': 40.0,  # [For additional experiments]: 40.0
            'kl_early_stop': False,
            'entropy_coef': 0.0,
        },
        'logger_cfgs': {
            # One log directory per (environment, seed) pair.
            'log_dir': f"{params['log_dir']}{env_id}_seed_{seed}",
            'use_tensorboard': True,
            'window_lens': 100,
        },
        'model_cfgs': {
            'actor': {
                'lr': 3e-5,
            },
        },
    }

    if algo_id == 'PPOLag':
        custom_cfgs['algo_cfgs'].update(
            update_iters=40,  # [For additional experiments]: 40
            batch_size=64,
            clip=0.2,
        )
        # PPOLag takes the cost limit through its Lagrange-multiplier config.
        custom_cfgs['lagrange_cfgs'] = {
            'cost_limit': params['cost_limit'],
            'lagrangian_multiplier_init': params['lagrangian_multiplier_init'],
        }
    else:  # 'CPO'
        # NOTE(review): only CPO pins a device; PPOLag sets none and so uses
        # whatever OmniSafe defaults to -- confirm this asymmetry is intended.
        custom_cfgs['train_cfgs']['device'] = 'cuda:0'
        custom_cfgs['algo_cfgs'].update(
            cost_limit=params['cost_limit'],
            update_iters=10,
            batch_size=128,  # [For additional experiments]: 128
        )

    return custom_cfgs


def main() -> None:
    """Parse command-line options, then train the selected agent."""
    parser = argparse.ArgumentParser(
        description="Train an OmniSafe agent on a Seaquest CMDP environment."
    )
    # argparse validates the values itself, so bad input yields a usage
    # message instead of a bare traceback.
    parser.add_argument(
        "--env-id",
        default=env_id,
        choices=sorted(ENV_PARAMETERS),
        help="environment to train on (default: %(default)s)",
    )
    parser.add_argument(
        "--algo-id",
        default=algo_id,
        choices=SUPPORTED_ALGOS,
        help="algorithm to train (default: %(default)s)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=seed,
        help="random seed (default: %(default)s)",
    )
    args = parser.parse_args()

    custom_cfgs = build_custom_cfgs(args.algo_id, args.env_id, args.seed)
    agent = omnisafe.Agent(args.algo_id, args.env_id, custom_cfgs=custom_cfgs)
    agent.learn()


if __name__ == '__main__':
    main()