R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/activations/tf_kmeans_test002.py

"""
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf

from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import bert_activations

###############################################################################

# Experiment root directory; the three paths derived below are direct
# sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

# Name of the saved CLS-activations file; it is joined onto
# PER_EXAMPLES_FISHERS_DIR when the activations are loaded.
CLS_ACTS_FILENAME = "feather_berts_0.snli.train.50000ex.bert_cls_activations.h5"

###############################################################################

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

acts = bert_activations.BertClsActivations.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, CLS_ACTS_FILENAME))

# Use only the representations from the last layer (the trailing 768 columns).
activations = acts.activations[:, -768:]

# L2-normalize every row so each example has unit Euclidean norm.
# The division is deliberately NOT done in place: if `acts.activations` is a
# NumPy array (assumed, not verified here), the slice above is a view, and an
# in-place `/=` would silently rescale part of the data held by `acts`.
norms = np.linalg.norm(activations, axis=-1, keepdims=True)
# Leave all-zero rows untouched rather than turning them into NaNs (0 / 0),
# which would otherwise propagate into the clustering.
norms[norms == 0] = 1.0
activations = activations / norms
# activations = tf.cast(activations, tf.float32)

###############################################################################

# Epoch limit handed to the training input in `input_fn`.
N_ITERS = 50

# Number of leading rows of `activations` used for training
# (alternative value tried: 10_000).
N_EXAMPLES = 50_000


def input_fn():
    """Build the training input for the k-means estimator.

    Converts the first ``N_EXAMPLES`` rows of the module-level ``activations``
    to a float32 tensor and wraps it with ``limit_epochs`` using
    ``num_epochs=N_ITERS``.

    Returns:
        The tensor produced by ``tf.compat.v1.train.limit_epochs``.
    """
    train_points = tf.convert_to_tensor(
        activations[:N_EXAMPLES], dtype=tf.float32)
    return tf.compat.v1.train.limit_epochs(train_points, num_epochs=N_ITERS)


# num_clusters = 32
num_clusters = 128


def predict_input_fn():
    """Build the prediction input: every row of ``activations``, one epoch.

    The tensor is converted with ``dtype=tf.float32`` so that prediction sees
    the same dtype that ``input_fn`` feeds during training, regardless of the
    dtype the activations were stored with.
    """
    return tf.compat.v1.train.limit_epochs(
        tf.convert_to_tensor(activations, dtype=tf.float32),
        num_epochs=1)


# Fit k-means (full-batch, not mini-batch) on the training input.
kmeans = tf.compat.v1.estimator.experimental.KMeans(
    num_clusters=num_clusters,
    use_mini_batch=False,
)

kmeans.train(input_fn)
cluster_centers = kmeans.cluster_centers()

# Cluster assignment of every example under the trained model.
cluster_indices = list(kmeans.predict_cluster_index(predict_input_fn))


# Second estimator seeded with the learned centers, to check whether
# `initial_clusters` alone reproduces the trained model's assignments.
# NOTE(review): kmeans2 is never trained before predicting. Whether an
# untrained estimator actually applies `initial_clusters` at predict time
# depends on the Estimator/KMeans initialization behavior -- confirm.
kmeans2 = tf.compat.v1.estimator.experimental.KMeans(
    num_clusters=num_clusters,
    use_mini_batch=False,
    initial_clusters=cluster_centers,
)
cluster_indices2 = list(kmeans2.predict_cluster_index(predict_input_fn))


# True only if both models assign every example to the same cluster index;
# array_equal also yields False (instead of misbehaving) on a length mismatch.
print(np.array_equal(cluster_indices, cluster_indices2))


# num_iterations = 10
# previous_centers = None
# for _ in range(num_iterations):
#     kmeans.train(input_fn)
#     cluster_centers = kmeans.cluster_centers()
#     if previous_centers is not None:
#         print('delta:', cluster_centers - previous_centers)
#     previous_centers = cluster_centers
#     print('score:', kmeans.score(input_fn))
# print('cluster centers:', cluster_centers)

# # map the input points to their clusters
# cluster_indices = list(kmeans.predict_cluster_index(input_fn))
# for i, point in enumerate(points):
#     cluster_index = cluster_indices[i]
#     center = cluster_centers[cluster_index]
#     print('point:', point, 'is in cluster', cluster_index, 'centered at', center)
