R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/wino1/wino_improve001.py

"""

import dataclasses
from importlib import reload
import os
import time

from colorama import Fore, Back, Style
import datasets as hfds
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sps
import seaborn as sns
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import ewc
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.tools.nmf import nmf_common
from em.tools.nmf import nmf_transform
from em.util import flat_pack
from em.util import hf_util
from em.util.color_util import cu

from em.projects.anli import anli_misc1 as am
from em.projects.wino import wino_misc1 as wm
from em.projects.wino import nmf_components_fisher as ncf

# Seaborn's "rocket" palette, requested as a colormap object (as_cmap=True).
# It is not referenced again in the visible part of this script; presumably it
# is kept around for ad-hoc plotting, since the script is run with `python -i`.
rocket_cmap = sns.color_palette("rocket", as_cmap=True)

###############################################################################

# Root directory for every artifact of this experiment. The three
# sub-directories below are all direct children of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/winogrange1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Batch size applied to the validation dataset.
EVAL_BATCH_SIZE = 128
# Batch size applied to the training dataset (128 was an earlier setting).
# TRAIN_BATCH_SIZE = 128
TRAIN_BATCH_SIZE = 32
# Passed as `sequence_length` when loading both dataset splits.
SEQUENCE_LENGTH = 50

# Hub checkpoint used for the tokenizer and for the "from scratch" BERT models.
PRETRAINED_MODEL = 'bert-base-uncased'
# Hub checkpoint of the MNLI model (loaded from PyTorch weights).
MNLI_MODEL = 'textattack/bert-base-uncased-MNLI'
# Directory name, relative to MODELS_DIR, of the locally saved Winogrande model.
WINO_MODEL = 'bert_base_mnli_to_winogrande_xl_4_epochs_01'

# File name, relative to FISHERS_DIR, of the saved diagonal Fisher for MNLI.
MNLI_FISHER_FILENAME = 'bert_base_mnli_fisher.32k.h5'

############################################

"""
Coefficient fractor: 0.500   Fraction threshold: 0.900   P-value threshold: 0.2000
    Top examples fraction: 0.202
    Top examples accuracy: 0.982
    Remaining examples accuracy: 0.497
