from sgornn import addition_problem
from sgornn import SGORNNModel, FastRNNModel, VPRNNModel, LSTMModel
from vprnn.models import save_vprnn
from argparse import ArgumentParser
import json

# Command-line options for the addition-problem training script.
parser = ArgumentParser()
# `choices` makes argparse reject an unknown model name with a usage error
# at parse time, instead of failing later when the name is looked up.
parser.add_argument('--model-type', type=str, default='sgornn',
                    choices=['sgornn', 'vprnn', 'fastrnn', 'lstm'],
                    help='the model type. one of [sgornn, vprnn, fastrnn, lstm]')
parser.add_argument('--sequence-length', type=int, default=500,
                    help='sequence length (T)')
parser.add_argument('--epochs', type=int, default=75,
                    help='epochs (batches x 100) used to train')
parser.add_argument('--lr', type=float, default=1e-2,
                    help='initial lr. decays to zero linearly.')
# Both output paths default to None, which the script treats as "do not save".
parser.add_argument('--model-output-path', type=str,
                    default=None,
                    help='if provided, save the model to this path')
parser.add_argument('--history-output-path',
                    type=str, default=None,
                    help='if provided, save the training history to this path')
# The script builds the model with clip_scalar=True unless this flag is given.
parser.add_argument('--no-scalar-clip', action='store_true',
                    help='build the model with clip_scalar=False')

if __name__ == '__main__':
    # Script entry point: build the requested model, train it on the addition
    # problem, then optionally save the model and the training history.
    args = parser.parse_args()

    # Explicit name -> class mapping, so the CLI name and the class cannot
    # drift apart the way two parallel lists can.
    model_classes = {
        'sgornn': SGORNNModel,
        'vprnn': VPRNNModel,
        'fastrnn': FastRNNModel,
        'lstm': LSTMModel,
    }
    if args.model_type not in model_classes:
        # Exit with a usage message rather than an uncaught lookup error.
        parser.error('unknown model type: {!r} (expected one of {})'.format(
            args.model_type, ', '.join(model_classes)))

    # The same keyword arguments are passed to every model type.
    model = model_classes[args.model_type](
        input_dim=2,
        output_dim=1,
        layers=1,
        clip_scalar=not args.no_scalar_clip,
        dim=128,
        rots=7,
        output_activation='linear',
    )
    model.summary()

    # Only the `.history` attribute of the fit result is kept.
    history = addition_problem.fit(model,
                                   args.sequence_length,
                                   epochs=args.epochs,
                                   initial_lr=args.lr).history

    if args.model_output_path:
        # SGORNN/VPRNN models are saved through save_vprnn; every other model
        # type uses its own save() method.
        if isinstance(model, (SGORNNModel, VPRNNModel)):
            save_vprnn(model, args.model_output_path)
        else:
            model.save(args.model_output_path)

    if args.history_output_path:
        # Context manager guarantees the file is flushed and closed even if
        # serialization fails part-way through.
        # NOTE(review): assumes `history` holds only JSON-serializable values
        # - confirm against addition_problem.fit.
        with open(args.history_output_path, 'w') as history_file:
            json.dump(history, history_file)
