R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/imagenet/comp_explore01.py
CUDA_VISIBLE_DEVICES= python -i local_scripts/imagenet/comp_explore01.py

"""
import dataclasses
from importlib import reload
import os
from typing import List, Optional, Sequence

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf

from em import datasets as em_datasets
from em.datasets.imagenet import imagenet_x

from em.models import em_models

from em.fishers import per_example
from em.tools.nmf import nmf_common


###############################################################################

# Root directory of this experiment's data; the three directories below are
# sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/imagenet1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory that both PEF_FILENAME and NMF_FILENAME are joined onto when the
# components context is constructed further down in this file.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# NOTE(review): not referenced in the visible part of this file (the dataset is
# loaded with tokenizer=None) -- presumably a leftover; confirm before removing.
TOKENIZER = 'bert-base-uncased'

##########################################################################

# Alternative configuration that points at the train-split files directly.
# Kept commented out for quick switching.
# PEF_FILENAME = "resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
# NMF_FILENAME = f"spH.nmf_decomp.c{512}_{2500}Iters_{65536}pe_mvpp{6}_{20000}ex.{PEF_FILENAME}"
# SPLIT = "train"


# "OG" (original) filenames: the per-example Fishers file whose name refers to
# imagenet_train, and the NMF decomposition filename derived from it. The
# literal numbers inside the f-string are spelled as placeholders so they are
# easy to edit by hand.
OG_PEF_FILENAME = "resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
OG_NMF_FILENAME = f"spH.nmf_decomp.c{512}_{2500}Iters_{65536}pe_mvpp{6}_{20000}ex.{OG_PEF_FILENAME}"

# Active configuration: validation-split per-example Fishers, with an NMF file
# named "fit_w.<...>.<OG_NMF_FILENAME>" (presumably W re-fit against the
# original decomposition's components -- TODO confirm).
PEF_FILENAME = "resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"
NMF_FILENAME = f"fit_w.{65536}vpe.{OG_NMF_FILENAME}"
SPLIT = "validation"


##########################################################################


@dataclasses.dataclass
class ImageNetExample:
    """A single dataset example: an image together with its label."""

    # Label of the example.
    label: int
    # Image array. In this file instances are built from a dataset item's
    # first element divided by 255.0 (presumably giving values in [0, 1] --
    # TODO confirm the raw pixel range).
    image: np.ndarray


@dataclasses.dataclass
class ImageNetCC:
    """Components context ("CC") for exploring an NMF decomposition.

    Constructing an instance immediately loads three things and stores them
    as attributes:

    * ``pef``: the per-example Fishers container read from ``pef_filepath``,
    * ``nmf``: the sparse NMF decomposition read from ``nmf_filepath``, with
      its components normalized to unit norm,
    * ``examples``: one ``ImageNetExample`` per row of ``nmf.W``, read from
      the ``ds_split`` split after skipping ``ds_offset`` items.
    """

    pef_filepath: str
    nmf_filepath: str

    ds_split: str
    ds_offset: int = 0

    def __post_init__(self):
        """Load the Fishers container, the decomposition and the examples."""
        # Per the original author's note: using 0 for both the start and the
        # end Fisher index leads to the Fishers not being loaded, which ends
        # up being much faster.
        self.pef = per_example.PerExampleFlatFishers.load(
            self.pef_filepath,
            n_examples=None,
            start_fisher_index=0,
            end_fisher_index=0,
        )

        self.nmf = nmf_common.SparseNmfDecomposition.load(self.nmf_filepath)
        self.nmf.normalize_components_to_unit_norm()

        # When the Fishers file covers more examples than W has rows, keep
        # only the leading examples so both line up.
        # TODO: Allow us to specify other subsets of example indices.
        n_decomposed = self.nmf.W.shape[0]
        if n_decomposed < self.pef.input_ids.shape[0]:
            leading_indices = list(range(n_decomposed))
            self.pef = self.pef.create_for_subset(leading_indices)

        self.examples = self._load_examples()

    def _load_examples(self):
        """Read as many dataset items as ``nmf.W`` has rows.

        Returns:
            A list of ``ImageNetExample``; each image is the item's first
            element divided by 255.0 and each label is its second element.
        """
        dataset = em_datasets.load(
            'imagenet/default',
            split=self.ds_split,
            tokenizer=None,
            sequence_length=224,
        )
        dataset = dataset.skip(self.ds_offset)
        n_wanted = self.nmf.W.shape[0]
        items = dataset.take(n_wanted).cache().as_numpy_iterator()
        return [
            ImageNetExample(image=item[0] / 255.0, label=item[1])
            for item in items
        ]

    ############################################################

    def get_examples_by_indices(self, example_indices: Sequence[int]) -> List[ImageNetExample]:
        """Return the loaded examples at ``example_indices``, in that order."""
        selected = []
        for position in example_indices:
            selected.append(self.examples[position])
        return selected

    ############################################################

    def get_top_example_indices(self, component_index: int, n_examples: int) -> np.ndarray:
        """Return indices of the examples with the largest weights for a component.

        Column ``component_index`` of ``nmf.W`` is sorted in descending order
        (by arg-sorting its negation) and the first ``n_examples`` indices are
        returned.
        """
        weights = self.nmf.W[:, component_index]
        descending = np.argsort(-weights)
        return descending[:n_examples]

    def get_top_examples(self, component_index: int, n_examples: int) -> List[ImageNetExample]:
        """Return the examples with the largest weights for a component."""
        return self.get_examples_by_indices(
            self.get_top_example_indices(component_index, n_examples)
        )


##########################################################################


def dumb_plot_h(H: tf.sparse.SparseTensor):
    """Convert the sparse tensor ``H`` to a dense array and show a quick plot of it."""
    dense_values = tf.sparse.to_dense(H).numpy()
    plt.plot(dense_values)
    plt.show()


##########################################################################

def example_images_to_single_horizontally(images: List[np.ndarray], pad_size: int) -> np.ndarray:
    """Join images left to right into one array, separated by NaN gaps.

    Every image except the first is preceded by a NaN-filled strip that has
    the shape of that image's first ``pad_size`` columns.

    Args:
        images: Arrays with at least two dimensions; they are concatenated
            along axis 1.
        pad_size: Number of columns in each separating strip.

    Returns:
        The concatenation of all images and gaps along axis 1.
    """

    def _with_left_gap(picture: np.ndarray) -> np.ndarray:
        # The gap is derived from the picture itself so that it matches the
        # picture's remaining dimensions.
        gap = np.ones_like(picture[:, :pad_size]) * np.nan
        return np.concatenate([gap, picture], axis=1)

    strip = [
        picture if position == 0 else _with_left_gap(picture)
        for position, picture in enumerate(images)
    ]
    return np.concatenate(strip, axis=1)


def example_images_to_single_vertically(images: List[np.ndarray], pad_size: int) -> np.ndarray:
    """Stack images top to bottom into one array, separated by NaN gaps.

    Every image except the first is preceded by a NaN-filled band that has
    the shape of that image's first ``pad_size`` rows.

    Args:
        images: Arrays to concatenate along axis 0.
        pad_size: Number of rows in each separating band.

    Returns:
        The concatenation of all images and gaps along axis 0.
    """

    def _with_top_gap(picture: np.ndarray) -> np.ndarray:
        # The gap is derived from the picture itself so that it matches the
        # picture's remaining dimensions.
        gap = np.ones_like(picture[:pad_size]) * np.nan
        return np.concatenate([gap, picture], axis=0)

    column = [
        picture if position == 0 else _with_top_gap(picture)
        for position, picture in enumerate(images)
    ]
    return np.concatenate(column, axis=0)


def plot_examples1(examples: List[ImageNetExample], pad_size: int = 3):
    """Show the examples' images side by side in a single row.

    Args:
        examples: Examples whose ``image`` arrays are drawn, left to right.
        pad_size: Width of the NaN gap inserted between neighbouring images.
    """
    pictures = [example.image for example in examples]
    strip = example_images_to_single_horizontally(pictures, pad_size)
    plt.imshow(strip)
    plt.show()


def plot_examples2(examples: List[ImageNetExample], n_rows: int, n_cols: int, pad_size: int = 3, *, save: Optional[str] = None):
    """Draw the examples' images as an ``n_rows`` x ``n_cols`` grid.

    The examples fill the grid row by row; neighbouring images are separated
    by NaN gaps of ``pad_size`` pixels.

    Args:
        examples: Examples whose ``image`` arrays are drawn. There must be
            exactly ``n_rows * n_cols`` of them.
        n_rows: Number of grid rows.
        n_cols: Number of grid columns.
        pad_size: Size of the gap between neighbouring images.
        save: If given, the figure is written to this path and then closed
            instead of being shown.

    Raises:
        AssertionError: If ``n_rows * n_cols`` differs from ``len(examples)``.
    """
    assert n_rows * n_cols == len(examples), (
        f"n_rows * n_cols must equal len(examples): "
        f"{n_rows} * {n_cols} != {len(examples)}"
    )
    pictures = [ex.image for ex in examples]
    row_strips = [
        example_images_to_single_horizontally(
            pictures[row * n_cols: (row + 1) * n_cols], pad_size
        )
        for row in range(n_rows)
    ]
    grid = example_images_to_single_vertically(row_strips, pad_size)

    # Size only this figure. Assigning plt.rcParams["figure.figsize"] would
    # change the default size of every figure created afterwards as well.
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.imshow(grid)

    if save is None:
        fig.show()
        return

    try:
        fig.savefig(save)
    finally:
        # Close the figure even when saving fails, so that saving many grids
        # in a loop does not accumulate open figures.
        plt.close(fig)


##########################################################################
##########################################################################

# Build the components context for the configured per-example Fishers file,
# NMF file and dataset split. Constructing it loads the files and the dataset
# examples, so this statement does the bulk of the script's I/O.
cc = ImageNetCC(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    ds_split=SPLIT,
)

# Full sparse H of the decomposition; the commented-out dumb_plot_h call near
# the end of this file indexes it by component.
hs = cc.nmf.get_full_sparse_H()

"""
Tuned or otherwise interesting component indices:
1, 3, 5, 6, 8, 9, 10

