R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i local_scripts/transfer1/anli1/anli_per_sublock_nmf_dev002.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
from sklearn.metrics import roc_auc_score
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.mi import continuous_discrete_mi as cd_mi
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util

from em.analysis import annotated_anli_analysis as a3

# Colormap referenced by the (currently commented-out) coefficient heatmap plot.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory of this experiment's saved artifacts; the sub-directories
# below are all derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/anli_correct1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory that both the per-example Fishers and the NMF decompositions are
# loaded from in this script.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Model identifier handed to TFAutoModelForSequenceClassification.from_pretrained.
MODEL = "textattack/bert-base-uncased-MNLI"
# Forwarded as `from_pt` when MODEL is loaded.
FROM_PT = True

# Checkpoint name the tokenizer is loaded from.
PRETRAINED_MODEL = 'bert-base-uncased'

# File name that is embedded in the NMF decomposition's file name (DECOMP).
NMF_PER_EXAMPLES_FISHERS = "textattack_bert_base_uncased_MNLI.anli_r3.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# Base file name of the saved NMF decompositions. load_nmf_decomp drops the
# trailing ".h5" and appends ".ssi<subset_index>.h5" to pick one sub-block.
DECOMP = f"nmf_decomp.per_sub_block.16k.16k.256.{NMF_PER_EXAMPLES_FISHERS}"

###############################################################################

# Passed as start_fisher_index / end_fisher_index when the per-example Fishers
# are loaded further down.
START_FISHER_INDEX = 0
END_FISHER_INDEX = 16 * 1024

# Classifier whose per-example Fishers are analysed; FROM_PT is forwarded as
# `from_pt`.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=FROM_PT)

# The tokenizer is loaded from PRETRAINED_MODEL rather than from MODEL.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

###############################################################################

# File (inside PER_EXAMPLES_FISHERS_DIR) holding the per-example Fishers that
# the NMF decomposition is applied to.
PER_EXAMPLES_FISHERS = "textattack_bert_base_uncased_MNLI.annotated_anli_r3.no_embeddings.sparse_dynamic_raw.all.32k.h5"

print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS)
fisher_index_range = {
    'start_fisher_index': START_FISHER_INDEX,
    'end_fisher_index': END_FISHER_INDEX,
}
pe_fishers_data = per_example.PerExampleFlatFishers.load(pe_fishers_path, **fisher_index_range)
print('Load saved per-example Fishers time: ', time.time() - start)

###############################################################################


def load_nmf_decomp(subset_index: int):
    """Load the saved NMF decomposition of one sub-block and normalize it.

    The file name is built from DECOMP by dropping its last three characters
    (the ".h5" extension) and appending ".ssi<subset_index>.h5"; the file is
    read from PER_EXAMPLES_FISHERS_DIR. Progress and elapsed wall-clock time
    are printed.

    Args:
        subset_index: Index of the sub-block whose decomposition is loaded.

    Returns:
        The object returned by nmf_common.NmfDecomposition.load, after
        normalize_components_to_unit_norm() has been called on it.
    """
    print('Starting to load saved NMF decomposition.')
    began = time.time()

    stem = DECOMP[:-3]
    path = os.path.join(PER_EXAMPLES_FISHERS_DIR, f"{stem}.ssi{subset_index}.h5")

    loaded = nmf_common.NmfDecomposition.load(path)
    loaded.normalize_components_to_unit_norm()
    # loaded.H = loaded.get_full_H()

    elapsed = time.time() - began
    print('Load saved NMF decomposition time: ', elapsed)
    return loaded


