"""
Evaluate a trained model checkpoint on the TSPLIB instances and print the mean relative tour-length error.

Usage:
    test_tsp.py (--load-model-from LFM) [options]

Options:
    -h --help              Show this screen.

    --load-model-from LFM  Path to the model checkpoint to load.

"""
import os

from docopt import docopt
import schema
import tqdm
import torch
import pytorch_lightning as pl
from torch_geometric.loader import DataLoader

from datasets.tsp_datasets import TSPLIB
from models.algorithm_processor import LitAlgorithmProcessor
from hyperparameters import get_hyperparameters
from datasets.constants import _DATASET_ROOTS

def main():
    """Evaluate a checkpoint on TSPLIB and print the mean relative tour error.

    Parses the command line described by the module docstring, loads the
    ``LitAlgorithmProcessor`` checkpoint given by ``--load-model-from``, runs
    the ``tsp_large`` algorithm's ``valtest_step`` in ``'test'`` mode on every
    TSPLIB instance (one per batch), and prints the mean of the per-instance
    ``tour_relative_error`` values.

    Raises:
        schema.SchemaError: if the checkpoint path does not exist.
        SystemExit: if the dataset yields no instances to evaluate.
    """
    # Named ``args_schema`` so the ``schema`` module itself is not shadowed.
    args_schema = schema.Schema({
        '--help': bool,
        '--load-model-from': schema.Or(None, os.path.exists),
    })
    args = args_schema.validate(docopt(__doc__))

    lit_processor = LitAlgorithmProcessor.load_from_checkpoint(
        args['--load-model-from'],
    )
    # The Trainer is bypassed here, so switch to eval mode ourselves (as
    # Trainer.test would): disables dropout / uses running norm statistics.
    lit_processor.eval()
    tsp_algorithm = lit_processor.algorithms.tsp_large

    dataset = TSPLIB(root=_DATASET_ROOTS['tsplib'])

    dl = DataLoader(dataset,
                    batch_size=1,
                    shuffle=False,
                    drop_last=False,
                    follow_batch=['edge_index'],
                    num_workers=1,
                    persistent_workers=True)
    # Point the algorithm module at the spec of the dataset being evaluated.
    tsp_algorithm.algorithm_module.dataset_spec = dataset.spec

    errors = []
    # Evaluation only: avoid building autograd graphs for every instance.
    with torch.no_grad():
        for batch in tqdm.tqdm(dl):
            out = tsp_algorithm.valtest_step(batch, 0, 'test')
            errors.append(out['accuracies']['tour_relative_error'].item())

    if not errors:
        # Previously a bare ZeroDivisionError; fail with an explicit message.
        raise SystemExit('No TSPLIB instances were evaluated.')
    print(sum(errors) / len(errors))


if __name__ == '__main__':
    main()
