R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/m_npeff/snli2/wrong_comps_search001.py

"""

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em import datasets as em_datasets
from em.fishers import lrm_pefs

from em.util.color_util import cu


###############################################################################

# Saved LRM-NPEFF decomposition; loaded below with LrmNpeffDecomposition.load.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/per_example_fishers/"
NMF_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.512comps.expansion64comps.no_full_join.001.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli2_lrm_npeff/per_example_fishers/"
PEFS_NAME = "bert_base_snli_150k_holdout_4_epochs_01_epoch2.heldout_from_train_2.50000ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

###############################################################################

# Name handed to AutoTokenizer.from_pretrained. Defined exactly once so the
# value cannot drift between two copies.
TOKENIZER = 'bert-base-uncased'
# Dataset split slice handed to the dataset loader.
# NOTE(review): presumably selects the same 50k held-out examples that the
# PEFs/NMF files were computed on (their names say "50000ex") - confirm.
SPLIT = 'train[-100000:-50000]'

###############################################################################

# Tokenizer handed to the dataset loader when the evaluation context is built.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# The decomposition is loaded with read_G=False; the analysis function in this
# script only reads nmf.W.
print('Starting to load in nmf.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(
    NMF_PATH,
    read_G=False,
)

# Logits come from the saved per-example-Fisher file, so the model itself
# never has to be evaluated here.
print('Starting to load in logits.')
logits = lrm_pefs.SparseLrmPefs.load_logits(
    PEFS_PATH,
)

# Pair the dataset split with the stored logits.
print('Starting to create evaluation context.')
eval_ctx = QCC.EvaluationContext2.create_from_ds_and_logits(
    ds=em_datasets.load(
        'snli/default',
        split=SPLIT,
        sequence_length=128,
        tokenizer=tokenizer,
    ),
    logits=logits,
)
print('Evaluation context made.')

###############################################################################


def get_wrong_components(threshold=0.5, n_examples=16):
    """Return the columns of ``nmf.W`` whose top rows are mostly misclassified.

    For each column of ``nmf.W``, the ``n_examples`` rows holding the largest
    values in that column are selected. Among those rows, the fraction whose
    argmax prediction (from ``eval_ctx.og_logits``) equals the label (from
    ``eval_ctx.all_examples[1]``) is computed, and the column is reported when
    that fraction does not exceed ``threshold``.

    NOTE(review): presumably the rows of ``nmf.W`` are per-example
    coefficients aligned with the examples held by ``eval_ctx`` - confirm.

    Args:
        threshold: Largest fraction of correct predictions (inclusive) for
            which a column is still reported.
        n_examples: How many of the highest-valued rows to inspect per column.

    Returns:
        A list of column indices of ``nmf.W``, in increasing order.
    """
    gold = eval_ctx.all_examples[1]
    predicted = np.argmax(eval_ctx.og_logits, axis=-1)
    # 1.0 where the prediction matches the label, 0.0 otherwise.
    is_correct = (gold == predicted).astype(np.float64)

    def top_accuracy(component):
        # Negating before argsort puts the largest values first.
        strongest = np.argsort(-nmf.W[:, component])[:n_examples]
        return np.mean(is_correct[strongest])

    return [
        component
        for component in range(nmf.W.shape[1])
        if top_accuracy(component) <= threshold
    ]
    

# Report the flagged components at the function's default settings, then again
# with a stricter threshold.
print(get_wrong_components())
print(get_wrong_components(threshold=.25))

"""
threshold=0.5, n_examples=16
[0, 60, 126, 201, 569]

>>> print(get_wrong_components(.4, 8))
[5, 17, 22, 201, 337, 371]
>>> print(get_wrong_components(.25, 8))
[17]



Component 214: Inconsistent labelling of whether smiling implies being happy.

Components 201, 371: Outside stuff.

"""