import argparse
import os
import pickle

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

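"""
Runs the IMDB pretrained-representation experiment for 10 independent iterations.

In each iteration a fresh sentiment classifier is trained on the IMDB reviews,
and after every training epoch the embedding (the output of the last hidden
layer) of the training reviews is saved to disk.
"""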


def main(log_dir: str, emb_dim: int, save_emb: bool, epochs: int, lr: float) -> list:
    """Train an IMDB sentiment classifier 10 times and save its hidden representations.

    Each iteration builds a fresh TextVectorization -> Embedding ->
    Bidirectional LSTM -> Dense(emb_dim, relu) -> Dense(1) model and trains it
    one epoch at a time. After every epoch, if ``save_emb`` is true, the
    output of the ``Dense(emb_dim)`` layer for the (unshuffled) training split
    is pickled to ``log_dir + 'IMDB_repr_epoch{epoch}_iter{iter}.pkl'``.

    Args:
        log_dir: Prefix prepended verbatim to every output file name
            (e.g. ``"log/IMDB/"``). Its directory part is created if missing.
        emb_dim: Width of the last hidden (embedding) layer.
        save_emb: Whether to extract and pickle the hidden representations.
        epochs: Number of training epochs per iteration.
        lr: Learning rate for the Adam optimizer.

    Returns:
        A list with one ``History.history`` metrics dict per trained epoch,
        in training order across all iterations.
    """
    BUFFER_SIZE = 10000
    BATCH_SIZE = 64
    VOCAB_SIZE = 1000
    NUM_ITERATIONS = 10

    dataset, _ = tfds.load('imdb_reviews', with_info=True,
                           as_supervised=True)
    train_split = dataset['train']

    train_dataset = (train_split
                     .shuffle(BUFFER_SIZE)
                     .batch(BATCH_SIZE)
                     .prefetch(tf.data.AUTOTUNE))

    encoder = tf.keras.layers.TextVectorization(max_tokens=VOCAB_SIZE)
    encoder.adapt(train_dataset.map(lambda text, label: text))
    # The vocabulary is fixed after adapt(), so compute its size only once.
    vocab_size = len(encoder.get_vocabulary())

    # Unshuffled review texts used for representation extraction, so that the
    # rows of every saved array refer to the same examples in the same order.
    repr_inputs = (train_split
                   .map(lambda text, label: text)
                   .batch(BATCH_SIZE)
                   .prefetch(tf.data.AUTOTUNE))

    if save_emb:
        # Fail early (before any training) if the output location is unusable.
        out_dir = os.path.dirname(log_dir)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    histories = []
    for iter_ in range(NUM_ITERATIONS):
        # Drop graph state left over from the previous iteration's model.
        tf.keras.backend.clear_session()

        # Keep a direct handle on the embedding layer instead of looking it up
        # by its auto-generated name.
        hidden_layer = tf.keras.layers.Dense(emb_dim, activation='relu')
        model = tf.keras.Sequential([
            encoder,
            tf.keras.layers.Embedding(
                input_dim=vocab_size,
                output_dim=16,
                # Use masking to handle the variable sequence lengths
                mask_zero=True),
            tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(16)),
            hidden_layer,
            tf.keras.layers.Dense(1),
        ])
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                      optimizer=tf.keras.optimizers.Adam(lr),
                      metrics=['accuracy'])

        extractor = None
        for epoch_ in range(epochs):
            history = model.fit(train_dataset, epochs=1, verbose=0)
            # Store only the metrics dict: a History object references its
            # model and would keep every trained model alive.
            histories.append(history.history)

            if not save_emb:
                continue

            if extractor is None:
                # The Sequential model only exposes its input once it has been
                # built by the first fit() call. The extractor shares the
                # model's layers, so it always reflects the current weights.
                extractor = tf.keras.Model(inputs=model.input,
                                           outputs=hidden_layer.output)

            # Batched inference instead of one eager call per review.
            hidden_repr = extractor.predict(repr_inputs, verbose=0)
            # Keep the (num_examples, 1, emb_dim) layout produced by
            # per-example extraction so existing consumers of the pickles
            # keep working.
            hidden_repr = hidden_repr[:, np.newaxis, :]

            out_path = log_dir + 'IMDB_repr_epoch{}_iter{}.pkl'.format(epoch_, iter_)
            with open(out_path, 'wb') as f:
                pickle.dump(hidden_repr, f)

        print("iter {} done".format(iter_))

    return histories
    
    

    
def str_to_bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this converter interprets the text instead.
    The empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_dir", type=str, default="log/IMDB/")
    parser.add_argument("--emb_dim", type=int, default=32)
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--lr", type=float, default=1e-4)
    # whether to save trained embeddings
    parser.add_argument("--save_emb", type=str_to_bool, default=True)
    # index into the list of physical GPUs reported by TensorFlow
    parser.add_argument("--GPU", type=int, default=1)

    args = parser.parse_args()

    # setting up GPU
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            selected_gpu = gpus[args.GPU]
        except IndexError:
            # Best effort: an out-of-range index (e.g. the default 1 on a
            # single-GPU machine) should not abort the run.
            print("GPU index {} is out of range ({} GPU(s) available); "
                  "using TensorFlow's default device placement".format(
                      args.GPU, len(gpus)))
        else:
            try:
                tf.config.experimental.set_memory_growth(selected_gpu, True)
                tf.config.experimental.set_visible_devices(selected_gpu, 'GPU')
            except RuntimeError as e:
                # Device settings cannot be changed once the runtime has been
                # initialized; report and continue with the current settings.
                print(e)

    main(args.log_dir, args.emb_dim, args.save_emb, args.epochs, args.lr)