R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/supervised_transfer/exex_merge_dev001.py

"""
import collections
from importlib import reload
import os
import time
from typing import Sequence

import numpy as np
import sympy as sp
import matplotlib.pyplot as plt

import datasets as hfds
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em import datasets as em_datasets
from em.evaluation import evaluation
from em.fishers import diagonal
from em.merging import merging
from em.util import hdf5_util

from em.util.color_util import cu

###############################################################################

# Root directory of this experiment's artifacts; the sub-directories below are
# derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
# Saved model checkpoints (FT_MODEL / SCOMP_MODEL are looked up under here).
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Saved Fisher files (FT_FISHER / SCOMP_FISHER are looked up under here).
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# NOTE(review): not referenced anywhere else in the visible part of this file.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hub name used only to load the tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# Checkpoint directory names, relative to MODELS_DIR.
FT_MODEL = "best_base_ead_infix_75k_dev003"
SCOMP_MODEL = "best_base_ead_infix_75k_dev003_to_scomp_exex_001_15k_epoch9"

# Dataset names; load_val_ds() turns each into the split '<name>.validation'.
FT_DS = 'ead_ds_002'
SCOMP_DS = 'scomp_exex_001'

# Fisher HDF5 file names, relative to FISHERS_DIR.
# NOTE(review): the SCOMP_FISHER file name does not match the SCOMP_MODEL
# checkpoint name -- confirm it was computed for that exact checkpoint.
FT_FISHER = 'best_base_ead_infix_75k_dev003.ead_ds_002_train.8096_examples.h5'
SCOMP_FISHER = 'best_base_ead_scomp_exex_15k_001.exex_train.epoch9.8096_examples.h5'

###############################################################################


def compile(model):
    """Compile `model` in place for evaluation.

    The model is compiled with a sparse categorical cross-entropy loss that
    expects raw logits, and sparse categorical accuracy as its only metric.
    No optimizer is passed. Returns None.

    NOTE: this name shadows the builtin `compile` within this module.
    """
    xent_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    model.compile(loss=xent_loss, metrics=[accuracy])


# Tokenizer comes from the base pretrained model name, not from the fine-tuned
# checkpoints; it is read as a global by load_val_ds().
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# Both checkpoints are loaded from MODELS_DIR with from_pt=False and compiled
# right away so `.evaluate(...)` can be called on them interactively.
ft_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, FT_MODEL), from_pt=False)
compile(ft_model)

scomp_model = TFAutoModelForSequenceClassification.from_pretrained(
    os.path.join(MODELS_DIR, SCOMP_MODEL), from_pt=False)
compile(scomp_model)

###############################################################################


def load_val_ds(name: str):
    """Load the validation split of the 'ead/infix' dataset called `name`.

    Args:
        name: Dataset name; the split requested is '<name>.validation'.

    Returns:
        Whatever `em_datasets.load` returns for that split, tokenized with the
        module-level `tokenizer` and `sequence_length` (both are read as
        globals at call time).
    """
    validation_split = f'{name}.validation'
    return em_datasets.load(
        'ead/infix',
        split=validation_split,
        tokenizer=tokenizer,
        sequence_length=sequence_length,
    )


# Read as a global by load_val_ds(); must be defined before that function is
# first called (it is, just below).
sequence_length = 128
EVAL_BATCH_SIZE = 128
# N_VAL_EXAMPLES = 4 * 1024
# Number of validation examples kept per task.
N_VAL_EXAMPLES = 16 * 1024

#################################################

# Keep the first N_VAL_EXAMPLES examples of each validation split, batch them,
# and cache the batched result since every merged model is evaluated on the
# same data (presumably these are tf.data.Dataset objects -- confirm against
# em_datasets.load).
ft_val_ds = load_val_ds(FT_DS).take(N_VAL_EXAMPLES).batch(EVAL_BATCH_SIZE).cache()
scomp_val_ds = load_val_ds(SCOMP_DS).take(N_VAL_EXAMPLES).batch(EVAL_BATCH_SIZE).cache()

###############################################################################

# ft_model.evaluate(ft_val_ds)
# ft_model.evaluate(scomp_val_ds)

# scomp_model.evaluate(ft_val_ds)
# scomp_model.evaluate(scomp_val_ds)

###############################################################################

# Load the saved Fishers for both models from FISHERS_DIR as non-trainable
# variables; they are passed to get_scores() alongside the matching models.
ft_fisher = hdf5_util.load_variables_from_hdf5(os.path.join(FISHERS_DIR, FT_FISHER), trainable=False)
scomp_fisher = hdf5_util.load_variables_from_hdf5(os.path.join(FISHERS_DIR, SCOMP_FISHER), trainable=False)

###############################################################################

# N_COEFFS = 20
# Passed to merging.create_pairwise_grid_coeffs() to build the coefficient grid.
N_COEFFS = 50
# Passed as `fisher_floor` to merging.generate_merged_for_coeffs_set().
FISHER_FLOOR = 1e-8

# NOTE(review): only referenced from commented-out code in the visible part of
# this file; the live path uses the Keras accuracy metric set in compile().
acc_metric = hfds.load_metric("glue", 'rte')

#################################################

# One evaluation result: the merging coefficients used and the accuracy of the
# merged model on the FT and SCOMP validation sets.
Score = collections.namedtuple('Score', ['coeffs', 'ft_acc', 'scomp_acc'])


# def print_scores(coeffs, ft_score, scomp_score):
#     print(f"Merging coefficients: {coeffs}")
#     #
#     print("OG Task Score:")
#     for name, value in ft_score.items():
#         print(f"  {name}: {cu.hlb(value)}")
#     #
#     print("SCOMP Task Score:")
#     for name, value in scomp_score.items():
#         print(f"  {name}: {cu.hly(value)}")


def print_score(score):
    """Print one `Score` as three lines: coefficients, then both accuracies.

    The FT ("OG") accuracy is passed through `cu.hlb` and the SCOMP accuracy
    through `cu.hly` before being formatted into the output.
    """
    merge_coeffs = score.coeffs
    print(f"Merging coefficients: {merge_coeffs}")
    og_text = cu.hlb(score.ft_acc)
    print(f"    OG Task Score: {og_text}")
    scomp_text = cu.hly(score.scomp_acc)
    print(f"    SCOMP Task Score: {scomp_text}")


def do_evaluation(merged_model, coeffs):
    """Compile `merged_model` and evaluate it on both validation sets.

    Args:
        merged_model: Model to evaluate; it is (re)compiled here via compile().
        coeffs: Merging coefficients that produced the model; stored unchanged
            in the returned Score.

    Returns:
        A `Score` holding `coeffs` and the second value returned by
        `evaluate` (the accuracy metric configured in compile()) for
        `ft_val_ds` and `scomp_val_ds` respectively.
    """
    compile(merged_model)
    # evaluate() yields (loss, accuracy); only the accuracy is kept.
    _ft_loss, ft_accuracy = merged_model.evaluate(ft_val_ds, verbose=0)
    _scomp_loss, scomp_accuracy = merged_model.evaluate(scomp_val_ds, verbose=0)
    return Score(coeffs, ft_accuracy, scomp_accuracy)


def get_scores(models, fishers, print_scores=True):
    """Merge `models` for every coefficient setting and score each result.

    The coefficient grid is read from the module-level `coefficients_set`,
    which must be defined before this function is called.

    Args:
        models: Sequence of models handed to
            `merging.generate_merged_for_coeffs_set`.
        fishers: Fishers in the same order as `models`, or None (this script
            also calls it with None; presumably that merges without Fisher
            weighting -- confirm in em.merging).
        print_scores: If true, print each score followed by a blank line as
            soon as it is computed.

    Returns:
        A list of `Score` tuples, one per yielded (coeffs, merged_model) pair,
        in generation order.
    """
    merged_models = merging.generate_merged_for_coeffs_set(
        models,
        fishers=fishers,
        coefficients_set=coefficients_set,
        fisher_floor=FISHER_FLOOR,
        favor_target_model=True,
        normalize_fishers=True,
    )
    scores = []
    for coeffs, merged_model in merged_models:
        # do_evaluation() compiles the model itself, so no separate compile
        # step (or "already compiled" bookkeeping) is needed here.
        score = do_evaluation(merged_model, coeffs)
        scores.append(score)
        if print_scores:
            print_score(score)
            print('')
    return scores

#################################################


# Coefficient grid for the sweep; get_scores() reads this as a global.
coefficients_set = merging.create_pairwise_grid_coeffs(N_COEFFS)

# Sweep with the saved Fishers. Models and Fishers are both given in
# [scomp, ft] order.
scores = get_scores([scomp_model, ft_model], [scomp_fisher, ft_fisher])
print(scores)

# Same sweep with fishers=None, for comparison.
scores2 = get_scores([scomp_model, ft_model], None)
print(scores2)

# gen = merging.generate_merged_for_coeffs_set(
#     # TODO: Experiment with the order of these.
#     # [ft_model, scomp_model],
#     [scomp_model, ft_model],
#     #
#     # fishers=[ft_fisher, scomp_fisher],
#     fishers=[scomp_fisher, ft_fisher],
#     # fishers=None,
#     #
#     coefficients_set=coefficients_set,
#     fisher_floor=FISHER_FLOOR,
#     favor_target_model=True,
#     normalize_fishers=True,
# )

# for coeffs, merged_model in gen:
#     ft_score = evaluation.evaluate_model(merged_model, ft_val_ds, acc_metric)
#     scomp_score = evaluation.evaluate_model(merged_model, scomp_val_ds, acc_metric)
#     print_scores(coeffs, ft_score, scomp_score)
#     print()





# TODO: MAKE PLOTS OF SCORES ACROSS MERGING COEFFS FOR THE DIFFERENT VARIANTS TO COMPARE.
#
# TODO: ALLOW FOR MERGING/SWITCHING OF HEADS AS WELL.
#
# THEN LOOK AT HYPERPARAMETERS OF NMF AND OTHER SIMILAR FACTORS!



# Try NMF variations on e.g. natural image patches for quick experimentation with hyperparameters.
#
# Protein-protein interaction datasets can have some known underlying causative factors,
# e.g. known interacting domains. They are also an interesting application for me, and
# there is some prior work on NMF for protein-protein interactions, so a baseline
# more or less exists.





"""
Merging coefficients: (0.631578947368421, 0.368421052631579)
OG Task Score:
  accuracy: 0.993408203125
