
"""
An instance of integrating Flatland into MARLlib
"""

import os, sys
import time
import json
import argparse

from marllib import marl
import utils

# Command-line interface of the training script.
# Every option is a single typed flag with a default value, so the options are
# declared as (flag, type, default) rows and registered in one pass. The row
# order is the order in which the flags are added to the parser.
parser = argparse.ArgumentParser()
_cli_options = (
    # environment configs
    ("--map_name", str, 'test_00'),
    ("--obs_builder", str, 'fast_tree'),
    ("--shared_reward", str, 'global'),
    ("--seed", int, 2023),
    # algorithm configs
    ("--alg_name", str, 'ippo'),
    # model configs
    ("--core_arch", str, 'tree'),
    # training configs
    ("--timesteps_total", int, int(2e7)),
    ("--num_workers", int, 2),
    ("--exp_name", str, 'default'),
)
for _flag, _value_type, _fallback in _cli_options:
    parser.add_argument(_flag, type=_value_type, default=_fallback)
# Read the command line into the `args` namespace used by the rest of the script.
args = parser.parse_args()

# set seed (delegated to the project-local `utils` helper)
utils.set_seed_everywhere(args.seed)

# prepare the environment
# The map name handed to MARLlib is the base map name extended with the reward
# mode and the experiment name: "<map_name>-<shared_reward>_<exp_name>".
# `args.map_name` itself is overwritten with this composite value.
args.map_name = f"{args.map_name}-{args.shared_reward}_{args.exp_name}"
_env_settings = {
    "environment_name": "flatland",
    "map_name": args.map_name,
    "obs_builder": args.obs_builder,
    "shared_reward": args.shared_reward,
    "seed": args.seed,
}
env = marl.make_env(**_env_settings)

# TODO: decide hyperparam based on alg, env, model
# initialize algorithm and load hyperparameters
hyperparam_source = 'flatland'

# Algorithms this script can launch. Each name is also the attribute looked up
# on `marl.algos` to obtain the algorithm constructor.
supported_algs = ('ippo', 'mappo', 'vdppo', 'ia2c', 'coma')

if args.alg_name not in supported_algs:
    # Fail right here with a descriptive message. Without this check an
    # unrecognised name would leave `alg` unbound and the script would only
    # crash later, with a NameError unrelated to the real cause.
    raise ValueError(
        f"Unsupported --alg_name {args.alg_name!r}; "
        f"expected one of: {', '.join(supported_algs)}"
    )

# Only the selected constructor is resolved, exactly as in an explicit
# `marl.algos.<name>(...)` call.
alg = getattr(marl.algos, args.alg_name)(hyperparam_source=hyperparam_source)

# Build the agent model from the environment, the selected algorithm and the
# user's model preference (only the core architecture is configurable here).
_model_preferences = {"core_arch": args.core_arch}
model = marl.build_model(env, alg, _model_preferences)

# Start learning, with extra experiment settings if needed.
# Remember to check ray.yaml before use.
# The Flatland callbacks class is passed to `fit` as-is (not instantiated here).
from marllib.envs.flatland_env.metric import FlatlandCallbacks
callbacks = FlatlandCallbacks

# Stop criteria handed to `fit`: a mean-episode-reward target and the
# timestep budget requested on the command line.
_stop_criteria = {
    'episode_reward_mean': 1000,
    'timesteps_total': int(args.timesteps_total),
}

# Run settings; only the worker count comes from the command line, the rest
# are fixed for this script.
_run_settings = {
    'local_mode': False,
    'num_gpus': 1,
    'num_workers': args.num_workers,
    'rollout_fragment_length': 100,
    'train_batch_size': 1000,
    'batch_episode': 1,
    'share_policy': 'all',
    'checkpoint_freq': int(1e3),
    'callbacks': callbacks,
}

alg.fit(env, model, stop=_stop_criteria, **_run_settings)

