
"""
An example of integrating Flatland into MARLlib: restores a trained policy and renders evaluation episodes.
"""

import os, sys
import time
import json
import argparse

from marllib import marl
from marllib.envs.flatland_env.metric import FlatlandCallbacks

# Callback class (passed uncalled) forwarded to alg.render below; presumably it
# collects Flatland-specific metrics during evaluation -- confirm in
# marllib.envs.flatland_env.metric.
callbacks = FlatlandCallbacks
# Absolute directory of this script; experiment result paths are built from it.
cur_path = os.path.dirname(
    os.path.abspath(__file__)
)


parser = argparse.ArgumentParser(
    description="Render a trained MARLlib policy on a Flatland map.")
# environment configs
parser.add_argument("--map_name", type=str, default='test_00',
                    help="Flatland map name; '-videos' is appended before the env is built.")
parser.add_argument("--obs_builder", type=str, default='fast_tree',
                    help="Observation builder passed to marl.make_env.")
parser.add_argument("--seed", type=int, default=2023,
                    help="Environment seed passed to marl.make_env.")
# algorithm configs
# Restrict to the algorithms this script actually dispatches on, so a typo is
# reported here with a usage message instead of surfacing later as an unbound
# `alg` variable.
parser.add_argument("--alg_name", type=str, default='ippo',
                    choices=('ippo', 'mappo'),
                    help="Algorithm whose Flatland hyperparameters are loaded.")
# model configs
parser.add_argument("--core_arch", type=str, default='tree',
                    help="Core model architecture passed to marl.build_model.")
parser.add_argument("--params_path", type=str, default='tree',
                    help="Experiment configuration path, relative to <script dir>/exp_results.")
parser.add_argument("--model_path", type=str, default='tree',
                    help="Model checkpoint path, relative to <script dir>/exp_results.")
args = parser.parse_args()

# NOTE: on a headless machine a virtual display may be needed before rendering,
# e.g. via pyvirtualdisplay:
#     import pyvirtualdisplay
#     pyvirtualdisplay.Display(visible=0, size=(800, 600)).start()

# Render against the "<map_name>-videos" variant of the requested map
# (presumably a video-recording map config -- confirm in the Flatland env).
args.map_name += '-videos'
env = marl.make_env(
    environment_name="flatland",
    map_name=args.map_name,
    obs_builder=args.obs_builder,
    shared_reward='independent',
    seed=args.seed,
)
# Element 0 of the returned env holds the environment object; restrict its
# advertised render modes to RGB frames.
env[0].metadata['render.modes'] = ["rgb_array"]

# TODO: decide hyperparam based on alg, env, model
# Initialize the algorithm with the Flatland hyperparameter set.
hyperparam_source = 'flatland'
if args.alg_name == 'ippo':
    alg = marl.algos.ippo(hyperparam_source=hyperparam_source)
elif args.alg_name == 'mappo':
    alg = marl.algos.mappo(hyperparam_source=hyperparam_source)
else:
    # Without this branch `alg` stays unbound and build_model below fails with
    # an unhelpful NameError.
    raise ValueError(
        f"Unsupported --alg_name {args.alg_name!r}; expected 'ippo' or 'mappo'")

# Build the agent model from env + algorithm + the user-selected core architecture.
model = marl.build_model(env, alg, {"core_arch": args.core_arch})

# Restored artifacts live under <script dir>/exp_results/.
results_dir = f"{cur_path}/exp_results"
restore_path = {
    'params_path': f"{results_dir}/{args.params_path}",  # experiment configuration
    'model_path': f"{results_dir}/{args.model_path}",
    'render': True,
}
# Single local worker, one policy shared by all agents, 100 evaluation episodes.
alg.render(
    env,
    model,
    restore_path=restore_path,
    local_mode=True,
    share_policy="all",
    checkpoint_end=False,
    num_workers=1,
    evaluation_num_episodes=100,
    callbacks=callbacks,
)
