R"""Learns a classifier on top of a frozen pretrained body.

NOTE: RoBERTa seems to be a lot worse for this than BERT. I'm just supporting
BERT models for now.

We save as a HF model even though this is quite inefficient because it
simplifies some downstream code.
"""
import os
import time

from absl import app
from absl import flags
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from transformers import TFBertPreTrainedModel

from em import datasets as em_datasets
from em.util import hf_util
from em.util import vat_da_faak_vpn

# Command-line configuration for this script. Flags are read through `FLAGS`
# inside the functions below.
FLAGS = flags.FLAGS

# Destination passed to `model.save_pretrained` at the end of `main`.
flags.DEFINE_string("output_path", None, "Path to folder that will be created containing HF information.")

# Model identifier/path handed to `from_pretrained`; `from_pt` is forwarded as
# the `from_pt` keyword of that call.
flags.DEFINE_string("pretrained_model", None, "")
flags.DEFINE_bool("from_pt", True, "")

# Task name; used to look up the number of classes and to load the datasets.
flags.DEFINE_string("task", None, "")

flags.DEFINE_string("tokenizer", None, "Defaults to `pretrained_model` if not set.")

# Train split falls back to 'train'; the validation split falls back to
# `em_datasets.de_facto_validation_split(task)`.
flags.DEFINE_string("split", None, "Uses a default if not set.")
flags.DEFINE_string("val_split", None, "Uses a default if not set.")

# Number of training examples taken from the train split before pooler outputs
# are cached, and the batch size used while caching (train and validation).
flags.DEFINE_integer("n_cache_train_examples",
                     None, 
                     "If this is larger than the number of train set examples, then will use that instead.")
flags.DEFINE_integer("cache_batch_size", 512, "Batch size to use when caching training examples.")

# Number of validation examples taken from the validation split.
flags.DEFINE_integer("n_val_examples", 2048, "")

# Adam optimizer settings and the sequence length forwarded to the dataset loader.
flags.DEFINE_float("learning_rate", 1e-3, "")
flags.DEFINE_float("clipnorm", None, "")
flags.DEFINE_integer("sequence_length", 128, "")

# Settings for fitting the classifier head on the cached pooler outputs.
flags.DEFINE_integer("batch_size", 256, "")
flags.DEFINE_integer("n_epochs", None, "")
flags.DEFINE_integer("steps_per_epoch", None, "")


def load_pretrained_model():
    """Loads the sequence-classification model and its tokenizer from the flags.

    Reads `--pretrained_model`, `--from_pt`, `--task` and `--tokenizer`. The
    tokenizer is loaded from `--tokenizer` when that flag is set, and from the
    model path otherwise. The number of labels comes from
    `em_datasets.n_classes_for_task`.

    Returns:
        A `(model, tokenizer)` tuple.

    Raises:
        ValueError: If `--pretrained_model` is not set.
        NotImplementedError: If the loaded model is not a BERT model.
    """
    if FLAGS.pretrained_model is None:
        # Without this, `os.path.expanduser(None)` fails with an opaque TypeError.
        raise ValueError('--pretrained_model must be set.')

    model_str = os.path.expanduser(FLAGS.pretrained_model)
    num_labels = em_datasets.n_classes_for_task(FLAGS.task)

    if FLAGS.tokenizer is None:
        tokenizer_str = model_str
    else:
        tokenizer_str = os.path.expanduser(FLAGS.tokenizer)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_str)

    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_str, from_pt=FLAGS.from_pt, num_labels=num_labels
    )
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if not isinstance(model, TFBertPreTrainedModel):
        raise NotImplementedError('TODO: Support models other than BERT.')

    return model, tokenizer


def _load_cache_split(task, split, tokenizer, n_examples):
    """Loads one split, optionally truncates it, and batches it for caching."""
    ds = em_datasets.load(
        task,
        split=split,
        tokenizer=tokenizer,
        sequence_length=FLAGS.sequence_length,
    )
    # `Dataset.take` needs an integer count; with no limit given, keep the
    # whole split instead of failing on `take(None)`.
    if n_examples is not None:
        ds = ds.take(n_examples)
    return ds.batch(FLAGS.cache_batch_size)


