R"""


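Usage: run from the project root with the project root on PYTHONPATH. The `-i`
flag keeps the Python interpreter open after the script finishes so that the
module-level objects can be inspected interactively.

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/activations/bert_activations_ica_test002.py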

"""
import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.activations import bert_activations
from em.tools.ica import tf_ica
from em.tools.ica import bert_activations_ica
from em.models import em_models
from em.fishers import per_example
from em.projects.pi import qqp_components_context as QCC


###############################################################################

# Turn on memory growth for every GPU TensorFlow can see. This is set here at
# the top of the script, ahead of the model loading below.
# Needed for some reason to prevent BLAS fail to launch.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

###############################################################################

# Root directory of this experiment's artifacts and its sub-directories.
# NOTE(review): MODELS_DIR and FISHERS_DIR are not referenced anywhere else in
# this script as shown; only PER_EXAMPLES_FISHERS_DIR is used below.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Per-example-Fishers file that is loaded below (without its Fishers) and
# handed to the evaluation context.
PEFS_FOR_PREDICTIONS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers'
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)


# File name (resolved relative to PER_EXAMPLES_FISHERS_DIR below) of the saved
# BERT CLS activations that get projected onto the ICA components.
CLS_ACTS_FILENAME = "feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"

###############################################################################

# Model under test and the tokenizer handed to the evaluation context.
# NOTE(review): from_pt=True presumably converts a PyTorch checkpoint to
# TensorFlow weights -- confirm in em_models.from_pretrained.
model = em_models.from_pretrained("connectivity/feather_berts_0", from_pt=True)
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

###############################################################################

# Load the per-example-Fishers file only as the source for the evaluation
# context; the Fishers themselves are deliberately skipped (see below).
pef = per_example.PerExampleFlatFishers.load(
    # expanduser() is a no-op for the absolute path configured above.
    os.path.expanduser(PEFS_FOR_PREDICTIONS_PATH),
    n_examples=None,
    # This leads to the Fishers not being loaded, which ends up being much faster.
    start_fisher_index=0,
    end_fisher_index=0,
)
# Evaluation context used by the run_for_component* helpers to score the model.
# NOTE(review): the file name above says "snli_train" while special_processing
# is 'HF_MNLI' -- confirm this pairing is intended.
eval_ctx = QCC.EvaluationContext2.create_from_pefs(
    pef=pef,
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)

###############################################################################

# Saved FastICA fit to load.
FILENAME = "test_bert_cls_acts_ica001.h5"
FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, FILENAME)

# Indices 8..11, i.e. the last four of twelve.
# NOTE(review): 12 is presumably the number of encoder layers -- confirm.
layer_indices = list(range(12 - 4, 12))

ica = tf_ica.TfFastICA.load(FILEPATH)
# NOTE: Using ica.mixing instead of ica.components is intentional here.
# components.shape = [n_components, n_features]
components = ica.mixing.numpy().T
# Scale every row (one component) to unit L2 norm.
components /= np.sqrt(np.sum(components**2, keepdims=True, axis=-1))


acts = bert_activations.BertClsActivations.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, CLS_ACTS_FILENAME))
# In-place L2 normalization along the last axis of the stored activations.
# NOTE(review): the norm is taken over the full last axis, before the slice to
# the trailing ica.n_features columns on the next line -- confirm intended.
acts.activations /= np.sqrt(np.sum(acts.activations**2, keepdims=True, axis=-1))
# ICA coefficients of the trailing ica.n_features activation columns; indexed
# below as coeffs[:, component_index].
coeffs = ica.transform(tf.cast(acts.activations[:, -ica.n_features:], tf.float32)).numpy()

###############################################################################


def set_component(component_index: int):
    """Load one ICA component into the perturbation context's variables.

    Row ``component_index`` of ``components`` is split into
    ``len(layer_indices)`` equally sized chunks, and chunk ``k`` is assigned
    to the ``k``-th entry of ``ctx.component_variables``.

    Args:
        component_index: Row of ``components`` to install.
    """
    n_layers = len(layer_indices)
    per_layer_chunks = np.reshape(components[component_index], (n_layers, -1))
    # zip() pairs variables with chunks positionally and stops at the shorter
    # of the two sequences.
    for variable, chunk in zip(ctx.component_variables, per_layer_chunks):
        variable.assign(chunk)


def run_for_component(component_index: int):
    """Print KL statistics for a perturbation along one ICA component.

    Installs the component with ``set_component`` and then, while
    ``ctx.monkey_patcher`` is active, evaluates ``model`` on two index sets:

    * the ``N_TOP_EXAMPLES`` examples with the largest signed coefficient
      for this component (descending sort of ``coeffs[:, component_index]``);
    * the first ``N_TOTAL_EXAMPLES`` examples (indices ``0 .. N - 1``).

    Three values are printed, in this order: ``kl()`` of the top set, ``kl()``
    of the reference set, and their ratio (top / reference). Each ``kl()`` is
    computed once and reused rather than re-evaluated for the ratio.

    ``ctx.magnitude_variable`` is not modified here; whatever value it
    currently holds is used.

    Args:
        component_index: Row of ``components`` / column of ``coeffs`` to use.
    """
    set_component(component_index)
    with ctx.monkey_patcher:
        # TODO: Top by magnitude or top by value?
        top_inds = np.argsort(-coeffs[:, component_index])[:N_TOP_EXAMPLES]
        top_results = eval_ctx.evaluate(model, top_inds)
        total_results = eval_ctx.evaluate(model, np.arange(N_TOTAL_EXAMPLES))
    # As before, kl() is evaluated after leaving the patched context.
    top_kl = top_results.kl()
    total_kl = total_results.kl()
    print(top_kl)
    print(total_kl)
    print(top_kl / total_kl)


def run_for_component_ratio(component_index: int):
    """Return the top-examples / reference KL ratio for one ICA component.

    The component is installed via ``set_component``. With
    ``ctx.monkey_patcher`` active, ``model`` is evaluated on the
    ``N_TOP_EXAMPLES`` examples ranked highest by signed coefficient for this
    component and on the first ``N_TOTAL_EXAMPLES`` examples.

    Args:
        component_index: Row of ``components`` / column of ``coeffs`` to use.

    Returns:
        ``kl()`` of the top-example results divided by ``kl()`` of the
        reference results.
    """
    set_component(component_index)
    with ctx.monkey_patcher:
        # TODO: Top by magnitude or top by value?
        descending_order = np.argsort(-coeffs[:, component_index])
        top_results = eval_ctx.evaluate(model, descending_order[:N_TOP_EXAMPLES])
        reference_inds = np.arange(N_TOTAL_EXAMPLES)
        total_results = eval_ctx.evaluate(model, reference_inds)
    top_kl = top_results.kl()
    total_kl = total_results.kl()
    return top_kl / total_kl


def run_for_component2(component_index: int, magnitude: float):
    """Return the larger KL ratio over both perturbation signs.

    ``ctx.magnitude_variable`` is set to ``+magnitude`` and then to
    ``-magnitude``; ``run_for_component_ratio`` is called after each
    assignment. The variable is left at ``-magnitude`` on return.

    Args:
        component_index: Component passed on to ``run_for_component_ratio``.
        magnitude: Value assigned (with both signs) to
            ``ctx.magnitude_variable``.

    Returns:
        The maximum of the two ratios.
    """
    ratios = []
    # Positive sign first, then negative, so the final state matches the
    # sequential assignments this replaces.
    for signed_magnitude in (magnitude, -magnitude):
        ctx.magnitude_variable.assign(signed_magnitude)
        ratios.append(run_for_component_ratio(component_index))
    return max(ratios)


###############################################################################

# Sizes of the two index sets evaluated by the run_for_component* helpers:
# the first N_TOTAL_EXAMPLES examples form the reference set, and the
# N_TOP_EXAMPLES highest-coefficient examples form the "top" set.
# NOTE(review): 1014 may be a typo for 1024 -- confirm.
N_TOTAL_EXAMPLES = 1014
N_TOP_EXAMPLES = 128


# Re-import the module before building the context, so that edits made to it
# during an interactive session are picked up when this code is re-run.
reload(bert_activations_ica)
# Context object providing component_variables, magnitude_variable and
# monkey_patcher, all used by the helper functions above.
ctx = bert_activations_ica.ClsTokenPerturbationContext(
    d_model=model.config.hidden_size,
    layer_indices=layer_indices,
)

# ctx.magnitude_variable.assign(5e0)

# run_for_component(13)

# Sweep every ICA component at magnitude `mp`, printing for each one the
# larger of the two KL ratios obtained with +mp and -mp.
mp = 5e0
for i in range(ica.n_components):
    print(run_for_component2(i, mp))


#
