"""
## Setup
"""
import os
os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import time

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.callbacks import Callback
import pandas as pd
from src.standard_attention import StandardMultiHeadAttention
from src.optimised_attention import OptimisedAttention
from src.efficient_attention import EfficientAttention
from src.super_attention import SuperAttention
# super without wo
# standard with w0

# Expose only the first GPU (CUDA device index 0) to this process
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

"""
## Implement a Transformer block as a layer
"""


class TransformerBlock(layers.Layer):
    """Single post-norm Transformer encoder block: self-attention, then a feed-forward network.

    Each sub-layer's output goes through dropout, is added back to the
    sub-layer's input (residual connection) and is then layer-normalised.

    Args:
        ATTENTION_ARCH: Attention layer class. It is instantiated as
            ``ATTENTION_ARCH(num_heads=..., key_dim=...)`` and called with the
            block input passed as both of its first two arguments.
        projection_dim: Model width. Each head gets
            ``projection_dim // num_heads`` key dimensions (floor division;
            presumably projection_dim is meant to be divisible by num_heads).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ATTENTION_ARCH, projection_dim, num_heads, ff_dim, rate=0.2, **kwargs):
        super().__init__(**kwargs)
        self.att = ATTENTION_ARCH(num_heads=num_heads, key_dim=(projection_dim // num_heads))
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                # Project back to projection_dim so the residual add shapes match.
                layers.Dense(projection_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block.

        Args:
            inputs: Sequence of feature vectors; the last axis must be projection_dim.
            training: Forwarded to the Dropout layers. ``None`` lets Keras
                resolve it from the calling context (fit vs. evaluate/predict).
        """
        # Self-attention: the same tensor is supplied for both attention inputs.
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)


"""
## Implement TimeHistory Callback
"""

class TimeHistory(Callback):
    """Keras callback that records the wall-clock duration of each training epoch.

    After ``fit`` returns, ``times`` holds one duration in seconds per
    completed epoch, in order. Each duration is measured between the
    ``on_epoch_begin`` and ``on_epoch_end`` hooks.
    """

    def on_train_begin(self, logs=None):
        # Reset at the start of every fit() so repeated fits do not accumulate.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        # perf_counter is monotonic, so system clock adjustments cannot skew durations.
        self.epoch_start_time = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        self.times.append(time.perf_counter() - self.epoch_start_time)


"""
## Implement embedding layer

Two separate embedding layers: one for tokens, one for token indices (positions).
"""


class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        maxlen: Maximum supported sequence length (rows of the position table).
        vocab_size: Number of distinct token ids (rows of the token table).
        projection_dim: Embedding width shared by both tables.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, maxlen, vocab_size, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.maxlen = maxlen
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=projection_dim)
        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=projection_dim)

    def call(self, x):
        # Build positions for the actual (static) sequence length so shorter
        # inputs also work; fall back to maxlen when the length is unknown.
        seq_len = x.shape[-1]
        if seq_len is None:
            seq_len = self.maxlen
        elif seq_len > self.maxlen:
            raise ValueError(
                f"Input sequence length {seq_len} exceeds maxlen={self.maxlen} "
                "of the position embedding."
            )
        positions = self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
        # positions has no batch axis and is broadcast across the batch in the add.
        return self.token_emb(x) + positions


"""
## Create classifier model using transformer layer

Transformer layer outputs one vector for each time step of our input sequence.
Here, we take the mean across all time steps and
use a feed forward network on top of it to classify text.
"""


# implement the model as a function
def create_transformer_classifier(attention_arch, maxlen, vocab_size, projection_dim, num_heads, ff_dim):
    """Build an uncompiled single-block Transformer binary text classifier.

    Pipeline: token+position embedding -> one TransformerBlock -> mean over the
    time axis -> Dropout(0.1) -> Dense(6, relu) -> Dropout(0.1) -> Dense(1, sigmoid).

    Args:
        attention_arch: Attention layer class handed to ``TransformerBlock``.
        maxlen: Length of the integer token sequences fed to the model.
        vocab_size: Number of distinct token ids.
        projection_dim: Embedding / model width.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the block's feed-forward network.

    Returns:
        A ``keras.Model`` whose single output unit is a sigmoid probability.
    """
    token_ids = layers.Input(shape=(maxlen,))

    # Sequence encoder: one vector per time step.
    hidden = TokenAndPositionEmbedding(maxlen, vocab_size, projection_dim)(token_ids)
    hidden = TransformerBlock(attention_arch, projection_dim, num_heads, ff_dim)(hidden)

    # Classification head on the time-averaged representation.
    pooled = layers.GlobalAveragePooling1D()(hidden)
    pooled = layers.Dropout(0.1)(pooled)
    pooled = layers.Dense(6, activation="relu")(pooled)
    pooled = layers.Dropout(0.1)(pooled)
    probability = layers.Dense(1, activation="sigmoid")(pooled)

    return keras.Model(inputs=token_ids, outputs=probability)
"""
## Train and Evaluate
"""

