R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/mech1/mech_dev002.py

"""
from importlib import reload
import os
import re
import time
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.models import bert_activations
from em.tools.clustering import vat
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_nmf_analysis2 as bna2
from em.experimental import selective_ablation1

from em.analysis import mechanistic1 as mech

###############################################################################

# Root directory of this experiment's artifacts. Model checkpoints, dense Fishers, and
# per-example Fishers (plus their NMF decompositions) live in the subdirectories below.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models1')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers0')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers0')

###############################################################################

# Task identifier (not referenced by the code shown in this script).
TASK = 'math_dataset/original_true_false'
# Length that ad-hoc questions are truncated/padded to when tokenized.
SEQUENCE_LENGTH = 128

# Name of the fine-tuned checkpoint directory under MODELS_DIR.
MODEL = "og_tf__bert_small__100k_steps"
# Forwarded as `from_pt` when loading the fine-tuned model.
FROM_PT = False

# Hugging Face model id; only its tokenizer is loaded in this script.
PRETRAINED_MODEL = "prajjwal1/bert-small"

# Dense (diagonal) Fisher file, loaded from FISHERS_DIR.
FISHER = "og_tf__bert_small__100k_steps.dense.32k.h5"
# Per-example Fisher file, loaded from PER_EXAMPLES_FISHERS_DIR. These Fishers were not
# computed for the embeddings or the pooler (see the variable filter further down).
PER_EXAMPLES_FISHERS = "og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5"

# Saved NMF decomposition, also in PER_EXAMPLES_FISHERS_DIR. Its file name embeds
# PER_EXAMPLES_FISHERS, so it was presumably computed from those Fishers -- confirm.
DECOMP = 'nmf_decomp.8k.4k.256.reduced_1.og_tf__bert_small__100k_steps.no_embeddings_no_pooler.sparse_dynamic_raw.32k.32k.h5'
# Forwarded to PerExampleFlatFishers.load as n_examples / start_fisher_index / end_fisher_index.
N_EXAMPLES = 8 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 4096)


###############################################################################
# TODO: In proper code, I can probably multithread/multiprocess this to do all these
# loads below in parallel.
#############################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Each artifact load below is timed with wall-clock `time.time()` deltas (seconds).
print('Starting to load fisher.')
start = time.time()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.time() - start)


print('Starting to load saved NMF decomposition.')
start = time.time()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
# Normalize before materializing the full H, so that the stored H presumably
# reflects the unit-norm components -- confirm in nmf_common.
decomp.normalize_components_to_unit_norm()
decomp.H = decomp.get_full_H()
print('Load saved NMF decomposition time: ', time.time() - start)


print('Starting to load saved per-example Fishers.')
start = time.time()
# Presumably loads N_EXAMPLES examples and the Fisher index range
# [START_FISHER_INDEX, END_FISHER_INDEX) -- see PerExampleFlatFishers.load.
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.time() - start)

###############################################################################

# We did not compute per-example Fishers for the embeddings or the pooler, so ignore them.
variable_filter = tmv.VariableFilter(
    merge_embeddings=False,
    merge_pooler=False,
)

finetuned_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=FROM_PT
)

finetuned_vars = hf_util.get_mergeable_variables(finetuned_model)
batch_fishers = dense_fisher.fishers

# Filter variables and dense Fisher entries together; presumably this keeps the two
# lists aligned index-for-index (per the helper's name) -- confirm in VariableFilter.
finetuned_vars, batch_fishers = variable_filter.filter_parallel_lists(finetuned_vars, batch_fishers)

###############################################################################
"""
We are focusing on examples of the rough form: is [integer literal] prime/composite?
"""

# Reload so that edits to `mechanistic1` are picked up when re-running this section
# in an interactive (`python -i`) session.
reload(mech)
# Bundle the decomposition, per-example Fishers, tokenizer and filtered variables so
# that the analysis helpers below can share them.
mctx = mech.MechanisticContext(
    decomp=decomp,
    pe_fishers_data=pe_fishers_data,
    tokenizer=tokenizer,
    finetuned_variables=finetuned_vars,
)

# Human-readable variable names, in REVERSE order of `finetuned_vars` (used as plot labels).
subset_labels = [tmv.to_nice_name(v) for v in reversed(finetuned_vars)]

# Components whose top examples match the literal-primality regex. Presumably this means
# at least 70% of each component's top-16 examples match -- exact semantics in mechanistic1.
prime_composite_comp_inds = mctx.get_components_matching_regex(
    mech.ExampleRegexes.LITERAL_PRIMALITY_REGEX,
    check_top_k=16,
    min_match_fraction=0.7,
)

# Reordering of the selected components (presumably a VAT ordering that places similar
# components next to each other), applied to the component indices themselves.
permutation = mctx.compute_vat_permutation(prime_composite_comp_inds)
reordered_prime_composite_comp_inds = np.array(prime_composite_comp_inds)[permutation]


# Reordered indices of components of interest: 1, 2, 18, 22
########################################
# Uncomment stuff below for nice plots #
########################################

# sim_matrix = mctx.compute_similarity_matrix(prime_composite_comp_inds)
# bna.plot_sim_matrix(sim_matrix, show=True)

# Per selected component, presumably the fraction of its weight that falls in each
# variable; labels come from `subset_labels`, which are in reversed variable order.
fractions_per_variable = mctx.compute_fractions_per_variable(prime_composite_comp_inds)
bna.plot_component_locations(
    frac_in_subsets=fractions_per_variable,
    subset_labels=subset_labels,
    yticks_fontsize=9,
)

###############################################################################

# Used here only for its (private) `_packer`, which decodes flat component rows of H
# back into per-variable tensors.
localizer = bna2.ComponentLocalizationInfo(variables=finetuned_vars)

# One entry per variable in `finetuned_vars`. Each entry presumably stacks the selected
# components along its first axis, in `prime_composite_comp_inds` order -- confirm in decode_tf.
comp_vars = localizer._packer.decode_tf(decomp.H[prime_composite_comp_inds])

#######################################

# nice_var_name = 'layer3/output/dense/kernel'
nice_var_name = 'layer1/intermediate/dense/kernel'
# nice_var_name = 'layer0/intermediate/dense/kernel'

# `subset_labels` is reversed relative to `finetuned_vars`, so un-reverse it before
# looking up the index into `finetuned_vars` / `comp_vars`.
var_index = subset_labels[::-1].index(nice_var_name)

variable = finetuned_vars[var_index]

# Move to numpy and put the component axis into the same order as
# `reordered_prime_composite_comp_inds`.
comp_var = comp_vars[var_index].numpy()
reordered_comp_var = comp_var[permutation]

#######################################

# Cosine similarity between the selected components, restricted to this one variable.
# Each component's 2-D slice is scaled to unit Frobenius norm (the 1e-12 guards against
# all-zero slices); then all pairwise inner products over the last two axes are taken.
q = reordered_comp_var / (np.sqrt((reordered_comp_var**2).sum(axis=-1).sum(axis=-1)) + 1e-12)[:, None, None]
S = np.einsum('ijk,ljk->il', q, q)
bna.plot_sim_matrix(S, show=True)

#######################################

# Pick up edits to `bert_activations` when re-running this section interactively.
reload(bert_activations)


def asdfasdf(example_indices, params=None, n_layers=4):
    """Run the fine-tuned model on stored examples and capture its activations.

    Args:
        example_indices: Indices into ``pe_fishers_data.input_ids`` selecting the
            stored, already-tokenized examples to run.
        params: ``bert_activations.BertActivationsParams`` describing what to
            capture. ``None`` (the default) uses the module-level ``ba_params``,
            looked up at call time.
        n_layers: Forwarded to ``bert_activations.BertActivationsContext``. The
            default, 4, is the value that used to be hard-coded here.

    Returns:
        ``(sequence_mask, activations, attn_weights)``:
        ``sequence_mask`` is True wherever the input id differs from the
        tokenizer's pad id. ``activations`` has one list per item of
        ``ctx.get_activations()``, holding that item's activation tensors with
        pad positions multiplied by zero. ``attn_weights`` holds the
        ``.weights`` of each item of ``ctx.get_attention_weights()``.
    """
    if params is None:
        params = ba_params
    input_ids = pe_fishers_data.input_ids[example_indices]
    examples = {
        'input_ids': input_ids,
        # Every position gets the tokenizer's padding token-type id.
        'token_type_ids': tokenizer.pad_token_type_id * np.ones_like(input_ids),
    }
    # NOTE(review): no `attention_mask` is passed. Under the usual Hugging Face
    # convention that means padding positions are attended to; presumably this
    # matches how the model was fine-tuned -- confirm.
    #
    sequence_mask = input_ids != tokenizer.pad_token_id
    #
    ctx = bert_activations.BertActivationsContext(params, n_layers=n_layers)
    ctx.reset_buffers()
    #
    @tf.function
    def call():
        finetuned_model(examples, training=False)
        sequence_mask_float = tf.cast(sequence_mask, tf.float32)
        # Zero out pad positions; the mask is broadcast along a new trailing axis.
        activations = [
            [b.activations * sequence_mask_float[..., None] for b in a]
            for a in ctx.get_activations()
        ]
        attn_weights = [a.weights for a in ctx.get_attention_weights()]
        return activations, attn_weights
    #
    # The forward pass runs inside `ctx`, presumably so it can record activations.
    with ctx:
        return (sequence_mask, *call())


def ipippsdf(questions: Sequence[str], params=None, n_layers=4):
    """Tokenize ad-hoc questions, run the fine-tuned model, and capture activations.

    Args:
        questions: Question strings. Each is tokenized with special tokens, then
            truncated and padded to ``SEQUENCE_LENGTH``. A bare ``str`` is
            treated as a single question rather than as a sequence of characters.
        params: ``bert_activations.BertActivationsParams`` describing what to
            capture. ``None`` (the default) uses the module-level ``ba_params``,
            looked up at call time.
        n_layers: Forwarded to ``bert_activations.BertActivationsContext``. The
            default, 4, is the value that used to be hard-coded here.

    Returns:
        ``(sequence_mask, activations, attn_weights)``:
        ``sequence_mask`` is a boolean tensor that is True wherever the input id
        differs from the tokenizer's pad id. ``activations`` has one list per item
        of ``ctx.get_activations()``, holding that item's activation tensors with
        pad positions multiplied by zero. ``attn_weights`` holds the ``.weights``
        of each item of ``ctx.get_attention_weights()``.

    Raises:
        ValueError: If ``questions`` is empty.
    """
    if params is None:
        params = ba_params
    questions = [questions] if isinstance(questions, str) else list(questions)
    if not questions:
        raise ValueError('questions must contain at least one string.')
    # One batched tokenizer call. Every row is truncated/padded to SEQUENCE_LENGTH,
    # giving tensors of shape (len(questions), SEQUENCE_LENGTH).
    encoded = tokenizer(
        questions,
        add_special_tokens=True,
        max_length=SEQUENCE_LENGTH,
        return_token_type_ids=True,
        truncation=True,
        padding='max_length',
        return_tensors='tf',
    )
    examples = {
        'input_ids': encoded['input_ids'],
        'token_type_ids': encoded['token_type_ids'],
    }
    # NOTE(review): no `attention_mask` is passed, matching `asdfasdf`. Under the usual
    # Hugging Face convention, padding positions are therefore attended to -- confirm.
    #
    sequence_mask = examples['input_ids'] != tokenizer.pad_token_id
    #
    ctx = bert_activations.BertActivationsContext(params, n_layers=n_layers)
    ctx.reset_buffers()
    #
    @tf.function
    def call():
        finetuned_model(examples, training=False)
        sequence_mask_float = tf.cast(sequence_mask, tf.float32)
        # Zero out pad positions; the mask is broadcast along a new trailing axis.
        activations = [
            [b.activations * sequence_mask_float[..., None] for b in a]
            for a in ctx.get_activations()
        ]
        attn_weights = [a.weights for a in ctx.get_attention_weights()]
        return activations, attn_weights
    #
    # The forward pass runs inside `ctx`, presumably so it can record activations.
    with ctx:
        return (sequence_mask, *call())


# What to capture: for layers 0 and 1, activations at the attention output and at the
# FFW intermediate, grouped 'BLOCKWISE_UNIFORM', plus attention weights. `asdfasdf` and
# `ipippsdf` are defined above but read this global only when called, so defining it here is fine.
ba_params = bert_activations.BertActivationsParams(
    layers=[0, 1],
    positions=['ATTENTION_OUTPUT', 'FFW_INTERMEDIATE'],
    activation_grouping='BLOCKWISE_UNIFORM',
    include_attention_weights=True,
)


# Reordered indices of some components with weight on 'layer1/intermediate/dense/kernel':
# 1, 2, 18, 22
reordered_index = 1


# Top-32 examples for the chosen (reordered) component.
example_indices = mctx.get_component_top_example_indices(
    reordered_prime_composite_comp_inds[reordered_index],
    n_examples=32)
# The second captured layer's two entries are unpacked as the input and output of
# layer1/intermediate/dense. This presumably follows the order of `ba_params.positions`
# (ATTENTION_OUTPUT, then FFW_INTERMEDIATE) -- confirm in bert_activations.
sequence_mask, [_, (layer_inputs, layer_outputs)], attn_weights = asdfasdf(example_indices)
# sequence_mask, [_, (layer_inputs, layer_outputs)], attn_weights = asdfasdf(list(range(16)))

# plt.imshow(reordered_comp_var[reordered_index], cmap=sns.color_palette("rocket", as_cmap=True));plt.show()
# plt.imshow(reordered_comp_var[0] + reordered_comp_var[1], cmap=sns.color_palette("rocket", as_cmap=True));plt.show()

# plt.imshow(np.log(reordered_comp_var[reordered_index] + 1e-12), cmap=sns.color_palette("rocket", as_cmap=True));plt.show()


def plot_sum_of_comp_var_by_axis(reordered_indices, axis, comp_var=None):
    """Plot one curve per selected component: its variable slice summed over `axis`.

    Args:
        reordered_indices: Indices along the first (component) axis of `comp_var`.
        axis: Axis of each component's slice to sum over before plotting.
        comp_var: Array indexed by component first. ``None`` (the default) uses the
            module-level ``reordered_comp_var``, looked up at call time, so
            existing calls behave exactly as before.
    """
    if comp_var is None:
        comp_var = reordered_comp_var
    for ind in reordered_indices:
        plt.plot(comp_var[ind].sum(axis=axis))
    plt.show()


# One curve per component of interest. For a 2-D kernel slice, axis=1 yields one value
# per row (presumably per input channel -- confirm the kernel layout).
# plot_sum_of_comp_var_by_axis([1, 2, 18, 22], axis=0)
plot_sum_of_comp_var_by_axis([1, 2, 18, 22], axis=1)

# Input channel indices of interest:
# 444, 276

# Heatmaps of the captured layer inputs at each channel of interest; the remaining axes
# are presumably (example, position) -- confirm.
plt.imshow(layer_inputs[..., 444], cmap=sns.color_palette("rocket", as_cmap=True));plt.show()
# plt.imshow(np.maximum(0, -layer_inputs[..., 444]), cmap=sns.color_palette("rocket", as_cmap=True));plt.show()
plt.imshow(layer_inputs[..., 276], cmap=sns.color_palette("rocket", as_cmap=True));plt.show()

# Print the same top-32 examples whose activations were captured above.
mctx.print_top_examples(reordered_prime_composite_comp_inds[reordered_index], n_examples=32)

# plt.plot(layer_inputs[0, 8]);plt.show()
# plt.plot(layer_outputs[0, 8]);plt.plot(layer_outputs[1, 8]);plt.plot(layer_outputs[2, 8]);plt.plot(layer_outputs[3, 8]);plt.show()

"""
# Large spike in the last component's channels 444 and 276 when the question is phrased
# like "is <n> a prime number?" instead of "is <n> prime?".
"""


def imshow_diverging(img):
    """Display `img` with a diverging colormap whose limits are symmetric about zero.

    The color limits are +/- the largest absolute value in `img`, so zero always
    maps to the colormap's midpoint.
    """
    limit = tf.reduce_max(tf.abs(img)).numpy()
    diverging_cmap = sns.color_palette("vlag", as_cmap=True)
    plt.imshow(img, cmap=diverging_cmap, vmin=-limit, vmax=limit)
    plt.tight_layout()
    plt.show()


# Minimal prime/composite prompt pairs over neighboring integers. The commented-out
# entries use the longer "is <n> a prime/composite number?" phrasing.
questions = [
    # 'is 6324 a prime number?',
    'is 6421 prime?',
    # 'is 6421 a composite number?',
    'is 6421 composite?',
    # 'is 6424 a prime number?',
    #
    'is 6424 prime?',
    'is 6424 composite?',
    #
    'is 6423 prime?',
    'is 6423 composite?',
    #
    'is 6425 prime?',
    'is 6425 composite?',
]
# Run the model on the questions above, unpacking the captured layers as for the top examples.
sequence_mask2, [_, (layer_inputs2, layer_outputs2)], attn_weights2 = ipippsdf(questions)

# imshow_diverging(layer_inputs2[:, :16, 444])
# imshow_diverging(layer_inputs2[:, :16, 276])


def plot_attn_weights(attn_weights_for_example, example: str, n_cols: int = 4):
    """Plot each attention head's map for a single example as a grid of heatmaps.

    Args:
        attn_weights_for_example: Per-head attention scores for one example,
            indexed as ``[head, query_position, key_position]``. A softmax over the
            key axis is applied here, so these are presumably pre-softmax logits --
            TODO confirm against ``bert_activations``.
        example: Raw text of the example. It is re-tokenized (with the same
            truncation to ``SEQUENCE_LENGTH`` used for model inputs) only to
            recover the number of non-padding tokens and their tick labels.
        n_cols: Number of subplot columns; rows are added as needed. With the
            default and 8 heads this reproduces the original 2x4 grid at figsize (9, 4).
    """
    token_ids = tokenizer.encode_plus(
        example,
        add_special_tokens=True,
        # Without truncation, a long example could produce more tick labels than
        # the (at most SEQUENCE_LENGTH-wide) sliced attention matrix has columns.
        max_length=SEQUENCE_LENGTH,
        truncation=True,
    )['input_ids']
    # Only the first `seqlen` (non-padding) positions are plotted.
    seqlen = len(token_ids)
    token_labels = tokenizer.convert_ids_to_tokens(token_ids)
    ticks = list(range(seqlen))
    #
    n_heads = attn_weights_for_example.shape[0]
    n_rows = (n_heads + n_cols - 1) // n_cols  # ceil(n_heads / n_cols)
    # squeeze=False keeps `axs` 2-D even when there is a single row.
    _, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        figsize=(2.25 * n_cols, 2 * n_rows),
        squeeze=False,
    )
    rocket = sns.color_palette("rocket", as_cmap=True)
    for head, ax in zip(range(n_heads), axs.flat):
        img = attn_weights_for_example[head, :seqlen, :seqlen]
        # Normalizing over the truncated key axis excludes padding positions.
        img = tf.nn.softmax(img, axis=-1)
        ax.imshow(img, cmap=rocket, vmin=0, vmax=1)
        ax.set_xticks(ticks)
        ax.set_xticklabels(token_labels, fontsize=8)
        ax.set_yticks(ticks)
        ax.set_yticklabels(token_labels, fontsize=8)
    # Hide grid cells left over when n_heads is not a multiple of n_cols.
    for ax in axs.flat[n_heads:]:
        ax.axis('off')
    #
    plt.tight_layout()
    plt.show()


# Tentative layer-0 head notes: heads 3 and 6(?) seem to attend to the adjacent token
# (before/after); head 7(?) looks to focus on the number suffix.
# Plot attention for the first question. The first index presumably selects the first
# captured layer (layer 0 in `ba_params.layers`), the second selects the example.
plot_attn_weights(attn_weights2[0][0], questions[0])

# plot_attn_weights(attn_weights2[1][0], questions[0])

# plot_attn_weights(attn_weights2[0][-2], questions[-2])

'''
head 7 [for layer0]:
    Numbers ending in 21, 23, 27, 29 => strong, selective attention of the token to itself.
    Numbers ending in 20, 22, 24, 25 => no notable attention to itself.

head 3 [for layer0]:
    Looks to select the number suffix for the "prime" token.
    I see a stark difference in the values of channels 444 and 276 on the "prime" token.
'''