R"""Computes and saves BERT CLS token activations."""
import os

from absl import app
from absl import flags

from em import datasets as em_datasets
from em.models import em_models
from em.activations import bert_activations
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# Model / tokenizer.
flags.DEFINE_string(
    "model", None,
    "Model name or path passed to em_models.from_pretrained ('~' is expanded).")
flags.DEFINE_bool(
    "from_pt", True, "Forwarded as `from_pt` to em_models.from_pretrained.")

flags.DEFINE_string("tokenizer", None, "Defaults to the value of --model if not set.")

# Dataset construction.
flags.DEFINE_string("task", None, "Task/dataset name passed to em_datasets.load.")
flags.DEFINE_list("split", ["train"], "If multiple splits are provided, will run on their concatenation.")
flags.DEFINE_integer(
    "n_examples", None,
    "Number of examples to compute activations for. The data is repeated if "
    "it contains fewer examples than this.")
flags.DEFINE_integer("batch_size", 16, "Batch size of the dataset fed to the model.")
flags.DEFINE_integer("sequence_length", 128, "Note that this is used to set image sizes as well.")

# Example selection / ordering.
flags.DEFINE_bool(
    "shuffle", False,
    "Shuffle examples (buffer size 1000) before taking --n_examples.")
flags.DEFINE_integer(
    "skip", None,
    "Number of examples to skip from the start of the concatenated splits.")

flags.DEFINE_bool("ds_force_deterministic", False, "Only has effects for some datasets.")

###############################################################################


def create_dataset(tokenizer):
    """Build the batched input dataset described by the command-line flags.

    Loads every split in --split for --task, concatenates them in the given
    order, applies --skip and then --shuffle, takes exactly --n_examples
    examples (repeating the data if it is shorter), caches them and batches
    them with --batch_size.

    Args:
      tokenizer: Tokenizer forwarded to `em_datasets.load`.

    Returns:
      The batched dataset.

    Raises:
      app.UsageError: If --task or --split is empty, or --n_examples is not a
        positive integer.
    """
    # Validate up front: otherwise these surface as AttributeError on None or
    # as an endless stream (take(-1) after repeat() never terminates).
    if not FLAGS.task:
        raise app.UsageError("--task must be set.")
    if not FLAGS.split:
        raise app.UsageError("--split must name at least one split.")
    if FLAGS.n_examples is None or FLAGS.n_examples < 1:
        raise app.UsageError("--n_examples must be a positive integer.")

    # --ds_force_deterministic is only forwarded for winogrande tasks; per its
    # flag help it only has an effect for some datasets.
    extra_kwargs = {}
    if FLAGS.task.startswith('winogrande/') and FLAGS.ds_force_deterministic:
        extra_kwargs['force_deterministic'] = True

    ds = None
    for split in FLAGS.split:
        split_ds = em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=tokenizer,
            sequence_length=FLAGS.sequence_length,
            **extra_kwargs,
        )
        ds = split_ds if ds is None else ds.concatenate(split_ds)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)

    # repeat() guarantees take() yields exactly n_examples items; if the data
    # has fewer examples than that, examples are silently reused.
    ds = ds.repeat().take(FLAGS.n_examples).cache()
    return ds.batch(FLAGS.batch_size)


def main(_):
    """Compute CLS activations for the configured dataset and save them.

    Loads the model (--model) and tokenizer (--tokenizer, defaulting to
    --model), builds the input dataset via `create_dataset`, and streams the
    activations for --n_examples examples to --output_path.

    Raises:
      app.UsageError: If --model or --output_path is not set.
    """
    # Resolve paths before the (potentially slow) model loading so a missing
    # flag fails immediately instead of after all the setup work.
    if not FLAGS.model:
        raise app.UsageError("--model must be set.")
    if not FLAGS.output_path:
        raise app.UsageError("--output_path must be set.")
    model_str = os.path.expanduser(FLAGS.model)
    output_path = os.path.expanduser(FLAGS.output_path)

    model = em_models.from_pretrained(model_str, from_pt=FLAGS.from_pt)

    tokenizer = em_models.load_tokenizer(FLAGS.tokenizer or model_str)
    ds = create_dataset(tokenizer)

    computer = bert_activations.ClsActivationsComputer(
        model=model,
    )

    saver = bert_activations.StreamingClsActivationsSaver(
        computer=computer,
        n_examples=FLAGS.n_examples,
        sequence_length=FLAGS.sequence_length,
        use_tqdm=True,
    )

    # TODO: create_dataset repeats the data to reach --n_examples, so if the
    # dataset is smaller than that, examples are silently duplicated in the
    # output. Consider checking the dataset size and warning/failing.
    saver.compute_and_save_activations(output_path, ds)


if __name__ == "__main__":
    # These flags default to None, which the pipeline cannot handle. Mark them
    # required here, rather than at import time, so that importing this module
    # elsewhere does not make them required.
    flags.mark_flags_as_required(["model", "task", "n_examples", "output_path"])
    app.run(main)
