R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/math_ds_writeup1/sa_math001.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.analysis import bert_nmf_analysis as bna
from em.analysis import bert_selective_ablation_analysis as bsaa
from em.experimental import selective_ablation1

from local_scripts.math_ds_writeup1 import data_dump


###############################################################################

# Root directory of the experiment; the sub-directories below hold the saved
# models, the dense (diagonal) Fisher and the per-example Fishers / NMF
# decomposition that this script loads.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models1')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers0')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers0')

###############################################################################

# Dataset name and sequence length handed to `em_datasets.load` when building
# the full evaluation dataset.
TASK = 'math_dataset/original_true_false'
SEQUENCE_LENGTH = 128

# Sub-directory of MODELS_DIR containing the model to load; FROM_PT is
# forwarded as `from_pt` to `from_pretrained`.
MODEL = "og_tf__bert_small__100k_steps"
FROM_PT = False

# Checkpoint name used only to load the tokenizer.
PRETRAINED_MODEL = "prajjwal1/bert-small"

# File names: FISHER lives in FISHERS_DIR; PER_EXAMPLES_FISHERS and DECOMP both
# live in PER_EXAMPLES_FISHERS_DIR.
FISHER = "og_tf__bert_small__100k_steps.dense.32k.h5"
PER_EXAMPLES_FISHERS = "og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP = 'nmf_decomp.8k.4k.128.og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5'

# Forwarded to `PerExampleFlatFishers.load` as `n_examples`,
# `start_fisher_index` and `end_fisher_index`.
# NOTE(review): presumably these must match the subset the NMF decomposition
# was computed on (the DECOMP name mentions "8k.4k") -- confirm.
N_EXAMPLES = 8 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 4096)

###############################################################################

'''
Some "simple" comp indices: 2, 12, 14, 19, 23, 24, 25, 28, 29, 34, 45, 50, 53, 57, 59, 62

14: Stuff like "are x and x equal?"
28: numbers ending in 5 do not divide odd numbers not ending in 5
34: 2 does not divide odd numbers
45: a number ending in 5 is composite
50: different numbers are not equal
53: different numbers are not equal
57: even numbers do not divide odd numbers
59: non-negative numbers are not larger than negative numbers
62: even numbers do not divide odd numbers
63: a number does not have a different value than itself

14 ablates nicely (only top 16 look completely similar)
29 ablates fairly nicely as well.
34 ablates fairly nicely as well.
59 ablates fairly nicely as well.
63 ablates very nicely

28 and 62 both appear to have parameter representations concentrated in the last layer
'''
# bna.print_top_examples(decomp.W, tokenizer, pe_fishers_data.input_ids, pe_fishers_data.labels, n_examples=8, component=14)

# Row index into `decomp.H` (and the component passed when selecting top
# examples) for the component to ablate. Un-comment exactly one alternative
# below to switch components.
# COMPONENT_TO_ABLATE = 28
# COMPONENT_TO_ABLATE = 14
# COMPONENT_TO_ABLATE = 23
# COMPONENT_TO_ABLATE = 24
# COMPONENT_TO_ABLATE = 29
# COMPONENT_TO_ABLATE = 34
# COMPONENT_TO_ABLATE = 45
# COMPONENT_TO_ABLATE = 50
# COMPONENT_TO_ABLATE = 53
# COMPONENT_TO_ABLATE = 59
COMPONENT_TO_ABLATE = 63

###############################################################################
# TODO: In proper code, I can probably multithread/multiprocess this to do all these
# loads below in parallel.
#############################

# The tokenizer is loaded from the pretrained checkpoint name, not from the
# fine-tuned model directory.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Each load below is timed with wall-clock time and reported on stdout. The
# `start` variable is reused, so the statements must stay in this order.
print('Starting to load fisher.')
start = time.time()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.time() - start)


print('Starting to load saved NMF decomposition.')
start = time.time()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
# The normalization is included in the reported load time.
# NOTE(review): presumably this rescales the components in place -- confirm.
decomp.normalize_components_to_unit_norm()
print('Load saved NMF decomposition time: ', time.time() - start)


print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.time() - start)

###############################################################################

