
import omnisafe
from bridge_crossing_omnisafe import BridgeCrossing
from media_streaming_omnisafe import MediaStreaming
from colour_bomb_grid_world_omnisafe import ColourBombGridWorld
from colour_bomb_grid_world_v2_omnisafe import ColourBombGridWorldV2
import argparse

# Command-line defaults for a training run.
env_id = 'BridgeCrossing-v0'
algo_id = 'PPOLag'
seed = 0

# Per-environment hyperparameters, shared by every supported algorithm.
# "[For additional experiments]" comments record the alternative values used
# for the extended runs.
PARAMETER_DICT = {
    'BridgeCrossing-v0': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.99,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.15,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_bridge_tensorboard/',
        'steps_per_epoch': 2000,
    },
    'BridgeCrossing-v2': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.99,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.15,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_bridge_v2_tensorboard/',
        'steps_per_epoch': 2000,
    },
    'MediaStreaming-v0': {
        'total_steps': 25000,  # [For additional experiments]: 500000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.01,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_media_tensorboard/',
        'steps_per_epoch': 400,
    },
    'ColourBomb9x9-v0': {
        'total_steps': 100000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.01,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_bomb_tensorboard/',
        'steps_per_epoch': 1000,
    },
    'ColourBomb9x9-v1': {
        'total_steps': 100000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.12,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_bomb_tensorboard/',
        'steps_per_epoch': 1000,
    },
    'ColourBomb15x15-v0': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.02,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_v2_tensorboard/',
        'steps_per_epoch': 2500,
    },
    'ColourBomb15x15-v1': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        # NOTE(review): 1.2 is 10x the 9x9-v1 limit (0.12) and far above the
        # other entries -- confirm this is intended and not a typo.
        'cost_limit': 1.2,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_v2_tensorboard/',
        'steps_per_epoch': 2500,
    },
    'ColourBomb15x15-v2': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.01,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_v2_tensorboard/',
        'steps_per_epoch': 2500,
    },
    'ColourBomb15x15-v3': {
        'total_steps': 300000,  # [For additional experiments]: 1000000
        'gamma': 0.95,
        'lagrangian_multiplier_init': 10.0,
        'cost_limit': 0.01,  # [For additional experiments]: 0.5
        'log_dir': 'omnisafe_colour_v2_tensorboard/',
        'steps_per_epoch': 2500,
    },
}

# Algorithms this script knows how to configure.
SUPPORTED_ALGOS = ('PPOLag', 'CPO')


def build_custom_cfgs(algo_id, env_id, seed, device=None):
    """Build the OmniSafe ``custom_cfgs`` dict for one training run.

    Args:
        algo_id: Algorithm name, one of ``SUPPORTED_ALGOS``.
        env_id: Key into ``PARAMETER_DICT`` selecting the per-environment
            hyperparameters.
        seed: Random seed; forwarded as ``cfgs['seed']`` and embedded in the
            log directory name.
        device: Optional value for ``train_cfgs['device']``. When ``None``,
            CPO uses ``'cuda:0'`` (this script's historical default) and
            PPOLag leaves the key unset.

    Returns:
        A freshly built nested dict to pass as ``custom_cfgs`` to
        ``omnisafe.Agent``.

    Raises:
        NotImplementedError: If ``algo_id`` is not in ``SUPPORTED_ALGOS``.
        ValueError: If ``env_id`` has no entry in ``PARAMETER_DICT``.
    """
    if algo_id not in SUPPORTED_ALGOS:
        raise NotImplementedError(
            f'Unsupported algo_id {algo_id!r}; expected one of {SUPPORTED_ALGOS}'
        )
    try:
        params = PARAMETER_DICT[env_id]
    except KeyError:
        raise ValueError(
            f'Unknown env_id {env_id!r}; expected one of {sorted(PARAMETER_DICT)}'
        ) from None

    train_cfgs = {
        'total_steps': params['total_steps'],
        'vector_env_nums': 1,
        'parallel': 1,
    }
    if device is None and algo_id == 'CPO':
        device = 'cuda:0'
    if device is not None:
        train_cfgs['device'] = device

    # Settings common to both algorithms; algorithm-specific keys are added below.
    algo_cfgs = {
        'steps_per_epoch': params['steps_per_epoch'],  # [For additional experiments]: 20000
        'gamma': params['gamma'],
        'cost_gamma': params['gamma'],
        'lam': 0.95,
        'lam_c': 0.95,
        'use_max_grad_norm': True,
        'max_grad_norm': 0.5,
        'kl_early_stop': False,
        'entropy_coef': 0.0,
    }

    cfgs = {
        'seed': seed,
        'train_cfgs': train_cfgs,
        'algo_cfgs': algo_cfgs,
        'logger_cfgs': {
            'log_dir': params['log_dir'] + env_id + f'_seed_{seed}',
            'use_tensorboard': True,
            'window_lens': 100,
        },
        'model_cfgs': {
            'actor': {
                'lr': 0.0003,
            },
        },
    }

    if algo_id == 'PPOLag':
        algo_cfgs.update(update_iters=40, batch_size=64, clip=0.2)
        cfgs['lagrange_cfgs'] = {
            'cost_limit': params['cost_limit'],
            'lagrangian_multiplier_init': params['lagrangian_multiplier_init'],
        }
    else:  # 'CPO'
        # CPO gets no lagrange_cfgs section; its cost limit goes under algo_cfgs.
        algo_cfgs.update(
            cost_limit=params['cost_limit'],
            update_iters=10,
            batch_size=128,
        )
    return cfgs


def parse_args(argv=None):
    """Parse command-line options (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser(
        description='Train an OmniSafe agent on one of the custom environments.'
    )
    parser.add_argument('--env-id', default=env_id, choices=sorted(PARAMETER_DICT))
    parser.add_argument('--algo-id', default=algo_id, choices=SUPPORTED_ALGOS)
    parser.add_argument('--seed', default=seed, type=int)
    parser.add_argument(
        '--device',
        default=None,
        help="Training device, e.g. 'cpu' or 'cuda:0' "
             "(default: 'cuda:0' for CPO, OmniSafe default for PPOLag).",
    )
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_args()
    env_id = args.env_id
    algo_id = args.algo_id
    seed = args.seed

    custom_cfgs = build_custom_cfgs(algo_id, env_id, seed, device=args.device)

    agent = omnisafe.Agent(algo_id, env_id, custom_cfgs=custom_cfgs)
    agent.learn()