R"""Computes and saves Resnet activations."""
import os

from absl import app
from absl import flags
import tensorflow as tf

from em import datasets as em_datasets
from em.models import em_models
from em.activations import resnet_activations
from em.util import vat_da_faak_vpn
###############################################################################

# Ask TensorFlow to grow GPU memory on demand rather than reserving it all up
# front. Without this, BLAS kernels were observed to fail to launch (the exact
# cause was not recorded by the original author).
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

###############################################################################

FLAGS = flags.FLAGS

# Where the computed activations are written.
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

# What to run: the activations model and the dataset task.
flags.DEFINE_string(
    "model", None,
    "Model identifier passed to resnet_activations.create_activations_model.")
flags.DEFINE_string("task", None, "Task name passed to em.datasets.load.")

# How much data to read and how to batch it.
flags.DEFINE_list(
    "split", ["train"],
    "If multiple splits are provided, will run on their concatenation.")
flags.DEFINE_integer(
    "n_examples", None,
    "Number of examples to process. The dataset is repeated and then "
    "truncated to this many examples.")
flags.DEFINE_integer("batch_size", 16, "Number of examples per batch.")

# Which examples are selected and in what order.
flags.DEFINE_bool(
    "shuffle", False,
    "If true, shuffle examples with a buffer of 1000 before taking n_examples.")
flags.DEFINE_integer(
    "skip", None,
    "If set, number of leading examples of the concatenated splits to skip.")

###############################################################################


def create_dataset():
    """Builds the batched input dataset described by the command-line flags.

    Every split named in ``FLAGS.split`` is loaded for ``FLAGS.task`` and the
    splits are concatenated in the order given. The result is then optionally
    advanced past ``FLAGS.skip`` examples, optionally shuffled, repeated and
    truncated to ``FLAGS.n_examples`` examples, cached, and finally batched by
    ``FLAGS.batch_size``.

    Returns:
        The batched dataset built from the objects returned by
        ``em_datasets.load``.

    Raises:
        ValueError: If ``FLAGS.split`` is empty or ``FLAGS.n_examples`` is
            unset. Both previously failed later with an opaque error.
    """
    # Validate up front: an empty split list would leave nothing to build on,
    # and `take(None)` is rejected by tf.data with an unhelpful message.
    if not FLAGS.split:
        raise ValueError("At least one split must be provided via --split.")
    if FLAGS.n_examples is None:
        raise ValueError("--n_examples must be set.")

    split_datasets = [
        em_datasets.load(
            FLAGS.task,
            split=split,
            tokenizer=None,
            # NOTE(review): presumably the input image side length — confirm
            # against em_datasets.load.
            sequence_length=224,
        )
        for split in FLAGS.split
    ]

    # Chain the splits together in the order they were given on the command line.
    ds = split_datasets[0]
    for split_ds in split_datasets[1:]:
        ds = ds.concatenate(split_ds)

    if FLAGS.skip is not None:
        ds = ds.skip(FLAGS.skip)

    if FLAGS.shuffle:
        ds = ds.shuffle(1000)

    # repeat() lets take() deliver n_examples even when the underlying data is
    # shorter; in that case examples wrap around and appear more than once.
    # cache() then pins the selected examples so later passes replay them.
    ds = ds.repeat().take(FLAGS.n_examples).cache()
    ds = ds.batch(FLAGS.batch_size)
    return ds

###############################################################################


def main(_):
    """Computes activations for the flag-configured dataset and saves them.

    Args:
        _: Unused positional command-line arguments forwarded by ``app.run``.

    Raises:
        app.UsageError: If ``--output_path`` or ``--n_examples`` is not set.
    """
    # Fail fast on missing required flags, before the dataset and model are
    # built; otherwise the run only crashes after that expensive setup.
    if FLAGS.output_path is None:
        raise app.UsageError("--output_path is required.")
    if FLAGS.n_examples is None:
        raise app.UsageError("--n_examples is required.")
    output_path = os.path.expanduser(FLAGS.output_path)

    ds = create_dataset()

    saver = resnet_activations.StreamingActivationsSaver(
        model=resnet_activations.create_activations_model(FLAGS.model),
        n_examples=FLAGS.n_examples,
        # NOTE(review): 2048 and 1000 are presumably the activation width and
        # the number of output classes of the model — confirm against
        # resnet_activations.
        d_activations=2048,
        n_classes=1000,
        use_tqdm=True,
    )

    # TODO: Maybe add some safety in case n_examples is greater than the size of the dataset?
    saver.compute_and_save_activations(output_path, ds)


# Script entry point: absl's app.run parses the command-line flags defined
# above and then invokes main with the remaining arguments.
if __name__ == "__main__":
    app.run(main)