# We did not compute per-example Fishers for the embeddings, so ignore them.
variable_filter = tmv.VariableFilter(
    merge_embeddings=False,
)

# The model whose variables are ablated, loaded from MODELS_DIR/MODEL.
model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=FROM_PT
)

# `merge_vars` and `fishers` are treated as parallel lists (one Fisher entry
# per variable).
merge_vars = hf_util.get_mergeable_variables(model)
fishers = dense_fisher.fishers

# Apply the same filter to both lists so they stay aligned.
merge_vars, fishers = variable_filter.filter_parallel_lists(merge_vars, fishers)

###############################################################################

# Row of the decomposition's H matrix for the component being ablated.
component = decomp.H[COMPONENT_TO_ABLATE]

# Packer built from the shapes of the filtered variables; the assert checks
# that the component has exactly one entry per packed element.
flat_packer = flat_pack.FlatPacker([v.shape for v in merge_vars])
assert flat_packer.flat_size == component.shape[0]

# What will take the role of the "fishers" for the component in the merge.
# NOTE(review): presumably `decode_tf` splits the flat vector back into one
# tensor per variable shape -- confirm against `flat_pack`.
component_fishers = flat_packer.decode_tf(component)

#######################################

# What will take the role of the "parameters" for the component in the merge.
#
# The active choice is the negated variables. The commented-out lines are
# other candidates (zeros, other negative scalings, negated sign).
# component_parameters = [tf.zeros_like(v) for v in merge_vars]
# component_parameters = [-2 * v for v in merge_vars]
# component_parameters = [-1.5 * v for v in merge_vars]
# component_parameters = [- v for v in merge_vars]
# component_parameters = [- tf.sign(v) for v in merge_vars]

component_parameters = [- v for v in merge_vars]
# component_parameters = [tf.zeros_like(v) for v in merge_vars]
# component_parameters = [tf.zeros_like(v) for v in merge_vars]


###############################################################################

# Separate copy of the model that receives the merged/ablated variables, so
# the originally loaded `model` is not the one being evaluated.
output_model = hf_util.clone_model(model)

output_variables = hf_util.get_mergeable_variables(output_model)
# NOTE(review): `filter_parallel_lists` is called with two lists and unpacked
# into two names where `merge_vars` is filtered; here it gets a single list
# and the result is not unpacked. Confirm that the single-argument form
# returns the filtered list itself rather than a one-element tuple.
output_variables = variable_filter.filter_parallel_lists(output_variables)

###############################################################################

# Split and number of examples used for the "full dataset" evaluation.
EVAL_FULL_SPLIT = 'train'
EVAL_FULL_N_EXAMPLES = 4 * 1024

# Number of top examples of the ablated component to evaluate on.
# EVAL_COMPONENT_N_EXAMPLES = 64
EVAL_COMPONENT_N_EXAMPLES = 16

EVAL_BATCH_SIZE = 512

#######################################

# Compiled with metrics only (no optimizer or loss are passed).
# The cross-entropy metric treats the model outputs as logits.
output_model.compile(
    metrics=[
        tf.keras.metrics.SparseCategoricalAccuracy(),
        tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
    ]
)

# First EVAL_FULL_N_EXAMPLES examples of the split, cached and then batched.
eval_ds_full = em_datasets.load(TASK, split=EVAL_FULL_SPLIT, tokenizer=tokenizer, sequence_length=SEQUENCE_LENGTH)
eval_ds_full = eval_ds_full.take(EVAL_FULL_N_EXAMPLES).cache().batch(EVAL_BATCH_SIZE)

# bna.print_top_examples(decomp.W, tokenizer, pe_fishers_data.input_ids, pe_fishers_data.labels, n_examples=64, component=COMPONENT_TO_ABLATE)

# Dataset built from the top examples of COMPONENT_TO_ABLATE according to
# `decomp.W`.
# NOTE(review): within this file it is only referenced by the commented-out
# evaluation loop at the bottom.
eval_ds_comp = selective_ablation1.make_dataset_for_top_component_examples(
    W=decomp.W,
    component=COMPONENT_TO_ABLATE,
    n_examples=EVAL_COMPONENT_N_EXAMPLES,
    tokenizer=tokenizer,
    pe_fishers_data=pe_fishers_data,
)
eval_ds_comp = eval_ds_comp.batch(EVAL_BATCH_SIZE)

