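R"""Find NMF components whose top-weighted QQP examples are mostly misclassified.

Run from the project root with the project on PYTHONPATH:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/m_npeff/qqp/wrong_components_search002.py

"""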

from importlib import reload
import os

import numpy as np
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import lrm_pefs
from em.projects.pi import qqp_components_context as QCC
from em.tools.nmf import lrm_npeff

from em.util.color_util import cu

###############################################################################

# Location of the saved decomposition that is inspected below.
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers"

# Previously used decomposition files are kept commented out so they can be
# swapped back in; only the last uncommented NMF_NAME takes effect.
# NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.mnpeff.256comps.001.h5"
# NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.expansion_001.coeffs_fit001.h5"
NMF_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.wrongs_only.expansion_001.coeffs_fit_to_validation001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

###############################################################################

# Location of the per-example-Fisher file; only its logits are read here.
PEFS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/qqp_lrm_npeff2/per_example_fishers"

# Alternative (commented out) PEFS file for the held-out-from-train examples.
# PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.heldout_from_train.50000ex.65536.h5"
PEFS_NAME = "bert_base_qqp_50k_holdout_4_epochs_01_epoch9.validation.40430ex.65536.h5"
PEFS_PATH = os.path.join(PEFS_DIR, PEFS_NAME)

###############################################################################

TOKENIZER = 'bert-base-uncased'
# NOTE(review): SPLIT presumably has to name the same examples the PEFS file
# was computed on (compare the 'validation' / 'heldout_from_train' file names
# with the two SPLIT options) -- keep them in sync when switching; confirm.
# SPLIT = 'train[-50000:]'
SPLIT = 'validation'

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

print('Starting to load in nmf.')
# Only nmf.W is used in this script; read_G=False is passed to the loader
# (presumably to skip reading G -- confirm against LrmNpeffDecomposition.load).
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)

print('Starting to load in logits.')
logits = lrm_pefs.SparseLrmPefs.load_logits(PEFS_PATH)

print('Starting to create evaluation context.')
# The evaluation context is built from the tokenized dataset split plus the
# logits loaded above; get_wrong_components reads its `all_examples` and
# `og_logits` attributes.
eval_ctx = QCC.EvaluationContext2.create_from_ds_and_logits(
    ds=em_datasets.load('glue/qqp', split=SPLIT, sequence_length=128, tokenizer=tokenizer),
    logits=logits,
)
print('Evaluation context made.')

###############################################################################


def get_wrong_components(threshold=0.5, n_examples=16):
    """Return indices of components whose top examples are mostly mispredicted.

    Reads the module-level ``eval_ctx`` and ``nmf`` objects. An example counts
    as correct when ``eval_ctx.all_examples[1]`` (used as the labels) equals
    the argmax over the last axis of ``eval_ctx.og_logits``.

    Args:
        threshold: A column of ``nmf.W`` is selected when the fraction of
            correct examples among its top examples is less than or equal to
            this value.
        n_examples: Number of rows with the largest values in each column of
            ``nmf.W`` to inspect.

    Returns:
        List of selected column indices of ``nmf.W``, in increasing order.
    """
    # 1.0 where the prediction matches the label, 0.0 otherwise.
    is_correct = (
        eval_ctx.all_examples[1] == np.argmax(eval_ctx.og_logits, axis=-1)
    ).astype(np.float64)

    def _top_accuracy(component):
        # Negating the column makes argsort rank rows from largest to smallest.
        strongest = np.argsort(-nmf.W[:, component])[:n_examples]
        return np.mean(is_correct[strongest])

    return [
        component
        for component in range(nmf.W.shape[1])
        if _top_accuracy(component) <= threshold
    ]
    

# Report the selected components at the default threshold first, then at
# progressively looser thresholds.
print(get_wrong_components())
for cutoff in (.6, .7, .8, .85, .9):
    print(get_wrong_components(threshold=cutoff))


"""
# n_examples=16

[44]
[5, 18, 34, 44]
[0, 5, 10, 12, 18, 22, 26, 34, 37, 44, 54, 59, 65, 222]
[0, 2, 5, 10, 12, 16, 17, 18, 22, 26, 27, 34, 36, 37, 39, 41, 44, 45, 48, 54, 59, 65, 222]
[0, 2, 5, 10, 12, 16, 17, 18, 21, 22, 26, 27, 34, 36, 37, 39, 41, 44, 45, 48, 51, 54, 59, 60, 65, 222, 249]
[0, 2, 5, 10, 12, 13, 16, 17, 18, 19, 21, 22, 24, 26, 27, 29, 30, 33, 34, 36, 37, 39, 41, 44, 45, 48, 51, 53, 54, 56, 59, 60, 61, 62, 63, 65, 181, 222, 229, 236, 249, 279, 306



[1, 3, 4, 6, 12, 18, 22, 23, 27, 30, 31, 32, 34, 36, 38, 41, 43, 44, 45, 50, 51, 52, 54, 55, 57, 59]
[1, 3, 4, 5, 6, 11, 12, 13, 18, 19, 21, 22, 23, 24, 25, 27, 28, 30, 31, 32, 34, 35, 36, 38, 41, 43, 44, 45, 46, 50, 51, 52, 54, 55, 57, 59]
[0, 1, 3, 4, 5, 6, 8, 11, 12, 13, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 50, 51, 52, 54, 55, 56, 57, 59, 61, 62, 63]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 62, 63]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 65, 222]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 222, 233]

# "textattack_bert_qqp.mnpeff.train_to_validation.coeffs_fit001.h5":
#     threshold=0.5: []
#     threshold=0.8: []
#     threshold=0.85: [66, 220]
#     threshold=0.9: [63, 66, 80, 173, 220]

# "textattack_bert_qqp.mnpeff.test_to_validation.coeffs_fit001.h5":
#     threshold=0.85: []
#     threshold=0.9: [102, 153, 200]

# "textattack_bert_qqp.mnpeff.test_to_validation.1024comps.coeffs_fit001.h5":
#     threshold=0.8: []
#     threshold=0.85: [181, 258, 734, 963]
#     threshold=0.9:[39, 87, 181, 258, 665, 734, 750, 773, 903, 913, 959, 963]


"""
