R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/hans_analysis_02.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_util
from em.projects.ll import hans_analysis as ha
from em.projects.wino import nmf_components_fisher as ncf

###############################################################################

# Root directory for all artifacts of this experiment.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'

# Sub-directories of the experiment root. PER_EXAMPLES_FISHERS_DIR holds the
# per-example Fisher (PEF) files and their NMF decompositions read by
# make_container().
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Hugging Face checkpoint name; its tokenizer is loaded once at import time.
PRETRAINED_MODEL: str = 'bert-base-uncased'
# Intended as the ``from_pt`` flag (load PyTorch weights into TF classes).
FROM_PT: bool = True

# Number of NMF decompositions requested per container (passed as ``n_nmfs``).
N_DECOMPS: int = 25

##########################################################################

# Filename templates for the per-example Fisher (PEF) file of a model and its
# NMF decomposition. ``{pef}`` is filled with the complete PEF filename, so the
# NMF filename embeds it (including the ``.h5`` suffix).
#
# Alternative configuration (16k/16k "hans" run), kept for reference:
#   PEF_FILENAME = "feather_berts_{model_number}.hans.no_embeddings.16k.16k.h5"
#   NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"
PEF_FILENAME: str = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
NMF_FILENAME: str = "nmf_decomp.per_sub_block.5k.32k.256.{pef}"

###############################################################################

# Single tokenizer shared by every container built in this script; it is
# loaded once, at import time, from the PRETRAINED_MODEL checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_MODEL,
)