###############################################################################

# Size argument for the coefficient grid created below.
GRID_SIZE = 50

# Component ordering used in the write-up. `permutation.tolist().index(i)` is
# used further down to map an original component index `i` to its position in
# this ordering.
permutation = data_dump.COMP_PERMUTATION

# Figures are saved under FIGS_DIR/selective_ablations1.
FIGS_DIR = os.path.join(EXPS_DIR, 'writeup1', 'images')

#######################################

# NOTE(review): presumably a set of (model, component) coefficient pairs on a
# grid of GRID_SIZE points -- confirm against `merging`.
coefficients_set = merging.create_pairwise_grid_coeffs(GRID_SIZE)

# Helper object bundling everything needed to run a selective ablation for an
# arbitrary component; used by `plot_selective_ablation`.
sau = bsaa.SelectiveAblationUtility(
    decomp=decomp,
    pe_fishers_data=pe_fishers_data,
    batch_fisher=fishers,
    original_variables=merge_vars,
    output_model=output_model,
    output_variables=output_variables,
    flat_packer=flat_packer,
    tokenizer=tokenizer,
    full_eval_ds=eval_ds_full,
    permutation=permutation,
)

# reload(bsaa)


def plot_selective_ablation(
    component_index: int,
    n_component_examples_eval: int = 16,
    *,
    show: bool = True,
) -> None:
    """Plot cross-entropy loss against the ablation coefficient for one component.

    Runs ``sau.compute_ablation_selectivity_for_component`` over the
    module-level ``coefficients_set`` and draws two curves against
    ``1 - first_coeffs``: the loss on the component's top examples and the
    loss on the full evaluation dataset.

    Args:
        component_index: Original component index. It must occur in
            ``permutation``; otherwise the position lookup raises
            ``ValueError``.
        n_component_examples_eval: Forwarded to the ablation utility as the
            number of top component examples to evaluate on; also part of
            the saved file name.
        show: When true, display the figure with ``plt.show()``. When false,
            save it as an SVG under ``FIGS_DIR/selective_ablations1``
            (creating the directory if needed) and close the figure.
    """
    # The title and file name use the component's position in `permutation`,
    # not only its original index.
    reordered_comp_index = permutation.tolist().index(component_index)
    #
    # Accuracies are returned alongside the losses but are not plotted.
    first_coeffs, (_, full_losses), (_, comp_losses) = sau.compute_ablation_selectivity_for_component(
        component_index=component_index,
        n_component_examples_eval=n_component_examples_eval,
        coefficients_set=coefficients_set
    )
    #
    plt.figure(figsize=(8, 4.5))
    ax = plt.subplot(111)
    #
    # Drop the top/right frame lines and keep ticks on the remaining axes.
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    #
    ax.get_xaxis().tick_bottom()
    ax.get_yaxis().tick_left()
    #
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    #
    plt.xlabel(R"$\lambda_{\mathrm{comp}}$", fontsize=12)
    plt.ylabel("Cross-Entropy Loss", fontsize=12)
    #
    plt.title(f'Component {reordered_comp_index} Selective Ablation')
    #
    # The x-axis is the complement of the first coefficient of each pair.
    lambda_comp = 1 - first_coeffs
    ax.plot(lambda_comp, comp_losses, label='Top Component Examples')
    ax.plot(lambda_comp, full_losses, label='Full Dataset')
    #
    plt.legend(
        loc='best',
        fontsize=12,
        title_fontsize=13,
    )
    #
    plt.tight_layout()
    #
    if show:
        plt.show()
    else:
        figs_subdir = os.path.join(FIGS_DIR, 'selective_ablations1')
        # `savefig` does not create missing directories.
        os.makedirs(figs_subdir, exist_ok=True)
        figname = f'ro{reordered_comp_index}_og{component_index}_loss_top{n_component_examples_eval}.svg'
        figpath = os.path.join(figs_subdir, figname)
        plt.savefig(figpath, transparent=True)
        plt.close()


