R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/activations/bert_activations_ica_test001.py

"""
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import bert_activations
from em.tools.ica import tf_ica
from em.tools.ica import bert_activations_ica

###############################################################################

# Turn on memory growth for every GPU that TensorFlow lists. Without this the
# script was observed to fail with a BLAS launch error; the root cause has not
# been investigated.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

###############################################################################

# Root directory of the experiment data; the sub-directories below are derived
# from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
# NOTE(review): MODELS_DIR and FISHERS_DIR are not referenced in the visible
# part of this script; they are kept for use from the interactive session.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory holding both the activations file that is read and the ICA file
# that is saved/loaded below.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

# File (inside PER_EXAMPLES_FISHERS_DIR) with the stored BERT CLS activations.
# Presumably 50k SNLI training examples, judging by the name -- TODO confirm.
CLS_ACTS_FILENAME = "feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"

###############################################################################

# Tokenizer used by `something` below to decode input ids back into text.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Load the stored CLS activations. The loaded object is read below through its
# `activations`, `input_ids` and `logits` attributes.
acts = bert_activations.BertClsActivations.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, CLS_ACTS_FILENAME))

# Normalization: divide every activation vector by its L2 norm along the last
# axis. This happens IN PLACE, so every later read of `acts.activations`
# (including the final `ica.transform` call) sees the normalized values.
# NOTE(review): a vector with zero norm would cause a division by zero here.
acts.activations /= np.sqrt(np.sum(acts.activations**2, keepdims=True, axis=-1))

# float32 tensor copy of the normalized activations; this is the input to the
# (currently commented-out) ICA fit.
activations = tf.cast(acts.activations, tf.float32)


# Keep only the first N_EXAMPLES rows and the last N_FEATURES columns.
# NOTE(review): 4 * 768 presumably selects the CLS vectors of the last four
# BERT-base layers (hidden size 768) -- TODO confirm against the saved layout.
N_FEATURES = 4 * 768
N_EXAMPLES = 50_000
activations = activations[:N_EXAMPLES, -N_FEATURES:]

# FastICA hyper-parameters. They are consumed only by the commented-out
# `TfFastICA(...)` construction below.
N_COMPONENTS = 64
# MAX_ITER = 200
MAX_ITER = 500
TOL = 1e-4
PRINT_INTERVAL = 1


# Location of the serialized ICA model (target of `ica.save`, source of
# `TfFastICA.load`).
FILENAME = "test_bert_cls_acts_ica001.h5"
FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, FILENAME)


# Fitting and saving are disabled; the model saved by an earlier run is loaded
# instead. Un-comment the block below (and drop the `load` line) to refit.
# ica = tf_ica.TfFastICA(
#     n_components=N_COMPONENTS,
#     n_features=activations.shape[-1],
#     max_iter=MAX_ITER,
#     tol=TOL,
#     print_interval=PRINT_INTERVAL,
# )
# ica.fit(activations)
# ica.save(FILEPATH)
ica = tf_ica.TfFastICA.load(FILEPATH)


# Project the normalized activations onto the ICA components and convert the
# result to a NumPy array. Unlike `activations` above, this uses ALL loaded
# rows (no N_EXAMPLES cut), restricted to the same last N_FEATURES columns.
# `something` indexes the result as coeffs[example, component].
coeffs = ica.transform(tf.cast(acts.activations[:, -N_FEATURES:], tf.float32)).numpy()


def something(component_index: int, n_examples: int = 12):
    """Print the examples that load most strongly on one ICA component.

    Examples are ranked by their coefficient in column `component_index` of
    the module-level `coeffs` array, largest first. For each of the top
    `n_examples` the arg-max of its logits and its decoded input text (with
    the tokenizer's pad token removed and surrounding whitespace stripped)
    are printed on one line.

    Args:
        component_index: Column of `coeffs` to rank the examples by.
        n_examples: How many top-ranked examples to print.

    Returns:
        None. Output goes to stdout.
    """
    # Negating the column turns the ascending argsort into a descending one.
    ranking = np.argsort(-coeffs[:, component_index])
    selected = ranking[:n_examples]

    token_rows = acts.input_ids[selected]
    predicted = np.argmax(acts.logits[selected], axis=-1)

    pad = tokenizer.pad_token
    for label, token_ids in zip(predicted, token_rows):
        text = tokenizer.decode(token_ids.astype(np.int32))
        print(label, text.replace(pad, '').strip())


# # something(1)
# for i in range(N_COMPONENTS):
#     something(i)
#     print(2 * '\n')

