R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/m_npeff/qqp/wrong_components_search001.py

"""

from importlib import reload
import os

import numpy as np
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import lrm_pefs
from em.projects.pi import qqp_components_context as QCC
from em.tools.nmf import lrm_npeff

from em.util.color_util import cu

###############################################################################

# Directory holding the NMF decomposition files (and, below, the per-example
# Fisher files). Previously used location:
#   "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers"

# Decomposition under study. Other decompositions whose results are recorded
# at the bottom of this file:
#   "textattack_bert_qqp.mnpeff.train_to_validation.coeffs_fit001.h5"
#   "textattack_bert_qqp.mnpeff.test_to_validation.coeffs_fit001.h5"
NMF_NAME = "textattack_bert_qqp.mnpeff.test_to_validation.1024comps.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

###############################################################################

# The per-example Fisher files live in the same directory as the NMF files.
PEFS_DIR = NMF_DIR

PEFS_NAME = "textattack_bert_qqp.validation.40430ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

###############################################################################

# MODEL is not referenced below; it is kept as a record of which checkpoint
# produced the Fishers/logits.
MODEL = "textattack/bert-base-uncased-QQP"
TOKENIZER = 'bert-base-uncased'
# Dataset split whose examples line up with the rows of the NMF coefficients.
SPLIT = 'validation'

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

print('Starting to load in nmf.')
# Only the coefficient matrix (nmf.W) is used below; read_G=False presumably
# skips reading the decomposition's G factor — confirm in lrm_npeff.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

print('Starting to load in logits.')
# load_logits presumably reads just the stored logits rather than the
# per-example Fishers themselves — confirm in lrm_pefs.
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)

print('Starting to create evaluation context.')
_qqp_ds = em_datasets.load(
    'glue/qqp',
    split=SPLIT,
    sequence_length=128,
    tokenizer=tokenizer,
)
eval_ctx = QCC.EvaluationContext2.create_from_ds_and_logits(ds=_qqp_ds, logits=logits)
# The dataset was only ever an argument; don't leave it bound at module level.
del _qqp_ds
print('Evaluation context made.')

###############################################################################


def get_wrong_components(threshold=0.5, n_examples=16, *, W=None, labels=None, logits=None):
    """Return indices of components whose top examples are mostly mispredicted.

    For every component (column) ``i`` of the coefficient matrix ``W``, the
    ``n_examples`` rows (examples) with the largest coefficient for that
    component are selected, and the fraction of them that the model classified
    correctly is computed. Components whose correct fraction is
    ``<= threshold`` are reported.

    Args:
        threshold: Components whose correct fraction is at or below this value
            are returned.
        n_examples: Number of top-coefficient examples inspected per
            component. Must be positive.
        W: Example-by-component coefficient matrix. Defaults to the
            module-level ``nmf.W``.
        labels: Ground-truth label for each row of ``W``. Defaults to
            ``eval_ctx.all_examples[1]``.
        logits: Model logits for each row of ``W``, classes on the last axis.
            Defaults to ``eval_ctx.og_logits``.

    Returns:
        List of flagged component indices, in increasing order.

    Raises:
        ValueError: If ``n_examples`` is not positive, or if the labels or
            predictions are not 1-D with one entry per row of ``W``.
    """
    # Fall back to the script-level globals only when not supplied, so the
    # function can also be used with other decompositions / contexts.
    if W is None:
        W = nmf.W
    if labels is None:
        labels = eval_ctx.all_examples[1]
    if logits is None:
        logits = eval_ctx.og_logits

    if n_examples <= 0:
        # Otherwise np.mean over an empty selection yields NaN and every
        # component is silently skipped.
        raise ValueError(f"n_examples must be positive, got {n_examples}.")

    labels = np.asarray(labels)
    predictions = np.argmax(logits, axis=-1)

    # Row indices of W are used to index the correctness indicator, so the
    # two must be aligned; a mismatch would otherwise silently pair
    # coefficients with the wrong examples (or broadcast into a 2-D array).
    n_rows = W.shape[0]
    if labels.shape != (n_rows,) or predictions.shape != (n_rows,):
        raise ValueError(
            f"Expected one label and one prediction per row of W ({n_rows} rows); "
            f"got label shape {labels.shape} and prediction shape {predictions.shape}."
        )

    correct_indicator = (labels == predictions).astype(np.float64)

    wrong = []
    for i in range(W.shape[1]):
        # Rows with the n_examples largest coefficients for component i.
        top_inds = np.argsort(-W[:, i])[:n_examples]
        correct_fraction = np.mean(correct_indicator[top_inds])
        if correct_fraction <= threshold:
            wrong.append(i)
    return wrong
    

# Sweep the correctness threshold and print the flagged components at each
# level (0.5 is get_wrong_components' default threshold).
for _threshold in (0.5, 0.6, 0.7, 0.8, 0.85, 0.9):
    print(get_wrong_components(threshold=_threshold))
del _threshold


"""
Components flagged by get_wrong_components(threshold=...) with n_examples=16, per NMF file:

"textattack_bert_qqp.mnpeff.train_to_validation.coeffs_fit001.h5":
    threshold=0.5: []
    threshold=0.8: []
    threshold=0.85: [66, 220]
    threshold=0.9: [63, 66, 80, 173, 220]

"textattack_bert_qqp.mnpeff.test_to_validation.coeffs_fit001.h5":
    threshold=0.85: []
    threshold=0.9: [102, 153, 200]

"textattack_bert_qqp.mnpeff.test_to_validation.1024comps.coeffs_fit001.h5":
    threshold=0.8: []
    threshold=0.85: [181, 258, 734, 963]
    threshold=0.9: [39, 87, 181, 258, 665, 734, 750, 773, 903, 913, 959, 963]


"""
