R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/input_salience/snli_is001.py

"""
from importlib import reload
import os

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import input_salience
from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

# Fisher file location.
# NOTE(review): FISHER_PATH is not referenced in the visible part of this script.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved LRM-NPEFF decomposition; loaded below via LrmNpeffDecomposition.load.
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_006.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
# NOTE(review): PEFS_FOR_PREDICTIONS_PATH is not referenced in the visible part
# of this script.
PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)

# Identifiers passed to `from_pretrained` for the model and its tokenizer.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################


def inputs_fishers_to_token_saliences(inputs_fisher):
    """Return per-token saliences from a full inputs Fisher, normalized to sum to 1.

    Every entry of ``inputs_fisher`` is squared, then summed over axes 0 and 2.
    For a rank-3 input this leaves one value per index of axis 1, which this
    script treats as the token axis (TODO confirm against ``process_example``).
    The result is divided by its own total.
    """
    squared = tf.square(inputs_fisher)
    per_token = tf.reduce_sum(squared, axis=[0, 2])
    grand_total = tf.reduce_sum(per_token)
    return per_token / grand_total


def component_inputs_fishers_to_token_saliences(inputs_fisher):
    """Return per-token saliences for a single component, normalized to sum to 1.

    Squares every entry and sums over the last axis only. When called on a
    ``(tokens, hidden)`` matrix, this gives one value per token. The values are
    then divided by their grand total.
    """
    per_token = tf.reduce_sum(tf.square(inputs_fisher), axis=-1)
    grand_total = tf.reduce_sum(per_token)
    return per_token / grand_total


def print_token_saliences(x, *token_saliences):
    """Print one line per non-pad token: its salience values, then the token.

    Args:
        x: Example mapping whose ``'input_ids'`` entry is a 1-D sequence of
            token ids (a tf tensor or anything ``np.asarray`` accepts).
        *token_saliences: Zero or more 1-D per-token salience sequences,
            aligned position-by-position with ``input_ids``. Each may be a tf
            tensor or a NumPy array.

    Uses the module-level ``tokenizer`` to map ids to tokens and to detect
    padding. Positions are paired with ``zip``, so output stops at the shortest
    sequence. A salience vector restricted to non-pad tokens therefore lines up
    correctly only if padding comes after all real tokens.
    """
    tokens = tokenizer.convert_ids_to_tokens(np.asarray(x['input_ids']))
    for *row, token in zip(*token_saliences, tokens):
        if token == tokenizer.pad_token:
            continue
        # float() works for both scalar eager tf tensors and NumPy scalars,
        # whereas .numpy() does not exist on NumPy values.
        formatted = ' '.join(f'{float(sal):.4f}' for sal in row)
        print(f'{formatted} {token}')


###############################################################################

# Load the LRM-NPEFF decomposition from NMF_PATH (with read_G=True), then
# normalize its components to unit norm before any salience computation.
print('Starting to read in decomposition.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
print('Decomposition read in.')
nmf.normalize_components_to_unit_norm()
print('Decomposition components normalized.')

###############################################################################

model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
# Single source of truth for the tokenized sequence length; passed to the
# dataset loader below instead of repeating the literal.
sequence_length = 128

# NOTE(review): 'train_skip_50k' presumably pairs with the "skip50000"
# per-example Fisher file named above. Confirm that the split's example order
# matches the rows of nmf.W, which are indexed by EXAMPLE_INDEX below.
ds = em_datasets.load(
    'snli/default',
    split='train_skip_50k',
    sequence_length=sequence_length,
    tokenizer=tokenizer,
)

###############################################################################

# reload(input_salience)
# Builds the per-example computer over all of the model's trainable variables.
# Its process_example(x) is used below to obtain the params and inputs
# Fisher factors for a single example.
pef_computer = input_salience.TransformerJointLrmPefComputer(
    model=model,
    variables=model.trainable_variables,
)

###############################################################################

# Index of the example to analyze. It selects both the dataset element and the
# corresponding row of nmf.W, which assumes the decomposition rows follow the
# dataset order (TODO confirm).
# EXAMPLE_INDEX = 0
# EXAMPLE_INDEX = 1
# EXAMPLE_INDEX = 2
# EXAMPLE_INDEX = 3
# EXAMPLE_INDEX = 4
EXAMPLE_INDEX = 5

# Take the EXAMPLE_INDEX-th element of ds. Fail loudly if the index is past the
# end, instead of silently leaving `x` undefined as a `for ...: break` would.
try:
    x, _ = next(iter(ds.skip(EXAMPLE_INDEX)))
except StopIteration:
    raise IndexError(
        f'EXAMPLE_INDEX={EXAMPLE_INDEX} is past the end of the dataset.'
    ) from None

w = nmf.W[EXAMPLE_INDEX]

# Cumulative share of sum(w) covered by the k largest entries of w, for
# k = 1..len(w). The fractions are sorted in decreasing order before the cumsum.
# TODO: Use something like this to determine n_top_components.
cumsum_w_frac = np.cumsum(-np.sort(-w / np.sum(w)))

###############################################################################

# TODO: Remove dummy batch dim from inputs_fisher.
# params_fisher is iterated below as a sequence of per-variable tensors.
# inputs_fisher is indexed below as (leading dim, token, feature).
params_fisher, inputs_fisher = pef_computer.process_example(x)

# Normalized per-token saliences over the whole (unmasked) sequence.
total_token_saliences = inputs_fishers_to_token_saliences(inputs_fisher)
print_token_saliences(x, total_token_saliences)


# Flatten the params_fishers
# Each per-variable tensor becomes (leading dim, -1), and the pieces are
# concatenated along the last axis into one 2-D matrix. This assumes every
# piece shares the same leading dim.
params_fisher = tf.concat([
    tf.reshape(f, [f.shape[0], -1])
    for f in params_fisher
], axis=-1)

frob_norm_params_fisher = input_salience.compute_mpef_frobenius_norms(params_fisher)

# The sqrt of the frob_norm is on purpose here.
# Presumably the Fisher is quadratic in these factors, so dividing each factor
# by sqrt(norm) scales the Fisher itself by 1/norm. TODO confirm.
# Both factors use the params-Fisher norm, and both are converted to NumPy here.
params_fisher = (params_fisher / tf.sqrt(frob_norm_params_fisher)).numpy()
inputs_fisher = (inputs_fisher / tf.sqrt(frob_norm_params_fisher)).numpy()

# Remove parts corresponding to pad tokens.
# Keep only token positions (axis 1) whose attention_mask is 1, for both the
# inputs Fisher and the total saliences.
inputs_fisher = inputs_fisher[:, x['attention_mask'].numpy() == 1, :]
total_token_saliences2 = total_token_saliences.numpy()[x['attention_mask'].numpy() == 1]
# Rescale each token's slice by the inverse of its total salience.
# NOTE(review): a token with zero total salience would produce inf/nan here.
inputs_fisher /= total_token_saliences2[None, :, None]

# Flatten the inputs fisher.
# Result shape: (leading dim, n_non_pad_tokens * feature).
inputs_fisher = np.reshape(inputs_fisher, [inputs_fisher.shape[0], -1])

#

# Number of components to use, taken as the largest entries of w (the same
# selection is made for top_comp_inds below).
n_top_components = 16
# n_top_components = 2
# n_top_components = 8
# n_top_components = 16
# n_top_components = 32

# Interactive-only, like the reload above pef_computer. importlib.reload
# re-executes the module and rebinds its classes/functions, so `pef_computer`
# (built from the pre-reload class) would no longer match the module's current
# definitions. In a plain script run the module is already fresh.
# reload(input_salience)
puter = input_salience.compute_inputs_salience(
    nmf=nmf,
    w=w,
    flattened_params_lrm_fisher=params_fisher,
    flattened_inputs_lrm_fisher=inputs_fisher,
    n_top_components=n_top_components,
    #
    # Weight on d_loss. The training loop below reports
    # c_loss + lmbda_diag_block * d_loss.
    # lmbda_diag_block=1.0,
    lmbda_diag_block=0.0,
)
# loss = puter._compute_loss()

# Step size passed to puter._gradient_update_step on every iteration; the
# commented lines are alternative values.
lr = tf.cast(1e1, tf.float32)
# lr = tf.cast(1e2, tf.float32)
# lr = tf.cast(1e-0, tf.float32)
# lr = tf.cast(1e-1, tf.float32)

# 3000 update steps via the (private) _gradient_update_step. Every 25 steps the
# combined loss c_loss + lmbda_diag_block * d_loss is printed. Whether the
# update itself uses the same weighting is internal to puter (not visible here).
for i in range(3000):
    c_loss, d_loss = puter._gradient_update_step(lr)
    loss = c_loss + puter.lmbda_diag_block * d_loss
    if (i + 1) % 25 == 0:
        print(loss.numpy())


# Component ids sorted by decreasing coefficient in w, truncated to
# n_top_components. Presumably this matches the order of puter.inputs_salience
# (TODO confirm).
top_comp_inds = np.argsort(-w)[:n_top_components]

# # 103 appears to have multiple people in its selectivity.
# list(top_comp_inds).index(103)
# 9

# Position within top_comp_inds (not the raw component id) to inspect.
top_comp_ind = 0
# top_comp_ind = 1
# top_comp_ind = 2
# top_comp_ind = 3
# top_comp_ind = 9
# Presumably undoes the flattening of the inputs Fisher: one row per non-pad
# token and one column per hidden unit. The model's hidden size is used instead
# of a hard-coded 768, so non-base-width encoders also work.
comp_inputs_pseudo_fisher = tf.reshape(
    puter.inputs_salience[top_comp_ind], [-1, model.config.hidden_size]
)
token_saliences = component_inputs_fishers_to_token_saliences(comp_inputs_pseudo_fisher)
# Rows cover only non-pad tokens; they line up with the printed tokens as long
# as padding follows all real tokens.
print_token_saliences(x, token_saliences)
# print_token_saliences(x, token_saliences - 1 / token_saliences.shape[0])
# print_token_saliences(x, token_saliences, total_token_saliences)


print(top_comp_inds)


#

# plt.plot(cumsum_w_frac); plt.show()
# loss = puter._compute_loss()

#


###############################################################################

# token_saliences = inputs_fishers_to_token_saliences(inputs_fisher)

# print_token_saliences(x, token_saliences)

