R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES= python -i local_scripts/wino1/wino_comps_analysis001.py

"""

import dataclasses
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am
from em.projects.wino import wino_misc1 as wm
from em.projects.wino import nmf_components_fisher as ncf

# Seaborn's "rocket" palette, requested as a colormap object (as_cmap=True).
rocket_cmap = sns.color_palette(palette="rocket", as_cmap=True)

###############################################################################

# Root directory of this experiment's artifacts; every other directory used by
# the script is a direct child of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/winogrange1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Checkpoint name passed to AutoTokenizer.from_pretrained when loading the
# tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'

# Alternative configurations, kept for reference. Exactly one MODEL /
# PER_EXAMPLES_FISHERS / DECOMP_FILENAME / unlabeled_indicator group should be
# active at a time.

# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"
# unlabeled_indicator = None

# MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
# PER_EXAMPLES_FISHERS = f"{MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.6k.16k.64.{PER_EXAMPLES_FISHERS}"
# # Dev set has labels, test set does not.
# unlabeled_indicator = np.ones([6_000], dtype=bool)
# unlabeled_indicator[:2 * 1_267] = False

# Active configuration.
MODEL = "bert_base_mnli_to_winogrande_xl_4_epochs_01"
PER_EXAMPLES_FISHERS = f"{MODEL}.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME = f"nmf_decomp.per_sub_block.6k.16k.256.{PER_EXAMPLES_FISHERS}"
# Dev set has labels, test set does not.
# Boolean mask over the 6,000 examples: True marks an unlabeled example. The
# first 2 * 1,267 entries are the labeled ones (presumably the dev examples
# come first in the per-example Fisher file -- confirm against how it was
# generated).
# NOTE: use the builtin `bool` as dtype; the `np.bool` alias was deprecated in
# NumPy 1.20 and removed in 1.24, where it raises AttributeError.
unlabeled_indicator = np.ones([6_000], dtype=bool)
unlabeled_indicator[:2 * 1_267] = False

FROM_PT = False

# Passed as `n_nmfs` when the analysis container is loaded.
N_DECOMPS = 25

###############################################################################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# The per-example Fishers and their NMF decomposition are both read from the
# per-example Fishers directory.
_pef_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS)
_nmf_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME)

container = am.load_pef_nmf_analysis_container(
    pef_filepath=_pef_filepath,
    nmf_filepath=_nmf_filepath,
    n_nmfs=N_DECOMPS,
    tokenizer=tokenizer,
    shift_labels=True,
    unlabeled_indicator=unlabeled_indicator,
)

# HACK: load with shift_labels=True, then overwrite the container's labels
# with container.pef.labels and rebuild the examples. Per the original note,
# this combination is what makes the string labels and the integer labels
# match up. Keep these two statements in this order.
container.labels = container.pef.labels
container.examples = container._make_nli_examples()

# Presumably loads every NMF decomposition eagerly so later lookups do not
# hit the disk -- confirm against the container implementation.
container.nmfs.force_load_all()

# reload(am); container.__class__ = am.PefNmfAnalysisContainer

###############################################################################

# Entropy of predicted logits to select for overfit examples?

# Filter out components where the highest coefficient is n times bigger than the second highest.
# Look for components that appear correct, make sure to handle the unlabelled examples.
# Maybe transfer from training set components to dev/val components.
# Some automatic selection/annotation of "types/characteristics" of examples.
# Something with pairs of examples showing up in the tops.
# Maybe look at examples with up to c * the coefficient of the top example, where c < 1.

############################################


# Pick up edits made to the project modules during the interactive session,
# then point the already-loaded container at the freshly reloaded class.
for _module in (am, wm, ncf):
    reload(_module)
container.__class__ = am.PefNmfAnalysisContainer

# Selection thresholds for components whose top examples appear correct.
_correct_comp_kwargs = dict(
    coeff_factor=0.3,
    frac_threshold=0.85,
    # Other p-value thresholds that were tried: 0.1, 0.01.
    p_value_threshold=0.05,
)
correct_comp_infos = ncf.get_components_appearing_correct(
    container, **_correct_comp_kwargs)


# Indices of every labeled example that shows up in at least one of the
# selected components.
example_indices = {
    example.index
    for info in correct_comp_infos
    for example in info.labeled_examples
}

# Number of labeled examples; they occupy the first `total_labeled` slots of
# container.examples.
total_labeled = 2 * 1_267
# Fraction of the labeled examples covered by the selected components.
print(len(example_indices) / total_labeled)

# Split the labeled examples into "covered by a selected component" and
# "everything else", recording 1.0 for a correct prediction and 0.0 otherwise.
_hits_covered, _hits_others = [], []
for _example in container.examples[:total_labeled]:
    _is_correct = float(_example.label_char == _example.prediction_char)
    if _example.index in example_indices:
        _hits_covered.append(_is_correct)
    else:
        _hits_others.append(_is_correct)

# Accuracy on the covered examples vs. on the remaining labeled examples.
acc_infos = np.mean(_hits_covered)
acc_others = np.mean(_hits_others)

print(acc_infos)
print(acc_others)


# Pick up edits made to the project modules during the interactive session,
# then point the already-loaded container at the freshly reloaded class.
for _module in (am, wm, ncf):
    reload(_module)
container.__class__ = am.PefNmfAnalysisContainer

# Thresholds passed to the Fisher computation; the values match those used
# for the component selection earlier in the script.
_fisher_kwargs = dict(
    coeff_factor=0.3,
    frac_threshold=0.85,
    p_value_threshold=0.05,
)
correct_fishers, erroring_fishers = ncf.get_apparently_correct_fisher(
    container, **_fisher_kwargs)


###############################################################################

# # TODO: Actually compute this.
# p_correct = 0.6


# def hobi(factor: float, frac_threshold: float, p_value_threshold: float, max_examples: int = None):
#     for nmf_index in range(container.n_nmfs):
#         n_components = container.nmfs[nmf_index].W.shape[-1]
#         for component_index in range(n_components):
#             top_examples = container.get_top_examples_based_on_relative_coefficient(
#                 nmf_index=nmf_index, component_index=component_index, factor=factor, max_examples=max_examples)
#             top_labeled_examples = [e for e in top_examples if not container.unlabeled_indicator[e.index]]
#             #
#             n_top_ex = len(top_labeled_examples)
#             if n_top_ex <= 1:
#                 continue
#             n_correctly_labeled = sum(int(e.label_char == e.prediction_char) for e in top_labeled_examples)
#             pmf = am._binomial_pmf(n_top_ex, np.arange(n_top_ex + 1), p_correct)
#             #
#             frac = n_correctly_labeled / n_top_ex
#             p_value = pmf[n_correctly_labeled:].sum()
#             if frac >= frac_threshold and p_value <= p_value_threshold:
#                 print(f'{nmf_index} {component_index} {frac:.3f} {p_value}')


# # hobi(factor=0.5, frac_threshold=0.85, p_value_threshold=0.1, max_examples=None)
# # hobi(factor=0.5, frac_threshold=0.85, p_value_threshold=0.05, max_examples=None)
# # hobi(factor=0.4, frac_threshold=0.85, p_value_threshold=0.05, max_examples=None)
# hobi(factor=0.3, frac_threshold=0.85, p_value_threshold=0.05, max_examples=None)


# container.print_top_examples(16, 1, n_examples=16)

# container.print_top_examples(17, 230, n_examples=16)

# container.print_top_examples(21, 5, n_examples=16)
# container.print_top_examples(21, 156, n_examples=16)

# container.print_top_examples(22, 238, n_examples=16)
# container.print_top_examples(22, 239, n_examples=16)

# container.print_top_examples(23, 210, n_examples=16)



"""
0 14 1.000 0.04665599999999999
0 22 0.862 0.002201527783885635
0 41 0.850 0.01596116279000825
0 179 0.900 0.04635740159999999
0 238 0.867 0.02711400077721599
1 194 1.000 0.04665599999999999
2 60 0.909 0.030233087999999988
3 162 0.857 0.03979158110207999
3 248 1.000 0.010077695999999997
4 0 1.000 0.027993599999999993
4 38 0.909 0.030233087999999988
5 120 1.000 0.04665599999999999
5 242 1.000 0.04665599999999999
6 189 0.867 0.02711400077721599
7 35 1.000 0.04665599999999999
7 204 1.000 0.04665599999999999
7 255 1.000 0.04665599999999999
9 47 0.900 0.04635740159999999
9 213 1.000 0.027993599999999993
9 237 1.000 0.016796159999999994
10 92 0.900 0.04635740159999999
13 49 1.000 0.04665599999999999
15 120 1.000 0.027993599999999993
15 133 0.900 0.04635740159999999
15 202 1.000 0.04665599999999999
16 1 0.857 0.0009510010731745533
16 102 1.000 0.04665599999999999
16 238 0.857 0.03979158110207999
17 7 0.885 0.0015917802375868405
17 43 1.000 0.04665599999999999
17 230 0.880 0.0023667688298101095
19 37 1.000 0.027993599999999993
19 121 1.000 0.04665599999999999
19 187 1.000 0.016796159999999994
19 215 1.000 0.04665599999999999
19 234 1.000 0.04665599999999999
20 11 1.000 0.027993599999999993
20 33 0.909 0.030233087999999988
20 50 1.000 0.04665599999999999
20 191 1.000 0.04665599999999999
20 250 1.000 0.027993599999999993
21 5 0.875 0.01833721439846399
21 38 1.000 0.027993599999999993
21 47 0.875 0.01833721439846399
21 71 1.000 0.04665599999999999
21 132 1.000 0.04665599999999999
21 137 0.923 0.012625337548799995
21 156 1.000 0.027993599999999993
21 200 1.000 0.027993599999999993
21 207 1.000 0.04665599999999999
21 208 1.000 0.027993599999999993
21 214 1.000 0.027993599999999993
21 220 1.000 0.006046617599999997
21 230 0.867 0.02711400077721599
22 29 0.900 0.04635740159999999
22 30 1.000 0.04665599999999999
22 32 1.000 0.016796159999999994
22 54 1.000 0.006046617599999997
22 75 1.000 0.006046617599999997
22 87 0.900 0.04635740159999999
22 134 1.000 0.016796159999999994
22 145 0.900 0.04635740159999999
22 162 0.882 0.012318846595891195
22 170 0.867 0.02711400077721599
22 238 0.950 0.0005240493764090262
22 239 1.000 0.04665599999999999
23 19 1.000 0.027993599999999993
23 34 1.000 0.016796159999999994
23 78 1.000 0.016796159999999994
23 130 0.885 0.0015917802375868405
23 137 1.000 0.04665599999999999
23 140 0.857 0.03979158110207999
23 210 0.900 0.04635740159999999
23 233 1.000 0.04665599999999999
23 243 0.923 0.012625337548799995
24 36 1.000 0.04665599999999999
24 48 0.864 0.007563373093010272
24 87 1.000 0.010077695999999997
24 120 1.000 0.016796159999999994
24 125 1.000 0.04665599999999999
24 138 0.923 0.012625337548799995
24 172 1.000 0.04665599999999999
24 174 1.000 0.027993599999999993
24 212 0.947 0.0008327916446810107
24 227 1.000 0.006046617599999997
"""
