R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/activations/resnet_activations_ica001.py

"""
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import resnet_activations
from em.tools.ica import tf_ica
# from em.tools.ica import bert_activations_ica

###############################################################################

# Turn on memory growth for every visible GPU. Empirically this is required to
# avoid a "BLAS fail to launch" error in this script (root cause not known).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in iter(gpus):
    tf.config.experimental.set_memory_growth(gpu, True)

###############################################################################

# Root of this experiment's artifacts; each working directory sits directly
# beneath it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir_name)
    for subdir_name in ('models', 'fishers', 'per_example_fishers')
)

# File (inside PER_EXAMPLES_FISHERS_DIR) holding the saved ResNet activations
# that ResnetActivations.load reads.
ACTS_FILENAME = "resnet50.imagenet.train.20000ex.activations.h5"

###############################################################################

# Load the saved ResNet activations. This script reads the loaded object's
# .activations, .labels and .logits attributes.
acts = resnet_activations.ResnetActivations.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, ACTS_FILENAME))

# Normalization: rescale every vector along the last axis to unit L2 norm
# (presumably one vector per example -- confirm against ResnetActivations).
# An all-zero vector would otherwise turn into NaN (0 / 0) and those NaNs
# would be fed straight into the ICA fit. Its norm is therefore replaced by 1,
# so the vector stays all-zero. Non-zero vectors are divided exactly as before.
_act_norms = np.linalg.norm(acts.activations, axis=-1, keepdims=True)
acts.activations /= np.where(_act_norms > 0, _act_norms, 1.0)
del _act_norms
activations = tf.cast(acts.activations, tf.float32)


# FastICA hyperparameters, passed straight through to tf_ica.TfFastICA.
N_COMPONENTS = 64    # number of independent components to learn
MAX_ITER = 500       # presumably the iteration cap of the fit
TOL = 1e-4           # presumably the convergence tolerance
PRINT_INTERVAL = 1   # presumably how often (in iterations) progress is printed

# Fit ICA on the normalized activations, then transform the same activations
# to get per-example coefficients. The code after this indexes the result as
# coeffs[example, component].
ica_kwargs = dict(
    n_components=N_COMPONENTS,
    n_features=activations.shape[-1],
    max_iter=MAX_ITER,
    tol=TOL,
    print_interval=PRINT_INTERVAL,
)
ica = tf_ica.TfFastICA(**ica_kwargs)
ica.fit(activations)
coeffs = ica.transform(activations).numpy()

# Quick look at a single component: print the labels of the examples with
# the largest coefficients on it. NOTE: ICA only recovers components up to
# sign, so the most negative coefficients may be just as informative.
component_index = 0
n_examples = 12
component_coeffs = coeffs[:, component_index]
# argsort is ascending; sorting the negated column puts the largest first.
top_inds = np.argsort(-component_coeffs)[:n_examples]
print(acts.labels[top_inds])


def print_top_preds(
    component_index: int,
    n_examples: int = 12,
    *,
    coefficients=None,
    logits=None,
):
    """Print the predicted classes of the examples scoring highest on a component.

    Rows of the coefficient matrix are ranked by column ``component_index``
    in descending order and the first ``n_examples`` are kept. For those rows
    the function prints the argmax over the last axis of the matching logits.

    Args:
        component_index: Column of the coefficient matrix to rank examples by.
        n_examples: Number of top-ranked examples to report.
        coefficients: Optional 2-D coefficient matrix (rows = examples,
            columns = components). Defaults to the module-level ``coeffs``
            produced by the ICA fit.
        logits: Optional logits whose rows align with ``coefficients``.
            Defaults to ``acts.logits``.
    """
    # Globals are only looked up when no explicit array is passed, so the
    # function can also be used on other fits (or tested) without them.
    if coefficients is None:
        coefficients = coeffs
    if logits is None:
        logits = acts.logits
    # argsort is ascending; negating the column puts the largest first.
    top_inds = np.argsort(-coefficients[:, component_index])[:n_examples]
    print(np.argmax(logits[top_inds], axis=-1))


# Report the top predicted classes for every learned component in turn,
# using the default number of examples per component.
for i in range(0, N_COMPONENTS, 1):
    print_top_preds(component_index=i)

