R"""Makes LaTeX showing the top examples for each component of BERT CLS activations.



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=0 python em/projects/baselines/make_bert_latex_svd.py

"""

import dataclasses
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import bert_activations
from em.tools.ica import tf_ica
from em.projects.pi import bert_decomposition_container

###############################################################################

# Turn on memory growth for every visible GPU before any op runs.
# The original author found this necessary to avoid a "BLAS failed to launch"
# error; the root cause was not recorded.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(device=gpu, enable=True)

###############################################################################

# Root directory of the experiment artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'

# Sub-directories of EXPS_DIR, one per artifact kind.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

# Input files; both are read from PER_EXAMPLES_FISHERS_DIR further down.
# NOTE(review): the activations file is named `train_skip_50k` while the ICA
# file is named `train` -- presumably the ICA was fit on a different slice of
# the training set than the one visualised here; confirm.
CLS_ACTS_FILENAME = "feather_berts_0.snli.train_skip_50k.50000ex.bert_cls_activations.h5"
ICA_FILENAME = "ica.128comps.feather_berts_0.snli.train.50000ex.bert_cls_activations.h5"

###############################################################################

# Tokenizer passed to the container below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Saved CLS activations; this object also supplies the labels, logits and
# input ids used later in the script.
acts = bert_activations.BertClsActivations.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, CLS_ACTS_FILENAME))

# Use only the representations from the last layer (the trailing 768 columns).
activations = acts.activations[:, -768:]

# Scale every row to unit L2 norm. This is deliberately NOT done in place:
# the slice above may be a view onto `acts.activations`, in which case `/=`
# would silently overwrite part of the loaded object.
norms = np.sqrt(np.sum(activations**2, keepdims=True, axis=-1))
activations = activations / norms
activations = tf.cast(activations, tf.float32)

# Centre the activations with the stored ICA mean and project them through the
# transposed whitening matrix to get one coefficient per (example, component).
# The signed coefficients are kept; no absolute value is taken.
# NOTE(review): only `ica.whitening` is applied here, which may be why the
# output file is tagged `_svd` -- confirm this is the intended baseline.
ica = tf_ica.TfFastICA.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, ICA_FILENAME))
centered = activations - ica.mean
coeffs = tf.matmul(centered, tf.transpose(ica.whitening)).numpy()

# Cyclically shift the label ids by one, modulo 3.
# The original author was unsure whether this remapping is needed; presumably
# it aligns the stored label order with the one the container expects -- TODO
# confirm.
labels = (acts.labels + 1) % 3

# Number of top examples handed to the container.
N_TOP_EXAMPLES = 16

# Everything the container needs, gathered in one place: the per-example
# component coefficients plus the data and tokenizer passed alongside them.
container_kwargs = dict(
    coeffs=coeffs,
    tokenizer=tokenizer,
    labels=labels,
    predicted_logits=acts.logits,
    input_ids=acts.input_ids,
    n_top_examples=N_TOP_EXAMPLES,
)
container = bert_decomposition_container.BertContainer(**container_kwargs)

# LaTeX source covering all components; written to disk below.
latex = container.generate_all_components_latex()

# Destination of the generated LaTeX source.
OUT_DIR = '/fruitbasket/users/m/tmp'

OUT_FILENAME = 'bert_activations_ica128_train_skip_50k_svd.tex'

# Write as UTF-8 explicitly. The default encoding depends on the locale, so on
# a machine with an ASCII/C locale any non-ASCII token in the LaTeX would raise
# UnicodeEncodeError. xelatex, used to compile the file, reads UTF-8.
with open(os.path.join(OUT_DIR, OUT_FILENAME), 'w', encoding='utf-8') as f:
    f.write(latex)

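R"""
Shell helper, run on the local machine, that pulls the generated .tex file
from the server and compiles it to a PDF with xelatex.
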
make_pdf() {
    local filename=$1

    rsync -ra -e ssh \
        "m@mango.cs.unc.edu:/fruitbasket/users/m/tmp/${filename}.tex" \
        "$HOME/Downloads/${filename}.tex"

    xelatex -interaction=batchmode -output-directory=/tmp ~/Downloads/${filename}.tex
    mv /tmp/${filename}.pdf ~/Desktop/projects_data/extract_merge1/ll/pdfs/baselines
}

make_pdf bert_activations_ica128_train_skip_50k_svd


"""
