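R"""
Interactive analysis of the NMF decomposition of per-example Fishers for the
math-dataset writeup: component similarity (VAT reordering), per-layer
localization, and top-example figures.

Sync the writeup figures from the remote machine to the local Desktop:
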
rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1/writeup1/images/" \
    "$HOME/Desktop/projects_data/extract_merge1/math_ds_writeup1/images/"




cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/math_ds_writeup1/nmf1_analysis001.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.clustering import vat
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.analysis import bert_nmf_analysis as bna
from em.experimental import selective_ablation1

from local_scripts.math_ds_writeup1 import data_dump

###############################################################################

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models1')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers0')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers0')

###############################################################################

TASK = 'math_dataset/original_true_false'
SEQUENCE_LENGTH = 128

MODEL = "og_tf__bert_small__100k_steps"
FROM_PT = False

PRETRAINED_MODEL = "prajjwal1/bert-small"

# Artifact filenames are prefixed with the name of the model they were computed
# for. Deriving them from MODEL keeps them in sync if MODEL is switched.
FISHER = f"{MODEL}.dense.32k.h5"
PER_EXAMPLES_FISHERS = f"{MODEL}.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
# The decomposition filename embeds the per-example Fisher filename it was
# built from. NOTE(review): "8k.4k" presumably mirrors N_EXAMPLES and
# END_FISHER_INDEX below — confirm before deriving those too.
DECOMP = f'nmf_decomp.8k.4k.128.{PER_EXAMPLES_FISHERS}'

N_EXAMPLES = 8 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 4096)

###############################################################################

'''
Some "simple" comp indices: 2, 12, 14, 19, 23, 24, 25, 28, 29, 34, 45, 50, 53, 57, 59, 62, 120

14: Stuff like "are x and x equal?"
28: numbers ending in 5 do not divide odd numbers not ending in 5
34: 2 does not divide odd numbers
45: a number ending in 5 is composite
50: different numbers are not equal
53: different numbers are not equal
57: even numbers do not divide odd numbers
59: non-negative numbers are not larger than negative numbers
62: even numbers do not divide odd numbers
63: a number does not have a different value than itself
120: a number is not bigger/smaller than itself

14 ablates nicely (only top 16 look completely similar)
29 ablates fairly nicely as well.
34 ablates fairly nicely as well.
59 ablates fairly nicely as well.
63 ablates very nicely

28 and 62 both appear to have parameter representations concentrated in the last layer
'''
# bna.print_top_examples(decomp.W, tokenizer, pe_fishers_data.input_ids, pe_fishers_data.labels, n_examples=8, component=14)

# COMPONENT = 28
# COMPONENT = 14
# COMPONENT = 23
# COMPONENT = 24
# COMPONENT = 29
# COMPONENT = 34
# COMPONENT = 45
# COMPONENT = 50
# COMPONENT = 53
# COMPONENT = 59
# Component under inspection, as an index into decomp.H (the original,
# non-VAT-reordered NMF ordering). The commented values above are other
# components that were looked at. Not referenced by the active code in this file.
COMPONENT = 63

###############################################################################
# TODO: In proper code, I can probably multithread/multiprocess this to do all these
# loads below in parallel.
#############################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Load times are measured with time.perf_counter(), a monotonic clock intended
# for timing intervals; time.time() follows the adjustable system clock.
print('Starting to load fisher.')
start = time.perf_counter()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.perf_counter() - start)


print('Starting to load saved NMF decomposition.')
start = time.perf_counter()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
# Rescale the components to unit norm. The component similarity matrix
# computed later as H @ H.T is treated as a cosine similarity, which presumably
# relies on this normalizing the rows of H — TODO confirm.
decomp.normalize_components_to_unit_norm()
print('Load saved NMF decomposition time: ', time.perf_counter() - start)


print('Starting to load saved per-example Fishers.')
start = time.perf_counter()
# NOTE(review): presumably loads the first N_EXAMPLES examples restricted to
# flat Fisher entries [START_FISHER_INDEX, END_FISHER_INDEX) — confirm.
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.perf_counter() - start)

###############################################################################

# Unpack both NMF factors; the rows of H are compared pairwise further down.
W, H = decomp.W, decomp.H

#######################################

# Per-example Fishers were not computed for the embedding variables, so the
# filter is configured to leave them out.
variable_filter = tmv.VariableFilter(merge_embeddings=False)

finetuned_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL),
    from_pt=FROM_PT,
)

# Filter the model's mergeable variables together with the dense Fisher
# entries (presumably keeping the two lists index-aligned — see
# filter_parallel_lists).
finetuned_vars, batch_fishers = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(finetuned_model),
    dense_fisher.fishers,
)

###############################################################################

# component = H[COMPONENT]

# flat_packer = flat_pack.FlatPacker([v.shape for v in merge_vars])
# assert flat_packer.flat_size == component.shape[0]
###############################################################################

# reload(bna)
# localizer = bna.ComponentLocalizationInfo(variables=finetuned_vars)

# # loc_info = localizer.fraction_per_layer(decomp.H[COMPONENT])

# for i in range(decomp.H.shape[0]):
#     # loc_info = localizer.fraction_per_layer(decomp.H[i])
#     # print(','.join([str(n) for n in loc_info]))
#     print(localizer.fraction_in_pooler(decomp.H[i]))

###############################################################################

# Pairwise similarity between NMF components (rows of H). It is a cosine
# similarity when the rows have unit norm.
cos_sim_matrix = H @ H.T
cos_dissim_matrix = 1 - cos_sim_matrix

# ivat_dissim_matrix = vat.ivat_reorder_dissimilarity_matrix(cos_dissim_matrix)
# ivat_sim_matrix = 1 - ivat_dissim_matrix

# VAT reordering of the dissimilarities. `permutation[i]` is the original index
# of the component placed at position i of the reordered matrix.
vat_dissim_matrix, permutation = vat.vat_reorder_dissimilarity_matrix(cos_dissim_matrix)
vat_sim_matrix = 1 - vat_dissim_matrix

# plt.imshow(vat_sim_matrix, vmin=0, vmax=1);plt.show()

# Permuting rows and then columns of the similarity matrix reproduces
# vat_sim_matrix up to numerical precision.
M = cos_sim_matrix[permutation][:, permutation]


#######################################

# Per-component localization fractions from data_dump, put into the dumped
# component order. NOTE(review): presumably the per-layer fractions from
# bna.ComponentLocalizationInfo — confirm against data_dump.
reordered_frac_in_subsets = data_dump.FRAC_IN_SUBSETS[data_dump.COMP_PERMUTATION]

# make_example_figure() indexes reordered_frac_in_subsets with positions taken
# from the VAT `permutation` computed in this run. If the dumped permutation
# differs (e.g. different VAT tie-breaking), localization panels would be
# silently attributed to the wrong component, so flag the mismatch.
if not np.array_equal(np.asarray(data_dump.COMP_PERMUTATION), np.asarray(permutation)):
    print(
        'WARNING: data_dump.COMP_PERMUTATION differs from the VAT permutation '
        'computed in this run; reordered_frac_in_subsets may not line up with it.'
    )

#######################################

# Display labels for the localization subsets. Assumed to follow the order of
# FRAC_IN_SUBSETS' subset axis — TODO confirm.
subset_labels = [
    'Pooling',
    'Layer 4',
    'Layer 3',
    'Layer 2',
    'Layer 1',
]

FIGS_DIR = os.path.join(EXPS_DIR, 'writeup1', 'images')


# reload(bna)

# bna.plot_sim_matrix(M, show=False)
# # plt.savefig(os.path.join(FIGS_DIR, 'component_sim_matrix.svg'), transparent=True)
# plt.show()

# bna.plot_component_locations(reordered_frac_in_subsets, subset_labels=subset_labels, vertical_stretch=5, show=False)
# # plt.savefig(os.path.join(FIGS_DIR, 'component_layer_localizations.svg'), transparent=True)
# plt.show()


###############################################################################


def make_example_figure(
    component_index: int,
    n_examples_coeff_plot: int = 32,
    *,
    save: bool = False,
    n_examples_latex: int = 8,
):
    """Plot a component's localization next to its top-example coefficients.

    Builds a two-panel figure with two panels:
    - Left: the component's row of ``reordered_frac_in_subsets``, labelled
      with ``subset_labels``.
    - Right: ``bna.plot_coeffs_for_top_examples`` for the component.

    A LaTeX listing of the component's top examples is also printed.

    Args:
        component_index: Index of the component in the original (non-reordered)
            NMF decomposition. The figure title shows its VAT position instead.
        n_examples_coeff_plot: Number of top examples in the coefficient panel.
        save: If True, write the figure as an SVG under
            ``FIGS_DIR/comp_top_and_loc`` and close it; otherwise show it.
        n_examples_latex: Number of top examples in the printed LaTeX listing.

    Raises:
        ValueError: If ``component_index`` does not occur in ``permutation``.
    """
    # Position of the original component within the VAT ordering.
    reordered_comp_index = permutation.tolist().index(component_index)

    fig, axs = plt.subplots(1, 2, figsize=(12, 3.14159), gridspec_kw={'width_ratios': [1, 15]})
    fig.suptitle(f'Component {reordered_comp_index}', fontsize=22)

    bna.plot_localization_for_single_component(
        reordered_frac_in_subsets[reordered_comp_index],
        subset_labels=subset_labels,
        size_x=3,
        size_y=3,
        figsize=None,
        ax=axs[0],
        show=False,
    )
    bna.plot_coeffs_for_top_examples(
        decomp,
        permutation,
        n_examples=n_examples_coeff_plot,
        component=component_index,
        figsize=None,
        ax=axs[1],
        show=False,
    )
    if save:
        out_dir = os.path.join(FIGS_DIR, 'comp_top_and_loc')
        # savefig does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        # Save this figure explicitly instead of whichever one is "current".
        fig.savefig(
            os.path.join(
                out_dir,
                f'ro{reordered_comp_index}_og{component_index}_top_{n_examples_coeff_plot}_and_loc.svg',
            ),
            transparent=True,
        )

    bna.print_top_examples_for_latex(
        decomp,
        pe_fishers_data,
        tokenizer,
        n_examples=n_examples_latex,
        component=component_index,
        permutation=permutation,
        full_figure=True,
    )

    if save:
        plt.close(fig)
    else:
        plt.show()


# make_example_figure(63, save=False)


# Original indices of components that looked "simple" on inspection (cf. the
# notes string near the top of the file), sorted by their VAT position.
og_comp_inds = [2, 12, 14, 19, 23, 24, 25, 28, 29, 34, 45, 50, 53, 57, 59, 62]
# permutation.tolist() is built once; its bound .index maps an original
# component index to its position in the VAT ordering.
og_comp_inds = sorted(og_comp_inds, key=permutation.tolist().index)

for comp_ind in og_comp_inds:
    # make_example_figure(comp_ind, save=True)
    pass


# make_example_figure(permutation.tolist().index(0), save=True)
# make_example_figure(permutation.tolist().index(2), save=True)

# NOTE(review): make_example_figure() passes the ORIGINAL component index
# together with `permutation`. This call instead passes the VAT position of
# original component 2. Both cannot match what
# bna.print_top_examples_for_latex expects — confirm its convention.
bna.print_top_examples_for_latex(
    decomp,
    pe_fishers_data,
    tokenizer,
    n_examples=32,
    component=permutation.tolist().index(2),
    permutation=permutation,
    full_figure=False,
    capitalize_start_of_examples=True,
)


###############################################################################

# component_index = 63
# # component_index = 80
# # component_index = 28

# reordered_comp_index = permutation.tolist().index(component_index)

# n_examples_coeff_plot = 32

# # reload(bna)
# # bna.print_top_examples_for_latex(decomp, pe_fishers_data, tokenizer, n_examples=8, component=component_index)


# # raise ValueError("Visually, plot is fine. However, it looks like I'm getting some reversing/permutation stuff mixed up.")


# reload(bna)

# # fig, axs = plt.subplots(1, 2, figsize=(12, 6))
# fig, axs = plt.subplots(1, 2, figsize=(12, 3.14159), gridspec_kw={'width_ratios': [1, 15]})
# # fig, axs = plt.subplots(1, 2, figsize=(12, 3))

# fig.suptitle(f'Component {reordered_comp_index}', fontsize=22)

# # reload(bna)
# bna.plot_localization_for_single_component(
#     reordered_frac_in_subsets[reordered_comp_index],
#     subset_labels=subset_labels,
#     # size_x=7,
#     # size_y=7,
#     size_x=3,
#     size_y=3,
#     # figsize=(3, 5),
#     figsize=None,
#     ax=axs[0],
#     show=False,
# )
# # plt.savefig(os.path.join(FIGS_DIR, f'ro{reordered_comp_index}_og{component_index}_comp_loc.svg'), transparent=True)
# # plt.show()

# # reload(bna)
# # bna.plot_coeffs_for_top_examples(decomp, permutation, n_examples=n_examples_coeff_plot, component=component_index, show=False)
# bna.plot_coeffs_for_top_examples(
#     decomp,
#     permutation,
#     n_examples=n_examples_coeff_plot,
#     component=component_index,
#     figsize=None,
#     ax=axs[1],
#     show=False,
# )
# # plt.savefig(os.path.join(FIGS_DIR, f'ro{reordered_comp_index}_og{component_index}_top_{n_examples_coeff_plot}_and_loc.svg'), transparent=True)
# plt.show()


# reload(bna)
# n_rows = 8
# n_cols = 2
# bna.plot_coeffs_for_top_examples2(
#     decomp,
#     permutation,
#     n_examples=n_examples_coeff_plot,
#     # NOTE: Not sure if this is right or should be inverse permutation.
#     components=permutation[: n_rows * n_cols],
#     n_rows=n_rows,
#     n_cols=n_cols,
#     figsize=(8, 11),
# )

