from sgornn import ptb
from sgornn.models import FastRNNModel, VPRNNModel, SGORNNModel, LSTMModel
from keras.layers import Embedding, Dropout
from keras import Sequential
from argparse import ArgumentParser
from math import log2, ceil
import numpy as np

# Command-line flags for this PTB training script.
parser = ArgumentParser()
parser.add_argument('--model-type', type=str, default='sgornn',
                    # Must stay in sync with the model classes the script body can build.
                    choices=['sgornn', 'vprnn', 'fastrnn', 'lstm'],
                    help='the model type. one of [sgornn, vprnn, fastrnn, lstm]')
parser.add_argument('--units', type=int, default=256,
                    help='hidden state size (h)')
parser.add_argument('--layers', type=int, default=1,
                    help='number of layers (L)')
parser.add_argument('--lr', type=float, default=1e-3,
                    help='learning rate.')
parser.add_argument('--batch-size', type=int, default=100,
                    help='batch size used for both training and test evaluation')
parser.add_argument('--no-scalar-clip', action='store_true',
                    help='build the RNN model with clip_scalar=False')
parser.add_argument('--epochs', type=int, default=300,
                    help='number of training epochs')
parser.add_argument('--seed', type=int, default=42,
                    help="seed for NumPy's global random generator")
parser.add_argument('--num-words', type=int, default=10000,
                    help='vocabulary size (embedding input size and softmax output size)')
parser.add_argument('--sequence-length', type=int, default=300,
                    help='number of time steps per sequence (passed to ptb as num_steps)')
parser.add_argument('--dropout', type=float, default=0.3,
                    help='dropout rate applied after the embedding; 0 disables it')

# Learning-rate schedule: from epoch LR_DECAY_EPOCH onward, the rate is
# args.lr / LR_DECAY_FACTOR; before that it is args.lr.
LR_DECAY_EPOCH = 200
LR_DECAY_FACTOR = 10


def main():
    """Build an embedding + RNN model from the CLI flags, fit it with ``ptb.fit``
    and print the result of ``ptb.evaluate``.

    Unknown ``--model-type`` values are reported through ``parser.error``
    (usage message, exit status 2).
    """
    args = parser.parse_args()
    # NOTE(review): only NumPy's global RNG is seeded here; Keras weight
    # initialisation / dropout may draw from a separate backend RNG — confirm
    # whether runs are expected to be fully reproducible.
    np.random.seed(args.seed)

    # Maps each --model-type value to the sgornn model class it builds.
    model_classes = {
        'sgornn': SGORNNModel,
        'vprnn': VPRNNModel,
        'fastrnn': FastRNNModel,
        'lstm': LSTMModel,
    }
    try:
        model_cls = model_classes[args.model_type]
    except KeyError:
        parser.error('argument --model-type: invalid choice: {!r} (choose from {})'.format(
            args.model_type, ', '.join(sorted(model_classes))))

    embedding_layer = Embedding(args.num_words, args.units)
    rnn_model = model_cls(input_dim=args.units,
                          output_dim=args.num_words,
                          output_activation='softmax',
                          clip_scalar=not args.no_scalar_clip,
                          layers=args.layers,
                          dim=args.units,
                          return_sequences=True,
                          rots=ceil(log2(args.units)))
    print(rnn_model)

    # Final stack: embedding -> optional dropout -> the RNN model's own layers.
    model = Sequential()
    model.add(embedding_layer)
    if args.dropout:  # a rate of 0.0 skips the Dropout layer entirely
        model.add(Dropout(args.dropout))
    for layer in rnn_model.layers:
        model.add(layer)
    model.summary()

    def schedule(epoch, _):
        """Return the learning rate for ``epoch``.

        The second argument is ignored (presumably the current learning rate
        supplied by a Keras-style LR scheduler callback — verify in ptb.fit).
        """
        if epoch >= LR_DECAY_EPOCH:
            return args.lr / LR_DECAY_FACTOR
        return args.lr

    ptb.fit(model,
            lr=args.lr,
            lr_scheduler=schedule,
            epochs=args.epochs,
            batch_size=args.batch_size,
            num_words=args.num_words,
            early_stopping=None,
            num_steps=args.sequence_length)

    print('TEST EVAL: ', ptb.evaluate(model,
                                      batch_size=args.batch_size,
                                      num_words=args.num_words,
                                      num_steps=args.sequence_length))


if __name__ == '__main__':
    main()