def load_datasets_to_cache(tokenizer):
    """Builds the batched train and validation datasets used for caching.

    The train split is `--split` (falling back to 'train') limited to
    `--n_cache_train_examples` examples; when that flag is unset the whole
    split is used. The validation split is `--val_split` (falling back to
    `em_datasets.de_facto_validation_split(task)`) limited to
    `--n_val_examples` examples. Both are batched by `--cache_batch_size`.

    Args:
        tokenizer: Tokenizer forwarded to `em_datasets.load`.

    Returns:
        A `(train_ds, val_ds)` tuple of batched datasets.
    """
    task = FLAGS.task

    train_ds = _load_cache_split(
        task, FLAGS.split or 'train', tokenizer, FLAGS.n_cache_train_examples
    )

    val_split = FLAGS.val_split or em_datasets.de_facto_validation_split(task)
    val_ds = _load_cache_split(task, val_split, tokenizer, FLAGS.n_val_examples)
    return train_ds, val_ds


def cache_poolings(body, ds):
    """Runs `body` over every batch of `ds` and collects its pooler outputs.

    Args:
        body: Callable invoked as `body(x, training=False)`; its result must
            expose a non-None `pooler_output`.
        ds: Iterable yielding `(x, y)` batches.

    Returns:
        A `(pooler_outputs, labels)` tuple, each the per-batch values
        concatenated along axis 0.
    """
    print('Starting to cache poolings.')
    pooled_batches, label_batches = [], []
    began = time.time()
    for inputs, targets in ds:
        pooled = body(inputs, training=False).pooler_output
        assert pooled is not None
        pooled_batches.append(pooled)
        label_batches.append(targets)
    elapsed = time.time() - began
    print('Caching poolings time taken:', elapsed)
    # Concatenate the per-batch tensors into single tensors.
    return tf.concat(pooled_batches, axis=0), tf.concat(label_batches, axis=0)


def main(_):
    """Trains the classifier head on cached pooler outputs and saves the model.

    Steps: load the model and tokenizer, split the model into body and head
    via `hf_util.get_body_and_head`, run the body once over the train and
    validation data to cache pooler outputs, fit the head on those cached
    outputs, and save the full model to `--output_path`.

    Raises:
        ValueError: If `--output_path`, `--n_epochs` or `--steps_per_epoch`
            is not set.
    """
    # Fail fast: `output_path` is only consumed after training, and the
    # training dataset below repeats indefinitely, so an epoch has no end
    # without `steps_per_epoch`. Check before doing any expensive work.
    missing = [
        name
        for name in ('output_path', 'n_epochs', 'steps_per_epoch')
        if getattr(FLAGS, name) is None
    ]
    if missing:
        raise ValueError(
            'Missing required flags: ' + ', '.join('--' + name for name in missing)
        )
    output_path = os.path.expanduser(FLAGS.output_path)

    model, tokenizer = load_pretrained_model()
    body, head = hf_util.get_body_and_head(model)
    train_ds, val_ds = load_datasets_to_cache(tokenizer)

    # Replace the raw datasets with (pooler_output, label) pairs so the body
    # is only evaluated once per example.
    train_ds = cache_poolings(body, train_ds)
    train_ds = tf.data.Dataset.from_tensor_slices(train_ds)
    train_ds = train_ds.repeat().shuffle(1000).batch(FLAGS.batch_size)

    val_ds = cache_poolings(body, val_ds)
    val_ds = tf.data.Dataset.from_tensor_slices(val_ds)
    val_ds = val_ds.batch(FLAGS.batch_size)

    # Only the head is wrapped for training.
    # NOTE(review): presumably `head` shares its weights with `model`, so the
    # trained head is what gets saved below -- confirm against hf_util.
    classifier = tf.keras.Sequential([head])

    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(FLAGS.learning_rate, clipnorm=FLAGS.clipnorm),
        metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )

    classifier.fit(
        train_ds,
        steps_per_epoch=FLAGS.steps_per_epoch,
        epochs=FLAGS.n_epochs,
        validation_data=val_ds,
    )

    model.save_pretrained(output_path)


# Script entry point: absl parses the command-line flags, then calls `main`.
if __name__ == "__main__":
    app.run(main)