def get_model(model_number: int, from_pt: bool = FROM_PT):
    """Load and compile the ``connectivity/feather_berts_{model_number}`` classifier.

    Args:
        model_number: Index of the feather BERT checkpoint to load.
        from_pt: Forwarded as ``from_pt`` to ``from_pretrained``, i.e. whether to
            load the checkpoint from PyTorch weights. Defaults to the
            module-level ``FROM_PT`` setting, which is bound when this function
            is defined. Previously ``True`` was hard-coded here, so ``FROM_PT``
            had no effect.

    Returns:
        A TF sequence-classification model compiled with sparse categorical
        cross-entropy (on logits) as the loss and sparse categorical accuracy
        as the metric. No optimizer is passed to ``compile``.
    """
    model = TFAutoModelForSequenceClassification.from_pretrained(
        f'connectivity/feather_berts_{model_number}', from_pt=from_pt)
    model.compile(
        # from_logits=True: the classification head outputs unnormalized logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def make_container(model_number: int):
    """Build the PEF/NMF analysis container for one feather BERT model.

    Reads the model's per-example Fisher (PEF) file and its NMF decomposition
    file from PER_EXAMPLES_FISHERS_DIR, applies the HANS-specific fix-ups, and
    forces all NMF data to load before returning.

    Args:
        model_number: Index substituted into PEF_FILENAME.

    Returns:
        The container produced by ``am.load_pef_nmf_analysis_container``.
    """
    pef_name = PEF_FILENAME.format(model_number=model_number)
    # The NMF filename template embeds the complete PEF filename.
    nmf_name = NMF_FILENAME.format(pef=pef_name)
    pef_path, nmf_path = (
        os.path.join(PER_EXAMPLES_FISHERS_DIR, name)
        for name in (pef_name, nmf_name)
    )

    container = am.load_pef_nmf_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    # Return value is discarded; presumably this fixes the container up in place.
    hans_util.fix_up_hans_container(container)
    container.nmfs.force_load_all()
    return container


##########################################################################


# Build one analysis container per compared model, in the order 0, 1, 15, 25.
container0, container1, container15, container25 = (
    make_container(model_number) for model_number in (0, 1, 15, 25)
)

# (model number, container) pairs, in the same order as above.
CONTAINERS = [
    (0, container0),
    (1, container1),
    (15, container15),
    (25, container25),
]

# Print each model's accuracy, one per line, in CONTAINERS order.
for _model_number, _container in CONTAINERS:
    print(ha.get_accuracy(_container))
# Drop the loop variables so the interactive namespace is unchanged.
del _model_number, _container

"""
>>> print(ha.get_accuracy(container0))
0.3896
>>> print(ha.get_accuracy(container1))
0.2988
>>> print(ha.get_accuracy(container15))
0.0976
>>> print(ha.get_accuracy(container25))
0.106

"""


# TODO: Look for components that are selective for examples one model predicts
# incorrectly but another model predicts correctly (and vice versa). See whether
# components covering the same set of examples, but with different predictions,
# can be matched across the two models.


# Re-import hans_analysis so interactive edits to it take effect.
reload(ha)

# Pairs of containers to compare, plotted in this order. The third positional
# argument (20) appears to be the sublayer index (cf. the commented-out calls
# below) — confirm against ha.plot_component_similarity.
for _left, _right in (
    (container15, container25),
    (container0, container25),
    (container0, container1),
    (container1, container15),
):
    ha.plot_component_similarity(
        _left, _right, 20, comp_sort='avg_coeff', mass_fraction=0.75)
# Drop the loop variables so the interactive namespace is unchanged.
del _left, _right

# ha.plot_component_similarity(container15, container25, 18)
# ha.plot_component_similarity(container0, container25, 18)
# ha.plot_component_similarity(container0, container1, 18)


# reload(ha)
# ha.plot_component_similarity(container15, container25, 20, 15, 25, 'Sublayer 20 NMF H Cosine Similarity')
# ha.plot_component_similarity(container0, container25, 19)
# ha.plot_component_similarity(container0, container1, 19)


# ha.plot_frac_true_false(container0)
# ha.plot_frac_true_false(container1)
# ha.plot_frac_true_false(container15)
# ha.plot_frac_true_false(container25)


##########################################################################

# TODO: Try sorting components by decreasing average coeff (or squared coeff) value in the heat maps.

# OUT_DIR = '/fruitbasket/users/m/tmp'

# ha.plot_component_similarity(container15, container25, 16, 15, 25, 'Sublayer 16 NMF H Cosine Similarity', show=False)
# plt.savefig(os.path.join(OUT_DIR, 'hans_lone_models_15_25_sublayer_16_cos_sim.svg'))
# plt.show()

# ha.plot_component_similarity(container0, container25, 16, 0, 25, 'Sublayer 16 NMF H Cosine Similarity', show=False)
# plt.savefig(os.path.join(OUT_DIR, 'hans_lone_models_0_25_sublayer_16_cos_sim.svg'))
# plt.show()

# ha.plot_component_similarity(container0, container1, 16, 0, 1, 'Sublayer 16 NMF H Cosine Similarity', show=False)
# plt.savefig(os.path.join(OUT_DIR, 'hans_lone_models_0_1_sublayer_16_cos_sim.svg'))
# plt.show()

# ha.plot_component_similarity(container15, container1, 16, 15, 1, 'Sublayer 16 NMF H Cosine Similarity', show=False)
# plt.savefig(os.path.join(OUT_DIR, 'hans_lone_models_15_1_sublayer_16_cos_sim.svg'))
# plt.show()

##########################################################################

# OUT_DIR = '/fruitbasket/users/m/tmp'

# for i, container in CONTAINERS:
#     ha.plot_frac_true_false(container, f'MNLI Model {i}', show=False)
#     plt.savefig(os.path.join(OUT_DIR, f'hans_lone_{i}_prediction_selectivity.svg'))
#     plt.show()


"""

download_plot() {
    local filename=$1

    rsync -ra -e ssh \
        "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/$filename.svg" \
        "$HOME/Downloads/$filename.svg"
}

download_plot hans_lone_models_15_25_sublayer_16_cos_sim
download_plot hans_lone_models_0_25_sublayer_16_cos_sim
download_plot hans_lone_models_0_1_sublayer_16_cos_sim
download_plot hans_lone_models_15_1_sublayer_16_cos_sim


download_plot hans_lone_0_prediction_selectivity
download_plot hans_lone_1_prediction_selectivity
download_plot hans_lone_15_prediction_selectivity
download_plot hans_lone_25_prediction_selectivity

"""


##########################################################################


# container = make_container(MODEL_NUMBER)

# acc = ha.get_accuracy(container)
# print(acc)


# reload(hans_util)
# examples = hans_util.get_first_hans_examples('validation', container.pef.input_ids.shape[0])
# # len([e for e in examples if e['heuristic'] == 'lexical_overlap'])

'''
Accuracies:
   1: 0.4298095703125
   15: 0.45623779296875

'''
