R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/pi/paws_ablate_01.py
"""

from importlib import reload
import random
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import qqp_components_context as QAE
from em.projects.pi import qqp_merging_context as QMC


"""
There seem to be a lot of examples with similar words. Combined with the 16+ values needed for
parameters to be taken into account, this seems to have biased components toward being selective
for such groups of examples.

Maybe try a longer sequence length.

- Maybe try ranking components by their overall "mass", i.e. their average normalized coefficient value.
- Maybe compare the performance of any ablation method (incl. gradient following and top-example PEFs) using
  the top-k examples for a component vs. k random examples (maybe with the same ratio of correct and incorrect predictions).


- There are two main "outputs" of this method:
    - The component groupings (i.e. the W)
    - The parameter space representations (i.e. the H)
- Both can be useful on their own.

- See if something is slowing down the CUDA NMF implementation.


94 0.1382 0.0963
98?
102
104?
125?
128?
138 0.1585 0.1029
139?
154
170  [seems similar to the HANS sub-obj swapping]
177?
183 0.5973 0.4158
"""


# Root directory holding this experiment's models and Fisher artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'

# Artifact subdirectories of EXPS_DIR, in the order they are unpacked.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Model identifier handed to the component context as `model_name_pattern`.
MODEL = "textattack/bert-base-uncased-QQP"

# Per-example Fisher (PEF) file and its NMF decomposition; both are read from
# PER_EXAMPLES_FISHERS_DIR.
PEF_FILENAME = "bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"
NMF_FILENAME = "spH.nmf_decomp.c1024_2kIters_65536pe_mvpp16.bert_base_qqp.paws_final_train.all_vars.all_ex.131072.h5"

# Diagonal Fisher file, read from FISHERS_DIR by get_fisher().
FISHER_FILENAME = "bert_base_qqp.paws_final_train.all_vars.all_ex.h5"

##########################################################################

# Tokenizer shared by the component context and the PAWS validation dataset
# built below.
# NOTE(review): assumes MODEL's checkpoint uses the bert-base-uncased
# vocabulary -- confirm.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################


def get_fisher():
    """Load the diagonal Fisher stored at FISHERS_DIR/FISHER_FILENAME.

    Returns:
        Whatever `diagonal.DiagonalFisher.load` returns for that path; the
        script reads its `.fishers` attribute.
    """
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)
    return diagonal.DiagonalFisher.load(fisher_path)

##########################################################################


# Component context built from the per-example Fisher (PEF) file and its NMF
# decomposition file, both under PER_EXAMPLES_FISHERS_DIR.
hacc = QAE.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
)

# NOTE(review): the meaning of the 'a' key is defined inside
# qqp_components_context and is not visible from this script.
mc = hacc.make_model_context('a')

# Evaluation context used with example indices below (eval_inds and
# sorted_example_indices). Presumably these are the PAWS train examples the
# PEFs were computed on -- confirm.
eval_ctx = mc.get_evaluation_context()

# Builds datasets for chosen examples, computes loss gradients over them, and
# applies the resulting "sign guide" to model variables (all used below).
sign_guider = QMC.SignGuider(mc=mc)

############################################################

# Generalization eval set: the first N_GEN_EVAL examples of the PAWS "final"
# validation split, tokenized to sequence length 64. `.cache()` lets the
# repeated evaluations in the sweeps below reuse the materialized examples.
N_GEN_EVAL = 3000

gen_eval_ctx = QAE.EvaluationContext2.create_from_ds(
    ds=em_datasets.load(
        'paws/final',
        split='validation',
        sequence_length=64,
        tokenizer=tokenizer,
    ).take(N_GEN_EVAL).cache(),
    # NOTE(review): presumably the reference model for the .kl() numbers
    # reported later -- confirm in EvaluationContext2.
    model=mc.model,
)
############################################################

# Number of the component's top-ranked examples used to build the sign guide.
# N_EXAMPLES_SG = 12
N_EXAMPLES_SG = 24

# Batch size for the loss-gradient computations in this script.
GET_LOSS_GRADIENT_BATCH_SIZE = 16


# NMF component under study (cf. the component index notes near the top of
# the file).
COMP_INDEX = 170
# COMP_INDEX = 183


# Example indices ranked for COMP_INDEX. Presumably sorted by descending
# component coefficient -- confirm in sort_example_indices_for_component.
sorted_example_indices = mc.sort_example_indices_for_component(COMP_INDEX)

# Dataset of the component's top N_EXAMPLES_SG examples.
sg_ds = sign_guider.get_ds_for_examples(sorted_example_indices[:N_EXAMPLES_SG])


# TODO: Maybe weigh by coefficient after normalization?
# Loss gradient over sg_ds. With normalize_gradients_by_example=True, each
# example's gradient is presumably normalized before aggregation -- confirm
# in SignGuider.compute_loss_gradient.
sign_guide = sign_guider.compute_loss_gradient(
    sg_ds.batch(GET_LOSS_GRADIENT_BATCH_SIZE),
    # normalize_gradients_by_example=False,
    normalize_gradients_by_example=True,
)

############################################################

# Fishers used as merge weights:
#   fisher1 -- the diagonal Fisher loaded from FISHERS_DIR/FISHER_FILENAME;
#   fisher2 -- a Fisher built by the model context from COMP_INDEX alone.
fisher1 = get_fisher().fishers
fisher2 = mc.make_fisher_for_components([COMP_INDEX])
# fisher2 = diagonal.compute_fisher_for_model(mc.model, sg_ds.batch(4), variables=mc.variables)

# Merge input 1: the model context's variables.
# NOTE(review): list() copies the container, not the variable values. This
# relies on apply_sign_guide below not mutating mc.variables in place -- confirm.
variables1 = list(mc.variables)

# Merge input 2: variables moved along the sign guide. delta is presumably the
# (signed) step size -- confirm in SignGuider.apply_sign_guide.
# delta = -6e-2
delta = -1e-2
# delta = -2e-1
# delta = -1e-4
variables2 = sign_guider.apply_sign_guide(mc.variables, sign_guide, delta)

############################################################

# Indices of the first N_SAME_EVAL examples of eval_ctx, used for the
# per-merge evaluation below.
N_SAME_EVAL = 3000
eval_inds = np.arange(N_SAME_EVAL)


# Separately loaded model; its variables receive each merge in the sweep.
output_model = mc.load_model()
output_variables = hf_util.get_all_variables(output_model)

# Paired by position: variables1 with fisher1, variables2 (sign-guided) with fisher2.
variables_to_merge = [variables1, variables2]
fishers_to_merge = [fisher1, fisher2]

# One normalization constant per Fisher (an L2 norm, going by the helper's name).
norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

# Sweep the coefficient grid from create_pairwise_grid_coeffs(21) and report
# metrics for each merged model.
# NOTE(review): output_variables comes from hf_util.get_all_variables while the
# merge inputs come from mc.variables. The merge presumably pairs them by
# position, so both must list the same variables in the same order -- confirm.
for coefficients in merging.create_pairwise_grid_coeffs(21):
    # Nothing is returned or assigned here, yet output_model is evaluated next,
    # so this presumably writes the merge into output_variables in place -- confirm.
    merging._merge_with_coeffs(
        output_variables,
        variables_to_merge,
        coefficients=coefficients,
        fishers=fishers_to_merge,
        fisher_floor=1e-7,
        favor_target_model=True,
        normalization_constants=norm_constants,
    )
    eval_results = eval_ctx.evaluate(output_model, eval_inds)
    gen_eval_results = gen_eval_ctx.evaluate(output_model)
    #
    print(coefficients)
    #
    # PAWS validation subset: KL, accuracy, loss.
    print(f'  {gen_eval_results.kl()}, {gen_eval_results.acc():.4f}, {gen_eval_results.loss()}')
    print('')
    #
    # First N_SAME_EVAL examples of eval_ctx: KL, accuracy, loss.
    print(f'  {eval_results.kl()}, {eval_results.acc():.4f}, {eval_results.loss()}')
    print('')
    #
    # The component's top-i examples: KL, accuracy. The sign guide was built
    # from the top N_EXAMPLES_SG, so larger i also includes examples it never saw.
    for i in [12, 24, 48]:
        sub_eval_results = eval_ctx.evaluate(output_model, sorted_example_indices[:i])
        print(f'  {sub_eval_results.kl()}, {sub_eval_results.acc()}')
    print('')


#


#

# Second experiment: no Fisher merging. Instead, move the variables in
# variables1 directly along the normalized loss gradient of sg_ds, scaled by
# delta * coeff, and report the same metrics as the merge sweep.

# Reload a fresh model so the previous sweep's merged weights are discarded.
output_model = mc.load_model()
output_variables = hf_util.get_all_variables(output_model)

# Test random examples vs top examples for subset.
# (The commented-out line swaps in the fixed, non-top index block
# 1000 .. 1000 + N_EXAMPLES_SG - 1 as a baseline.)
subset_loss_gradient = list(sign_guider.compute_loss_gradient(
    sg_ds.batch(GET_LOSS_GRADIENT_BATCH_SIZE),
    # sign_guider.get_ds_for_examples(1000 + np.arange(N_EXAMPLES_SG)).batch(GET_LOSS_GRADIENT_BATCH_SIZE),
    normalize_gradients_by_example=False,
    # normalize_gradients_by_example=True,
))
# Rescale the whole gradient by one shared magnitude (presumably its global
# L2 norm across all variables).
# NOTE(review): a zero magnitude would raise or produce non-finite values here.
subset_loss_gradient_mag = merging._l2_norm_of_fisher(subset_loss_gradient)
subset_loss_gradient = [g / subset_loss_gradient_mag for g in subset_loss_gradient]

# Base step along the normalized gradient (negative sign steps against it).
# delta = -1e-1
delta = -3e-1
# delta = 1e-1

for coefficients in merging.create_pairwise_grid_coeffs(21):
    # Only the second coefficient of each pair is used, as a multiplier on delta.
    _, coeff = coefficients
    #
    # NOTE(review): zip stops at the shortest list, silently. output_variables
    # (hf_util.get_all_variables) and variables1 (mc.variables) must list the
    # same variables in the same order -- confirm.
    for outv, ogv, grad in zip(output_variables, variables1, subset_loss_gradient):
        # Need the underscore to prevent every variable from being printed.
        _ = outv.assign(ogv + delta * coeff * grad)
    #
    eval_results = eval_ctx.evaluate(output_model, eval_inds)
    gen_eval_results = gen_eval_ctx.evaluate(output_model)
    #
    print(coefficients)
    #
    # PAWS validation subset: KL, accuracy, loss.
    print(f'  {gen_eval_results.kl()}, {gen_eval_results.acc():.4f}, {gen_eval_results.loss()}')
    print('')
    #
    # First N_SAME_EVAL examples of eval_ctx: KL, accuracy, loss.
    print(f'  {eval_results.kl()}, {eval_results.acc():.4f}, {eval_results.loss()}')
    print('')
    #
    # The component's top-i examples: KL, accuracy.
    for i in [12, 24, 48]:
        sub_eval_results = eval_ctx.evaluate(output_model, sorted_example_indices[:i])
        print(f'  {sub_eval_results.kl()}, {sub_eval_results.acc()}')
    print('')


#


#
