R"""
Sign-guided ablation experiment for a single NMF component of the QQP per-example Fishers.

Usage:
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/pi/qqp_ablate_01.py
"""

from importlib import reload
import random
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

# from em.projects.ll import hans_ablation_experiment as HAE
# from em.projects.ll import hans_components_context as HCC
# from em.projects.ll import hans_merging_context as HMC

from em.projects.pi import qqp_components_context as QAE
from em.projects.pi import qqp_merging_context as QMC


###############################################################################









"""

'
'
'
'
'
'
'
' NOTE: It looks like a decent number of the examples the model gets wrong are fairly
'        ambiguous, and some appear to be mislabeled.
'
' TODO: FIND A BETTER TASK THAN QQP!!!
'  Options: WinoGrande? [20k train, 20k for Fisher stuff], binarized MNLI?, (binarized?) SNLI?,
'           sci_tail?, doc_nli? (looks like it has long examples), civil_comments?, gap? (seems similar to Wino),
'           hellaswag?, wikipedia_toxicity_subtypes?, paws_wiki? (seems like a cleaner QQP)
'
' Maybe look at textual entailment tasks? Maybe frame some stuff as domain adaptation?
' Maybe some image stuff?
'
'
'
'
'
'
'
'
'


IDEA: FINE-TUNE the QQP model on paws_wiki. Look into what the splits mean; it looks like a cleaner
version of paraphrase detection.
        - Maybe train mostly on the noisily labeled split (I think that is the unlabeled split), then use
          the clean train split for the NMF stuff, then evaluate on the validation set.
        - Can also try zero-shot transfer from the QQP model.






"""
# raise Exception('TODO: READ ABOVE!!!!')














# Root of this experiment's on-disk artifacts; the three sub-directories below
# (models, diagonal Fishers, per-example Fishers) all live directly under it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Name of the fine-tuned QQP classifier under study (looks like a Hugging Face
# Hub id); passed as the model name pattern to the component context below.
MODEL = "textattack/bert-base-uncased-QQP"

# Per-example Fisher (PEF) file and the NMF decomposition computed from it; both
# are read from PER_EXAMPLES_FISHERS_DIR. An earlier run used the non-enriched
# PEFs instead:
#   PEF_FILENAME = "bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
#   NMF_FILENAME = f"spH.nmf_decomp.c1024_2kIters_65536pe_20000ex_mvpp8.{PEF_FILENAME}"
PEF_FILENAME = (
    "enriched_incorrects_to_0_25"
    ".bert_base_qqp.qqp_val.all_vars.first_20k.131072.h5"
)
NMF_FILENAME = "spH.nmf_decomp.c1024_2kIters_65536pe_mvpp8.{}".format(PEF_FILENAME)

# Diagonal Fisher file (read from FISHERS_DIR by get_fisher()).
FISHER_FILENAME = "bert_base_qqp.qqp_val.all_vars.first_20k.h5"

##########################################################################

# Tokenizer checkpoint; the resulting tokenizer is shared by the component
# context and by the held-out evaluation dataset built further down.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TOKENIZER)

##########################################################################


def get_fisher():
    """Load the precomputed diagonal Fisher named by ``FISHER_FILENAME``.

    The file is read from ``FISHERS_DIR``. The return value is whatever
    ``diagonal.DiagonalFisher.load`` produces; callers in this script use its
    ``.fishers`` attribute.
    """
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)
    return diagonal.DiagonalFisher.load(fisher_path)

##########################################################################


# NOTE: This only works when the per-example Fishers line up with the dataset's
# examples in order (presumably the first N examples; the PEF filename says
# "first_20k").
#
# (An N_EVAL_EXAMPLES constant of 2048 or 20000 was tried at some point; it is
# not used by the current code.)

# Component context over the PEF file and its NMF decomposition; both paths
# point into PER_EXAMPLES_FISHERS_DIR.
hacc = QAE.QqpComponentContext(
    tokenizer=tokenizer,
    model_name_pattern=MODEL,
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
)

# Per-model context for key 'a' (presumably this fills the *_pattern
# placeholders above — confirm in qqp_components_context), plus its evaluation
# context and the sign guider that later builds the perturbed variables.
mc = hacc.make_model_context('a')
eval_ctx = mc.get_evaluation_context()
sign_guider = QMC.SignGuider(mc=mc)

############################################################

# Held-out evaluation set: N_GEN_EVAL QQP validation examples starting at index
# GEN_EVAL_OFFSET. This skips past the first 20k validation examples that the
# Fisher files were computed on (per their "first_20k" filenames).
GEN_EVAL_OFFSET = 20_000
N_GEN_EVAL = 7500
gen_eval_ctx = QAE.EvaluationContext2.create_from_ds(
    ds=(
        em_datasets.load(
            'glue/qqp',
            tokenizer=tokenizer,
            split='validation',
            sequence_length=64,
        )
        .skip(GEN_EVAL_OFFSET)
        .take(N_GEN_EVAL)
        .cache()
    ),
    model=mc.model,
)

############################################################


# How many of the component's top-ranked examples feed the sign guide
# (12 and 48 were also tried).
N_EXAMPLES_SG = 24

GET_LOSS_GRADIENT_BATCH_SIZE = 16

# NMF component under study. Components examined in earlier runs:
#   252, 231, 275, 239,
#   98, 42, 86, 74, 101, 11,
#   14,
#   18, 391, 403, 414, 419
COMP_INDEX = 17

# Example indices in the order given by mc.sort_example_indices_for_component;
# the leading N_EXAMPLES_SG of them form the sign-guide dataset.
sorted_example_indices = mc.sort_example_indices_for_component(COMP_INDEX)

