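R"""Evaluate feather_berts checkpoints on the HANS lexical-overlap validation split.

Usage (the script is meant to be run interactively via `python -i`):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/hans_analysis_01.py

"""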
from importlib import reload
import os
import time

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.projects.anli import anli_misc1 as am

###############################################################################

# Root directory of this experiment's artifacts; the directories below are all
# resolved relative to it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# NOTE(review): MODELS_DIR and FISHERS_DIR are not referenced in the visible
# part of this file.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory from which make_container() reads both the PEF and the NMF file.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hugging Face model id used to build the initial module-level tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# NOTE(review): not referenced in the visible code; get_model() passes
# `from_pt=True` as a literal instead of using this constant.
FROM_PT = True

# Passed as `n_nmfs` when make_container() loads the analysis container.
N_DECOMPS = 25

##########################################################################

# Filename templates used by make_container(): PEF_FILENAME is formatted with
# the model number, and the resulting filename is substituted into
# NMF_FILENAME's `{pef}` placeholder.
PEF_FILENAME = "feather_berts_{model_number}.hans.no_embeddings.16k.16k.h5"
NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"

###############################################################################

# Initial tokenizer binding. NOTE: this global is rebound to the
# feather_berts tokenizer further down in the script; make_container() looks
# the name up at call time, so it sees whichever binding is current.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)


def get_model(model_number: int):
    """Load and compile the `connectivity/feather_berts_{model_number}` classifier.

    The checkpoint is loaded through
    `TFAutoModelForSequenceClassification.from_pretrained` with `from_pt=True`
    and compiled with sparse categorical cross-entropy on logits plus a sparse
    categorical accuracy metric.

    Args:
        model_number: Index substituted into the checkpoint name.

    Returns:
        The compiled model.
    """
    checkpoint = f'connectivity/feather_berts_{model_number}'
    classifier = TFAutoModelForSequenceClassification.from_pretrained(
        checkpoint, from_pt=True)
    # The loss is configured for raw logits, i.e. no softmax is expected on
    # the model output.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    classifier.compile(loss=loss_fn, metrics=[accuracy_metric])
    return classifier


def make_container(model_number: int):
    """Load the PEF/NMF analysis container for one feather_berts model.

    Both filenames are derived from the module-level templates and resolved
    inside PER_EXAMPLES_FISHERS_DIR. The container is loaded with N_DECOMPS
    NMFs, the module-level `tokenizer` (as bound at call time) and
    `shift_labels=True`, after which `force_load_all()` is called on its
    `nmfs` attribute.

    Args:
        model_number: Index substituted into the PEF filename template.

    Returns:
        The object returned by `am.load_pef_nmf_analysis_container`.
    """
    pef_name = PEF_FILENAME.format(model_number=model_number)
    # The NMF filename embeds the complete PEF filename.
    nmf_name = NMF_FILENAME.format(pef=pef_name)
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name)

    analysis = am.load_pef_nmf_analysis_container(
        pef_filepath=pef_path,
        nmf_filepath=nmf_path,
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=True,
    )
    analysis.nmfs.force_load_all()
    return analysis


###############################################################################

def get_accuracy(container):
    """Return the accuracy of the predictions stored in `container.pef`.

    The predicted class is the argmax over the last axis of
    `pef.predicted_logits`; class 2 (contradiction) is folded into class 1
    (non-entailment) before comparing against `pef.labels`.

    NOTE(review): unlike get_accuracy_ds, classes 0 and 1 are not remapped
    here; presumably the stored logits are already in the label order of
    `pef.labels` (the container is loaded with shift_labels=True) -- confirm.

    Args:
        container: Object exposing `pef.labels` and `pef.predicted_logits`.

    Returns:
        Fraction of examples whose folded prediction equals the label.
    """
    stored = container.pef
    targets = stored.labels
    predicted = np.argmax(stored.predicted_logits, axis=-1)
    # Fold contradiction predictions into non-entailment.
    folded = np.where(predicted == 2, 1, predicted)
    matches = targets == folded
    return matches.astype(np.float64).mean()


def get_accuracy_ds(model, ds):
    """Return the two-way (entailment vs. non-entailment) accuracy of `model` on `ds`.

    Model class 1 is treated as entailment and classes 0 and 2 as
    non-entailment. Predictions use index 0 for entailment and index 1 for
    non-entailment, and are compared against the dataset labels.

    Args:
        model: Callable as `model(x, training=False)` returning an object
            whose `.logits.numpy()` is a 2-D array with at least 3 columns.
        ds: Iterable of `(x, y)` batches where `y.numpy()` yields the labels.

    Returns:
        Fraction of examples across all batches that were predicted
        correctly.
    """
    hits = []
    for inputs, targets in ds:
        raw = model(inputs, training=False).logits.numpy()
        # Column 0: entailment score (model class 1).
        # Column 1: the larger of the two non-entailment scores (classes 0, 2).
        # argmax returns the first maximal index, so an exact tie resolves to
        # entailment. Absent ties, this equals taking the 3-way argmax and
        # remapping 0 -> 1, 1 -> 0, 2 -> 1.
        two_way = np.stack(
            [raw[:, 1], np.maximum(raw[:, 0], raw[:, 2])], axis=-1)
        guesses = two_way.argmax(axis=-1)
        hits.append(guesses == targets.numpy())
    return np.concatenate(hits, axis=0).astype(np.float64).mean()


###############################################################################

# Checkpoint to evaluate: keep exactly one assignment uncommented.
# MODEL_NUMBER = 0
# MODEL_NUMBER = 1
MODEL_NUMBER = 15
# MODEL_NUMBER = 25

# MODEL_NUMBER = 3

##########################################################################
model = get_model(MODEL_NUMBER)
# NOTE: this rebinds the module-level `tokenizer` (initially built from
# PRETRAINED_MODEL); make_container() reads that global at call time.
tokenizer = AutoTokenizer.from_pretrained(f'connectivity/feather_berts_{MODEL_NUMBER}')

# ds = em_datasets.load('hans/lexical_overlap', split='validation', sequence_length=128, tokenizer=tokenizer)
# ds = em_datasets.load('hans/lexical_overlap', split='validation', sequence_length=64, tokenizer=tokenizer)
# ds = em_datasets.load('hans/lexical_overlap', split='validation', sequence_length=256, tokenizer=tokenizer)
ds = em_datasets.load('hans/lexical_overlap', split='validation', sequence_length=128, tokenizer=tokenizer)

# # NOTE: This will be lower than actual due to contradiction predictions always being considered false.
# # _, acc = model.evaluate(ds.batch(128))
# _, acc = model.evaluate(ds.filter(lambda x, y: y == 1).batch(128))
# print(acc)

# Accuracy on the examples labelled 1 (the index get_accuracy_ds uses for
# non-entailment). The result is bound and printed explicitly: a bare
# expression statement is not echoed when the file is executed as a script,
# even under `python -i`, so the value would otherwise be discarded.
# acc = get_accuracy_ds(model, ds.batch(128))
acc = get_accuracy_ds(model, ds.filter(lambda x, y: y == 1).batch(128))
# acc = get_accuracy_ds(model, ds.filter(lambda x, y: y == 0).batch(128))
print(acc)


"""
1: 0.299, 0.0452, 0.0104
15: 0.012
"""


"""
0: 0.4805, 0.9574, 0.0036
15: 0.4966, 0.988, 0.0052
"""

'''
HANS-LO Accuracies:
    1: 0.49559998512268066
    3: 0.49619999527931213

    15: 0.4966000020503998
    25: 0.4959999918937683
'''



##########################################################################

# container = make_container(MODEL_NUMBER)

# acc = get_accuracy(container)
# print(acc)


'''
Accuracies:
    0: 0.48004150390625
    1: 0.4906005859375
    15: 0.48834228515625
    25: 0.49102783203125

'''
