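R"""

Search an LRM-NPEFF decomposition for "wrong" components: components whose
top-weighted examples (largest coefficients) are predicted correctly by the
model at most `threshold` of the time (0.5 by default).

Usage:
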
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/m_npeff/wrong_heuristics/wrong_comps_search001.py

"""

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


# ###############################################################################

# # FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
# # FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
# # FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# # NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
# NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/m_npeff1/per_example_fishers/"
# # NMF_NAME = "test_mnpeff_002.coeffs_fit001.h5"
# # NMF_NAME = "test_mnpeff_004.coeffs_fit001.h5"
# # NMF_NAME = "test_mnpeff_005.coeffs_fit001.h5"
# # NMF_NAME = "test_mnpeff_006.coeffs_fit001.h5"
# NMF_NAME = "test_mnpeff_002.expansion_005.coeffs_fit001.h5"
# NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# # Use this only to get the predictions and example token ids without having to
# # evaluate the model.
# PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
# PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
# PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)

# MODEL = "connectivity/feather_berts_0"
# TOKENIZER = 'bert-base-uncased'

# ###############################################################################

###############################################################################

# ---- Decomposition whose components are searched ----------------------------
NMF_DIR = "/fruitbasket/users/m/project_data/extract_merge1/snli3og_lrm_npeff/per_example_fishers"
NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.fit_to_train_skip_50k.h5"
# Alternative (expanded) decomposition; uncomment to override the line above.
# NMF_NAME = "feather_berts_0.train.50000ex.65536.mnpeff.512comps.001.wrongs_only.expansion_001.no_full_joint.fit_to_train_skip_50k.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# ---- Saved per-example Fishers used only for predictions / token ids ---------
# Reading these from disk avoids having to evaluate the model.
PEFS_FOR_PREDICTIONS_DIR = (
    "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
)
PEFS_FOR_PREDICTIONS_NAME = (
    "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
)
# Alternative; uncomment to override the line above.
# PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(
    PEFS_FOR_PREDICTIONS_DIR,
    PEFS_FOR_PREDICTIONS_NAME,
)

# ---- Hugging Face model and tokenizer identifiers ----------------------------
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)


# fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
# flat_fisher = fisher.as_flat_fisher().numpy()
# flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

# Load the LRM-NPEFF decomposition from NMF_PATH. read_G=False is passed so the
# "G" part is not read; presumably only the coefficients (nmf.W) are needed by
# this script -- TODO confirm against LrmNpeffDecomposition.load.
print('Starting to read in decomposition.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=False)
print('Decomposition read in.')
# nmf.normalize_components_to_unit_norm()
# print('Decomposition components normalized.')

# SNLI context for the 'train_skip_50k' split, wrapping the decomposition.
# load_examples=False: the examples are not loaded here; presumably they come
# from the PEFS file below instead -- TODO confirm.
ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
)
print('SNLI context made.')

###############################################################################

# Commented-out alternative: build the eval context by running og_model.
# N_TOTAL_EXAMPLES = 1014
# reload(snli_context); ctx.__class__ = snli_context.SnliContext
# eval_ctx = ctx.create_eval_ctx(og_model)
# Build the eval context from saved per-example Fishers so that predictions and
# example token ids are read from disk rather than by evaluating the model.
# get_wrong_components reads eval_ctx.all_examples and eval_ctx.og_logits.
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)
print('Eval context made.')

# model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

###############################################################################


def get_wrong_components(threshold=0.5, n_examples=16, *, W=None, labels=None, logits=None):
    """Return indices of components whose top examples are mostly mispredicted.

    For each component (column ``i`` of the coefficient matrix ``W``), take the
    ``n_examples`` examples with the largest coefficient ``W[:, i]`` and compute
    the fraction of them that the model predicts correctly. The component is
    reported when that fraction is ``<= threshold``.

    Args:
        threshold: Maximum fraction of correctly-predicted top examples for a
            component to count as "wrong".
        n_examples: Number of top-coefficient examples inspected per component.
            Must be at least 1.
        W: Coefficient matrix indexed as ``W[example, component]``. Defaults to
            the module-level ``nmf.W``.
        labels: Gold labels already expressed in the model's label order.
            Defaults to ``(eval_ctx.all_examples[1] + 1) % 3``.
        logits: Model logits with classes on the last axis. Defaults to
            ``eval_ctx.og_logits``.

    Returns:
        List of component indices (Python ints) in increasing order.

    Raises:
        ValueError: If ``n_examples`` is less than 1. A negative value would
            otherwise slice ``[:-k]`` and silently average almost every example,
            and 0 would average an empty array (NaN plus a RuntimeWarning).
    """
    if n_examples < 1:
        raise ValueError(f"n_examples must be at least 1, got {n_examples}.")

    if W is None:
        W = nmf.W
    if labels is None:
        # Shift the dataset labels into the model's label order. Presumably this
        # maps SNLI label ids onto the model's class ordering -- TODO confirm.
        labels = (eval_ctx.all_examples[1] + 1) % 3
    if logits is None:
        logits = eval_ctx.og_logits

    predictions = np.argmax(logits, axis=-1)
    correct_indicator = (np.asarray(labels) == predictions).astype(np.float64)

    wrong_components = []
    # Process one column at a time rather than argsorting all of W at once, which
    # would need an index array as large as W itself.
    for i in range(W.shape[1]):
        top_inds = np.argsort(-W[:, i])[:n_examples]
        correct_fraction = np.mean(correct_indicator[top_inds])
        if correct_fraction <= threshold:
            wrong_components.append(i)
    return wrong_components
    

# Print the indices of components whose top 16 examples are predicted correctly
# at most half of the time (get_wrong_components' default arguments).
print(get_wrong_components())

"""
Component indices returned by get_wrong_components(threshold=0.5, n_examples=16), per decomposition:

"test_mnpeff_002.coeffs_fit001.h5" [68, 146, 169, 251]
"test_mnpeff_004.coeffs_fit001.h5" [91, 182, 418, 473]
"test_mnpeff_005.coeffs_fit001.h5" [116, 135, 188]
"test_mnpeff_006.coeffs_fit001.h5" [148, 222, 232]
"test_mnpeff_002.expansion_005.coeffs_fit001.h5" [13, 16, 18, 34, 36, 40, 210, 233, 315]
    [13, 16, 18, 34, 36, 40] are from the expansion.



No expansion:
    [15, 26, 102, 119, 128, 129, 134, 143, 146, 162, 168, 177, 189, 200, 212, 221, 242, 277, 282, 379, 404, 426, 436, 471, 491]
Expansion:
    [0, 1, 2, 4, 9, 14, 17, 18, 19, 21, 22, 24, 27, 28, 30, 34, 38, 40, 42, 45, 50, 54, 55, 59, 62, 63, 79, 90, 166, 193, 198, 207, 210, 226, 232, 241, 253, 276, 285, 335, 341, 346, 443, 468, 490, 535, 547, 555]
"""