R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/hans_analysis_03.py

"""
from importlib import reload
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_util
from em.projects.ll import hans_analysis as ha
from em.projects.wino import nmf_components_fisher as ncf
from em.tools.clustering import vat


###############################################################################

# Root directory for this experiment's artifacts; the directories below are
# all sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# MODELS_DIR and FISHERS_DIR are not referenced in the visible part of this
# file.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory that make_container() reads both the PEF_FILENAME and the
# NMF_FILENAME files from.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Checkpoint name used to build the module-level tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# NOTE(review): not referenced in the visible part of this file; get_model()
# passes from_pt=True as a literal.
FROM_PT = True

# Passed as `n_nmfs` when make_container() loads an analysis container.
N_DECOMPS = 25

##########################################################################

# Alternative file-name pair, kept commented out for reference.
# PEF_FILENAME = "feather_berts_{model_number}.hans.no_embeddings.16k.16k.h5"
# NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"

# File-name templates resolved by make_container(): `{model_number}` is filled
# with the model number, and `{pef}` with the already-formatted PEF file name.
PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
NMF_FILENAME = "nmf_decomp.per_sub_block.5k.32k.256.{pef}"

###############################################################################

# Module-level tokenizer for PRETRAINED_MODEL; make_container() passes it to
# every analysis container it loads.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


def get_model(model_number: int):
    """Load and compile one ``connectivity/feather_berts_*`` classifier.

    Args:
        model_number: Number appended to ``connectivity/feather_berts_`` to
            form the checkpoint name.

    Returns:
        The loaded sequence-classification model, compiled with a sparse
        categorical cross-entropy loss on logits and a sparse categorical
        accuracy metric. No optimizer is passed to ``compile``.
    """
    checkpoint = f'connectivity/feather_berts_{model_number}'
    classifier = TFAutoModelForSequenceClassification.from_pretrained(
        checkpoint, from_pt=True)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    classifier.compile(loss=loss_fn, metrics=[accuracy])
    return classifier


def make_container(model_number: int):
    """Build the analysis container for one feather_berts model.

    The PEF and NMF file names are derived from ``PEF_FILENAME`` and
    ``NMF_FILENAME`` and both files are looked up in
    ``PER_EXAMPLES_FISHERS_DIR``.

    Args:
        model_number: Value substituted for ``{model_number}`` in
            ``PEF_FILENAME``.

    Returns:
        The container returned by ``am.load_pef_nmf_analysis_container``,
        after it has been passed through ``hans_util.fix_up_hans_container``
        and ``container.nmfs.force_load_all()`` has been called.
    """
    pef_name = PEF_FILENAME.format(model_number=model_number)
    # The NMF file name embeds the complete PEF file name.
    nmf_name = NMF_FILENAME.format(pef=pef_name)

    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name)

    loaded = am.load_pef_nmf_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    hans_util.fix_up_hans_container(loaded)
    # Presumably loads every decomposition eagerly so later indexing is
    # cheap -- confirm against the nmfs implementation.
    loaded.nmfs.force_load_all()
    return loaded


##########################################################################


# Load one analysis container per model number, in ascending order. The
# individual names are kept because the rest of the script refers to them.
container0 = make_container(model_number=0)
container1 = make_container(model_number=1)
container15 = make_container(model_number=15)
container25 = make_container(model_number=25)

# (model_number, container) pairs for iterating over all loaded models.
CONTAINERS = [
    (0, container0),
    (1, container1),
    (15, container15),
    (25, container25),
]

##########################################################################


def plot_vat_comp_sim(container, nmf_index: int):
    """Show a VAT-reordered cosine-similarity heatmap for one NMF's components.

    The rows of ``container.nmfs[nmf_index].H`` are L2-normalised, their
    pairwise dot products are computed, the resulting matrix is reordered by
    ``vat.vat_reorder_dissimilarity_matrix`` and drawn with ``plt.imshow`` on
    a fixed [0, 1] colour scale.

    Args:
        container: Analysis container exposing ``nmfs[nmf_index].H``.
        nmf_index: Index of the decomposition within ``container.nmfs``.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    H = container.nmfs[nmf_index].H
    # Normalise along the last axis; the epsilon keeps all-zero rows finite.
    row_norms = np.sqrt(np.sum(H**2, axis=-1, keepdims=True))
    H = H / (row_norms + 1e-9)
    sim = H @ H.T
    # NOTE(review): `sim` is a similarity matrix but the helper is named for
    # dissimilarities; confirm whether it expects `1 - sim` instead.
    # The returned permutation is not needed here.
    sim2, _permutation = vat.vat_reorder_dissimilarity_matrix(sim)
    plt.imshow(
        sim2,
        vmin=0,
        vmax=1,
        cmap=sns.color_palette("rocket", as_cmap=True),
        # The string 'none' disables resampling so each matrix entry is drawn
        # as-is. Python None would instead fall back to the rcParams default
        # interpolation, which can smooth neighbouring cells together.
        interpolation='none',
    )
    plt.show()


