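R"""Cluster L2-normalized BERT CLS activations with k-means.

Meant to be run interactively (`python -i`) from the project root:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/activations/bert_activations_kmeans_test001.py

"""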
import dataclasses
from importlib import reload
import os

import numpy as np
# from sklearn.cluster import KMeans
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import bert_activations
from em.tools.ica import tf_ica
from em.tools.ica import bert_activations_ica

from em.tools.k_means import KMeans

###############################################################################

# Turn on memory growth for every visible GPU. Empirically this is needed to
# keep BLAS from failing to launch.
# `tf.config.list_physical_devices` is the stable endpoint; the
# `tf.config.experimental.list_physical_devices` alias is deprecated.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

###############################################################################

# Root directory of the experiment; every other directory below is a direct
# child of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'

MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

# HDF5 file with the saved CLS activations; it is read from
# PER_EXAMPLES_FISHERS_DIR.
CLS_ACTS_FILENAME = (
    "feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"
)

###############################################################################

# Not referenced again in this script; presumably kept around for use in the
# interactive session.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

cls_acts_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, CLS_ACTS_FILENAME)
acts = bert_activations.BertClsActivations.load(cls_acts_path)

# Keep only the trailing 768 columns of each row.
# NOTE(review): presumably the last layer's CLS vector (bert-base hidden size
# is 768) -- confirm against how the activations file was written.
acts.activations = acts.activations[:, -768:]

# Rescale each row to unit L2 norm, in place.
row_norms = np.sqrt(np.sum(acts.activations**2, keepdims=True, axis=-1))
acts.activations /= row_norms
activations = tf.cast(acts.activations, tf.float32)

N_COMPONENTS = 64

# Clustering uses the project's own KMeans (em.tools.k_means) rather than
# sklearn.cluster.KMeans.
kmeans = KMeans(k=N_COMPONENTS, n_iter=10)

print('starting fit')
clusters = kmeans.fit_predict(activations)
