import numpy as np
from keras.preprocessing.text import Tokenizer
from keras.callbacks import LearningRateScheduler, EarlyStopping
from keras import Model
from keras.optimizers import RMSprop
import keras.backend as K
import io
import os
import wget
import zipfile


def ptb_supervised(raw_data, num_steps=300, batch_size=100, shuffle=True):
    """Yield ``(X, Y)`` batches of next-token prediction windows.

    ``raw_data`` is cut into consecutive, non-overlapping windows of
    ``num_steps`` tokens. For a window ending (exclusively) at ``ptr`` the
    input is ``raw_data[ptr - num_steps:ptr]`` and the target is the same
    window shifted one token to the right. Every window whose shifted target
    still fits inside ``raw_data`` is produced exactly once per call.

    Args:
        raw_data: Sequence of integer token ids (converted to ``int32``).
        num_steps: Length of each window.
        batch_size: Maximum number of windows per yielded batch; the final
            batch may be smaller.
        shuffle: If true, window order is shuffled with NumPy's global RNG;
            otherwise windows come out in corpus order, so concatenating the
            batches reproduces a contiguous token stream.

    Yields:
        Tuples ``(X, Y)`` of ``int32`` arrays, each of shape
        ``(n, num_steps)`` with ``1 <= n <= batch_size``.

    Raises:
        ValueError: If ``num_steps`` or ``batch_size`` is smaller than 1
            (raised on first iteration, since this is a generator).
    """
    if num_steps < 1:
        raise ValueError('num_steps must be >= 1, got %r' % (num_steps,))
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))

    raw_data = np.array(raw_data, dtype=np.int32)
    data_len = len(raw_data)

    # The window ending at ``ptr`` reads its target up to index ``ptr``, so
    # every end point with ``ptr <= data_len - 1`` is usable. Bounding by
    # ``data_len // num_steps`` instead would drop the last full window
    # whenever ``data_len`` is not a multiple of ``num_steps``.
    num_windows = max(0, (data_len - 1) // num_steps)
    inds = [num_steps * i for i in range(1, num_windows + 1)]
    if shuffle:
        np.random.shuffle(inds)

    X, Y = [], []
    for ptr in inds:
        X.append(raw_data[ptr - num_steps:ptr])
        Y.append(raw_data[ptr - num_steps + 1:ptr + 1])
        if len(X) == batch_size:
            yield np.array(X), np.array(Y)
            X, Y = [], []

    # Emit the trailing partial batch, if any.
    if X:
        yield np.array(X), np.array(Y)


def _read_ptb_split(zf, member):
    """Return the UTF-8 text of ``member`` in ``zf`` with newlines replaced by ``<eos>``."""
    with io.TextIOWrapper(zf.open(member, 'r'), encoding="utf-8") as fp:
        return fp.read().replace('\n', '<eos>')


def load_sequences(num_words=10000):
    """Download PTB and convert each split into a flat list of token ids.

    The archive is downloaded into the current directory and always removed
    afterwards, even if reading it fails. A tokenizer is fitted on the
    training text only and then applied to all three splits.

    Args:
        num_words: Forwarded to ``Tokenizer(num_words=...)``. Per the Keras
            docs only the ``num_words - 1`` most frequent training words are
            kept, and (since no ``oov_token`` is set) other words are dropped
            from the sequences. ``None`` keeps the whole vocabulary.

    Returns:
        ``(train_seq, valid_seq, test_seq)``: one list of integer ids per split.
    """
    # wget.download returns the path it actually wrote, which may differ from
    # 'ptbdataset.zip' (e.g. when a file with that name already exists).
    archive = wget.download('https://data.deepai.org/ptbdataset.zip')
    try:
        with zipfile.ZipFile(archive) as zf:
            train_raw = _read_ptb_split(zf, 'ptb.train.txt')
            test_raw = _read_ptb_split(zf, 'ptb.test.txt')
            valid_raw = _read_ptb_split(zf, 'ptb.valid.txt')
    finally:
        os.remove(archive)

    # NOTE(review): Tokenizer uses Keras' default punctuation filters, which
    # (per Keras docs) strip '<'/'>' and split hyphenated PTB tokens -- confirm
    # this re-tokenisation of the already-tokenised corpus is intended.
    tokenizer = Tokenizer(num_words=num_words)
    tokenizer.fit_on_texts([train_raw])

    train_seq, valid_seq, test_seq = tokenizer.texts_to_sequences([train_raw,
                                                                   valid_raw,
                                                                   test_raw])
    return train_seq, valid_seq, test_seq


def load_ptb_data(num_steps=300, num_words=10000, batch_size=100, shuffle=True):
    """Load PTB and return one batch-generator factory per split.

    Returns:
        A ``(train, valid, test)`` tuple of zero-argument callables. Each call
        starts a fresh ``ptb_supervised`` generator over that split, so a
        factory can be iterated any number of times (reshuffling on every
        pass when ``shuffle`` is true).
    """
    train_seq, valid_seq, test_seq = load_sequences(num_words=num_words)

    def make_factory(tokens):
        # Bind the split's tokens now; batching options come from this call.
        def factory():
            return ptb_supervised(tokens, num_steps=num_steps,
                                  batch_size=batch_size, shuffle=shuffle)
        return factory

    return (make_factory(train_seq),
            make_factory(valid_seq),
            make_factory(test_seq))


def entropy(y_true, y_pred):
    """Mean sparse categorical cross-entropy over every (sample, step) position.

    Args:
        y_true: Integer token targets whose element count is ``batch * steps``,
            presumably shaped ``(batch, steps)`` or ``(batch, steps, 1)``;
            it is flattened either way.
        y_pred: Per-step scores over the vocabulary, vocabulary on the last
            (statically known) axis. Assumed to be probabilities, since
            ``from_logits`` is left at its default.

    Returns:
        A scalar tensor: cross-entropy averaged over all token positions.
    """
    n_classes = K.int_shape(y_pred)[-1]
    per_token_probs = K.reshape(y_pred, (-1, n_classes))
    per_token_targets = K.reshape(y_true, (-1,))
    token_losses = K.sparse_categorical_crossentropy(per_token_targets,
                                                     per_token_probs)
    return K.mean(token_losses)


def fit(model: Model,
        lr=0.001,
        lr_scheduler=None,
        epochs=300,
        batch_size=100,
        num_steps=300,
        num_words=10000,
        early_stopping=None):
    """Compile ``model`` and train it on PTB, validating on the PTB valid split.

    Args:
        model: Keras model taking ``(batch, num_steps)`` token-id batches and
            producing output compatible with ``entropy``.
        lr: Initial RMSprop learning rate.
        lr_scheduler: Optional schedule passed to ``LearningRateScheduler``.
        epochs: Maximum number of epochs.
        batch_size: Windows per batch.
        num_steps: Tokens per window.
        num_words: Vocabulary limit forwarded to the data loader.
        early_stopping: Optional patience (in epochs) for ``EarlyStopping``;
            a falsy value disables early stopping.

    Returns:
        The ``history`` dict of the Keras ``History`` object returned by
        ``fit_generator``.
    """
    train, valid, _ = load_ptb_data(num_steps=num_steps,
                                    num_words=num_words,
                                    batch_size=batch_size)
    # Count batches by streaming the generators rather than materialising
    # every batch in a throwaway list.
    n_train = sum(1 for _ in train())
    n_val = sum(1 for _ in valid())

    model.compile(loss=entropy,
                  optimizer=RMSprop(lr=lr))

    callbacks = []
    if lr_scheduler:
        callbacks.append(LearningRateScheduler(lr_scheduler))
    if early_stopping:
        callbacks.append(EarlyStopping(patience=early_stopping,
                                       mode='min',
                                       restore_best_weights=True))

    # fit_generator pulls exactly steps_per_epoch / validation_steps batches
    # per epoch, so each loader is restarted (and reshuffled) when exhausted.
    def infinite_train():
        while True:
            yield from train()

    def infinite_valid():
        while True:
            yield from valid()

    history = model.fit_generator(infinite_train(),
                                  epochs=epochs,
                                  steps_per_epoch=n_train,
                                  validation_steps=n_val,
                                  validation_data=infinite_valid(),
                                  callbacks=callbacks or None,
                                  verbose=2).history
    return history


def evaluate(model: Model,
             batch_size=100,
             num_steps=300,
             num_words=10000):
    """Return ``(cross_entropy, perplexity)`` of ``model`` on the PTB test split.

    The test windows are loaded in corpus order and joined into one
    ``(1, total_tokens)`` sequence, so the model sees the whole split in a
    single pass.

    Args:
        model: Compiled Keras model (e.g. by ``fit``) whose loss is the
            per-token cross-entropy.
        batch_size: Windows per batch when loading (does not affect the result).
        num_steps: Window length used to cut the test split.
        num_words: Vocabulary limit forwarded to the data loader.

    Returns:
        ``(xe, ppl)``: the loss reported by ``model.evaluate`` and
        ``exp(xe)`` as a Python float.

    Raises:
        ValueError: If the test split is too short to form a single window.
    """
    _, _, test = load_ptb_data(num_steps=num_steps,
                               num_words=num_words,
                               batch_size=batch_size,
                               shuffle=False)

    batches = list(test())
    if not batches:
        raise ValueError('test split yields no windows of num_steps=%r tokens'
                         % (num_steps,))
    # With shuffle=False the windows are consecutive and non-overlapping, so
    # concatenating and flattening them rebuilds one contiguous token stream;
    # flatten so that state is retained through computation.
    x_test = np.concatenate([x for x, _ in batches]).reshape((1, -1))
    y_test = np.concatenate([y for _, y in batches]).reshape((1, -1))

    xe = model.evaluate(x_test, y_test)
    # model.evaluate returns a list (loss first) when metrics are compiled in.
    if isinstance(xe, list):
        xe = xe[0]

    ppl = float(np.exp(xe))
    return xe, ppl
