R"""



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/gradient_merging_dev/gm_dev002.py

"""
from importlib import reload
import os
import re
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import per_example
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.util import flat_pack
from em.util import hf_util

from em.merging import gradient_merging as gm


###############################################################################

# Root directory of this experiment; the sub-directories below are derived
# from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/gradient_merging1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# NOTE(review): CTASK is not referenced anywhere else in the visible script,
# and 'math_dataset/mnli' looks inconsistent with the MNLI checkpoint below --
# confirm the intended task name.
CTASK = 'math_dataset/mnli'
# Passed as `sequence_length` when loading the target-task datasets.
SEQUENCE_LENGTH = 128

# Fine-tuned "component" checkpoint and the pretrained checkpoint it is paired
# with. FROM_PT is forwarded as `from_pt` to every `from_pretrained` call.
CMODEL = 'prajjwal1/roberta-base-mnli'
PRETRAINED_MODEL = 'roberta-base'
FROM_PT = True

# File name of the per-example Fishers; the NMF decomposition file name is
# derived from it by adding a prefix.
PER_EXAMPLES_FISHERS = 'roberta_base_mnli.no_embeddings.sparse_dynamic_raw.32k.32k.h5'
DECOMP = 'nmf_decomp.8k.16k.64.' + PER_EXAMPLES_FISHERS

N_EXAMPLES = 8 * 1024
START_FISHER_INDEX = 0
END_FISHER_INDEX = 4 * 4096


###############################################################################


###############################################################################
# TODO: In proper code, I can probably multithread/multiprocess this to do all these
# loads below in parallel.
#############################

# Tokenizer of the pretrained checkpoint; reused when loading the datasets.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Load the saved NMF decomposition and report how long the load (including
# the unit-norm normalization of its components) takes.
print('Starting to load saved NMF decomposition.')
start = time.time()
decomp_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP)
decomp = nmf_common.NmfDecomposition.load(decomp_path)
decomp.normalize_components_to_unit_norm()
# decomp.H = decomp.get_full_H()
print('Load saved NMF decomposition time: ', time.time() - start)

###############################################################################

# Short aliases for the two factors of the loaded decomposition, handy in the
# interactive session.
W, H = decomp.W, decomp.H

# from em.analysis import bert_nmf_analysis as bna
# bna.print_top_examples(decomp.W, tokenizer, pe_fishers_data.input_ids, pe_fishers_data.labels, n_examples=16, component=0)

#######################################

# Per-example Fishers were not computed for the embeddings, so the embeddings
# are left out of the variables being merged.
variable_filter = tmv.VariableFilter(merge_embeddings=False)

# The fine-tuned model whose mergeable variables the decomposition components
# are later decoded against.
component_model = TFAutoModelForSequenceClassification.from_pretrained(
    CMODEL,
    from_pt=FROM_PT,
)

component_vars = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(component_model)
)

###############################################################################

# Re-import flat_pack so edits made during the interactive session are used.
reload(flat_pack)

# The packer is built from the shapes of the component variables and decodes
# the flat H factor into one sparse tensor per variable.
packer = flat_pack.FlatPacker([var.shape for var in component_vars])

# Keep only the entries whose stored value is non-zero.
# NOTE(review): the timing notes at the end of this script (2s vs 300ms per
# step) presumably refer to this retain step -- confirm.
components_by_var = [
    tf.sparse.retain(sparse_component, sparse_component.values != 0.0)
    for sparse_component in packer.decode_sparse_tf(decomp.H, decomp.reduce_kept_indices)
]


###############################################################################


# The pretrained (not fine-tuned) checkpoint; its variables go through the
# same filter as the component model's.
base_model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL,
    from_pt=FROM_PT,
)

base_vars = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(base_model)
)

###############################################################################

# Task the merged model is fitted on, and the batch size of its datasets.
TARGET_TASK = 'glue/rte'
BATCH_SIZE = 32

# Adam learning rate and gradient clip norm used when compiling the fitter.
# Alternative learning rates kept for experimentation: 2e-4, 4e-5.
LR = 1e-5
CLIPNORM = 0.1

#######################################

# The model that is actually fitted: the pretrained checkpoint with a
# classification head sized for the target task.
n_target_labels = em_datasets.n_classes_for_task(TARGET_TASK)
output_model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL, from_pt=FROM_PT, num_labels=n_target_labels
)

# Embeddings are not merged, so copy them over from the component model.
component_embedding_weights = component_model.roberta.embeddings.get_weights()
output_model.roberta.embeddings.set_weights(component_embedding_weights)

# Variables handed to the fitter as merge targets (filtered like the others).
output_merge_vars = variable_filter.filter_parallel_lists(
    hf_util.get_mergeable_variables(output_model)
)

# Variables handed to the fitter as directly trainable: the classifier head.
output_train_vars = output_model.classifier.trainable_variables

###############################################################################

# Training stream for the target task: repeated indefinitely, shuffled with a
# 1000-element buffer, then batched. Because of `.repeat()` this dataset never
# ends, so `fit` must be given `steps_per_epoch`.
train_ds = em_datasets.load(
    TARGET_TASK,
    split='train',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
train_ds = train_ds.repeat().shuffle(1000).batch(BATCH_SIZE)

# Held-out data for the validation metrics. This previously loaded the
# 'train' split again, which made the reported validation metrics a measure of
# training fit rather than generalization.
# NOTE(review): 'validation' is the GLUE split name; confirm that
# `em_datasets.load` accepts it for this task.
val_ds = em_datasets.load(
    TARGET_TASK,
    split='validation',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
val_ds = val_ds.batch(BATCH_SIZE)

###########################################################

# # Try training the classifier head first.
# output_model.roberta.trainable = False

# output_model.compile(
#     optimizer=tf.keras.optimizers.Adam(3e-4, clipnorm=CLIPNORM),
#     metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
#     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
# )

# output_model.fit(
#     train_ds,
#     steps_per_epoch=256,
#     epochs=8,
#     validation_data=val_ds,
# )


###########################################################


# Re-import the merging module so edits made during the interactive session
# are used.
reload(gm)

# normalize_fishers stays False: enabling it (normalize_fishers=True) is
# causing a segfault.
fitter = gm.SparseFitter(
    components_by_var=components_by_var,
    component_variables=component_vars,
    base_variables=base_vars,
    output_merge_variables=output_merge_vars,
    output_train_variables=output_train_vars,
    output_model=output_model,
    normalize_fishers=False,
)

# Keras-style compile: clipped Adam, sparse categorical accuracy, and a
# cross-entropy loss that consumes raw logits.
fit_optimizer = tf.keras.optimizers.Adam(LR, clipnorm=CLIPNORM)
fit_metric = tf.keras.metrics.SparseCategoricalAccuracy()
fit_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
fitter.compile(optimizer=fit_optimizer, metrics=fit_metric, loss=fit_loss)

# 100 epochs of 128 steps each, with validation after every epoch.
# Other steps_per_epoch values kept for experimentation: 1024, 4.
fitter.fit(
    train_ds,
    steps_per_epoch=128,
    epochs=100,
    validation_data=val_ds,
)

# TODO: Try fitting just the classification head first, then initialize the
# coefficients to favor the base model.

# Timing notes:
#   ~2 s / step without the retain.
#   ~300 ms / step when I do the retain.