SCOMP Task Score:
  accuracy: 0.9541015625

Merging coefficients: (0.5789473684210527, 0.42105263157894735)
OG Task Score:
  accuracy: 0.991943359375
SCOMP Task Score:
  accuracy: 0.967041015625

Merging coefficients: (0.5263157894736842, 0.4736842105263158)
OG Task Score:
  accuracy: 0.985107421875
SCOMP Task Score:
  accuracy: 0.97119140625


>>> ft_model.evaluate(ft_val_ds)
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
32/32 [==============================] - 11s 214ms/step - loss: 0.0550 - sparse_categorical_accuracy: 0.9912
[0.03886324539780617, 0.992431640625]
>>> ft_model.evaluate(scomp_val_ds)
32/32 [==============================] - 7s 214ms/step - loss: 2.0841 - sparse_categorical_accuracy: 0.5496
[2.084089517593384, 0.549560546875]
>>> 
>>> scomp_model.evaluate(ft_val_ds)
WARNING:tensorflow:The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
WARNING:tensorflow:The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
32/32 [==============================] - 9s 214ms/step - loss: 0.0993 - sparse_categorical_accuracy: 0.9871
[0.10470808297395706, 0.985595703125]
>>> scomp_model.evaluate(scomp_val_ds)
32/32 [==============================] - 7s 214ms/step - loss: 0.0540 - sparse_categorical_accuracy: 0.9856
[0.05398287624120712, 0.985595703125]



