R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/pi/scitail_ablate_01.py
"""


# Comp 35!
# 56, 71, 78, 95?, 103, 110, 127, 169?, 172, 190, 227


from importlib import reload
import random
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
# Sub-directories of the experiment root. FISHERS_DIR holds the file loaded by
# get_fisher(); PER_EXAMPLES_FISHERS_DIR holds the per-example Fisher (PEF)
# and NMF files used to build the component context.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

MODEL = "textattack/bert-base-uncased-RTE"

# Number of per-example Fishers (PEFs) stored in the PEF file. The NMF filename
# also encodes this count (the "...pe" field) and embeds PEF_FILENAME, so the
# value is defined once here to keep both filenames in agreement.
N_PEF_EXAMPLES = 65536
# NMF hyper-parameters as encoded in the NMF filename.
NMF_N_COMPONENTS = 1024  # presumably the number of NMF components ("c" field) — confirm
NMF_MVPP = 8  # NOTE(review): meaning of the "mvpp" field is not visible here — document

PEF_FILENAME = f"bert_base_rte.sci_tail_train.all_vars.all_ex.{N_PEF_EXAMPLES}.h5"
NMF_FILENAME = (
    f"spH.nmf_decomp.c{NMF_N_COMPONENTS}_2kIters_{N_PEF_EXAMPLES}pe"
    f"_mvpp{NMF_MVPP}.{PEF_FILENAME}"
)

# NOTE(review): spelled "scit_tail" (not "sci_tail" as in PEF_FILENAME); kept
# as-is because it must match the on-disk filename — confirm it is intentional.
FISHER_FILENAME = "bert_base_rte.scit_tail_final_train.all_vars.all_ex.h5"

##########################################################################

# Tokenizer shared by the component context and every evaluation dataset below.
# Presumably it matches the vocabulary of MODEL (a bert-base-uncased
# fine-tune, judging by its name) — confirm.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################


def get_fisher():
    """Load and return the DiagonalFisher stored at FISHERS_DIR/FISHER_FILENAME.

    No caching is done here: every call goes through DiagonalFisher.load.
    """
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)
    return diagonal.DiagonalFisher.load(fisher_path)

##########################################################################

# - eval on the train set of scitail, the validation set of scitail, and RTE (probably validation set)
# - multiply/geometric average for intersection, add/average for union.
# - maybe try ablation of HANs component and transfer to sci_tail?


# Component context for MODEL, built from its per-example Fishers (PEFs) and
# their NMF decomposition (both files live in PER_EXAMPLES_FISHERS_DIR).
# NOTE(review): the *_pattern arguments receive plain paths here; presumably
# they may hold placeholders filled per model key — confirm in QqpComponentContext.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
)
# Model context for key 'a'. Later code uses its .model, .container.examples
# and .make_fisher_for_components(...).
mc = hacc.make_model_context('a')

##########################################################################

# Number of SciTail validation examples kept for the generalization eval.
N_GEN_EVAL = 3000

# Evaluation context over the first N_GEN_EVAL examples (in load order) of the
# SciTail validation split, tokenized to length 64 and constructed with mc.model.
# .take/.cache presumably operate on a tf.data.Dataset — confirm em_datasets.load.
gen_eval_ctx = QCC.EvaluationContext2.create_from_ds(
    ds=em_datasets.load(
        'sci_tail/default',
        split='validation',
        sequence_length=64,
        tokenizer=tokenizer,
    ).take(N_GEN_EVAL).cache(),
    model=mc.model,
)

# Evaluation context over the full GLUE RTE validation split (no .take),
# tokenized to length 64 and constructed with mc.model.
rte_eval_ctx = QCC.EvaluationContext2.create_from_ds(
    ds=em_datasets.load(
        'glue/rte',
        split='validation',
        sequence_length=64,
        tokenizer=tokenizer,
    ).cache(),
    model=mc.model,
)
##########################################################################

# Index of the NMF component to ablate. The commented-out values are other
# candidate components; only the last uncommented assignment takes effect.
# COMP_INDEX = 35
# COMP_INDEX = 56
# COMP_INDEX = 71
# COMP_INDEX = 227
# COMP_INDEX = 110
# COMP_INDEX = 172
# COMP_INDEX = 191
# COMP_INDEX = 466
# COMP_INDEX = 485
# COMP_INDEX = 539
# COMP_INDEX = 582
# COMP_INDEX = 607
# COMP_INDEX = 609
# COMP_INDEX = 628
# COMP_INDEX = 660
# COMP_INDEX = 669
# COMP_INDEX = 757
# COMP_INDEX = 758
# COMP_INDEX = 770
# COMP_INDEX = 782
COMP_INDEX = 785
# COMP_INDEX = 789


# Ablation experiment configuration for COMP_INDEX.
# n_coefficients=21: presumably the number of merge coefficients streamed by
# stream_merge_based_ablation below — confirm in BAE.AblationExperimentConfig.
exp_config = BAE.AblationExperimentConfig(
    component_index=COMP_INDEX,
    n_coefficients=21,
)
# dense_fisher is the .fishers attribute of the DiagonalFisher loaded by
# get_fisher(); exp.dense_fisher is later passed as the retaining Fisher.
# No extra evaluation contexts are attached to the experiment.
exp = BAE.AblationExperiment(
    config=exp_config,
    mc=mc,
    dense_fisher=get_fisher().fishers,
    extra_eval_contexts=[],
)

##########################################################################

# Number of leading entries of exp.sorted_example_indices used to build the
# sign guide (ordering is defined by BAE — presumably most component-relevant
# examples first; confirm).
# N_EXAMPLES_SG = 12
N_EXAMPLES_SG = 24

# Dataset of the selected examples, and the loss gradient over it; the gradient
# is passed as `sign_guide` to stream_merge_based_ablation below.
sg_ds = exp.sign_guider.get_ds_for_examples(exp.sorted_example_indices[:N_EXAMPLES_SG])
# TODO: Maybe weigh by coefficient after normalization?
sign_guide = exp.sign_guider.compute_loss_gradient(
    sg_ds.batch(exp_config.compute_gradient_batch_size),
    # Per-example gradient normalization is off; the commented line enables it.
    normalize_gradients_by_example=False,
    # normalize_gradients_by_example=True,
)

##########################################################################


# TODO: Investigate why increasing delta does not appear to have the effect of
# "increasing the range of the coefficients" when merging.
#
# TODO: Look into all of the examples that share a hypothesis (especially those
# with components selective for a label).

# Indices 0..N_SAME_EVAL-1 evaluated with exp.eval_context for every merged model.
N_SAME_EVAL = 3000
eval_inds = np.arange(N_SAME_EVAL)

# Step size passed to stream_merge_based_ablation; alternatives kept for sweeps.
# delta = -8e-4
delta = -2e-4
# delta = -4e-5
# delta = -2e-3
# delta = -7e-3


hyp = "Vertebrates reproduce sexually."
# Indices into mc.container.examples whose hypothesis equals hyp.lower() and
# whose label_char is 'e' (presumably "entailment").
# An explicit integer dtype keeps this a valid index array even when nothing
# matches (np.array([]) on its own would be float64).
# h_ex_inds = np.array([i for i, e in enumerate(mc.container.examples) if e.hypothesis == hyp.lower()])
h_ex_inds = np.array(
    [
        i
        for i, e in enumerate(mc.container.examples)
        if e.hypothesis == hyp.lower() and e.label_char == 'e'
    ],
    dtype=np.int64,
)


def _print_eval(results, with_loss=True):
    """Print one indented line: KL, accuracy (4 decimals) and, optionally, loss."""
    line = f'  {results.kl()}, {results.acc():.4f}'
    if with_loss:
        line += f', {results.loss()}'
    print(line)


gen = exp.stream_merge_based_ablation(
    retaining_fisher=exp.dense_fisher,
    #
    ablating_fisher=mc.make_fisher_for_components([COMP_INDEX]),
    # ablating_fisher=diagonal.compute_fisher_for_model(mc.model, sg_ds.batch(4), variables=mc.variables),
    #
    sign_guide=sign_guide,
    delta=delta,
)
for coefficients, output_model in gen:
    eval_results = exp.eval_context.evaluate(output_model, eval_inds)
    gen_eval_results = gen_eval_ctx.evaluate(output_model)
    rte_eval_results = rte_eval_ctx.evaluate(output_model)
    #
    print(coefficients)
    #
    # GLUE RTE validation.
    _print_eval(rte_eval_results)
    print('')
    #
    # SciTail validation subset, then the first N_SAME_EVAL exp.eval_context examples.
    _print_eval(gen_eval_results)
    _print_eval(eval_results)
    print('')
    #
    # The first 12 / 24 / 48 entries of exp.sorted_example_indices.
    for i in [12, 24, 48]:
        sub_eval_results = exp.eval_context.evaluate(output_model, exp.sorted_example_indices[:i])
        _print_eval(sub_eval_results, with_loss=False)
    print('')
    #
    # Examples sharing `hyp` (h_ex_inds); skipped when there are none.
    if len(h_ex_inds) > 0:
        sub_eval_results = exp.eval_context.evaluate(output_model, h_ex_inds)
        _print_eval(sub_eval_results, with_loss=False)
    else:
        print(f'  no examples with hypothesis {hyp!r} and label_char == "e"')
    print('')


#


#


#

###############################################################################
###############################################################################
###############################################################################
###############################################################################


"""
Some components appear to be selective for a single hypothesis, and the examples
sharing that hypothesis almost all have the same label. Check whether correcting
the model based on these components disrupts the small subset of examples with
the opposite label, or preserves them.
"""


examples = mc.container.examples

# Candidate hypotheses; the numbered headers are the matching component indices.
# hyp = "Mercury is the one metal that melts below room temperature."
# hyp = "Formic acid is found in the secretions of stinging ants."
# hyp = "An increase in the body's cholesterol levels can lead to your arteries filling with plaque."

# # 770
# hyp = "Humans possess a ( n ) endoskeleton."

# # 782
# hyp = "Adulthood is divided into three stages."

# 785
hyp = "Vertebrates reproduce sexually."

# Examples whose hypothesis equals hyp.lower(); printing their label_char values
# shows how uniform the labels are for this hypothesis.
h_exs = list(filter(lambda ex: ex.hypothesis == hyp.lower(), examples))

print([ex.label_char for ex in h_exs])


