R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/selective_ablation/se_dev002.py

"""
from importlib import reload
import os
import time
from typing import Sequence, Optional

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.analysis import bert_nmf_analysis as bna
from em.experimental import selective_ablation1 as se1

###############################################################################

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/math_datasets_dev1'
# Sub-directories of the experiment directory: saved models, dense Fishers, and
# per-example Fishers (the NMF decomposition is also loaded from the latter).
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, sub_dir)
    for sub_dir in ('models1', 'fishers0', 'per_example_fishers0')
)

###############################################################################

# Task and tokenization settings.
TASK = 'math_dataset/original_true_false'
SEQUENCE_LENGTH = 128

# Model directory name (under MODELS_DIR); FROM_PT=False means the weights are
# loaded as a TF checkpoint rather than converted from PyTorch.
MODEL = "og_tf__bert_small__100k_steps"
FROM_PT = False

# Pretrained model whose tokenizer is used.
PRETRAINED_MODEL = "prajjwal1/bert-small"

# File names: the dense Fisher (in FISHERS_DIR), and the per-example Fishers
# and their NMF decomposition (both in PER_EXAMPLES_FISHERS_DIR).
FISHER = "og_tf__bert_small__100k_steps.dense.32k.h5"
PER_EXAMPLES_FISHERS = "og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP = 'nmf_decomp.8k.4k.128.og_tf__bert_small__100k_steps.no_embeddings.sparse_dynamic_raw.32k.32k.h5'

# Forwarded to PerExampleFlatFishers.load (n_examples / start / end index).
N_EXAMPLES = 8 * 1024
START_FISHER_INDEX = 0
END_FISHER_INDEX = 4096

###############################################################################

# NMF component whose parameters are ablated, and the component(s) it is
# contrasted against (only the first entry is used below).
COMPONENT_TO_ABLATE = 28
SIMILAR_COMPONENTS = [62]

###############################################################################
# TODO: In proper code, I can probably multithread/multiprocess this to do all these
# loads below in parallel.
#############################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Each load below is timed with time.perf_counter(), which is monotonic and
# intended for measuring intervals; time.time() follows the wall clock and can
# jump if the system clock is adjusted mid-load.
print('Starting to load fisher.')
start = time.perf_counter()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.perf_counter() - start)


print('Starting to load saved NMF decomposition.')
start = time.perf_counter()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
# NOTE(review): judging by its name, this rescales each NMF component to unit
# norm (presumably in place); the components are read from decomp.H later.
decomp.normalize_components_to_unit_norm()
print('Load saved NMF decomposition time: ', time.perf_counter() - start)


print('Starting to load saved per-example Fishers.')
start = time.perf_counter()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.perf_counter() - start)

###############################################################################

# We did not compute per-example Fishers for the embeddings, so ignore them.
variable_filter = tmv.VariableFilter(
    merge_embeddings=False,
)

model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=FROM_PT
)

# Filter the model variables and the dense-Fisher entries in lockstep so the
# two lists stay aligned and only cover the non-embedding variables.
merge_vars = hf_util.get_mergeable_variables(model)
full_fisher = dense_fisher.fishers

merge_vars, full_fisher = variable_filter.filter_parallel_lists(merge_vars, full_fisher)

###############################################################################

# Packs the filtered model variables into one flat vector (and back). The NMF
# components in decomp.H must live in this same flat parameter space.
flat_packer = flat_pack.FlatPacker([v.shape for v in merge_vars])
if flat_packer.flat_size != decomp.H.shape[1]:
    # Explicit raise rather than `assert`, so the check survives `python -O`.
    raise ValueError(
        f'Flat parameter size {flat_packer.flat_size} does not match the NMF '
        f'component size {decomp.H.shape[1]}.'
    )

# One row of decomp.H per component (its second axis is the flat parameter
# axis, as checked above).
flat_ablate_comp = decomp.H[COMPONENT_TO_ABLATE]
flat_sim_comp = decomp.H[SIMILAR_COMPONENTS[0]]

flat_full_fisher = flat_packer.encode_tf(full_fisher)

###############################################################################

# Copy of the model whose (filtered) variables are overwritten by the ablation
# and update routines below; those routines only read merge_vars and only
# assign to output_variables.
output_model = hf_util.clone_model(model)
output_model.compile(metrics=[
    tf.keras.metrics.SparseCategoricalAccuracy(),
    tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
])

# NOTE(review): unlike the two-list filter_parallel_lists call above, this
# result is not unpacked; presumably the helper returns the bare list when
# given a single list -- confirm.
output_variables = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(output_model)
)

###############################################################################

# Evaluation on a fixed slice of the task's training split.
EVAL_FULL_SPLIT = 'train'
EVAL_FULL_N_EXAMPLES = 4 * 1024

# Default number of top examples per NMF component (see make_top_component_ds).
EVAL_COMPONENT_N_EXAMPLES = 64

EVAL_BATCH_SIZE = 512

#######################################

# First EVAL_FULL_N_EXAMPLES examples, cached after the first pass so repeated
# evaluations do not rerun the upstream loading pipeline.
eval_ds_full = (
    em_datasets
    .load(TASK, split=EVAL_FULL_SPLIT, tokenizer=tokenizer, sequence_length=SEQUENCE_LENGTH)
    .take(EVAL_FULL_N_EXAMPLES)
    .cache()
    .batch(EVAL_BATCH_SIZE)
)


def make_top_component_ds(component: int, n_examples: int = EVAL_COMPONENT_N_EXAMPLES):
    """Build a batched evaluation dataset for one NMF component.

    Delegates the example selection to
    ``se1.make_dataset_for_top_component_examples`` using the decomposition's
    ``W`` matrix (presumably picking the examples with the largest weight on
    ``component`` -- see that helper).

    Args:
        component: Index of the NMF component.
        n_examples: Number of examples requested from the helper.

    Returns:
        The helper's dataset, batched with ``EVAL_BATCH_SIZE``.
    """
    top_examples = se1.make_dataset_for_top_component_examples(
        pe_fishers_data=pe_fishers_data,
        tokenizer=tokenizer,
        n_examples=n_examples,
        component=component,
        W=decomp.W,
    )
    return top_examples.batch(EVAL_BATCH_SIZE)


# Small eval sets for the ablated component and for the similar component it
# is contrasted with (built in that order).
eval_ds_ablate_comp, eval_ds_sim_comp = map(
    make_top_component_ds, (COMPONENT_TO_ABLATE, SIMILAR_COMPONENTS[0])
)


###############################################################################

def update_parameters(delta: np.ndarray, epsilon: float):
    """Set the output model's variables to ``merge_vars + epsilon * delta``.

    Args:
        delta: Flat parameter-space direction; cast to float32, scaled, and
            decoded with ``flat_packer`` into per-variable tensors.
        epsilon: Non-negative step size.

    Raises:
        ValueError: If ``epsilon`` is negative or NaN.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Written as `not epsilon >= 0` so NaN is rejected as well.
    if not epsilon >= 0:
        raise ValueError(f'epsilon must be non-negative, got {epsilon}.')
    deltas = flat_packer.decode_tf(epsilon * delta.astype(np.float32))
    for v_delta, og_v, out_v in zip(deltas, merge_vars, output_variables):
        # Offset from the original weights, not the current output weights, so
        # repeated calls do not accumulate.
        out_v.assign(og_v + v_delta)


