"""Runs and saves kmeans on BERT activations."""
import os

from absl import app
from absl import flags

import numpy as np
import tensorflow as tf

from em.activations import bert_activations
from em.tools import k_means

###############################################################################

FLAGS = flags.FLAGS

# Every flag carries help text so `--help` is self-explanatory.
flags.DEFINE_string(
    "activations_path", None,
    "Path of the saved BertClsActivations to cluster.")
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

flags.DEFINE_integer(
    "n_components", None, "Number of k-means clusters to fit (required).")
flags.DEFINE_integer(
    "n_iter", 500,
    "Number of passes over the data; with full-batch training each pass is "
    "one k-means step.")

###############################################################################


def _normalize_last_layer(activations, hidden_size=768):
    """Return the trailing `hidden_size` columns of each row, L2-normalized.

    The input is never modified, and integer inputs are accepted (the result
    is floating point). Rows whose norm is zero come back as all-zero rows
    instead of NaN.

    Args:
      activations: array-like, assumed 2-D (examples x features) given the
        `[:, ...]` indexing below.
      hidden_size: number of trailing feature columns to keep. The default
        of 768 presumably matches the BERT hidden size used to produce the
        activations; TODO confirm for other model sizes.

    Returns:
      A new float array of shape (num_examples, min(hidden_size, num_features)).

    Raises:
      ValueError: if `hidden_size` is not positive (a slice of `-0:` would
        silently keep every column).
    """
    if hidden_size <= 0:
        raise ValueError("hidden_size must be positive, got %r." % (hidden_size,))
    last = np.asarray(activations)[:, -hidden_size:]
    norms = np.sqrt(np.sum(last ** 2, axis=-1, keepdims=True))
    # Out-of-place division: avoids writing through the slice view into the
    # caller's array and works for integer dtypes. Zero norms are replaced by
    # 1 so all-zero rows stay zero rather than becoming NaN.
    return last / np.where(norms == 0, 1, norms)


def main(_):
    """Fit full-batch k-means on normalized activations and save the centers.

    All configuration comes from command-line flags. The positional argument
    (remaining argv from absl) is unused.

    Raises:
      app.UsageError: if a required flag is missing or a count is not
        positive. absl reports this together with usage text.
    """
    # Validate with real exceptions (`assert` disappears under `python -O`),
    # and do it before training so a missing output path cannot waste a run.
    for name in ("activations_path", "output_path", "n_components"):
        if getattr(FLAGS, name) is None:
            raise app.UsageError("--%s is required." % name)
    if FLAGS.n_components <= 0:
        raise app.UsageError("--n_components must be positive.")
    if FLAGS.n_iter <= 0:
        raise app.UsageError("--n_iter must be positive.")

    # Use only the representations from the last layer, unit-normalized per
    # example. The loaded object is not kept, so the full (all-layer) array
    # can be freed once the last layer has been extracted.
    activations = _normalize_last_layer(
        bert_activations.BertClsActivations.load(
            FLAGS.activations_path).activations)

    def input_fn():
        # The whole dataset is one batch; limit_epochs lets it be read
        # FLAGS.n_iter times, after which the input is exhausted and train()
        # stops.
        return tf.compat.v1.train.limit_epochs(
            tf.convert_to_tensor(activations, dtype=tf.float32),
            num_epochs=FLAGS.n_iter)

    kmeans = tf.compat.v1.estimator.experimental.KMeans(
        num_clusters=FLAGS.n_components,
        use_mini_batch=False,
    )
    kmeans.train(input_fn)

    km = k_means.KMeans(cluster_centers=kmeans.cluster_centers())
    km.save(FLAGS.output_path)


if __name__ == "__main__":
    # absl parses the command-line flags, then calls main() with the
    # remaining argv.
    app.run(main=main)
