R"""Make the top-example image grids (ImageNet NMF components) used as figures in the paper.


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i em/projects/icml2023/plots/make_imagenet_main_top_examples.py



rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/comp36_top_ex_main.png" \
    "$HOME/Desktop/projects_data/extract_merge1/paper1_imgs/"


rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/comp218_top_ex_main.png" \
    "$HOME/Desktop/projects_data/extract_merge1/paper1_imgs/"




"""
import os

import dataclasses
from importlib import reload

from typing import List, Optional, Sequence

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf

from em import datasets as em_datasets
from em.datasets.imagenet import imagenet_x

from em.models import em_models

from em.fishers import per_example
from em.tools.nmf import nmf_common

from em.projects.imagenet import imagenet_components_context as imagenet_cc

# Short alias for the components-context class; used as the `cc` type in
# make_image and to construct the module-level `cc` context further down.
ImageNetCC = imagenet_cc.ImageNetCC


# --- Experiment data locations ----------------------------------------------

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/imagenet1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

# NOTE(review): not referenced anywhere else in this script.
TOKENIZER = 'bert-base-uncased'

# --- Per-example Fisher (PEF) and NMF inputs --------------------------------
#
# The OG_* names are the training-split files the NMF decomposition was built
# from. The files actually loaded are the validation-split PEFs together with
# an NMF file whose name is the training-split NMF name prefixed by
# "fit_w.65536vpe." (presumably W refit on the validation PEFs -- confirm).
#
# To visualize the training split instead, use:
#     PEF_FILENAME = OG_PEF_FILENAME
#     NMF_FILENAME = OG_NMF_FILENAME
#     SPLIT = "train"

OG_PEF_FILENAME = "resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
OG_NMF_FILENAME = "spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex." + OG_PEF_FILENAME

PEF_FILENAME = "resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"
NMF_FILENAME = "fit_w.65536vpe." + OG_NMF_FILENAME
SPLIT = "validation"

# --- Output figure settings -------------------------------------------------

OUT_DIR = '/fruitbasket/users/m/tmp'

IMAGE_SIZE = 224  # side length in pixels of each grid tile
SEP_SIZE = 3  # height in pixels of the separator strip above the bottom row


def make_image(
    cc: "ImageNetCC",
    component_index: int,
    n_cols: int,
    n_top_rows: int,
    bottom_row_indices: Sequence[int],
    n_candidates: int = 8 * 12,
    image_size: Optional[int] = None,
    sep_size: Optional[int] = None,
):
    """Build a grid of a component's top examples plus a hand-picked bottom row.

    The first ``n_top_rows`` rows show the leading ``n_top_rows * n_cols`` top
    examples of ``component_index`` in row-major order. Below a horizontal
    separator strip of ones, one extra row shows the top examples at
    ``bottom_row_indices`` (indices into the list returned by
    ``cc.get_top_examples``).

    Args:
        cc: Components context exposing ``get_top_examples(component_index, n)``,
            whose result supports slicing and indexing and whose items have an
            ``.image`` array. Each ``.image`` is assumed to be
            ``(image_size, image_size, 3)`` with values in [0, 1], which makes
            the separator white. TODO confirm against ImageNetCC.
        component_index: Component whose top examples are shown.
        n_cols: Number of tiles per row.
        n_top_rows: Number of rows filled with the leading top examples.
        bottom_row_indices: Exactly ``n_cols`` indices into the top-example
            list, one per bottom-row tile.
        n_candidates: Minimum number of top examples to request. The request
            grows automatically if the top rows or ``bottom_row_indices`` need
            more.
        image_size: Tile side length in pixels; defaults to ``IMAGE_SIZE``.
        sep_size: Separator height in pixels; defaults to ``SEP_SIZE``.

    Returns:
        float64 array of shape
        ``(n_top_rows * image_size + sep_size + image_size, n_cols * image_size, 3)``.
        Top-row tiles with no corresponding example are left as zeros.

    Raises:
        ValueError: if ``len(bottom_row_indices) != n_cols``.
    """
    if len(bottom_row_indices) != n_cols:
        raise ValueError(
            f"bottom_row_indices has {len(bottom_row_indices)} entries; "
            f"expected n_cols={n_cols}."
        )
    # Resolve the module-level defaults lazily so callers may override them.
    Q = IMAGE_SIZE if image_size is None else image_size
    sep_height = SEP_SIZE if sep_size is None else sep_size

    n_top = n_top_rows * n_cols
    # Request enough examples for both the top rows and every bottom index.
    # For the default 96-candidate case this matches the historical request.
    n_fetch = max(n_candidates, n_top, max(bottom_row_indices, default=-1) + 1)
    imgs = cc.get_top_examples(component_index, n_fetch)

    top = np.zeros([n_top_rows * Q, n_cols * Q, 3], dtype=np.float64)
    for i, example in enumerate(imgs[:n_top]):
        row, col = divmod(i, n_cols)
        top[row * Q:(row + 1) * Q, col * Q:(col + 1) * Q, :] = example.image

    # Fill the bottom row explicitly, so a short top-example list cannot shift
    # the hand-picked images up into the top rows.
    bottom = np.zeros([Q, n_cols * Q, 3], dtype=np.float64)
    for col, idx in enumerate(bottom_row_indices):
        bottom[:, col * Q:(col + 1) * Q, :] = imgs[idx].image

    sep = np.ones([sep_height, n_cols * Q, 3], dtype=np.float64)
    return np.concatenate([top, sep, bottom], axis=0)


##########################################################################

# Components context over the validation-split per-example Fishers and the
# matching NMF decomposition file.
cc = ImageNetCC(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    ds_split=SPLIT,
)

GRID_KWARGS = {'n_cols': 4, 'n_top_rows': 2}

# Hand-picked bottom-row indices into each component's top-example list.
# Written as offset + column; presumably picked from a 12-wide browsing grid
# of the 8 * 12 candidates make_image requests by default -- TODO confirm.
COMP36_BRIS = [12 + 2, 12 + 3, 24 + 1, 24 + 3]
COMP218_BRIS = [24 + 1, 24 + 2, 24 + 4, 24 + 5]

# Component index -> bottom-row indices, in the order the figures are made.
COMPONENT_BOTTOM_ROW_INDICES = {
    36: COMP36_BRIS,
    218: COMP218_BRIS,
}

# plt.imsave does not create missing directories.
os.makedirs(OUT_DIR, exist_ok=True)

# Kept at module level (no main() wrapper) so `cc`, `img` and `filepath` stay
# available when the script is run with `python -i`.
for component_index, bottom_row_indices in COMPONENT_BOTTOM_ROW_INDICES.items():
    img = make_image(
        cc,
        component_index=component_index,
        bottom_row_indices=bottom_row_indices,
        **GRID_KWARGS,
    )
    # Derive the filename from the index so it always matches the plotted component.
    filepath = os.path.join(OUT_DIR, f'comp{component_index}_top_ex_main.png')
    plt.imsave(filepath, img)
    # To preview interactively instead: plt.imshow(img); plt.show()