import tensorflow_datasets as tfds
import tensorflow as tf
from vat.vat_model import VATTransformer, CustomSchedule, create_masks, loss_function_vae
from vat.decoding import translate


def build_vocabulary(dataset=None, target_vocab_size=2 ** 15):
    """Build a subword tokenizer from the second element of every example.

    Args:
        dataset: Iterable of example tuples whose element at index 1 is an
            eager string tensor (the English sentence, going by how
            ``tf_encode`` names the pair). When omitted, the module-level
            ``train_data`` is used, which keeps the original zero-argument
            call working.
        target_vocab_size: Approximate vocabulary size handed to
            ``SubwordTextEncoder.build_from_corpus``.

    Returns:
        A ``(tokenizer, vocab_size)`` tuple. ``vocab_size`` is the tokenizer's
        own vocabulary size plus three ids reserved for the start, end and
        separation tokens.
    """
    if dataset is None:
        # Backward-compatible fallback: the script defines ``train_data`` at
        # module scope before this function is called.
        dataset = train_data

    # Lazily decode each sentence so the corpus is never held in memory at once.
    corpus = (ex[1].numpy().decode('utf-8').strip() for ex in dataset)
    tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
        corpus, target_vocab_size=target_vocab_size)

    # original tokens + start, end and separation token
    vocab_size = tokenizer.vocab_size + 3
    return tokenizer, vocab_size


def py_encode(de, en):
    """Tokenize one English sentence and wrap it in start/end ids.

    Runs eagerly (it is invoked through ``tf.py_function``) and relies on the
    module-level ``tokenizer``. The German input ``de`` is accepted but unused.

    Args:
        de: German sentence tensor (ignored).
        en: English sentence as an eager string tensor.

    Returns:
        The same id list twice: ``[start_id, *subword_ids, end_id]``, where
        ``start_id`` is ``tokenizer.vocab_size`` and ``end_id`` is one more.
    """
    sentence = en.numpy().decode('utf-8').strip()
    start_id = tokenizer.vocab_size
    end_id = tokenizer.vocab_size + 1
    token_ids = [start_id, *tokenizer.encode(sentence), end_id]
    # replicated just in case we want to process the decoder input differently from the encoder input
    return token_ids, token_ids


def tf_encode(data_de, data_en):
    """Graph-compatible wrapper around ``py_encode`` for ``Dataset.map``.

    Args:
        data_de: German sentence tensor, forwarded to ``py_encode``.
        data_en: English sentence tensor, forwarded to ``py_encode``.

    Returns:
        The first of the two int64 id tensors produced by ``py_encode``, with
        its static shape set to a 1-D vector of unknown length.
    """
    encoded_ids, _duplicate_ids = tf.py_function(
        func=py_encode,
        inp=[data_de, data_en],
        Tout=[tf.int64, tf.int64],
    )
    # py_function outputs carry no static shape; declare the rank explicitly.
    encoded_ids.set_shape([None])
    return encoded_ids


def filter_max_length(x, max_length=100):
    """Return a boolean tensor telling whether ``x`` has at most ``max_length`` elements.

    Args:
        x: Tensor whose total element count is checked.
        max_length: Inclusive upper bound on the element count.
    """
    num_elements = tf.size(x)
    return num_elements <= max_length



# Running metrics that train_step updates on every call.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')


# Fixed input signature for train_step: three rank-2 int64 tensors with
# unconstrained dimensions, followed by one float32 scalar.
step_signature = [
    *(tf.TensorSpec(shape=(None, None), dtype=tf.int64) for _ in range(3)),
    tf.TensorSpec(shape=(), dtype=tf.float32),
]

@tf.function(input_signature=step_signature)
def train_step(inp, tar_inp, tar_real, global_step):
    """Run one optimization step and update the running metrics.

    Uses the module-level ``transformer``, ``optimizer``, ``STD_DEV``,
    ``train_loss`` and ``train_accuracy``.

    Args:
        inp: int64 batch fed to the model as its first input.
        tar_inp: int64 batch fed to the model as its second input.
        tar_real: int64 batch the predictions are scored against.
        global_step: float32 scalar forwarded to ``loss_function_vae``
            (presumably drives the regularization weight -- confirm there).
    """
    enc_mask, comb_mask, dec_mask = create_masks(inp, tar_inp)

    with tf.GradientTape() as tape:
        # Third positional argument is True -- presumably the training flag.
        model_out = transformer(inp, tar_inp, True, enc_mask, comb_mask, dec_mask)
        preds, z_mean, z_logv, _z_code = model_out
        rec_loss, reg_loss, reg_weight = loss_function_vae(
            tar_real, preds, z_logv, z_mean, global_step, STD_DEV)
        # Total loss: reconstruction term plus the weighted regularization term.
        total_loss = rec_loss + reg_loss * reg_weight

    variables = transformer.trainable_variables
    grads = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    train_loss(total_loss)
    train_accuracy(tar_real, preds)


if __name__ == '__main__':
    # NOTE: the names assigned here are module globals that py_encode,
    # build_vocabulary and train_step read at call time.
    STD_DEV = 0.4
    NUM_EPOCHS = 3
    D_MODEL = 128  # shared by the model and the learning-rate schedule

    examples, info = tfds.load('wmt19_translate/de-en', with_info=True, as_supervised=True)
    train_data, test_data = examples['train'], examples['validation']

    # Tokenization takes a long time. Only run this once and store the vocab file.
    # tokenizer, vocab_size = build_vocabulary(train_data)
    # tokenizer.save_to_file('vocab_file')
    # Alternatively: Load a stored vocab file
    tokenizer = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_file')
    vocab_size = tokenizer.vocab_size + 3  # 3 additional tokens: start, end, separation

    # Input pipeline: encode -> drop sequences over the length limit -> cache
    # -> pad each batch of 128 to its longest sequence -> prefetch.
    train_dataset = train_data.map(tf_encode)
    train_dataset = train_dataset.filter(filter_max_length)
    train_dataset = train_dataset.cache()
    train_dataset = train_dataset.padded_batch(128, padded_shapes=([None]))
    train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

    transformer = VATTransformer(num_layers=4, d_model=D_MODEL, num_heads=8, dff=512, vocab_size=vocab_size,
                                 pe_size=500, std_dev=STD_DEV, rate=0.1)

    learning_rate_fn = CustomSchedule(d_model=D_MODEL)
    optimizer = tf.keras.optimizers.Adam(learning_rate_fn, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    # Kept as float32 to match the scalar spec in step_signature; counts
    # optimizer steps across all epochs.
    step_counter = tf.Variable(0, name='global_step', trainable=False, dtype=tf.float32)
    for epoch in range(NUM_EPOCHS):
        # The metrics are running averages. Clear them so the numbers logged
        # under "Epoch N" describe that epoch only instead of every step so far.
        train_loss.reset_states()
        train_accuracy.reset_states()

        for batch, data in enumerate(train_dataset):
            train_step(data, data[:, :-1], data[:, 1:], step_counter)  # enc_input, dec_input, dec_target
            step_counter.assign_add(1)

            if batch % 1000 == 0:
                print(f'Epoch {epoch + 1} Batch {batch} Loss {train_loss.result():.4f} Accuracy {train_accuracy.result():.4f}')

    sample = 'In this class, we will introduce several fundamental concepts.'
    # reconstruction
    translate(transformer=transformer, tokenizer=tokenizer, sentence=sample, variation=False)
    # variational sampling
    translate(transformer=transformer, tokenizer=tokenizer, sentence=sample, variation=True)
