R"""Interactively inspect the NMF component coefficients of a single SNLI example.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/pe_interprets/snli_ex_interpret_01.py

"""

from importlib import reload
import os
import time

import numpy as np
from transformers import AutoTokenizer

from em.fishers import per_example
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_container
from em.projects.ll import hans_util
from em.tools.nmf import nmf_common

from em.util.color_util import cu

###############################################################################

# Root directory of the experiment data; the sub-directories below hang off it.
EXPS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1"
MODELS_DIR = os.path.join(EXPS_DIR, "models")
FISHERS_DIR = os.path.join(EXPS_DIR, "fishers")
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, "per_example_fishers")

###############################################################################

# Identifier handed to AutoTokenizer.from_pretrained in make_container.
TOKENIZER = "bert-base-uncased"

###################################################################################

MODEL = "connectivity/feather_berts_0"

# The file names nest by prefixing: NMF_NAME embeds og_h_name, which in turn
# embeds og_pef_name.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = "spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex." + og_pef_name
NMF_NAME = "fit_w.skip50000.50000ex.65536vpe." + og_h_name

# Per-example Fishers file that is passed to make_container below.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

###################################################################################


def make_container(
    PEF_FILENAME,
    NMF_FILENAME,
    model_number: int,
    n_components: int,
    n_vals_pe: int,
    min_vals_per_param: int,
    n_examples: int,
):
    """Load a per-example Fishers file and one NMF decomposition into a container.

    Both file names are treated as ``str.format`` templates, so a name without
    placeholders is used as given.

    Args:
        PEF_FILENAME: Template for the per-example Fishers file name; formatted
            with ``model_number``.
        NMF_FILENAME: Template for the NMF decomposition file name; formatted
            with ``n_components``, ``n_vals_pe``, ``min_vals_per_param``,
            ``n_examples``, ``pef`` (the resolved Fishers file name) and
            ``model_number``.
        model_number: Value for the ``model_number`` placeholder.
        n_components: Value for the ``n_components`` placeholder.
        n_vals_pe: Value for the ``n_vals_pe`` placeholder.
        min_vals_per_param: Value for the ``min_vals_per_param`` placeholder.
        n_examples: Value for the ``n_examples`` placeholder.

    Returns:
        An ``am.PefNmfAnalysisContainer`` built from the loaded Fishers object,
        a one-element list holding the loaded NMF decomposition (after its
        ``normalize_components_to_unit_norm()`` call), the tokenizer, and
        ``shift_labels=True``.
    """
    pef_file = PEF_FILENAME.format(model_number=model_number)
    nmf_format_fields = dict(
        n_components=n_components,
        n_vals_pe=n_vals_pe,
        min_vals_per_param=min_vals_per_param,
        n_examples=n_examples,
        pef=pef_file,
        model_number=model_number,
    )
    nmf_file = NMF_FILENAME.format(**nmf_format_fields)

    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

    # A zero-width Fisher index range leads to the Fishers not being loaded,
    # which ends up being much faster.
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_file)
    pef = per_example.PerExampleFlatFishers.load(
        pef_path,
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )

    # An "spH." marker in the file name is assumed to mean that this is a
    # sparsified NMF decomposition.
    # (An alternate copy was at some point loaded from /fruitbasket/users/m/tmp.)
    looks_sparse = nmf_file.startswith("spH.") or ".spH." in nmf_file
    decomposition_cls = (
        nmf_common.SparseNmfDecomposition if looks_sparse else nmf_common.NmfDecomposition
    )
    nmf = decomposition_cls.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_file))
    nmf.normalize_components_to_unit_norm()

    return am.PefNmfAnalysisContainer(
        pef=pef,
        nmfs=[nmf],
        tokenizer=tokenizer,
        shift_labels=True,
    )


###################################################################################


###################################################################################
###################################################################################

# Build the analysis container for the module-level PEF/NMF file names. The
# names above contain no format placeholders, so the numeric arguments do not
# alter them.
container = make_container(
    PEF_FILENAME=PEF_NAME,
    NMF_FILENAME=NMF_NAME,
    model_number=0,
    n_components=512,
    n_vals_pe=65536,
    min_vals_per_param=10,
    n_examples=50000,
)

# The container holds a single decomposition; keep a direct handle to it.
nmf = container.nmfs[0]

###################################################################################

# Index of the example to inspect; used for both container.examples and the
# rows of nmf.W.
# Other indices tried previously: 5777, 5439, 13490, 23490, 33490, 43490, 35959.
EXAMPLE_INDEX = 13959

example = container.examples[EXAMPLE_INDEX]

# This example's coefficients, rescaled so that they sum to 1.
raw_coeffs = nmf.W[EXAMPLE_INDEX]
example_coeffs = raw_coeffs / np.sum(raw_coeffs)

# Component indices ordered by decreasing coefficient, and the coefficients in
# that same order.
sorted_comp_inds = np.argsort(-example_coeffs)
sorted_coeffs = example_coeffs[sorted_comp_inds]

container.print_example_for_component(example, 0, 0)
# Show the ten largest coefficients, then the components they belong to.
print(sorted_coeffs[:10], sorted_comp_inds[:10], sep="\n")

# Component 59 looks interesting as a component that predicts wrongly (it seems to think "human" is not a synonym for words like "man", "child", ...).


###################################################################################

# # Stuff for seeing the number of components per example.

# sorted_W = -np.sort(-nmf.W, axis=-1)
# cs_W = np.cumsum(sorted_W, axis=-1)
# cs_W /= cs_W[:, -1][:, None]
# mean_cs_W = cs_W.mean(axis=0)