# LaTeX figure snippet filled in by `make_latex_figure`. Placeholders:
#   ###  position of the component in `permutation` (the reordered index)
#   @@@  original component index
#   $$$  number of top component examples used for evaluation
# The image path mirrors the file name written by `plot_selective_ablation`.
_LATEX_FIGURE_TEMPLATE = R"""  
\begin{figure}[h]
\begin{center}
\includesvg[width=0.9\linewidth]{images/selective_ablations1/ro###_og@@@_loss_top$$$.svg}
\caption{\textit{Original component index:} @@@. \textit{Description:} \textbf{[TODO]}}
\label{fig:ro###_selective_ablation_loss}
\end{center}
\vskip -0.1in
\end{figure}"""


def make_latex_figure(
    component_index: int,
    n_component_examples_eval: int = 16,
) -> str:
    """Return the LaTeX figure snippet for one component's loss plot.

    Fills ``_LATEX_FIGURE_TEMPLATE`` with the original component index, the
    component's position in ``permutation`` and the number of evaluated top
    examples. Raises ``ValueError`` from the position lookup when
    ``component_index`` does not occur in ``permutation``.
    """
    position_in_permutation = permutation.tolist().index(component_index)
    # Substitutions are applied in this order, one placeholder at a time.
    substitutions = (
        ('@@@', component_index),
        ('###', position_in_permutation),
        ('$$$', n_component_examples_eval),
    )
    rendered = _LATEX_FIGURE_TEMPLATE
    for placeholder, value in substitutions:
        rendered = rendered.replace(placeholder, str(value))
    return rendered


# Component picked for ad-hoc interactive use; the loop below does not read it.
# component_index = 63
# component_index = 28
# component_index = 14
component_index = 34


# Original indices of the components that get a figure in the write-up.
og_comp_inds = [2, 12, 14, 19, 23, 24, 25, 28, 29, 34, 45, 50, 53, 57, 59, 62]
# Convert the permutation once instead of on every lookup.
_permutation_list = permutation.tolist()
# NOTE(review): `.index(0)` / `.index(2)` yield the *positions* of original
# components 0 and 2 in the permutation, yet they are appended to a list of
# original indices. If the intent was "the components shown at positions 0 and
# 2", this should read `permutation[0]` / `permutation[2]`. Behavior is left
# unchanged here -- confirm the intent.
og_comp_inds += [_permutation_list.index(0), _permutation_list.index(2)]
# Order the figures by their position in the write-up ordering.
og_comp_inds = sorted(og_comp_inds, key=_permutation_list.index)


for comp_ind in og_comp_inds:
    # Un-comment to (re)generate the SVG for each component as well.
    # plot_selective_ablation(comp_ind, show=False)
    s = make_latex_figure(comp_ind)
    print(s)
    print('\n')


# # asdf = -18
# # asdf = -2
# _gen = selective_ablation1.generate_merged_for_coeffs_set(
#     #
#     # output_variables=output_variables[asdf:],
#     # variables_to_merge=[merge_vars[asdf:], component_parameters[asdf:]],
#     # # fishers=[fishers[asdf:], component_fishers[asdf:]],
#     # fishers=[fishers[asdf:], [f**2 for f in component_fishers[asdf:]]],
#     #
#     output_variables=output_variables,
#     variables_to_merge=[merge_vars, component_parameters],
#     fishers=[fishers, component_fishers],
#     # fishers=[fishers, [f**4 for f in component_fishers]],
#     #
#     coefficients_set=coefficients_set,
#     # fisher_floor=1e-3,
#     fisher_floor=1e-8,
#     # fisher_floor=1e-4,
#     normalize_fishers=True,
#     # normalize_fishers=False,
# )

# for coefficients in _gen:
#     _, full_acc, full_loss = output_model.evaluate(eval_ds_full, verbose=0)
#     _, comp_acc, comp_loss = output_model.evaluate(eval_ds_comp, verbose=0)
#     print('Coefficients:', coefficients)
#     print('Full accuracy:', full_acc)
#     print('Full loss:', full_loss)
#     print('Comp accuracy:', comp_acc)
#     print('Comp loss:', comp_loss)
#     print()

