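R"""List NMF components whose top-weighted SNLI examples are mostly misclassified.

Usage (run from the project root with the project on PYTHONPATH; the empty
CUDA_VISIBLE_DEVICES hides all GPUs, and `-i` leaves an interactive session
open once the script has finished):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/pi/wrong_comps_snli001.py

"""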
from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import per_example
from em.tools.nmf import nmf_common

###############################################################################

# Root directory of this experiment's artifacts; every other directory used
# by the script is a direct child of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# File holding the per-example Fisher data whose labels and predicted logits
# are inspected by this script.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

# The remaining file names are nested: each one is a fixed prefix followed by
# the name of another file (presumably the artifact it was derived from --
# TODO confirm against the scripts that produced them).
OG_PEF_NAME = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
OG_H_NAME = "spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex." + OG_PEF_NAME
NMF_NAME = "fit_w.skip50000.50000ex.65536vpe." + OG_H_NAME

###############################################################################

# Full paths of the two artifacts this script reads; both live in the
# per-example Fishers directory.
_pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME)
_nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_NAME)

# An empty Fisher index range (start == end == 0) leads to the Fishers not
# being loaded, which ends up being much faster. The analysis in this script
# only reads `labels` and `predicted_logits` from the loaded object.
pef = per_example.PerExampleFlatFishers.load(
    _pef_path,
    n_examples=None,
    start_fisher_index=0,
    end_fisher_index=0,
)

# Decomposition whose `W` matrix is ranked column by column in
# get_wrong_components().
nmf = nmf_common.SparseNmfDecomposition.load(_nmf_path)


def get_wrong_components(threshold=0.5, n_examples=16, *, fishers=None, decomposition=None):
    """Return the components whose top-weighted examples are mostly misclassified.

    For every column of ``decomposition.W``, the ``n_examples`` rows with the
    largest weights are selected and the fraction of those rows whose
    prediction matches the (remapped) label is computed. A column is reported
    when that fraction is less than or equal to ``threshold``.

    Args:
        threshold: Inclusive upper bound on the fraction of correctly
            classified top examples for a component to be reported.
        n_examples: Number of top-weighted rows inspected per component. When
            it exceeds the number of rows of ``W``, every row is used.
        fishers: Object exposing ``labels`` (integer array, one entry per
            example) and ``predicted_logits`` (classes on the last axis).
            Defaults to the module-level ``pef``.
        decomposition: Object exposing a 2-D ``W`` whose row indices are used
            to index the per-example correctness vector (assumed to support
            NumPy-style column slicing and negation). Defaults to the
            module-level ``nmf``.

    Returns:
        A list of column indices of ``W`` (plain ints, ascending).
    """
    # Fall back to the objects loaded at module level so existing callers,
    # which pass at most ``threshold`` and ``n_examples``, behave as before.
    if fishers is None:
        fishers = pef
    if decomposition is None:
        decomposition = nmf

    # NOTE(review): stored labels are shifted by one (mod 3) before being
    # compared with the argmax of the logits -- presumably the dataset's label
    # ids and the model's output classes are ordered differently; confirm.
    labels = (fishers.labels + 1) % 3
    predictions = np.argmax(fishers.predicted_logits, axis=-1)
    is_correct = (labels == predictions).astype(np.float64)

    weights = decomposition.W

    def correct_fraction(column):
        # Sorting the negated weights ascending ranks rows by descending weight.
        top_rows = np.argsort(-weights[:, column])[:n_examples]
        return np.mean(is_correct[top_rows])

    return [
        column
        for column in range(weights.shape[1])
        if correct_fraction(column) <= threshold
    ]


# Print the components whose top examples (default n_examples) are classified
# correctly at most 40% of the time.
# Default-threshold variant, kept for reference:
# print(get_wrong_components())
print(get_wrong_components(threshold=0.4))
"""
Recorded output of get_wrong_components(threshold, n_examples):

threshold=0.5, n_examples=16:
[11, 17, 28, 30, 38, 45, 50, 52, 54, 59, 97, 116, 128, 134, 143, 146, 150, 153, 158, 185, 192, 217, 243, 255, 264, 270, 273, 285, 292, 300, 303, 307, 308, 313, 314, 331, 332, 346, 352, 369, 374, 375, 378, 395, 402, 431, 435, 445, 470, 475, 490, 492, 496, 498]

threshold=0.4, n_examples=16:
[11, 54, 59, 97, 158, 300, 314, 331, 352, 374, 492]

threshold=0.3, n_examples=16:
[11, 54, 59, 300, 314]

threshold=0.2, n_examples=16:
[11, 300, 314]

threshold=0.1, n_examples=16:
[300]

"""