R"""

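Metric-derived sparse Fishers.

Scores each NMF component of the per-example Fishers by its dot product with
the frozen model's (L2-normalized) diagonal Fisher, prints the top-weighted
examples of the lowest- and highest-scoring components, and then fits a
sparse merge on the target task with `gm.SparseFitter`.

Usage: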
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/transfer1/transfer_dev003.py

"""
from importlib import reload
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.merging import gradient_merging as gm

###############################################################################

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/transfer1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hub checkpoint used both for the tokenizer and for `pretrained_model`.
PRETRAINED_MODEL = "bert-base-uncased"

# Hub checkpoint whose mergeable variables become the fitter's
# `component_variables`.
COMPONENTS_MODEL = "textattack/bert-base-uncased-MNLI"

# Local model directory under MODELS_DIR; FROM_PT is the `from_pt` flag for
# loading it.
MODEL = "frozen_bert_base_rte_001"
FROM_PT = False

TASK = "glue/rte"

SEQUENCE_LENGTH = 128

# File names, relative to FISHERS_DIR / PER_EXAMPLES_FISHERS_DIR.
FISHER = f'{MODEL}.fisher.h5'
PER_EXAMPLES_FISHERS = "bert_base_mnli.sparse_dynamic_metric_derived.32k.16k.h5"
DECOMP = f"nmf_decomp.16k.8k.64.reduced_1.{PER_EXAMPLES_FISHERS}"

# Forwarded verbatim to PerExampleFlatFishers.load below.
N_EXAMPLES = 16 * 1024
START_FISHER_INDEX, END_FISHER_INDEX = (0, 8 * 1024)


###############################################################################
# TODO: The loads below are independent of each other and could be run in
# parallel (threads or processes).
#############################

tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

print('Starting to load fisher.')
start = time.time()
dense_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER))
print('Load saved fisher time: ', time.time() - start)

print('Starting to load saved NMF decomposition.')
start = time.time()
decomp = nmf_common.NmfDecomposition.load(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP))
decomp.normalize_components_to_unit_norm()
print('Load saved NMF decomposition time: ', time.time() - start)

print('Starting to load saved per-example Fishers.')
start = time.time()
pe_fishers_data = per_example.PerExampleFlatFishers.load(
    os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    n_examples=N_EXAMPLES,
    start_fisher_index=START_FISHER_INDEX,
    end_fisher_index=END_FISHER_INDEX,
)
print('Load saved per-example Fishers time: ', time.time() - start)


###########################################################

# Local model: honour FROM_PT instead of hard-coding the flag.
frozen_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=FROM_PT)

# Hub checkpoints below are loaded with from_pt=True.
components_model = TFAutoModelForSequenceClassification.from_pretrained(
    COMPONENTS_MODEL, from_pt=True)

pretrained_model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL, from_pt=True)

# Mergeable variables of the pretrained model, plus all of them flattened and
# concatenated into a single 1-D tensor.
pretrained_variables = hf_util.get_mergeable_variables(pretrained_model)
flat_pretrained_variables = tf.concat([tf.reshape(v, [-1]) for v in pretrained_variables], axis=0)


###########################################################


def sp_de_dot_product(sp, de):
    """Return the dot product of a rank-1 sparse tensor and a rank-1 dense tensor.

    The sparse operand is reshaped into a (1, n) row and the dense operand into
    an (n, 1) column. They are multiplied with a sparse-dense matmul, and the
    (1, 1) result is collapsed to a scalar.

    Args:
        sp: rank-1 `tf.SparseTensor` (presumably the same length as `de`).
        de: rank-1 dense tensor.

    Returns:
        The dot product as a NumPy scalar (via `.numpy()`).
    """
    assert len(sp.shape) == len(de.shape) == 1
    row_vector = tf.sparse.reshape(sp, [1, sp.shape[0]])
    column_vector = tf.expand_dims(de, axis=1)
    product = tf.sparse.sparse_dense_matmul(row_vector, column_vector)
    return tf.reshape(product, []).numpy()


sparse_Hs = decomp.get_full_sparse_H()


# Flatten the frozen model's diagonal Fisher into one vector and scale it to
# unit L2 norm. The scaling multiplies every score below by the same constant,
# so the ordering of components is unaffected.
frozen_fisher = dense_fisher.as_flat_fisher()
# NOTE: Technically the components should be multiplied by
# pe_fishers_data.sq_parameter_deltas; multiplying the Fisher instead is
# mathematically equivalent and easier. Currently disabled:
#     frozen_fisher *= pe_fishers_data.sq_parameter_deltas
frozen_fisher = tf.linalg.l2_normalize(frozen_fisher)


# One score per NMF component: <H_i, frozen_fisher>.
dot_prods = [sp_de_dot_product(component, frozen_fisher) for component in sparse_Hs]

# Interactive inspection:
#     plt.plot(dot_prods); plt.show()
#     plt.plot(np.sort(dot_prods)); plt.show()


###########################################################

def print_top_examples(component: int, n_examples: int):
    """Print the examples with the largest weights for one NMF component.

    Takes the `n_examples` rows of `decomp.W` with the highest value in column
    `component`. For each one it prints "<label>, <pred>: <decoded text>", in
    top-k order.

    Args:
        component: column index into `decomp.W`.
        n_examples: number of top-weighted examples to print.
    """
    component_weights = decomp.W[:, component]
    top_indices = tf.math.top_k(component_weights, k=n_examples).indices
    for idx in top_indices:
        raw_label = pe_fishers_data.labels[idx]
        label = raw_label.numpy() if isinstance(raw_label, tf.Tensor) else raw_label
        # Hard-coded MNLI label remapping, (label + 1) % 3. Presumably this
        # aligns dataset label ids with the MNLI model's output order -- TODO
        # confirm.
        label = (label + 1) % 3
        pred = np.argmax(pe_fishers_data.predicted_logits[idx])
        text = tokenizer.decode(pe_fishers_data.input_ids[idx])
        text = text.replace(tokenizer.pad_token, '')
        # text = text.replace(tokenizer.bos_token, '')
        # text = text.replace(tokenizer.eos_token, '')
        text = text.strip()
        print(f'{label}, {pred}: {text}')


dot_prods = np.array(dot_prods)
# Component indices sorted by ascending and by descending score.
lowest_dot_prods = np.argsort(dot_prods)
highest_dot_prods = np.argsort(-dot_prods)

# Print 16 top examples for each of the 3 lowest-scoring components, then for
# each of the 3 highest-scoring components.
for ranking in (lowest_dot_prods, highest_dot_prods):
    for component_index in ranking[:3]:
        print_top_examples(component_index, n_examples=16)


###########################################################
# Hyperparameters for fitting the sparse merge on the target task.
# LR = 1e-5
LR = 4e-5
CLIPNORM = 0.1
# BATCH_SIZE = 32
BATCH_SIZE = 128

TARGET_TASK = 'glue/rte'

##################################

# Separate copy of MODEL. Its mergeable variables are passed to the fitter as
# `output_merge_variables`, and the copy itself as `output_model`.
output_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, MODEL), from_pt=FROM_PT)

output_merge_vars = hf_util.get_mergeable_variables(output_model)

# # See what happens when we freeze the head.
# output_model.classifier.trainable = False

##################################

# Decode the flat NMF components (decomp.H) into per-variable sparse tensors,
# using the pretrained model's variable shapes as the packing layout. Then drop
# any explicitly stored zeros.
packer = flat_pack.FlatPacker([v.shape for v in pretrained_variables])

components_by_var = packer.decode_sparse_tf(decomp.H, decomp.reduce_kept_indices)
components_by_var = [tf.sparse.retain(c, c.values != 0.0) for c in components_by_var]

component_vars = hf_util.get_mergeable_variables(components_model)

base_vars = hf_util.get_mergeable_variables(frozen_model)

base_fisher = dense_fisher.fishers

##################################

train_ds = em_datasets.load(
    TARGET_TASK,
    split='train',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
train_ds = train_ds.repeat().shuffle(1000).batch(BATCH_SIZE)
# Validate on the held-out split. Loading 'train' here would only re-measure
# fit on the training data. Assumes em_datasets.load accepts the TFDS split
# name 'validation' (as glue/rte defines it) -- confirm.
val_ds = em_datasets.load(
    TARGET_TASK,
    split='validation',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
val_ds = val_ds.batch(BATCH_SIZE)

##################################

# Keep only the variables selected by the filter, applied in parallel to every
# per-variable list (merge_embeddings=False presumably excludes embedding
# variables -- see tmv.VariableFilter).
variable_filter = tmv.VariableFilter(
    merge_embeddings=False,
)

##################################
filtered_lists = variable_filter.filter_parallel_lists(
    base_vars,
    components_by_var,
    component_vars,
    output_merge_vars,
    base_fisher,
)
base_vars, components_by_var, component_vars, output_merge_vars, base_fisher = filtered_lists

# Rescale the (filtered) base Fisher to unit global L2 norm across all variables.
base_fisher_sq_norm = tf.reduce_sum([tf.reduce_sum(f**2) for f in base_fisher])
inv_base_fisher_norm = tf.math.rsqrt(base_fisher_sq_norm)
base_fisher = [f * inv_base_fisher_norm for f in base_fisher]

# The classifier head's trainable variables go to the fitter as
# `output_train_variables`.
output_train_vars = output_model.classifier.trainable_variables

# Interactive dev: call reload(gm) here to pick up edits to gradient_merging.
fitter = gm.SparseFitter(
    components_by_var=components_by_var,
    component_variables=component_vars,
    base_variables=base_vars,
    output_merge_variables=output_merge_vars,
    # To see what happens with a frozen head, pass output_train_variables=[].
    output_train_variables=output_train_vars,
    base_batch_fisher=base_fisher,
    output_model=output_model,
    # normalize_fishers=True is causing a segfault, so it stays off.
    normalize_fishers=False,
)


STEPS_PER_EPOCH = 128
EPOCHS = 100

fitter.compile(
    optimizer=tf.keras.optimizers.Adam(LR, clipnorm=CLIPNORM),
    metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

fitter.fit(
    train_ds,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    validation_data=val_ds,
)