Merging coefficients: (0.7368421052631579, 0.26315789473684215)
OG Task Score:
  accuracy: 0.9921875
SCOMP Task Score:
  accuracy: 0.984619140625


# [scomp, ft], with Fishers
# scores = get_scores([scomp_model, ft_model], [scomp_fisher, ft_fisher])
scores = [Score(coeffs=(1.0, 0.0), ft_acc=0.98681640625, scomp_acc=0.98333740234375), Score(coeffs=(0.9795918367346939, 0.020408163265306145), ft_acc=0.9874267578125, scomp_acc=0.98358154296875), Score(coeffs=(0.9591836734693877, 0.04081632653061229), ft_acc=0.9881591796875, scomp_acc=0.98382568359375), Score(coeffs=(0.9387755102040817, 0.061224489795918324), ft_acc=0.98883056640625, scomp_acc=0.9837646484375), Score(coeffs=(0.9183673469387755, 0.08163265306122447), ft_acc=0.9893798828125, scomp_acc=0.983642578125), Score(coeffs=(0.8979591836734694, 0.10204081632653061), ft_acc=0.9896240234375, scomp_acc=0.98388671875), Score(coeffs=(0.8775510204081632, 0.12244897959183676), ft_acc=0.9901123046875, scomp_acc=0.98394775390625), Score(coeffs=(0.8571428571428571, 0.1428571428571429), ft_acc=0.9903564453125, scomp_acc=0.984375), Score(coeffs=(0.8367346938775511, 0.16326530612244894), ft_acc=0.9906005859375, scomp_acc=0.9840087890625), Score(coeffs=(0.8163265306122449, 0.18367346938775508), ft_acc=0.9908447265625, scomp_acc=0.98394775390625), Score(coeffs=(0.7959183673469388, 0.20408163265306123), ft_acc=0.9910888671875, scomp_acc=0.98394775390625), Score(coeffs=(0.7755102040816326, 0.22448979591836737), ft_acc=0.991455078125, scomp_acc=0.98388671875), Score(coeffs=(0.7551020408163265, 0.24489795918367352), ft_acc=0.99176025390625, scomp_acc=0.9837646484375), Score(coeffs=(0.7346938775510204, 0.26530612244897955), ft_acc=0.99176025390625, scomp_acc=0.9835205078125), Score(coeffs=(0.7142857142857143, 0.2857142857142857), ft_acc=0.99200439453125, scomp_acc=0.9832763671875), Score(coeffs=(0.6938775510204082, 0.30612244897959184), ft_acc=0.992431640625, scomp_acc=0.98260498046875), Score(coeffs=(0.673469387755102, 0.326530612244898), ft_acc=0.99249267578125, scomp_acc=0.982421875), Score(coeffs=(0.6530612244897959, 0.34693877551020413), ft_acc=0.99267578125, scomp_acc=0.982177734375), Score(coeffs=(0.6326530612244898, 0.36734693877551017), ft_acc=0.99273681640625, 
scomp_acc=0.98193359375), Score(coeffs=(0.6122448979591837, 0.3877551020408163), ft_acc=0.9927978515625, scomp_acc=0.98187255859375), Score(coeffs=(0.5918367346938775, 0.40816326530612246), ft_acc=0.992919921875, scomp_acc=0.9815673828125), Score(coeffs=(0.5714285714285714, 0.4285714285714286), ft_acc=0.99285888671875, scomp_acc=0.9808349609375), Score(coeffs=(0.5510204081632653, 0.44897959183673475), ft_acc=0.99285888671875, scomp_acc=0.9801025390625), Score(coeffs=(0.5306122448979592, 0.4693877551020408), ft_acc=0.99273681640625, scomp_acc=0.9794921875), Score(coeffs=(0.5102040816326531, 0.4897959183673469), ft_acc=0.99267578125, scomp_acc=0.9786376953125), Score(coeffs=(0.4897959183673469, 0.5102040816326531), ft_acc=0.99285888671875, scomp_acc=0.97723388671875), Score(coeffs=(0.46938775510204084, 0.5306122448979591), ft_acc=0.99273681640625, scomp_acc=0.97576904296875), Score(coeffs=(0.4489795918367347, 0.5510204081632653), ft_acc=0.992919921875, scomp_acc=0.97332763671875), Score(coeffs=(0.42857142857142855, 0.5714285714285714), ft_acc=0.992919921875, scomp_acc=0.9703369140625), Score(coeffs=(0.40816326530612246, 0.5918367346938775), ft_acc=0.99273681640625, scomp_acc=0.96697998046875), Score(coeffs=(0.3877551020408163, 0.6122448979591837), ft_acc=0.9925537109375, scomp_acc=0.9622802734375), Score(coeffs=(0.3673469387755102, 0.6326530612244898), ft_acc=0.99249267578125, scomp_acc=0.95684814453125), Score(coeffs=(0.3469387755102041, 0.653061224489796), ft_acc=0.99249267578125, scomp_acc=0.94964599609375), Score(coeffs=(0.32653061224489793, 0.6734693877551021), ft_acc=0.99237060546875, scomp_acc=0.9400634765625), Score(coeffs=(0.30612244897959184, 0.6938775510204082), ft_acc=0.99249267578125, scomp_acc=0.93017578125), Score(coeffs=(0.2857142857142857, 0.7142857142857143), ft_acc=0.99249267578125, scomp_acc=0.91839599609375), Score(coeffs=(0.2653061224489796, 0.7346938775510203), ft_acc=0.992431640625, scomp_acc=0.90386962890625), 
Score(coeffs=(0.24489795918367346, 0.7551020408163265), ft_acc=0.99237060546875, scomp_acc=0.88812255859375), Score(coeffs=(0.22448979591836735, 0.7755102040816326), ft_acc=0.9923095703125, scomp_acc=0.86865234375), Score(coeffs=(0.20408163265306123, 0.7959183673469388), ft_acc=0.99224853515625, scomp_acc=0.85015869140625), Score(coeffs=(0.1836734693877551, 0.8163265306122449), ft_acc=0.9921875, scomp_acc=0.83154296875), Score(coeffs=(0.16326530612244897, 0.8367346938775511), ft_acc=0.9920654296875, scomp_acc=0.81427001953125), Score(coeffs=(0.14285714285714285, 0.8571428571428572), ft_acc=0.99188232421875, scomp_acc=0.79730224609375), Score(coeffs=(0.12244897959183673, 0.8775510204081632), ft_acc=0.99176025390625, scomp_acc=0.7811279296875), Score(coeffs=(0.10204081632653061, 0.8979591836734694), ft_acc=0.991455078125, scomp_acc=0.766845703125), Score(coeffs=(0.08163265306122448, 0.9183673469387755), ft_acc=0.99114990234375, scomp_acc=0.75836181640625), Score(coeffs=(0.061224489795918366, 0.9387755102040817), ft_acc=0.9906005859375, scomp_acc=0.74798583984375), Score(coeffs=(0.04081632653061224, 0.9591836734693877), ft_acc=0.99029541015625, scomp_acc=0.72369384765625), Score(coeffs=(0.02040816326530612, 0.9795918367346939), ft_acc=0.98883056640625, scomp_acc=0.716796875), Score(coeffs=(0.0, 1.0), ft_acc=0.41900634765625, scomp_acc=0.5302734375)]

