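R"""Scratch script: run one batch of GLUE RTE through a BERT model and collect
the per-layer CLS-token activations for interactive inspection.

Usage (run from the project root so that the `em` package is importable):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/activations/bert_activations_test01.py

"""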
import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from transformers import TFBertForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.models import em_models
from em.util import monkey_patching

###############################################################################

# Load the model through the project's loader and the tokenizer from the
# HuggingFace hub. NOTE(review): `from_pt=True` presumably converts a PyTorch
# checkpoint, as in HuggingFace's own `from_pretrained` -- confirm in em_models.
model = em_models.from_pretrained("connectivity/feather_berts_0", from_pt=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

ds = em_datasets.load('glue/rte', split='train', sequence_length=64, tokenizer=tokenizer)

# Only the first batch of two examples is needed. `next(iter(...))` fetches it
# directly and fails right here (StopIteration) if the dataset is empty,
# instead of leaving `x`/`y` unbound and failing later at the model call.
x, y = next(iter(ds.batch(2)))

# Inference-mode forward pass that also returns the hidden states.
output = model(x, training=False, output_hidden_states=True)

# The first corresponds to the embeddings, which will always be the same for the CLS token.
activations = output.hidden_states[1:]
# Keep index 0 along the second axis of every remaining hidden state.
# Assumes each hidden state is (batch, sequence, hidden) with the CLS token at
# sequence position 0 -- TODO confirm.
cls_activations = [layer_states[:, 0] for layer_states in activations]

# 12 * 768 = 9216
# NOTE(review): presumably 12 layers x 768 hidden units, i.e. the size of the
# concatenated CLS activations per example -- confirm.