def _experiment_paths(arch_name, num_heads, run_number):
    """Return (checkpoint, history, test_history, general_info) file paths for one run."""
    run_subdir = os.path.join(f"{arch_name}_{num_heads}_heads", f"run_num_{run_number}")
    model_dir = os.path.join("./results/imdb/model", run_subdir)
    history_dir = os.path.join("./results/imdb/history", run_subdir)
    return (
        os.path.join(model_dir, "model.weights.h5"),
        os.path.join(history_dir, "history.csv"),
        os.path.join(history_dir, "test_history.csv"),
        os.path.join(history_dir, "general_info.csv"),
    )


def _find_transformer_block(model):
    """Return the first ``TransformerBlock`` layer in ``model``."""
    for layer in model.layers:
        if isinstance(layer, TransformerBlock):
            return layer
    raise ValueError("model contains no TransformerBlock layer")


def run_experiment(attention_arch, maxlen, vocab_size, projection_dim, num_heads, ff_dim, run_number):
    """Train, checkpoint and evaluate one classifier configuration, then save its metrics.

    Trains on the module-level ``x_train``/``y_train`` (10 epochs, batch 64,
    10% validation split) and evaluates the best checkpoint on the
    module-level ``x_test``/``y_test``.

    Writes, under ``./results/imdb/``, for ``<Arch>_<num_heads>_heads/run_num_<run_number>``:
      * ``model/.../model.weights.h5``: weights of the epoch with the best val_accuracy.
      * ``history/.../history.csv``: per-epoch training metrics, rounded to 5 dp.
      * ``history/.../test_history.csv``: test loss/accuracy, rounded to 5 dp.
      * ``history/.../general_info.csv``: attention parameter count and mean
        epoch times, rounded to 3 dp.
    All CSV files are tab-separated with a header row.

    Returns:
        The ``History`` object returned by ``model.fit``.
    """
    model = create_transformer_classifier(attention_arch, maxlen, vocab_size, projection_dim, num_heads, ff_dim)
    model.compile(optimizer="adam", loss="BCE", metrics=["accuracy"])

    checkpoint_filepath, history_filepath, test_history_filepath, general_info_filepath = _experiment_paths(
        attention_arch.__name__, num_heads, run_number
    )
    # All three CSV files share one directory; the checkpoint lives in another.
    os.makedirs(os.path.dirname(checkpoint_filepath), exist_ok=True)
    os.makedirs(os.path.dirname(history_filepath), exist_ok=True)

    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    time_callback = TimeHistory()
    history = model.fit(
        x=x_train, y=y_train,
        batch_size=64, epochs=10, validation_split=0.1,
        callbacks=[time_callback, checkpoint_callback],
    )

    pd.DataFrame(history.history).round(5).to_csv(history_filepath, sep='\t', index=False)

    # Evaluate the best (highest val_accuracy) weights rather than the last epoch's.
    model.load_weights(checkpoint_filepath)
    loss, accuracy = model.evaluate(x_test, y_test)
    pd.DataFrame([[round(loss, 5), round(accuracy, 5)]], columns=["loss", "accuracy"]).to_csv(
        test_history_filepath, sep='\t', index=False
    )

    num_of_attention_params = _find_transformer_block(model).att.count_params()
    epoch_times = time_callback.times
    average_epoch_time = sum(epoch_times) / len(epoch_times) if epoch_times else float("nan")
    # Also report the mean without epoch 1, presumably to exclude one-off
    # warm-up cost. It is undefined (NaN) when fewer than two epochs ran.
    later_epoch_times = epoch_times[1:]
    average_epoch_time_without_first_epoch = (
        sum(later_epoch_times) / len(later_epoch_times) if later_epoch_times else float("nan")
    )
    general_info_pd = pd.DataFrame(
        [[num_of_attention_params,
          round(average_epoch_time_without_first_epoch, 3), round(average_epoch_time, 3)]],
        columns=["num_of_attention_params",
                 "average_epoch_time_excluding_first_epoch", "average_epoch_time"],
    )
    general_info_pd.to_csv(general_info_filepath, sep='\t', index=False)

    return history

"""
## Download and prepare dataset
"""

vocab_size = 20000  # Only consider the top 20k words
# NOTE: pad_sequences truncates with truncating='pre' by default, so this keeps
# the LAST 32 tokens of each review (pass truncating="post" to keep the first 32).
maxlen = 32
projection_dim = 32
ff_dim = 32  # Hidden layer size in feed forward network inside transformer
(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=vocab_size)
print(len(x_train), "Training sequences")
print(len(x_test), "Test sequences")
x_train = keras.utils.pad_sequences(x_train, maxlen=maxlen)
x_test = keras.utils.pad_sequences(x_test, maxlen=maxlen)

# Different architectures for attention

ATTENTION_ARCHS = [EfficientAttention, SuperAttention, OptimisedAttention, StandardMultiHeadAttention]
NUM_OF_HEADS = [4, 2, 1]
NUM_OF_RUNS = 5

for run_number in range(NUM_OF_RUNS):
    for attention_arch in ATTENTION_ARCHS:
        for num_of_heads in NUM_OF_HEADS:
            print("Running for attention_arch: ", attention_arch.__name__,
                  " num_heads: ", num_of_heads, " run_number: ", run_number)
            # Many models are built in this loop; drop Keras' global state
            # between them so memory does not keep growing.
            keras.backend.clear_session()
            run_experiment(attention_arch, maxlen, vocab_size, projection_dim, num_of_heads, ff_dim, run_number)