# [scomp, ft], without Fishers
# scores2 = get_scores([scomp_model, ft_model], None)
scores2 = [Score(coeffs=(1.0, 0.0), ft_acc=0.98681640625, scomp_acc=0.98333740234375), Score(coeffs=(0.9795918367346939, 0.020408163265306145), ft_acc=0.9849853515625, scomp_acc=0.9827880859375), Score(coeffs=(0.9591836734693877, 0.04081632653061229), ft_acc=0.8056640625, scomp_acc=0.980224609375), Score(coeffs=(0.9387755102040817, 0.061224489795918324), ft_acc=0.75390625, scomp_acc=0.97882080078125), Score(coeffs=(0.9183673469387755, 0.08163265306122447), ft_acc=0.73040771484375, scomp_acc=0.97698974609375), Score(coeffs=(0.8979591836734694, 0.10204081632653061), ft_acc=0.7098388671875, scomp_acc=0.9744873046875), Score(coeffs=(0.8775510204081632, 0.12244897959183676), ft_acc=0.6893310546875, scomp_acc=0.9722900390625), Score(coeffs=(0.8571428571428571, 0.1428571428571429), ft_acc=0.66748046875, scomp_acc=0.96856689453125), Score(coeffs=(0.8367346938775511, 0.16326530612244894), ft_acc=0.6435546875, scomp_acc=0.962890625), Score(coeffs=(0.8163265306122449, 0.18367346938775508), ft_acc=0.616455078125, scomp_acc=0.9521484375), Score(coeffs=(0.7959183673469388, 0.20408163265306123), ft_acc=0.5806884765625, scomp_acc=0.93353271484375), Score(coeffs=(0.7755102040816326, 0.22448979591836737), ft_acc=0.53497314453125, scomp_acc=0.90716552734375), Score(coeffs=(0.7551020408163265, 0.24489795918367352), ft_acc=0.466552734375, scomp_acc=0.6209716796875), Score(coeffs=(0.7346938775510204, 0.26530612244897955), ft_acc=0.41900634765625, scomp_acc=0.5306396484375), Score(coeffs=(0.7142857142857143, 0.2857142857142857), ft_acc=0.4190673828125, scomp_acc=0.53070068359375), Score(coeffs=(0.6938775510204082, 0.30612244897959184), ft_acc=0.4320068359375, scomp_acc=0.55853271484375), Score(coeffs=(0.673469387755102, 0.326530612244898), ft_acc=0.4727783203125, scomp_acc=0.6751708984375), Score(coeffs=(0.6530612244897959, 0.34693877551020413), ft_acc=0.485595703125, scomp_acc=0.84912109375), Score(coeffs=(0.6326530612244898, 0.36734693877551017), ft_acc=0.48101806640625, 
scomp_acc=0.92974853515625), Score(coeffs=(0.6122448979591837, 0.3877551020408163), ft_acc=0.47161865234375, scomp_acc=0.92242431640625), Score(coeffs=(0.5918367346938775, 0.40816326530612246), ft_acc=0.46551513671875, scomp_acc=0.92047119140625), Score(coeffs=(0.5714285714285714, 0.4285714285714286), ft_acc=0.4658203125, scomp_acc=0.9246826171875), Score(coeffs=(0.5510204081632653, 0.44897959183673475), ft_acc=0.4249267578125, scomp_acc=0.9075927734375), Score(coeffs=(0.5306122448979592, 0.4693877551020408), ft_acc=0.41766357421875, scomp_acc=0.57781982421875), Score(coeffs=(0.5102040816326531, 0.4897959183673469), ft_acc=0.4385986328125, scomp_acc=0.47589111328125), Score(coeffs=(0.4897959183673469, 0.5102040816326531), ft_acc=0.517822265625, scomp_acc=0.467529296875), Score(coeffs=(0.46938775510204084, 0.5306122448979591), ft_acc=0.5609130859375, scomp_acc=0.46868896484375), Score(coeffs=(0.4489795918367347, 0.5510204081632653), ft_acc=0.57391357421875, scomp_acc=0.46929931640625), Score(coeffs=(0.42857142857142855, 0.5714285714285714), ft_acc=0.57806396484375, scomp_acc=0.46954345703125), Score(coeffs=(0.40816326530612246, 0.5918367346938775), ft_acc=0.5794677734375, scomp_acc=0.46942138671875), Score(coeffs=(0.3877551020408163, 0.6122448979591837), ft_acc=0.5806884765625, scomp_acc=0.46307373046875), Score(coeffs=(0.3673469387755102, 0.6326530612244898), ft_acc=0.60845947265625, scomp_acc=0.47747802734375), Score(coeffs=(0.3469387755102041, 0.653061224489796), ft_acc=0.79150390625, scomp_acc=0.68231201171875), Score(coeffs=(0.32653061224489793, 0.6734693877551021), ft_acc=0.9219970703125, scomp_acc=0.82147216796875), Score(coeffs=(0.30612244897959184, 0.6938775510204082), ft_acc=0.981201171875, scomp_acc=0.82000732421875), Score(coeffs=(0.2857142857142857, 0.7142857142857143), ft_acc=0.98944091796875, scomp_acc=0.80426025390625), Score(coeffs=(0.2653061224489796, 0.7346938775510203), ft_acc=0.98974609375, scomp_acc=0.792724609375), 
Score(coeffs=(0.24489795918367346, 0.7551020408163265), ft_acc=0.9896240234375, scomp_acc=0.7813720703125), Score(coeffs=(0.22448979591836735, 0.7755102040816326), ft_acc=0.98992919921875, scomp_acc=0.77264404296875), Score(coeffs=(0.20408163265306123, 0.7959183673469388), ft_acc=0.9898681640625, scomp_acc=0.7659912109375), Score(coeffs=(0.1836734693877551, 0.8163265306122449), ft_acc=0.9898681640625, scomp_acc=0.7569580078125), Score(coeffs=(0.16326530612244897, 0.8367346938775511), ft_acc=0.98980712890625, scomp_acc=0.74871826171875), Score(coeffs=(0.14285714285714285, 0.8571428571428572), ft_acc=0.98980712890625, scomp_acc=0.73834228515625), Score(coeffs=(0.12244897959183673, 0.8775510204081632), ft_acc=0.9896240234375, scomp_acc=0.7269287109375), Score(coeffs=(0.10204081632653061, 0.8979591836734694), ft_acc=0.9898681640625, scomp_acc=0.71337890625), Score(coeffs=(0.08163265306122448, 0.9183673469387755), ft_acc=0.98980712890625, scomp_acc=0.69915771484375), Score(coeffs=(0.061224489795918366, 0.9387755102040817), ft_acc=0.9898681640625, scomp_acc=0.68267822265625), Score(coeffs=(0.04081632653061224, 0.9591836734693877), ft_acc=0.9901123046875, scomp_acc=0.6651611328125), Score(coeffs=(0.02040816326530612, 0.9795918367346939), ft_acc=0.99053955078125, scomp_acc=0.647216796875), Score(coeffs=(0.0, 1.0), ft_acc=0.9921875, scomp_acc=0.5640869140625)]

"""