sg_ds = sign_guider.get_ds_for_examples(sorted_example_indices[:N_EXAMPLES_SG])

# Loss gradient over the sign-guide examples, normalized per example
# (normalize_gradients_by_example=False was also tried).
# TODO: Maybe weigh by coefficient after normalization?
sign_guide = sign_guider.compute_loss_gradient(
    sg_ds.batch(GET_LOSS_GRADIENT_BATCH_SIZE),
    normalize_gradients_by_example=True,
)

############################################################

# Step size used when applying the sign guide. Values tried in earlier runs:
#   3.5e-4, 1e-3, 5e-3, -6e-3, -1e-4, -6e-4, -1e-3, 6e-4, 2e-2
delta = 6e-3

# Merge weights: the precomputed diagonal Fisher for the original variables,
# and a Fisher built from COMP_INDEX alone for the sign-guided variables.
# (Alternative tried for fisher2:
#   diagonal.compute_fisher_for_model(mc.model, sg_ds.batch(4), variables=mc.variables))
fisher1 = get_fisher().fishers
fisher2 = mc.make_fisher_for_components([COMP_INDEX])

# list() copies only the container. If apply_sign_guide mutated the variables
# in place, variables1 would see the change too.
# NOTE(review): assumes apply_sign_guide returns new values — confirm.
variables1 = list(mc.variables)
variables2 = sign_guider.apply_sign_guide(mc.variables, sign_guide, delta)

############################################################

# Fresh model instance whose variables receive each merged parameter set.
output_model = mc.load_model()
output_variables = hf_util.get_all_variables(output_model)

# Model 1 = original variables with the precomputed Fisher; model 2 = the
# sign-guided variables with the COMP_INDEX component Fisher.
variables_to_merge = [variables1, variables2]
fishers_to_merge = [fisher1, fisher2]

# These constants depend only on the Fishers, not on the merge coefficients,
# so compute them once outside the sweep.
norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

# Sweep the grid from merging.create_pairwise_grid_coeffs(21). For each point,
# merge into output_variables, then print:
#   1. the coefficients,
#   2. held-out set (gen_eval_ctx): KL, accuracy, loss,
#   3. model-context eval set (eval_ctx): KL, accuracy, loss,
#   4. KL and accuracy on the top 12 / 24 / 48 examples for COMP_INDEX.
# NOTE: '#' lines stand in for blank lines inside the loop, presumably so the
# block can be pasted into the `python -i` REPL without ending it early.
for coefficients in merging.create_pairwise_grid_coeffs(21):
    # NOTE(review): assumes _merge_with_coeffs writes the merged values into
    # output_variables in place (its return value is unused) — confirm.
    merging._merge_with_coeffs(
        output_variables,
        variables_to_merge,
        coefficients=coefficients,
        fishers=fishers_to_merge,
        fisher_floor=1e-7,
        favor_target_model=True,
        normalization_constants=norm_constants,
    )
    eval_results = eval_ctx.evaluate(output_model)
    gen_eval_results = gen_eval_ctx.evaluate(output_model)
    #
    print(coefficients)
    #
    print(f'  {gen_eval_results.kl()}, {gen_eval_results.acc():.4f}, {gen_eval_results.loss()}')
    print('')
    #
    print(f'  {eval_results.kl()}, {eval_results.acc():.4f}, {eval_results.loss()}')
    print('')
    #
    for n_top in (12, 24, 48):
        # Slice once and reuse it for both per-example metrics.
        top_inds = sorted_example_indices[:n_top]
        print(f'  {eval_results.kl_for_examples(top_inds)}, {eval_results.acc_for_examples(top_inds)}')
    print('')


#

# '
# '
# '# TODO: Evaluate on (a subset of) the last 20k QQP dev examples.
# '
# '
# ' If this improves performance, then compute sci-tail PEFs and some QQP test-set PEFs, and work on the
# ' following: compute general components on a large dataset -> freeze them and learn a few subset-specific
# ' components on the subset, which is a (probably better) alternative to my example-enrichment strategy.
# '

"""
- Can I do an "intersection" of different components and ablate that? See if there are
  any combinations of components whose intersection of top examples looks tuned, especially for incorrect predictions.
- TODO: Look for components selective for incorrect predictions (maybe in a different file).
- Maybe learn general components first, then freeze them, add a few new ones, and learn those on a smaller dataset of interest.
- Maybe treat correctly and incorrectly tuned components differently when making the sign guide?
- Another idea: train components on the training dataset, then ablate those that seem to be "overfits",
  and look at loss and acc on the validation set.


221 -mixed 0.3036 0.1299
227 entail 0.2154 0.1669
231 entail 0.7556 0.5667
239 entail 0.3355 0.2363
244 entail 0.1646 0.1397
250 neutrl 0.1149 0.0773
252 entail 0.1638 0.1284

275 entail 0.6578 0.2679
276 entail 0.2199 0.1284
283 entail 0.2113 0.1465


# 252 looks like an interesting component to ablate.
# 250
# 245 maybe, low coeff values though.
# 244
# 239
# 231
# 227
# 221

# 268
# 275
# 276
# 283
# 287: Seems fun.
# 290

# 302
# 316
# 330
# 340

"""

"""
# NOTE: It looks like a decent number of the examples the model gets wrong are fairly
        ambiguous, and some appear to be mislabeled.

# TODO: Find a better task
The 0_25

Components that may be tuned for incorrect predictions:
    - Mixed label tuning: 10, 11, 23, 28, 31, 33, 34, 35
                          18, 74
    - Sort of label-selective: 42, 86, 101
    - Fairly label-selective: 98

"""