"""

# Selects which saved NMF-components Fisher file to load from FISHERS_DIR.
# The cf/ft/pvt fields appear to encode the coefficient factor, fraction
# threshold and p-value threshold used to build it -- TODO confirm against
# em.projects.wino.nmf_components_fisher.
# Other identifiers that have been tried:
# NCF_IDENTIFIER = 'cf0_500_ft0_900_pvt0_2000'
# NCF_IDENTIFIER = 'cf0_300_ft0_800_pvt0_2000'
NCF_IDENTIFIER = 'cf0_300_ft0_900_pvt0_0500'
NCF_FILENAME = (
    'nmf_components_fishers.per_sub_block.6k.16k.256'
    '.bert_base_mnli_to_winogrande_xl_4_epochs_01'
    '.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k'
    f'.{NCF_IDENTIFIER}.h5'
)

# _NCF_BASE_FILENAME = 'nmf_components_fishers.per_sub_block.6k.16k.256.bert_base_mnli_to_winogrande_xl_4_epochs_01.test_val_winogrande.no_embeddings.sparse_dynamic_raw.32k.32k.h5'
# ncf.print_parameters_and_accuracy_infos_matching_file(os.path.join(FISHERS_DIR, _NCF_BASE_FILENAME))

###############################################################################

# Load the saved NMF-components Fisher selected by NCF_FILENAME.
ac_fisher = ncf.NmfComponentsFisher.load(os.path.join(FISHERS_DIR, NCF_FILENAME))

# Two dense diagonal Fishers derived from the same object. Going by the method
# names, one is built from the "correct" components and one from the
# "erroring" components -- TODO confirm in nmf_components_fisher. They are
# zipped entry by entry in combine_ce_fishers(), so they are expected to be
# parallel sequences.
c_fisher = ac_fisher.to_correct_dense_diagonal_fisher()
e_fisher = ac_fisher.to_erroring_dense_diagonal_fisher()

# TODO: Is the different scaling of components per sub-layer something to
# worry about, or is it already handled?

############################################

# Tokenizer of the base checkpoint; it is handed to both dataset loaders.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# MNLI classifier from the hub; from_pt=True converts PyTorch weights to TF.
# NOTE: `mnli_model` is rebound later in this script to a plain pretrained
# BERT model, so the object created here is only in use until that point.
mnli_model = TFAutoModelForSequenceClassification.from_pretrained(MNLI_MODEL, from_pt=True)
mnli_model.bert.embeddings.trainable = False

# Locally saved Winogrande classifier (already in TF format, hence
# from_pt=False). It is used below as the EWC anchor and as the first model
# in the merge.
wino_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, WINO_MODEL), from_pt=False)
# Embeddings are marked non-trainable on both models; presumably this is what
# keeps them out of hf_util.get_mergeable_variables() -- TODO confirm.
wino_model.bert.embeddings.trainable = False

############################################

# Variable filter configured with merge_embeddings=False.
variable_filter = tmv.VariableFilter(merge_embeddings=False)

# `mnli_fisher` is rebound three times: first the loaded DiagonalFisher
# object, then its `.fishers` attribute, and finally the result of passing
# that through the variable filter. Presumably the last step drops the
# embedding entries so the list lines up with the mergeable variables --
# TODO confirm in transformer_model_vars.
mnli_fisher = diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, MNLI_FISHER_FILENAME))
mnli_fisher = mnli_fisher.fishers
mnli_fisher = variable_filter.filter_parallel_lists(mnli_fisher)

############################################

# Winogrande-XL training split, tokenized to SEQUENCE_LENGTH.
# The chained calls below match the tf.data.Dataset API (assumed return type
# of em_datasets.load -- TODO confirm).
train_ds = em_datasets.load(
    'winogrande/xl',
    split='train',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
    force_deterministic=False,
)
# repeat() with no count makes the stream endless, so every fit() call on
# this dataset must pass steps_per_epoch. Because repeat() comes before
# shuffle(), the 1000-element shuffle buffer can mix examples from adjacent
# passes over the data.
train_ds = train_ds.repeat().shuffle(1000).batch(TRAIN_BATCH_SIZE)

# Winogrande-XL validation split, loaded with force_deterministic=True.
val_ds = em_datasets.load(
    'winogrande/xl',
    split='validation',
    tokenizer=tokenizer,
    sequence_length=SEQUENCE_LENGTH,
    force_deterministic=True,
)
# Batched once and cached, so repeated evaluations reuse the same batches.
val_ds = val_ds.batch(EVAL_BATCH_SIZE).cache()

############################################

# NEED TO HANDLE THE LACK OF EMBEDDINGS FOR SOME FISHERS, WHICH
# GETS KINDA ANNOYING.


def new_output_model():
    """Build a fresh, compiled copy of the saved Winogrande classifier.

    The weights are read from the ``WINO_MODEL`` directory under
    ``MODELS_DIR`` on every call, so each call yields a new model object.
    Its embeddings are marked non-trainable before the model is compiled.

    Returns:
        The model, compiled with Adam (learning rate 2e-5, clipnorm 0.1),
        sparse categorical cross-entropy on logits as the loss, and sparse
        categorical accuracy as the metric.
    """
    checkpoint_dir = os.path.join(MODELS_DIR, WINO_MODEL)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        checkpoint_dir, from_pt=False)
    model.bert.embeddings.trainable = False

    adam = tf.keras.optimizers.Adam(2e-5, clipnorm=0.1)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    model.compile(optimizer=adam, metrics=accuracy, loss=cross_entropy)
    return model


# N_MERGE_COEFFS = 50

# output_model = new_output_model()

# mergeable_models = [wino_model, mnli_model]
# mergeable_fishers = [c_fisher, mnli_fisher]

# coefficients_set = merging.create_pairwise_grid_coeffs(N_MERGE_COEFFS)

# merge_results = merging.merging_coefficients_search(
#     mergeable_models=mergeable_models,
#     fishers=mergeable_fishers,
#     coefficients_set=coefficients_set,
#     dataset=val_ds,
#     metric=hfds.load_metric("glue", 'mnli'),
#     fisher_floor=1e-6,
#     # favor_target_model=True,
#     favor_target_model=False,
#     normalize_fishers=True,
#     output_model=output_model,
# )

########################
#
# REMEMBER TO TRY EWC-STYLE REGULARIZATION WHEN TRAINING AGAIN BELOW!!!
#
# Maybe also think of an example-based method rather than a component-based
# method where the components are only used to select a subset of examples
# to be used when computing a (batch dense) Fisher.
#
########################


# output_model = TFAutoModelForSequenceClassification.from_pretrained(
#     MNLI_MODEL, from_pt=True)

# EWC experiment: start from the plain pretrained BERT checkpoint with a
# 2-label classification head and fine-tune on Winogrande with an EWC loss
# anchored at `wino_model`.
output_model = TFAutoModelForSequenceClassification.from_pretrained(
    PRETRAINED_MODEL, from_pt=True, num_labels=2)


# output_model.bert.embeddings.set_weights(wino_model.bert.embeddings.get_weights())
# Embeddings are marked non-trainable here and switched back to trainable
# right after compile() below.
# NOTE(review): presumably the temporary freeze is there so that
# get_mergeable_variables(output_model) leaves the embeddings out and the
# variable list lines up with the Fisher entries (the NCF file name says
# "no_embeddings") -- confirm against hf_util.
output_model.bert.embeddings.trainable = False

output_model.compile(
    optimizer=tf.keras.optimizers.Adam(2e-5, clipnorm=0.1),
    metrics=tf.keras.metrics.SparseCategoricalAccuracy(),
    # Sparse cross-entropy on logits wrapped in the project's EwcLoss.
    # Presumably this adds a penalty, scaled by `lmbda` and weighted by
    # `ewc_fishers`, for `model_variables` drifting from `ewc_variables` --
    # TODO confirm in em.fishers.ewc.
    loss=ewc.EwcLoss(
        tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        model_variables=hf_util.get_mergeable_variables(output_model),
        ewc_variables=hf_util.get_mergeable_variables(wino_model),
        # Alternative Fisher choices that have been tried are kept below.
        # ewc_fishers=c_fisher,
        ewc_fishers=ac_fisher.batch_correct_fishers,
        # ewc_fishers=ac_fisher.batch_erroring_fishers,
        # ewc_fishers=[
        #     c + e
        #     for c, e in zip(ac_fisher.batch_correct_fishers, ac_fisher.batch_erroring_fishers)
        # ],
        # Other penalty strengths that have been tried:
        # lmbda=1e-2,
        # lmbda=1e10,
        # lmbda=1e7,
        # lmbda=1e3,
        # lmbda=1e0,
        lmbda=1e2,
    )
)
# NOTE(review): `trainable` is changed after compile(). Whether fit() then
# updates the embeddings depends on how the installed Keras version treats
# post-compile changes to `trainable` -- verify that this does what is
# intended.
output_model.bert.embeddings.trainable = True

# train_ds repeats forever, so an "epoch" here is simply 256 batches.
output_model.fit(train_ds, steps_per_epoch=256, epochs=32, validation_data=val_ds)

########################

def combine_ce_fishers(
    weight_correct: float = 1.0,
    correct_fishers=None,
    erroring_fishers=None,
):
    """Blend "correct" and "erroring" Fisher entries pairwise.

    Every output entry is ``weight_correct * c + (1 - weight_correct) * e``
    for the corresponding pair ``(c, e)``.

    Args:
        weight_correct: Weight applied to each correct entry; each erroring
            entry gets ``1 - weight_correct``. With the default of 1.0 the
            erroring entries are multiplied by zero.
        correct_fishers: Iterable of correct Fisher entries. When omitted,
            the module-level ``c_fisher`` is used.
        erroring_fishers: Iterable of erroring Fisher entries. When omitted,
            the module-level ``e_fisher`` is used.

    Returns:
        A list with one blended entry per pair, in input order. Pairs are
        formed with ``zip``, so surplus entries in the longer input are
        ignored.
    """
    # Fall back to the script-level Fishers only when the caller passes none,
    # so existing calls such as combine_ce_fishers(1) keep working unchanged.
    if correct_fishers is None:
        correct_fishers = c_fisher
    if erroring_fishers is None:
        erroring_fishers = e_fisher

    weight_erroring = 1 - weight_correct
    return [
        weight_correct * c + weight_erroring * e
        for c, e in zip(correct_fishers, erroring_fishers)
    ]


# Experiment: merge against plain pretrained BERT instead of the MNLI model.
# NOTE: the name `mnli_model` is reused, so from here on it refers to the
# PRETRAINED_MODEL checkpoint, not to the MNLI classifier loaded earlier.
# No num_labels is passed here, so the classification head size is whatever
# the library defaults to for this checkpoint -- TODO confirm it matches
# wino_model's head.
mnli_model = TFAutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL, from_pt=True)
mnli_model.bert.embeddings.trainable = False

# mnli_model.set_weights([-x for x in mnli_model.get_weights()])


# Fresh compiled copy of the Winogrande model to receive the merged weights.
# Its initial weights are snapshotted so they can be restored before each
# merge attempt.
output_model = new_output_model()
new_output_model_weights = output_model.get_weights()


# Models to merge. `mnli_model` is currently the plain pretrained BERT model
# (see the rebind above).
mergeable_models = [wino_model, mnli_model]

# Fisher pairings that have been tried are kept below. Presumably entry i of
# `mergeable_fishers` is applied to entry i of `mergeable_models` -- TODO
# confirm in em.merging.merging.
# mergeable_fishers = [c_fisher, mnli_fisher]
# mergeable_fishers = [e_fisher, mnli_fisher]
# mergeable_fishers = [mnli_fisher, mnli_fisher]

# ac_fisher.batch_correct_fishers
# ac_fisher.batch_erroring_fishers


# mergeable_fishers = [c_fisher, e_fisher]
# mergeable_fishers = [e_fisher, c_fisher]

# mergeable_fishers = [combine_ce_fishers(1), mnli_fisher]

# Active choice: both Fishers come from `ac_fisher`, i.e. neither was loaded
# from a file specific to the second model.
mergeable_fishers = [
    ac_fisher.batch_correct_fishers,
    ac_fisher.batch_erroring_fishers,
]


# A single coefficient pair is evaluated per run; earlier choices are kept
# as comments.
# coefficients_set = [(0.1, 0.9)]
# coefficients_set = [(0.05, 0.95)]
coefficients_set = [(0.5, 0.5)]
# coefficients_set = [(0.02, 0.98)]
# coefficients_set = [(0.15, 0.85)]
# coefficients_set = [(0.999, 0.001)]

# Restore the snapshot so that re-running this section interactively starts
# from the same output-model weights each time.
output_model.set_weights(new_output_model_weights)
merge_results = merging.merging_coefficients_search(
    mergeable_models=mergeable_models,
    fishers=mergeable_fishers,
    coefficients_set=coefficients_set,
    dataset=val_ds,
    # The GLUE "mnli" metric is loaded even though val_ds is Winogrande;
    # presumably it is used only for its accuracy computation -- TODO confirm.
    metric=hfds.load_metric("glue", 'mnli'),
    # fisher_floor=1e-6,
    fisher_floor=1e-20,
    favor_target_model=False,
    normalize_fishers=True,
    output_model=output_model,
)


# Evaluate on 2048 // TRAIN_BATCH_SIZE training batches (2048 examples)
# before any further training.
output_model.evaluate(train_ds.take(2048 // TRAIN_BATCH_SIZE))

# Continue training from the current output_model weights. train_ds repeats
# forever, so each "epoch" is 256 batches.
output_model.fit(train_ds, steps_per_epoch=256, epochs=64, validation_data=val_ds)
# output_model.fit(train_ds, steps_per_epoch=256, epochs=32, validation_data=val_ds)

#
