R""""""
# run_bert_tcav_exp.py

from absl import app
from absl import flags

from importlib import reload
import os
import time
import h5py

import numpy as np
from sklearn.linear_model import LogisticRegression
import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from em.tools import k_means

from em.models import em_models
from em.tools.nmf import nmf_common
from em.tools.ica import tf_ica
from em.activations import bert_activations
from em.analysis.tcav import bert_tcav
from em.util import hdf5_util

from em.util.color_util import cu

# Decomposition kinds understood by get_top_example_indices().
DECOMPOSITION_TYPES = ['NPEFF', 'ICA', "K_MEANS"]

FLAGS = flags.FLAGS

##########################################################################

flags.DEFINE_string(
    "decomposition_filepath", None,
    "Path to the saved decomposition: an HDF5 file with a 'data/W' dataset "
    "for NPEFF, or a file loadable by TfFastICA.load (ICA) or KMeans.load "
    "(K_MEANS).")
flags.DEFINE_enum(
    "decomposition_type", None, DECOMPOSITION_TYPES,
    "Kind of decomposition stored at --decomposition_filepath.")

flags.DEFINE_string(
    "activations_filepath", None,
    "Path to activations loadable by BertClsActivations.load.")

flags.DEFINE_string(
    "model", None, "Model name or path passed to em_models.from_pretrained.")
flags.DEFINE_boolean(
    "from_pt", False, "Forwarded as `from_pt` to em_models.from_pretrained.")
# Outputs
flags.DEFINE_string("output_filepath", None, "Path to file to write h5 output to.")

# Flags describing what components to run on.
flags.DEFINE_list(
    "component_indices", None,
    'Comma-separated component indices to run on. Leave set to None to run '
    'on all components.')

flags.DEFINE_integer(
    "n_runs", None, 'Number of runs per component (passed to BertTcavForComponent2).')

flags.DEFINE_integer(
    "n_top_examples", None,
    'Number of highest-scoring examples per component used as concept '
    'examples. If unset, all examples are used.')
flags.DEFINE_integer(
    "n_negative_examples", None,
    'Forwarded as `n_negative_examples` to BertTcavExperiment.')
flags.DEFINE_integer(
    "n_scoring_examples", None,
    'Forwarded as `n_scoring_examples` to BertTcavExperiment.')

##########################################################################


def get_top_example_indices(activations):
    """Rank examples by how strongly they express each component.

    The decomposition named by --decomposition_type is loaded from
    --decomposition_filepath and turned into a score matrix with examples
    along axis 0 and components along axis 1:
      * NPEFF: the stored `data/W` matrix as-is (`activations` is unused).
      * ICA: |(activations - mean) @ whitening^T| from the loaded TfFastICA.
      * K_MEANS: `KMeans.create_coeffs(activations)`.

    Args:
      activations: Example representations; only read by the ICA and
        K_MEANS branches.

    Returns:
      A numpy array with one row per component. Each row holds the indices of
      the --n_top_examples highest-scoring examples, in descending score order
      (all examples when --n_top_examples is unset).

    Raises:
      ValueError: If --decomposition_type is not one of the supported kinds.
    """

    def _npeff_scores(path):
        with h5py.File(path, "r") as h5_file:
            return hdf5_util.load_h5_ds(h5_file['data/W'])

    def _ica_scores(path):
        ica = tf_ica.TfFastICA.load(path)
        # NOTE(review): this projects onto `ica.whitening` only; confirm that
        # is the intended source estimate rather than a full unmixing matrix.
        projected = tf.matmul(activations - ica.mean, tf.transpose(ica.whitening))
        return np.abs(projected.numpy())

    def _kmeans_scores(path):
        return k_means.KMeans.load(path).create_coeffs(activations)

    score_fns = {
        'NPEFF': _npeff_scores,
        'ICA': _ica_scores,
        'K_MEANS': _kmeans_scores,
    }
    score_fn = score_fns.get(FLAGS.decomposition_type)
    if score_fn is None:
        raise ValueError(FLAGS.decomposition_type)

    scores = score_fn(os.path.expanduser(FLAGS.decomposition_filepath))

    # Sorting the negated scores along the example axis gives descending order.
    ranked = np.argsort(-scores, axis=0)
    return ranked[:FLAGS.n_top_examples].T


def main(_):
    """Learn CAVs for each selected component and save per-run TCAV scores.

    All inputs come from command-line flags. For every selected component,
    its top-ranked examples (from `get_top_example_indices`) are used as the
    concept examples of a `BertTcavForComponent2` run.

    Output HDF5 layout (all under the "data" group):
      component_indices: int32 array of the components that were run.
      component_{i}_scores: result of `compute_per_run_scores()` for
        component i.

    Raises:
      ValueError: If there are no components to run on, or if
        --component_indices contains an index outside [0, n_components).
    """
    # Expand "~" like the other path flags do.
    acts = bert_activations.BertClsActivations.load(
        os.path.expanduser(FLAGS.activations_filepath))

    # Use only the representations from the last layer.
    # NOTE(review): this hard-codes a 768-wide final layer (BERT-base sized);
    # confirm it matches the model that produced the activations.
    activations = acts.activations[:, -768:]
    labels = acts.labels

    # One row of ranked example indices per component.
    top_example_indices = get_top_example_indices(activations)
    n_components = top_example_indices.shape[0]

    if FLAGS.component_indices is None:
        component_indices = list(range(n_components))
    else:
        # Parse the string flag values and drop repeats (keeping first-seen
        # order) so each component's dataset is written only once.
        component_indices = list(
            dict.fromkeys(int(c) for c in FLAGS.component_indices))

    # Validate before loading the model so bad flags fail fast. Negative
    # indices are rejected rather than silently wrapping around and being
    # saved under a misleading "component_-1_scores" name.
    if not component_indices:
        raise ValueError("No components to run on.")
    out_of_range = [c for c in component_indices if not 0 <= c < n_components]
    if out_of_range:
        raise ValueError(
            f"--component_indices {out_of_range} out of range; "
            f"expected values in [0, {n_components}).")

    model = em_models.from_pretrained(FLAGS.model, from_pt=FLAGS.from_pt)

    exp = bert_tcav.BertTcavExperiment(
        activations=activations,
        labels=labels,
        model=model,
        n_negative_examples=FLAGS.n_negative_examples,
        n_scoring_examples=FLAGS.n_scoring_examples,
    )

    with h5py.File(os.path.expanduser(FLAGS.output_filepath), "w") as f:
        data_grp = f.create_group("data")

        hdf5_util.save_h5_ds(data_grp, 'component_indices', np.array(component_indices, dtype=np.int32))

        for component_index in tqdm(component_indices):
            comp_exp = bert_tcav.BertTcavForComponent2(
                exp=exp,
                concept_example_indices=top_example_indices[component_index],
                n_runs=FLAGS.n_runs,
            )
            comp_exp.learn_cavs()
            scores = comp_exp.compute_per_run_scores()
            hdf5_util.save_h5_ds(data_grp, f'component_{component_index}_scores', scores)


if __name__ == "__main__":
    # absl parses the flags defined in this module, then calls main().
    app.run(main)
