
"""
An instance of integrating Flatland into MARLlib: evaluate trained checkpoints on a Flatland test map.
"""

import os, sys
import time
import json
import argparse

from marllib import marl
from marllib.envs.flatland_env.metric import FlatlandCallbacks

# Directory containing this script; experiment results are looked up under
# `<cur_path>/exp_results/`.
cur_path = os.path.dirname(os.path.abspath(__file__))
# Callback class (not an instance) handed to the algorithm when evaluating.
callbacks = FlatlandCallbacks

parser = argparse.ArgumentParser(
    description="Evaluate trained MARLlib checkpoints on a Flatland test map.")
# environment configs
parser.add_argument("--map_name", type=str, default='test_00',
                    help="Flatland map name; '-test' is appended to select the test variant.")
parser.add_argument("--obs_builder", type=str, default='fast_tree',
                    help="Observation builder passed to marl.make_env.")
parser.add_argument("--seed", type=int, default=2023,
                    help="Environment seed for the first evaluated run; "
                         "incremented by one after each run.")
# model configs
parser.add_argument("--model_path", type=str, default='ippo_mlp_test_02-independent_default/',
                    help="Experiment directory under exp_results/, named "
                         "'<algorithm>_<core_arch>_...' (algorithm: ippo or mappo).")
args = parser.parse_args()


def run_test(args, params_path, model_path, test_episodes=20):
    """Evaluate one trained checkpoint on the test variant of a Flatland map.

    The algorithm and core architecture are parsed from the experiment
    directory name ``args.model_path``, which must look like
    ``<algorithm>_<core_arch>_...`` (e.g. ``ippo_mlp_test_02-independent_default/``).

    Args:
        args: Parsed CLI namespace; ``map_name``, ``obs_builder``, ``seed`` and
            ``model_path`` are read from it.
        params_path: Path to the run's ``params.json``.
        model_path: Checkpoint path handed to the restore configuration.
        test_episodes: Number of evaluation episodes.

    Returns:
        The value returned by ``alg.render``.

    Raises:
        ValueError: If the experiment name does not start with a supported
            algorithm (``ippo`` or ``mappo``) followed by a core architecture.
    """
    # Parse "<algorithm>_<core_arch>_..." from the experiment directory's own
    # name; normpath/basename tolerate a trailing slash or parent directories.
    exp_name = os.path.basename(os.path.normpath(args.model_path))
    name_parts = exp_name.split('_')
    alg_name = name_parts[0]
    supported_algs = ('ippo', 'mappo')
    # Validate before building the environment so a bad name fails fast with a
    # clear message instead of an UnboundLocalError/IndexError after env setup.
    if alg_name not in supported_algs or len(name_parts) < 2:
        raise ValueError(
            f"Cannot parse '<algorithm>_<core_arch>_...' from model_path "
            f"{args.model_path!r}; supported algorithms: {supported_algs}")
    core_arch = name_parts[1]

    # prepare the environment on the test variant of the requested map
    env = marl.make_env(environment_name="flatland",
                        map_name=args.map_name + '-test',
                        obs_builder=args.obs_builder,
                        shared_reward='independent',
                        seed=args.seed)

    # initialize algorithm with the Flatland hyperparameter preset
    hyperparam_source = 'flatland'
    if alg_name == 'ippo':
        alg = marl.algos.ippo(hyperparam_source=hyperparam_source)
    else:
        alg = marl.algos.mappo(hyperparam_source=hyperparam_source)

    # build agent model based on env + algorithm + requested core architecture
    model = marl.build_model(env, alg, {"core_arch": core_arch})

    return alg.render(env, model,
                      restore_path={'params_path': params_path,
                                    'model_path': model_path,
                                    'render': False},
                      local_mode=True,
                      share_policy="all",
                      checkpoint_end=False,
                      num_workers=1,
                      evaluation_num_episodes=test_episodes,
                      callbacks=callbacks)


# Checkpoints are located through their companion `*.tune_metadata` file; the
# restore path is that file's name with the suffix removed.
_METADATA_SUFFIX = '.tune_metadata'


def _checkpoint_sort_key(name):
    """Order checkpoint directory names by their trailing ``_<number>``.

    ``checkpoint_100`` sorts after ``checkpoint_50`` even without zero padding;
    names without a numeric suffix sort before numbered ones.
    """
    suffix = name.rsplit('_', 1)[-1]
    return (int(suffix) if suffix.isdecimal() else -1, name)


def find_latest_model_path(run_dir):
    """Return the restore path of the newest checkpoint inside ``run_dir``.

    Only sub-directories that contain a ``*.tune_metadata`` file are
    considered. The newest one is chosen by its trailing iteration number.

    Returns:
        ``<checkpoint_dir>/<metadata name without .tune_metadata>``, or
        ``None`` when ``run_dir`` holds no such checkpoint.
    """
    candidates = []
    for name in os.listdir(run_dir):
        ckpt_dir = os.path.join(run_dir, name)
        if not os.path.isdir(ckpt_dir):
            continue
        metadata_files = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(_METADATA_SUFFIX))
        if metadata_files:
            candidates.append((_checkpoint_sort_key(name), ckpt_dir, metadata_files[0]))
    if not candidates:
        return None
    _, ckpt_dir, metadata_file = max(candidates)
    # Strip only the suffix so restore names containing dots stay intact.
    return os.path.join(ckpt_dir, metadata_file[:-len(_METADATA_SUFFIX)])


if __name__ == '__main__':

    print(f'\n=== Start testing {args.model_path} on {args.map_name}...')

    exp_dir = os.path.join(f'{cur_path}/exp_results/', args.model_path)

    # os.listdir order is arbitrary; sort so each run gets the same seed on
    # every invocation (the seed is incremented per evaluated run).
    for run_name in sorted(os.listdir(exp_dir)):
        run_dir = os.path.join(exp_dir, run_name)
        if not os.path.isdir(run_dir):
            continue
        model_path = find_latest_model_path(run_dir)
        if model_path is None:
            print(f'- skipping {run_dir}: no checkpoint found\n')
            continue
        params_path = os.path.join(run_dir, 'params.json')

        print(f'- model_path: {model_path}\n')
        run_test(args, params_path, model_path, test_episodes=20)
        args.seed += 1




