R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/nmf_rank_explore_01.py

"""
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import per_example
from em.models import transformer_model_vars as tmv
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_util
from em.projects.ll import hans_analysis as ha
from em.projects.wino import nmf_components_fisher as ncf
from em.tools.clustering import vat
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util
from em.util import sparse_util


###############################################################################

# Root directory for this experiment's data; the sub-directories below hang off it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# NOTE(review): MODELS_DIR and FISHERS_DIR are not referenced in the visible
# part of this file; only PER_EXAMPLES_FISHERS_DIR is (by load_pef / load_nmf).
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Checkpoint name used to build the module-level tokenizer below.
PRETRAINED_MODEL = 'bert-base-uncased'
# NOTE(review): not referenced in the visible code; get_model passes a literal
# from_pt=True instead of this constant.
FROM_PT = True

# NOTE(review): not referenced in the visible part of this file.
N_DECOMPS = 25

##########################################################################

# File-name template for a per-example Fisher file; filled in with `model_number`.
PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
# File-name template for an NMF decomposition; `pef_file` is a formatted
# PEF_FILENAME. load_nmf additionally splices ".ssi{nmf_index}" in front of the
# trailing ".h5".
NMF_FILENAME = "nmf_decomp.per_sub_block.5k.32k.{n_components}.{pef_file}"

##########################################################################

# Tokenizer for PRETRAINED_MODEL, created at import time.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


def get_model(model_number: int):
    """Load and compile the `connectivity/feather_berts_{model_number}` classifier.

    Args:
        model_number: Number substituted into the checkpoint name.

    Returns:
        The TF sequence-classification model, loaded with `from_pt=True` and
        compiled with a sparse categorical cross-entropy loss on logits and a
        sparse categorical accuracy metric.
    """
    checkpoint = f'connectivity/feather_berts_{model_number}'
    classifier = TFAutoModelForSequenceClassification.from_pretrained(
        checkpoint, from_pt=True)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    classifier.compile(loss=loss_fn, metrics=[accuracy])
    return classifier


def get_packer(model):
    """Build a `FlatPacker` with one flat entry per sub-block of `model`.

    The model's mergeable variables are filtered with
    `VariableFilter(merge_embeddings=False)` and grouped by sub-block; each
    group contributes a single one-element shape holding its total size.

    Args:
        model: Model passed to `hf_util.get_mergeable_variables`.

    Returns:
        A `flat_pack.FlatPacker` built from the per-sub-block sizes.
    """
    mergeable = hf_util.get_mergeable_variables(model)
    no_embeddings_filter = tmv.VariableFilter(merge_embeddings=False)
    kept = no_embeddings_filter.filter_parallel_lists(mergeable)

    shapes = []
    for sub_block in tmv.group_by_sub_blocks(kept):
        # Only the parameter count matters here, not the real variable shapes.
        n_params = sum(tf.size(v) for v in sub_block)
        shapes.append([n_params])
    return flat_pack.FlatPacker(shapes)


def load_pef(model_number: int):
    """Load the per-example flat Fishers saved for the given model number.

    Args:
        model_number: Number substituted into `PEF_FILENAME`.

    Returns:
        Whatever `per_example.PerExampleFlatFishers.load` returns for the file
        under `PER_EXAMPLES_FISHERS_DIR`.
    """
    filename = PEF_FILENAME.format(model_number=model_number)
    path = os.path.expanduser(os.path.join(PER_EXAMPLES_FISHERS_DIR, filename))
    # Optional loader kwargs the author left disabled: `n_examples`, and
    # `start_fisher_index=0` / `end_fisher_index=0` (per the author's note, the
    # latter skips loading the Fishers, which ends up being much faster).
    return per_example.PerExampleFlatFishers.load(path)


def load_nmf(model_number: int, nmf_index: int, n_components: int):
    """Load an NMF decomposition and normalize its components to unit norm.

    Args:
        model_number: Number substituted into `PEF_FILENAME`.
        nmf_index: Index written into the file name as `.ssi{nmf_index}`.
        n_components: Component count substituted into `NMF_FILENAME`.

    Returns:
        The loaded `nmf_common.NmfDecomposition`, after
        `normalize_components_to_unit_norm()` has been called on it.
    """
    pef_name = PEF_FILENAME.format(model_number=model_number)
    base_name = NMF_FILENAME.format(n_components=n_components, pef_file=pef_name)

    # Drop the last three characters (the ".h5" that PEF_FILENAME ends with)
    # so the index tag can be inserted before the extension.
    stem = base_name[:-3]
    indexed_name = f"{stem}.ssi{nmf_index}.h5"

    full_path = os.path.expanduser(
        os.path.join(PER_EXAMPLES_FISHERS_DIR, indexed_name))
    decomposition = nmf_common.NmfDecomposition.load(full_path)
    decomposition.normalize_components_to_unit_norm()
    return decomposition

##########################################################################


def get_values_for_subset(pef, packer: flat_pack.FlatPacker, subset_index: int):
    """Select, per example, the entries whose index falls inside one subset.

    The subset's half-open index range `[start, end)` comes from
    `packer.get_range_for_tensor_by_index(subset_index)`.

    Args:
        pef: Object exposing parallel sequences `fishers` and `fisher_indices`
            (presumably arrays of sparse values and their flat indices --
            they must support comparison and boolean-mask indexing).
        packer: Provides the index range of the subset.
        subset_index: Which subset of the packer to extract.

    Returns:
        A `(values, indices)` pair of lists with one entry per example; the
        indices are shifted so that the subset starts at 0.
    """
    lo, hi = packer.get_range_for_tensor_by_index(subset_index)

    selected = []
    for example_values, example_indices in zip(pef.fishers, pef.fisher_indices):
        keep = (lo <= example_indices) & (example_indices < hi)
        selected.append((example_values[keep], example_indices[keep] - lo))

    values = [vals for vals, _ in selected]
    indices = [idx for _, idx in selected]
    return values, indices


def get_reduced_sparse_fishers(pef, packer: flat_pack.FlatPacker, subset_index: int):
    """Return one subset's per-example values after always-zero index removal.

    The subset is extracted with `get_values_for_subset` and then passed
    through `sparse_util.remove_always_zero_indices` with `threshold=1` and a
    dense size equal to the subset's index range.

    Args:
        pef: Per-example Fishers object, forwarded to `get_values_for_subset`.
        packer: Provides the index range of the subset.
        subset_index: Which subset of the packer to process.

    Returns:
        Only the reduced values from `remove_always_zero_indices`.
    """
    values, indices = get_values_for_subset(pef, packer, subset_index)
    range_start, range_end = packer.get_range_for_tensor_by_index(subset_index)

    # The reduction info and the reduced indices are computed but not used.
    _, (reduced_values, _) = sparse_util.remove_always_zero_indices(
        values,
        indices,
        dense_size=range_end - range_start,
        threshold=1,
    )
    # NOTE(review): the original author was unsure whether returning just the
    # reduced values (without the reduced indices) is correct.
    return reduced_values


##########################################################################

# Index of the packer subset (one sub-block) whose Fishers are extracted below.
subset_index = 17

# Per-example Fishers, model and packer, all for feather_berts model 0.
pef = load_pef(0)
model = get_model(0)
packer = get_packer(model)
# Per-example values for `subset_index`, reduced by get_reduced_sparse_fishers.
M = get_reduced_sparse_fishers(pef, packer, subset_index)

# 8-component NMF decomposition for model 0, with unit-norm components.
# NOTE(review): nmf_index=17 equals subset_index here; presumably they are
# meant to refer to the same sub-block -- confirm.
nmf_r8 = load_nmf(0, nmf_index=17, n_components=8)
