"""
Script to evaluate a saved reasoner model on the test environment.

Usage:
    train_reasoner.py [options]

Options:
    -h --help              Show this screen.

    --load-model-from LFM  Path to the model to be loaded.
"""

import os
import schema
import  pytorch_lightning as pl
from docopt import docopt
from models.policy_net import PolicyNet
from envs.tsp_env import TSPDistCost
from models.ppo_mask_recurrent import MaskableRecurrentPPO
from functools import partial
from utils.callbacks import evaluate_policy

if __name__ == '__main__':
    # Validator for the docopt arguments. Named `arg_schema` so it does not
    # shadow the imported `schema` module. A saved model is required: without
    # one, the evaluation below would only measure an untrained policy.
    arg_schema = schema.Schema({
        '--help': bool,
        '--load-model-from': schema.And(
            str, os.path.exists,
            error='--load-model-from must be the path of an existing saved model'),
    })

    args = docopt(__doc__)
    try:
        args = arg_schema.validate(args)
    except schema.SchemaError as err:
        # Exit with the readable validation message instead of a traceback.
        raise SystemExit(err)

    # Evaluation environment on the 'test' split, with invalid actions masked.
    ts_env = TSPDistCost('test',
                         mask_invalid_actions=True,
                         optimisation_mode='multiple_samples')
    # Reset before building the policy, which reads `ts_env.dataset.spec`
    # and `ts_env.state` below.
    _ = ts_env.reset()

    PolicyNetSB = partial(PolicyNet,
                          spec=ts_env.dataset.spec,
                          data=ts_env.state)

    model = MaskableRecurrentPPO(PolicyNetSB, ts_env)

    # Load the saved weights into the model built above. The previous
    # `model.load(path)` call discarded its result: on Stable-Baselines3
    # algorithms `load` is a classmethod that returns a *new* model, so the
    # evaluation ran on the untrained policy.
    # NOTE(review): assumes MaskableRecurrentPPO derives from SB3's
    # BaseAlgorithm (which provides `set_parameters`) — confirm.
    model.set_parameters(args['--load-model-from'])

    n_eval_episodes = 100
    mean_rw, std_rw = evaluate_policy(model, ts_env,
                                      n_eval_episodes=n_eval_episodes,
                                      deterministic=True)
    # Report the result instead of stopping in a leftover debugger session.
    print(f'Reward over {n_eval_episodes} episodes: '
          f'mean={mean_rw} std={std_rw}')

