R"""Makes image grids of the top examples for each resnet k-means component.



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES= python em/projects/baselines/make_resnet_images_kmeans.py


rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/imagenet_ica128_kmeans/" \
    "$HOME/Desktop/projects_data/extract_merge1/ll/pdfs/baselines/imagenet_ica128_kmeans/"


"""

import dataclasses
from importlib import reload
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.activations import resnet_activations
from em.tools import k_means
from em.projects.imagenet import resnet_decomposition_container

###############################################################################

# Enable memory growth on every GPU that TensorFlow reports.
# Per the original author's note, this was needed to prevent a "BLAS failed to
# launch" error; the root cause is not documented here.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

###############################################################################

# Base directory and its sub-directories. Only PER_EXAMPLES_FISHERS_DIR is
# used in the visible part of this script.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/m_npeff1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

# File names (relative to PER_EXAMPLES_FISHERS_DIR) of the saved activations
# and of the saved k-means model.
# NOTE(review): going by the file names, the activations come from the
# validation split while the k-means model was fit on the train split —
# confirm this is intended.
ACTS_FILENAME = "resnet50.imagenet.validation.30000ex.activations.h5"
KMEANS_FILENAME = "kmeans.128comps.resnet50.imagenet.train.20000ex.activations.h5"

###############################################################################

# Number of top examples handed to the container.
N_TOP_EXAMPLES = 16

# Load the saved activations.
acts_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, ACTS_FILENAME)
acts = resnet_activations.ResnetActivations.load(acts_path)

# Divide each activation vector by its L2 norm along the last axis. The
# division is done in place on the object returned by `acts.activations`.
activations = acts.activations
l2_norms = np.sqrt(np.sum(activations**2, axis=-1, keepdims=True))
activations /= l2_norms

# Load the saved k-means model and compute coefficients for the normalized
# activations.
kmeans_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, KMEANS_FILENAME)
km = k_means.KMeans.load(kmeans_path)
coeffs = km.create_coeffs(activations)

container = resnet_decomposition_container.ResnetContainer(
    n_top_examples=N_TOP_EXAMPLES,
    ds_split="validation",
    coeffs=coeffs,
)


# Side length, in pixels, of one square cell in the output image grid.
IMAGE_SIZE = 224
# Directory the per-component grid images are written to.
OUT_DIR = '/fruitbasket/users/m/tmp/imagenet_ica128_kmeans'


def make_image(cc, component_index: int, n_cols: int, n_rows: int, extension: str = 'jpg'):
    """Save a grid of the top examples of one component as a single image.

    The grid has `n_rows` x `n_cols` square cells of side `IMAGE_SIZE`, filled
    in row-major order. Cells for which no example is returned stay
    zero-filled. The image is written to
    `OUT_DIR/comp{component_index}.{extension}`; `OUT_DIR` is created first if
    it does not exist yet.

    Args:
        cc: Object providing `get_top_examples(component_index, n)`. Each
            returned item must have an `.image` attribute that can be
            assigned into an (IMAGE_SIZE, IMAGE_SIZE, 3) slice.
        component_index: Index of the component to visualize; also used in
            the output file name.
        n_cols: Number of grid columns.
        n_rows: Number of grid rows.
        extension: File extension of the output image (without the dot).
    """
    cell = IMAGE_SIZE
    canvas = np.zeros([n_rows * cell, n_cols * cell, 3], dtype=np.float64)

    examples = cc.get_top_examples(component_index, n_cols * n_rows)
    for i, example in enumerate(examples):
        # Row-major placement: fill a row left to right, then move down.
        row, col = divmod(i, n_cols)
        canvas[row * cell : (row + 1) * cell, col * cell : (col + 1) * cell, :] = example.image

    # plt.imsave does not create missing parent directories.
    os.makedirs(OUT_DIR, exist_ok=True)
    filepath = os.path.join(OUT_DIR, f'comp{component_index}.{extension}')
    plt.imsave(filepath, canvas)


# Grid layout shared by every component image.
# NOTE(review): this grid requests 8 * 4 = 32 examples per component, while the
# container above is built with n_top_examples=16 — confirm that
# `get_top_examples` can serve 32 examples.
GRID_KWARGS = {'n_cols': 8, 'n_rows': 4}

# One output image per component; the component count is the size of the last
# axis of the coefficients.
n_components = container.coeffs.shape[-1]
for component_index in range(n_components):
    make_image(container, component_index, extension='jpg', **GRID_KWARGS)
