# make_comp_top_exs001.py
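R"""Save a grid image of the top examples for every NMF component to OUT_DIR.

Run from the project root (an empty CUDA_VISIBLE_DEVICES hides all GPUs):
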
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python local_scripts/m_npeff/imagenet2/make_comp_top_exs001.py

"""
import dataclasses
from importlib import reload
import os
from typing import List, Optional, Sequence

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf

from em import datasets as em_datasets
from em.datasets.imagenet import imagenet_x

from em.models import em_models

from em.fishers import per_example
from em.tools.nmf import nmf_common

from em.projects.imagenet import imagenet_components_context as imagenet_cc

# Local alias for the components-context class. The script below builds one
# from a per-example Fisher file and an NMF file, then asks it for each
# component's top examples.
LrmImageNetCC = imagenet_cc.LrmImageNetCC

###############################################################################

# Side length, in pixels, of each square tile in the grids built by make_image.
IMAGE_SIZE = 224

###############################################################################

# Root of the experiment outputs; the sub-directories below hang off it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/imagenet2_lrm_npeff'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir_name)
    for subdir_name in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Alternative configuration for the train split. It is kept for reference;
# swap it in for the validation settings below to render train-split grids.
# PEF_FILENAME = "resnet50_imagenet.train.20000ex.65536.mpc3e-3.35mc.h5"
# NMF_FILENAME = "resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.h5"
# SPLIT = "train"
# OUT_DIR = '/fruitbasket/users/m/imagenet_lvrm_npeff_save_bin1'


# Active configuration: the validation split.
SPLIT = "validation"
PEF_FILENAME = "resnet50_imagenet.validation.30000ex.65536.mpc3e-3.35mc.h5"
# Judging by its name, this NMF was fit on train and then fit to validation.
NMF_FILENAME = "resnet50_imagenet.train.20000ex.65536.mpc3e-3.mnpeff.512comps.001.fit_to_validation.001.h5"
OUT_DIR = '/fruitbasket/users/m/imagenet_lvrm_npeff_save_bin.validation.1'

##########################################################################


def make_image(
    cc: LrmImageNetCC,
    component_index: int,
    n_cols: int,
    n_rows: int,
    extension: str = 'png',
    out_dir: Optional[str] = None,
) -> str:
    """Save a grid of one component's top examples as a single image file.

    Requests ``n_cols * n_rows`` top examples from ``cc`` and pastes them
    row-major into an ``n_rows`` x ``n_cols`` grid of ``IMAGE_SIZE`` x
    ``IMAGE_SIZE`` RGB tiles. Cells with no example stay zero (black).

    Args:
        cc: Components context providing ``get_top_examples``.
        component_index: Index of the component whose top examples are drawn.
        n_cols: Number of tiles per grid row.
        n_rows: Number of grid rows.
        extension: File extension; ``plt.imsave`` infers the format from it.
        out_dir: Directory to write into. When ``None`` (the default), the
            module-level ``OUT_DIR`` is used.

    Returns:
        Path of the written file: ``<out_dir>/comp<component_index>.<extension>``.
    """
    Q = IMAGE_SIZE
    if out_dir is None:
        out_dir = OUT_DIR
    #
    # NOTE(review): the canvas is float64, and plt.imsave expects float RGB
    # values in [0, 1]. Confirm that the examples' pixels are scaled that way.
    grid = np.zeros([n_rows * Q, n_cols * Q, 3], dtype=np.float64)
    #
    examples = cc.get_top_examples(component_index, n_cols * n_rows)
    for i, example in enumerate(examples):
        row, col = divmod(i, n_cols)
        # Assumes example.image broadcasts to (IMAGE_SIZE, IMAGE_SIZE, 3) -- TODO confirm.
        grid[row * Q : (row + 1) * Q, col * Q : (col + 1) * Q, :] = example.image
    #
    filepath = os.path.join(out_dir, f'comp{component_index}.{extension}')
    plt.imsave(filepath, grid)
    return filepath


##########################################################################

# Create the output directory, including any missing parents. exist_ok makes
# reruns a no-op without a separate (racy) exists() check. It still raises if
# OUT_DIR exists as a regular file, rather than failing later inside imsave.
os.makedirs(OUT_DIR, exist_ok=True)


cc = LrmImageNetCC(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    # NOTE(review): the NMF file is also read from PER_EXAMPLES_FISHERS_DIR;
    # confirm that is where NMF outputs are stored.
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    ds_split=SPLIT,
)


# Each grid is 12 columns x 8 rows, i.e. 96 top examples per component.
GRID_KWARGS = {'n_cols': 12, 'n_rows': 8}
# The last axis of cc.nmf.W is treated as the component axis: one grid per index.
for component_index in range(cc.nmf.W.shape[-1]):
    make_image(cc, component_index, extension='jpg', **GRID_KWARGS)


R"""
Copy the rendered grids to a local machine (train-split and validation-split output dirs):

rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/imagenet_lvrm_npeff_save_bin1" \
    "$HOME/Desktop/projects_data/extract_merge1/imagenet1/save_bin/"


rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/imagenet_lvrm_npeff_save_bin.validation.1" \
    "$HOME/Desktop/projects_data/extract_merge1/imagenet1/save_bin/"

"""