def evaluate_output_model(datasets: Sequence[tf.data.Dataset], names: Optional[Sequence[str]] = None):
    """Evaluate ``output_model`` on each dataset and print accuracy and loss.

    Args:
        datasets: Batched datasets to evaluate on.
        names: One label per dataset, used in the printed output. Defaults to
            ``'Dataset 0'``, ``'Dataset 1'``, ...

    Raises:
        ValueError: If ``names`` is given and its length differs from that of
            ``datasets``.
    """
    if names is None:
        names = [f'Dataset {i}' for i in range(len(datasets))]
    elif len(names) != len(datasets):
        # Explicit check instead of `assert` (stripped under `python -O`);
        # otherwise zip() below would silently skip the unmatched datasets.
        raise ValueError(f'Got {len(names)} names for {len(datasets)} datasets.')
    for ds, name in zip(datasets, names):
        # evaluate() returns the loss followed by the compiled metrics in
        # order: SparseCategoricalAccuracy, then SparseCategoricalCrossentropy.
        # The first value (presumably the model's built-in loss, since compile()
        # was given no loss) is discarded; the cross-entropy metric is reported.
        _, acc, loss = output_model.evaluate(ds, verbose=0)
        print(f'{name} accuracy: {acc}')
        print(f'{name} loss: {loss}')