def get_values_for_subset(pe_fishers_data, packer, subset_index: int):
    """Restrict every example's sparse Fisher entries to one packed subset.

    Args:
        pe_fishers_data: Object exposing parallel sequences `fishers` (values)
            and `fisher_indices` (flat indices), one entry per example.
        packer: Object whose get_range_for_tensor_by_index(subset_index)
            returns the (start, end) flat-index range of the subset.
        subset_index: Which subset's range to select.

    Returns:
        A pair (subset_values, subset_indices) of lists with one entry per
        example: the values whose index lies in [start, end), and those
        indices shifted so that the subset starts at 0.
    """
    lo, hi = packer.get_range_for_tensor_by_index(subset_index)

    def _restrict(example_values, example_inds):
        # Half-open range test; operand order mirrors the bounds check so that
        # the comparison dispatches exactly as before.
        in_range = (lo <= example_inds) & (example_inds < hi)
        return example_values[in_range], example_inds[in_range] - lo

    restricted = [
        _restrict(example_values, example_inds)
        for example_values, example_inds in zip(
            pe_fishers_data.fishers, pe_fishers_data.fisher_indices
        )
    ]

    subset_values = [kept_values for kept_values, _ in restricted]
    subset_indices = [kept_inds for _, kept_inds in restricted]
    return subset_values, subset_indices


###############################################################################

# Index of the sub-block analysed in this session: it selects both the NMF
# decomposition file that is loaded here and, further down, the flat-index
# range used to slice the per-example Fishers.
ssi = 20

decomp = load_nmf_decomp(ssi)

###############################################################################

# NOTE(review): presumably merge_embeddings=False drops the embedding
# variables, matching the "no_embeddings" Fisher files - confirm.
variable_filter = tmv.VariableFilter(merge_embeddings=False)
variables = variable_filter.filter_parallel_lists(hf_util.get_mergeable_variables(model))

subsets = tmv.group_by_sub_blocks(variables)

# We don't care about the actual shapes, just their number of parameters, so
# every subset is packed as a single pseudo-tensor holding its total size.
subset_sizes = [sum(tf.size(v) for v in subset) for subset in subsets]
packer = flat_pack.FlatPacker([[n_params] for n_params in subset_sizes])

subset_values, subset_indices = get_values_for_subset(pe_fishers_data, packer, ssi)

print('Starting NMF transform.')
start = time.time()
coeffs = nmf_transform.transform(decomp, subset_values, subset_indices)
print('NMF transform time: ', time.time() - start)


# plt.imshow(coeffs.T, cmap=rocket_cmap);plt.tight_layout();plt.show()


def print_top_examples(coeffs, thing, component: int, n_examples: int):
    """Print the examples with the largest coefficients for one component.

    Reads the module-level `pe_fishers_data` (labels, predicted logits and
    input ids) and `tokenizer`.

    Args:
        coeffs: 2-D array indexed as [example, component].
        thing: Per-example values indexed by example; the entry for each
            printed example is shown alongside it.
        component: Column of `coeffs` to rank the examples by.
        n_examples: How many top-ranked examples to print.

    Each printed line has the form
    "<shifted label>, <argmax of predicted logits>: <thing entry> : <text>".
    """
    top_indices = tf.math.top_k(coeffs[:, component], k=n_examples).indices

    for example_idx in top_indices:
        gold = pe_fishers_data.labels[example_idx]
        if isinstance(gold, tf.Tensor):
            gold = gold.numpy()

        # Stuff for MNLI: shift the stored label by one (mod 3) so that it is
        # printed in the same label order as the model's prediction.
        gold = (gold + 1) % 3

        predicted = np.argmax(pe_fishers_data.predicted_logits[example_idx])

        # Decode the token ids back to text and drop the padding tokens.
        decoded = tokenizer.decode(pe_fishers_data.input_ids[example_idx])
        text = decoded.replace(tokenizer.pad_token, '').strip()

        print(f'{gold}, {predicted}: {thing[example_idx]} : {text}')


# print_top_examples(coeffs, n_examples=8, component=0)

# # 124
# print_top_examples(coeffs, n_examples=8, component=124)


# print_top_examples(coeffs, n_examples=8, component=118)
# print_top_examples(coeffs, n_examples=8, component=229)
# print_top_examples(coeffs, n_examples=8, component=232)

