import argparse
from utils import ModelHandler
import torch
import pickle
from dataset import LongRangeDataset

# Command-line entry point: load a saved model and evaluate it on the
# "test" split of a LongRangeDataset using beam search, then pickle the
# per-example results recorded by ModelHandler.test_model.
parser = argparse.ArgumentParser()
parser.add_argument('--data_file', type=str, help="Location of data",
    default='')
parser.add_argument('--model_file', type=str, help="Saved model to test",
    default='')
parser.add_argument('--save_results', type=str, help="File to save results",
    default="results.pickle")
parser.add_argument('--force_cpu', action='store_true',
    help="Force use of cpu")
args = parser.parse_args()

dataset = LongRangeDataset(args.data_file)

# Use CUDA when available unless the user explicitly asked for CPU.
if args.force_cpu:
    device = torch.device('cpu')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# map_location remaps tensors that were saved on another device (e.g. a
# checkpoint written on a GPU) onto `device`. Without it, --force_cpu or a
# CPU-only machine cannot load a CUDA-saved model.
# Passing the path lets torch.load open and close the file itself.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted
# model files. On torch >= 2.6, loading a whole pickled model object may
# additionally require weights_only=False (not passed here so that older
# torch versions keep working) — confirm against the installed version.
model = torch.load(args.model_file, map_location=device)

# Evaluation only, so no optimizer is required.
optim = None

handler = ModelHandler(model, optim, dataset, device, '')
acc, results = handler.test_model("beam_search", split="test",
    record_results=True)
# Report the metric returned by test_model instead of silently dropping it.
print("test acc:", acc)

# Close (and therefore flush) the results file deterministically.
with open(args.save_results, 'wb') as results_file:
    pickle.dump(results, results_file)
