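R"""Fine-tune the Winogrande BERT checkpoint with an NMF-component ablation loss.

Intended to be run interactively (note the `python -i`):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/wino1/ablate_reg_001.py

"""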

import dataclasses
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import regularizers
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf


# Seaborn's "rocket" palette as a continuous colormap. Not referenced again in
# this script; presumably kept around for interactive plotting.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################
# Filesystem layout of the experiment artifacts.

# NOTE(review): the directory is spelled 'winogrange1' (not 'winogrande1');
# confirm it matches the on-disk name before "fixing" it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/winogrange1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Holds the NMF-components Fisher file loaded into `ac_fisher` below.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Holds both the per-example Fisher file and its NMF decomposition.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################
# Model / data hyper-parameters.

# Only used to pick the tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# Forwarded to the dataset loader for every split.
SEQUENCE_LENGTH = 50
# Batch size of the training stream.
BATCH_SIZE = 32
# Batch size of the heldout / validation streams.
EVAL_BATCH_SIZE = 128

# Directory name (under MODELS_DIR) of the fine-tuned Winogrande checkpoint.
MODEL = "bert_base_mnli_to_winogrande_custom_8_epochs_01"
PER_EXAMPLES_FISHERS = f"{MODEL}.winogrande_heldout.no_embeddings.sparse_dynamic_raw.32k.32k.h5"
DECOMP_FILENAME = f"nmf_decomp.per_sub_block.10k.16k.256.{PER_EXAMPLES_FISHERS}"

# Passed as `from_pt` when loading the Winogrande checkpoint below.
FROM_PT = False

# Passed as `n_nmfs` when building the analysis container.
N_DECOMPS = 25

###############################################################################

WINO_MODEL = os.path.join(MODELS_DIR, MODEL)
MNLI_MODEL = 'textattack/bert-base-uncased-MNLI'

###############################################################################
# File name of the saved NMF-components Fisher.

_NCF_BASE_FILENAME = f'nmf_components_fishers.incorrect.per_sub_block.10k.16k.256.{PER_EXAMPLES_FISHERS}'
# ncf.print_parameters_and_accuracy_infos_matching_file(os.path.join(FISHERS_DIR, _NCF_BASE_FILENAME))

NCF_IDENTIFIER = 'cf0_300_ft0_500_pvt0_0100'
# `[:-3]` drops the trailing ".h5" so the identifier can be inserted before
# the extension.
NCF_FILENAME = f'{_NCF_BASE_FILENAME[:-3]}.{NCF_IDENTIFIER}.h5'

###############################################################################

ac_fisher = ncf.NmfComponentsFisher.load(os.path.join(FISHERS_DIR, NCF_FILENAME))

# Dense diagonal Fishers derived from the loaded components. Neither is
# referenced again in this script; presumably kept for interactive inspection.
c_fisher = ac_fisher.to_correct_dense_diagonal_fisher()
e_fisher = ac_fisher.to_erroring_dense_diagonal_fisher()

# TODO: Is the different scaling of components per sublayer something to worry
# about, or is it already handled?

############################################

# Single tokenizer instance shared by the dataset loaders and the analysis
# container below.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# The MNLI checkpoint is loaded with `from_pt=True`; it is not referenced again
# in this script.
mnli_model = TFAutoModelForSequenceClassification.from_pretrained(MNLI_MODEL, from_pt=True)
mnli_model.bert.embeddings.trainable = False

# Use the FROM_PT constant instead of a duplicated literal so the checkpoint
# format is configured in one place.
wino_model = TFAutoModelForSequenceClassification.from_pretrained(
    WINO_MODEL, from_pt=FROM_PT)
wino_model.bert.embeddings.trainable = False

############################################

# Training stream: repeated indefinitely, shuffled with a 1000-element buffer,
# then batched. Being endless, it needs an explicit step count when consumed.
train_ds = em_datasets.load(
    'winogrande/custom',
    split='train',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
)
train_ds = train_ds.repeat().shuffle(1000).batch(BATCH_SIZE)

# Heldout stream: only the first 8 * 1024 examples, batched and cached.
heldout_ds = em_datasets.load(
    'winogrande/custom',
    split='heldout',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
    force_deterministic=True,
)
heldout_ds = heldout_ds.take(8 * 1024).batch(EVAL_BATCH_SIZE).cache()

# Validation stream: the whole split, batched and cached.
val_ds = em_datasets.load(
    'winogrande/custom',
    split='validation',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
    force_deterministic=True,
)
val_ds = val_ds.batch(EVAL_BATCH_SIZE).cache()


###############################################################################

# Reuses the tokenizer created above rather than loading an identical second
# copy.
container = am.load_pef_nmf_analysis_container(
    pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS),
    nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME),
    n_nmfs=N_DECOMPS,
    tokenizer=tokenizer,
    shift_labels=True,
)

# Do this kinda hack to get labels correct. Having shift_labels=True
# and then doing this gets the string labels and the integer labels
# properly matched.
container.labels = container.pef.labels
container.examples = container._make_nli_examples()


container.nmfs.force_load_all()

###############################################################################

# Passed to `create_variables_per_subset`; constructed with
# `merge_embeddings=False`.
variable_filter = tmv.VariableFilter(merge_embeddings=False)

###############################################################################

# baseline_variables = hf_util.get_mergeable_variables(baseline_model)
# baseline_variables = variable_filter.filter_parallel_lists(baseline_variables)
# baseline_variables_per_subset = tmv.group_by_sub_blocks(baseline_variables)


def create_variables_per_subset(model, variable_filter):
    """Return `model`'s mergeable variables, filtered and grouped by sub-block.

    Args:
        model: Model handed to `hf_util.get_mergeable_variables`.
        variable_filter: Object whose `filter_parallel_lists` method is applied
            to the mergeable variables before grouping.

    Returns:
        Whatever `tmv.group_by_sub_blocks` produces for the filtered variables.
    """
    return tmv.group_by_sub_blocks(
        variable_filter.filter_parallel_lists(
            hf_util.get_mergeable_variables(model)
        )
    )


def create_components_per_subset(container, ac_fisher):
    """Build one list of dense diagonal component Fishers per subset.

    Pairs each entry of `ac_fisher.subset_infos` with the entry of
    `container.nmfs` at the same position. For every component listed in the
    subset's `correct_component_infos`, the row `nmf.H[component_index]` is
    converted with the subset's `to_dense_diagonal_fisher`.

    Args:
        container: Object exposing an iterable `nmfs`, each item having an
            indexable `H`.
        ac_fisher: Object exposing an iterable `subset_infos`.

    Returns:
        A list with one inner list per (subset, nmf) pair, in iteration order.
        Pairing uses `zip`, so it stops at the shorter of the two iterables.
    """

    def _fishers_for(subset_info, decomposition):
        # One dense Fisher per "correct" component of this subset.
        return [
            subset_info.to_dense_diagonal_fisher(
                decomposition.H[info.component_index])
            for info in subset_info.correct_component_infos
        ]

    return [
        _fishers_for(subset_info, decomposition)
        for subset_info, decomposition in zip(
            ac_fisher.subset_infos, container.nmfs)
    ]


###############################################################################

# Per-subset component Fishers handed to the ablation loss below.
components_per_subset = create_components_per_subset(container, ac_fisher)


# Model that actually gets fine-tuned. It is loaded from the same checkpoint
# as `wino_model`, which is used as the loss's baseline.
# NOTE(review): unlike `wino_model` / `mnli_model`, the embeddings are not
# frozen here; confirm that this is intentional.
output_model = TFAutoModelForSequenceClassification.from_pretrained(
    WINO_MODEL,
    from_pt=False,
)

optimizer = tf.keras.optimizers.Adam(2e-5, clipnorm=0.1)
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()

# Plain task loss. To train without the regularizer, pass this directly as
# `loss=` to `compile` instead of `ablation_loss`.
task_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Task loss wrapped by the regularizer; the meaning of `epsilon`, `c` and
# `lmbda` is defined by `regularizers.MultiComponentBySubsetPowerAblationLoss`.
ablation_loss = regularizers.MultiComponentBySubsetPowerAblationLoss(
    task_loss,
    model_variables_per_subset=create_variables_per_subset(
        output_model, variable_filter),
    baseline_variables_per_subset=create_variables_per_subset(
        wino_model, variable_filter),
    components_per_subset=components_per_subset,
    epsilon=1e0,
    c=2,
    lmbda=1e0,
)

output_model.compile(
    optimizer=optimizer,
    metrics=accuracy_metric,
    loss=ablation_loss,
)

# Ideas:
# - Maybe try to do something with the e_fishers so that we do not disrupt
#   "good" components.
# - Maybe try to do stuff for each component rather than coalesced Fisher.

# Accuracy before any fine-tuning, for reference.
_, heldout_acc = output_model.evaluate(heldout_ds)
print(cu.hly(f'Original heldout accuracy: {heldout_acc}'))

# `train_ds` repeats forever, so the epoch length is fixed via
# `steps_per_epoch`. Per-epoch validation uses the heldout split; `val_ds` is
# not used here.
output_model.fit(
    train_ds,
    steps_per_epoch=256,
    epochs=32,
    validation_data=heldout_ds,
)