# print_top_examples(coeffs, n_examples=8, component=178)
# print_top_examples(coeffs, n_examples=8, component=77)
# print_top_examples(coeffs, n_examples=8, component=191)


###############################################################################

######################################################################################################
# NOTE: Mutual information (MI) might not be the best fit for my use case. I'm
# looking more for something like a threshold on the coeffs that guarantees the
# presence (or maybe also the absence) of a particular label. Binary discrete
# variables might make more sense than categorical ones in this use case.
######################################################################################################

# predicted_labels = np.argmax(pe_fishers_data.predicted_logits, axis=-1)


# reload(cd_mi)

# # mi = cd_mi.compute_mi_fixed_width_bins(coeffs[:, 5], pe_fishers_data.labels, n_bins=5)
# # print(mi)

# mis = []
# for i in range(coeffs.shape[-1]):
#     mi = cd_mi.compute_mi_fixed_width_bins(coeffs[:, i], predicted_labels, n_bins=5)
#     mis.append(mi)

# mis = np.array(mis)
# print(mis)

# # plt.plot(-np.sort(-coeffs[:, 5]));plt.show()
# # plt.plot(-np.sort(-coeffs[:, 124]));plt.show()


###############################################################################


# Pick up edits made to the analysis module during the interactive session.
reload(a3)

# The re-ordering is needed due to how the textattack model was trained.
#
# TODO: Look to see what the GLUE processors do to the label order for MNLI.
#
reordered_logits = pe_fishers_data.predicted_logits[:, [1, 2, 0]]

a3ctx = a3.AaaContext(task='r3', coeffs=coeffs, predicted_logits=reordered_logits)

# discretes = a3ctx.get_contains_annotation_indicator('NUMERICAL-CARDINAL-AGE')
# discretes = a3ctx.get_contains_annotation_indicator('BASIC-IDIOM')


# mis = []
# for i in range(coeffs.shape[-1]):
#     mi = cd_mi.compute_mi_fixed_width_bins(coeffs[:, i], discretes, n_bins=5)
#     mis.append(mi)

# mis = np.array(mis)
# print(-np.sort(-mis))
# print(np.argsort(-mis))

# print_top_examples(coeffs, n_examples=8, component=227)
# print_top_examples(coeffs, n_examples=8, component=3)
# print_top_examples(coeffs, n_examples=8, component=17)
# print_top_examples(coeffs, n_examples=8, component=106)


# print_top_examples(coeffs, n_examples=8, component=18)
# print_top_examples(coeffs, n_examples=8, component=91)
# print_top_examples(coeffs, n_examples=8, component=125)


# discretes = a3ctx.get_contains_annotation_indicator('NUMERICAL-CARDINAL-AGE')
# discretes = a3ctx.get_contains_annotation_indicator('NUMERICAL')


# discretes = a3ctx.get_contains_annotation_indicator('BASIC-CAUSEEFFECT')
# print(discretes.astype(np.int32).sum())

# Binary indicator per example: does the label, shifted by one (mod 3) into
# the model's label order, match the argmax of the predicted logits?
predicted_classes = np.argmax(pe_fishers_data.predicted_logits, axis=-1)
discretes = ((pe_fishers_data.labels + 1) % 3) == predicted_classes
print(discretes.astype(np.int32).sum())


# ROC AUC of every component's coefficients against the indicator.
# NOTE(review): each example is weighted by its own coefficient
# (sample_weight equals the score) - confirm that this is intended.
aucs = np.array([
    roc_auc_score(discretes, coeffs[:, i], sample_weight=coeffs[:, i])
    for i in range(coeffs.shape[-1])
])
top_auc_inds = np.argsort(-aucs)

# Show the top examples for the two components with the highest AUC.
for rank in (0, 1):
    print_top_examples(coeffs, discretes, n_examples=8, component=top_auc_inds[rank])
