"""Runs and saves ica on BERT activations."""
import os

from absl import app
from absl import flags

import numpy as np
import tensorflow as tf

from em.activations import bert_activations
from em.tools.ica import tf_ica

###############################################################################

# Let TensorFlow grow GPU memory on demand instead of reserving (nearly) all of
# it up front. Without this, BLAS kernels failed to launch; presumably cuBLAS
# could not allocate its workspace memory (unconfirmed).
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for physical_device in gpus:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

###############################################################################

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "activations_path", None,
    "Path of the saved activations to load with BertClsActivations.load.")
flags.DEFINE_string("output_path", None, "Path to h5 file to write output to.")

flags.DEFINE_integer(
    "n_components", None,
    "Number of independent components to extract (passed to TfFastICA).",
    lower_bound=1)
flags.DEFINE_integer(
    "max_iter", 500,
    "Maximum number of iterations for the ICA fit (passed to TfFastICA).")

# main() cannot do anything useful without these, so reject a missing flag at
# parse time instead of failing mid-load or after a long ICA fit.
flags.mark_flags_as_required(["activations_path", "output_path", "n_components"])

###############################################################################


def main(_):
    """Fits FastICA to last-layer BERT activations and saves the result.

    Configuration comes from FLAGS:
      activations_path: where the activations are loaded from.
      output_path: h5 file the fitted ICA is written to.
      n_components: number of ICA components to extract (must be >= 1).
      max_iter: iteration cap for the ICA fit.

    Raises:
      app.UsageError: if a required flag is missing or n_components < 1.
    """
    # Use UsageError rather than `assert`, which `python -O` strips out.
    # output_path is checked here so a missing value is caught before the
    # (potentially long) fit rather than at save time.
    for name in ("activations_path", "output_path", "n_components"):
        if getattr(FLAGS, name) is None:
            raise app.UsageError(f"--{name} is required.")
    if FLAGS.n_components < 1:
        raise app.UsageError("--n_components must be a positive integer.")

    activations = bert_activations.BertClsActivations.load(FLAGS.activations_path).activations
    # Use only the representations from the last layer.
    # NOTE(review): assumes a 2-D (examples, features) array whose trailing 768
    # columns are the final layer (768 = BERT-base hidden size). Confirm this
    # against whatever wrote the activations.
    activations = np.asarray(activations)[:, -768:]
    # L2-normalize every row. This is done out of place: the slice above is a
    # view, and an in-place `/=` fails for non-float dtypes. All-zero rows stay
    # zero instead of turning into NaN, which would poison the ICA fit.
    norms = np.sqrt(np.sum(np.square(activations), axis=-1, keepdims=True))
    activations = activations / np.where(norms > 0, norms, 1.0)
    activations = tf.cast(activations, tf.float32)

    ica = tf_ica.TfFastICA(
        n_components=FLAGS.n_components,
        n_features=activations.shape[-1],
        max_iter=FLAGS.max_iter,
        print_interval=10,
    )
    ica.fit(activations)

    ica.save(FLAGS.output_path)


if __name__ == "__main__":
    # absl parses the command-line flags, then invokes main with leftover argv.
    app.run(main=main)