?: 4, 

comp 29 looks tuned for bees and hexagonal-ish patterns (i.e., honeycomb-like)


Fur color/patterns of dogs/animals.
"""

# Output directory used by the (commented-out) save= variants of the plotting
# calls below. The first line is the directory of an earlier run.
# SAVE_DIR = "/fruitbasket/users/m/imagenet_save_bin1"
SAVE_DIR = "/fruitbasket/users/m/imagenet_save_bin.validation.1"


# Grid layout for plot_examples2. N_EX is derived from it so that the number
# of fetched examples always equals N_ROWS * N_COLS, as plot_examples2 requires.
N_ROWS = 8
N_COLS = 12
N_EX = N_ROWS * N_COLS

# Index of the NMF component to inspect.
COMP_INDEX = 29

# Despite the name, `imgs` holds ImageNetExample objects (image plus label):
# the N_EX examples with the largest W weights for COMP_INDEX.
imgs = cc.get_top_examples(COMP_INDEX, N_EX)
print([ex.label for ex in imgs])
# plot_examples1(imgs)
plot_examples2(imgs, N_ROWS, N_COLS)
# plot_examples2(imgs, N_ROWS, N_COLS, save=os.path.join(SAVE_DIR, f'___comp{COMP_INDEX}.png'))



# reload(imagenet_x)
# annots = imagenet_x.ImageNetXAnnotations()



# Re-import imagenet_x so that edits to that module are picked up when this
# line is re-run in the interactive (`python -i`) session, then build the
# annotations object from the freshly loaded module.
reload(imagenet_x);annots = imagenet_x.ImageNetXAnnotations()


# for comp_index in range(cc.nmf.W.shape[1]):
#     imgs = cc.get_top_examples(comp_index, N_EX)
#     plot_examples2(imgs, N_ROWS, N_COLS, save=os.path.join(SAVE_DIR, f'comp{comp_index}.png'))
    



# dumb_plot_h(hs[COMP_INDEX])



# model = em_models.from_pretrained('resnet:resnet50_imagenet')
# model.summary()


R"""

rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/imagenet_save_bin1" \
    "$HOME/Desktop/projects_data/extract_merge1/imagenet1/save_bin/"


rsync -ra -e ssh \
    "m@mango.cs.unc.edu:/fruitbasket/users/m/imagenet_save_bin.validation.1" \
    "$HOME/Desktop/projects_data/extract_merge1/imagenet1/save_bin/"

"""