##########################################################


def ablate_top_diffs(top_k_ablate: int, top_k_other: int, ablate_comp, other_comp):
    """Zero the parameters that rank highly for one component but not another.

    Takes the indices of the ``top_k_ablate`` largest entries of
    ``ablate_comp``, removes those that are also among the ``top_k_other``
    largest entries of ``other_comp``, and sets each output variable to the
    original variable with those flat positions zeroed. Every other position
    keeps its original value. Prints the number of ablated positions.

    Args:
        top_k_ablate: How many of ``ablate_comp``'s largest entries to consider.
            Clamped to the vector length.
        top_k_other: How many of ``other_comp``'s largest entries to protect.
            Clamped to the vector length.
        ablate_comp: Flat vector over the packed parameters (e.g. a row of
            ``decomp.H``).
        other_comp: Flat vector over the packed parameters to contrast against.
    """
    flat_size = flat_packer.flat_size
    # tf.math.top_k errors out when k exceeds the last dimension, so clamp k.
    _, top_inds_ablate = tf.math.top_k(ablate_comp, k=min(top_k_ablate, ablate_comp.shape[-1]))
    _, top_inds_other = tf.math.top_k(other_comp, k=min(top_k_other, other_comp.shape[-1]))

    # Set difference of the two index sets (order is irrelevant for the mask).
    ablate_inds = np.setdiff1d(top_inds_ablate.numpy(), top_inds_other.numpy())
    print(len(ablate_inds))

    # Mask is 1 everywhere except 0 at the ablated positions. Building it in
    # NumPy also works when ablate_inds is empty, whereas tf.scatter_nd fed a
    # Python list of [i] indices would see an index array of shape (0,) rather
    # than (0, 1) and fail.
    mask = np.ones([flat_size], dtype=np.float32)
    mask[ablate_inds] = 0.0

    unflat_mask = flat_packer.decode_tf(tf.constant(mask))
    for v_mask, og_v, out_v in zip(unflat_mask, merge_vars, output_variables):
        out_v.assign(og_v * v_mask)


##########################################################

# Ablate the parameters among the top 10,000 entries of the ablated component
# that are not among the top 100 entries of the similar component.
ablate_top_diffs(
    top_k_ablate=10000,
    top_k_other=100,
    ablate_comp=flat_ablate_comp,
    other_comp=flat_sim_comp,
)

# Alternative run with the roles of the two components swapped:
# ablate_top_diffs(
#     200_000,
#     1,
#     flat_sim_comp,
#     flat_ablate_comp,
# )
evaluate_output_model(
    datasets=[eval_ds_ablate_comp, eval_ds_sim_comp, eval_ds_full],
    names=['Ablate', 'Similar', 'Full'],
)

##########################################################

# ablator = se1.RiemannianAblator(
#     flat_full_fisher=flat_full_fisher,
#     flat_component_fisher=flat_ablate_comp,
#     normalize_fishers=True,
# )

# # ablator = se1.RiemannianAblator(
# #     flat_full_fisher=flat_sim_comp,
# #     flat_component_fisher=flat_ablate_comp,
# #     normalize_fishers=True,
# # )

# ###################

# alpha = 0.5
# # alpha = 0.1
# # alpha = 0.9

# print('Starting to solve for squared delta.')
# start = time.time()
# sq_delta = ablator.solve_for_sq_delta(alpha)
# print('Solve for squared delta time: ', time.time() - start)

# assert sq_delta is not None
# delta = ablator.randomly_unsquare_delta(sq_delta)

# epsilon = 2e1
# update_parameters(delta, epsilon)
# evaluate_output_model([eval_ds_ablate_comp, eval_ds_sim_comp, eval_ds_full], ['Ablate', 'Similar', 'Full'])