# Component-similarity heatmap for decomposition 18 of model 0.
plot_vat_comp_sim(container0, nmf_index=18)


##########################################################################


def do_pairs_thing(container_correct, container_incorrect, selection_parameters):
    """Select tuned components in two containers over the same example mask.

    The mask marks examples where ``container_incorrect.predictions`` is 0
    and ``container_correct.predictions`` is 1. The same mask and
    ``selection_parameters`` are passed to
    ``ncf.get_components_appearing_tuned`` for each container.

    Args:
        container_correct: Container whose prediction must equal 1.
        container_incorrect: Container whose prediction must equal 0.
        selection_parameters: Forwarded unchanged to
            ``ncf.get_components_appearing_tuned``.

    Returns:
        A 2-tuple ``(comp_infos_correct, comp_infos_incorrect)``.
    """
    incorrect_predicts_0 = container_incorrect.predictions == 0
    correct_predicts_1 = container_correct.predictions == 1
    indicator = incorrect_predicts_0 & correct_predicts_1

    # Evaluate the "correct" container first, then the "incorrect" one.
    return tuple(
        ncf.get_components_appearing_tuned(
            each_container,
            indicator=indicator,
            selection_parameters=selection_parameters,
        )
        for each_container in (container_correct, container_incorrect)
    )


def get_example_inds(info):
    """Return the set of ``index`` values of ``info.labeled_examples``."""
    return {example.index for example in info.labeled_examples}


def print_matches(infos1, infos2, min_frac=0.3, same_layer_only=True):
    """Print component pairs whose labeled examples overlap strongly.

    For every pair ``(info1, info2)`` from ``infos1`` x ``infos2`` the overlap
    is the number of shared example indices divided by the size of the
    smaller of the two index sets. Pairs reaching ``min_frac`` are printed as
    ``<frac>: [<nmf_index> <component_index>], [<nmf_index> <component_index>]``.

    Args:
        infos1: Iterable of component infos exposing ``nmf_index``,
            ``component_index`` and ``labeled_examples``.
        infos2: Second iterable of component infos, same shape as ``infos1``.
        min_frac: Minimum overlap fraction for a pair to be printed.
        same_layer_only: When true, pairs whose ``nmf_index`` values differ
            are skipped.

    Returns:
        None. Matches are written to stdout.
    """
    # Build each example-index set once instead of once per pair.
    indexed1 = [(info, get_example_inds(info)) for info in infos1]
    indexed2 = [(info, get_example_inds(info)) for info in infos2]

    for (info1, inds1), (info2, inds2) in itertools.product(indexed1, indexed2):
        if same_layer_only and info1.nmf_index != info2.nmf_index:
            continue
        smaller_size = min(len(inds1), len(inds2))
        if smaller_size == 0:
            # A component without labeled examples has no defined overlap
            # fraction; skip it rather than divide by zero.
            continue
        frac_shared = len(inds1 & inds2) / smaller_size
        if frac_shared >= min_frac:
            print(f'{frac_shared:.3f}: [{info1.nmf_index} {info1.component_index}], [{info2.nmf_index} {info2.component_index}]')


# selection_parameters = ncf.SelectionParameters(
#     coeff_factor=0.4,
#     frac_threshold=0.75,
#     p_value_threshold=0.05,
# )

# tuned_comp_infos0, tuned_comp_infos15 = do_pairs_thing(container0, container15, selection_parameters)
# # tuned_comp_infos0, tuned_comp_infos15 = do_pairs_thing(container1, container15, selection_parameters)
# print(len(tuned_comp_infos0))
# print(len(tuned_comp_infos15))

# print_matches(tuned_comp_infos0, tuned_comp_infos15)


# # print_example_for_component(self, example: NliExample, nmf_index: int, component_index: int)

# # 1.000: [6 191], [6 245]
# # 0.667: [18 108], [18 161]

# # container0.print_top_examples(6, 191, 8)
# # container15.print_top_examples(6, 245, 8)

# container0.print_top_examples(18, 108, 8)
# container15.print_top_examples(18, 161, 8)
