R"""Makes image grids of the top examples for each ResNet-50 ICA component.



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=0 python em/projects/baselines/make_resnet_images.py


rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/imagenet_ica128/" \
    "$HOME/Desktop/projects_data/extract_merge1/ll/pdfs/baselines/imagenet_ica128/"


"""

import dataclasses
from importlib import reload
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import resnet_activations
from em.tools.ica import tf_ica
from em.projects.imagenet import resnet_decomposition_container

###############################################################################

# Have TensorFlow allocate GPU memory on demand rather than reserving it all
# when the device is first initialized. Without this, runs of this script
# have hit a "BLAS fail to launch" error (root cause not investigated).
gpus = tf.config.experimental.list_physical_devices('GPU')
for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)

###############################################################################

# Root of this experiment's data and the standard subdirectories beneath it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

# Inputs read from PER_EXAMPLES_FISHERS_DIR. Judging by the names, these are
# the saved ResNet-50 activations and a 128-component ICA fit (verify).
ACTS_FILENAME = "resnet50.imagenet.validation.30000ex.activations.h5"
ICA_FILENAME = "ica.128comps.resnet50.imagenet.train.20000ex.activations.h5"

###############################################################################

acts = resnet_activations.ResnetActivations.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, ACTS_FILENAME))

# Scale each activation vector to unit L2 norm along the last axis before
# projecting onto the ICA. The division is in place, so `acts.activations`
# ends up normalized too. A row whose norm is zero is divided by 1 and stays
# all-zero, instead of becoming 0 / 0 = NaN and (likely) poisoning that
# example's ICA coefficients.
activations = acts.activations
norms = np.linalg.norm(activations, axis=-1, keepdims=True)
activations /= np.where(norms > 0, norms, 1)
activations = tf.cast(activations, tf.float32)

# Project the normalized activations onto the saved ICA. The result is a
# NumPy array; its last axis is what the rendering loop iterates over as
# components.
ica = tf_ica.TfFastICA.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, ICA_FILENAME))
coeffs = ica.transform(activations).numpy()

# How many top examples per component the container is asked to track.
# NOTE(review): the rendering loop requests 8 x 4 = 32 examples per
# component, which is more than this value; confirm ResnetContainer can
# supply that many (TODO confirm against get_top_examples).
N_TOP_EXAMPLES = 16

container = resnet_decomposition_container.ResnetContainer(
    ds_split="validation",
    n_top_examples=N_TOP_EXAMPLES,
    coeffs=coeffs,
)


# Side length, in pixels, of each square cell in a rendered grid.
IMAGE_SIZE = 224
# Directory the per-component grid images are written to.
OUT_DIR = '/fruitbasket/users/m/tmp/imagenet_ica128'


def make_image(cc, component_index: int, n_cols: int, n_rows: int, extension: str = 'jpg',
               out_dir=None, image_size=None):
    """Save a grid of one component's top examples as an image file.

    Cells are filled row-major with the images returned by
    `cc.get_top_examples`. Cells left over when fewer images come back stay
    black. Images beyond `n_cols * n_rows` are ignored.

    Args:
        cc: Decomposition container exposing
            `get_top_examples(component_index, n)`, which returns an iterable
            of objects with an `.image` array attribute.
        component_index: Component whose top examples are drawn.
        n_cols: Number of grid columns.
        n_rows: Number of grid rows.
        extension: File extension of the output, `comp{index}.{extension}`.
        out_dir: Output directory; defaults to the module-level OUT_DIR.
            It is created if missing.
        image_size: Pixel side length of each square cell; defaults to the
            module-level IMAGE_SIZE.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are honored.
    out_dir = OUT_DIR if out_dir is None else out_dir
    q = IMAGE_SIZE if image_size is None else image_size
    n_cells = n_cols * n_rows

    grid = np.zeros([n_rows * q, n_cols * q, 3], dtype=np.float64)

    imgs = cc.get_top_examples(component_index, n_cells)
    # zip() bounds the loop, so an over-long result cannot index past the
    # last row, which would fail with a broadcast error.
    # NOTE(review): assumes each `.image` is a (q, q, 3) array with float
    # values in [0, 1], as plt.imsave expects for float RGB input; confirm.
    for i, example in zip(range(n_cells), imgs):
        row, col = divmod(i, n_cols)
        grid[row * q:(row + 1) * q, col * q:(col + 1) * q, :] = example.image

    # plt.imsave does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    filepath = os.path.join(out_dir, f'comp{component_index}.{extension}')
    plt.imsave(filepath, grid)


# Write one 8-column by 4-row JPEG grid for each index along the last axis
# of `container.coeffs`.
# NOTE(review): that is 32 examples per component, but the container is
# built with N_TOP_EXAMPLES = 16; confirm this is intended.
GRID_KWARGS = dict(n_cols=8, n_rows=4)
for comp_idx in range(container.coeffs.shape[-1]):
    make_image(container, comp_idx, extension='jpg', **GRID_KWARGS)
