R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/pi/scitail_ablate_baselines_01.py
"""

from importlib import reload
import random
import os

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

# Root directory for this experiment family; the three directories below are
# its subdirectories. MODELS_DIR is not referenced in the visible part of the
# file.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Passed as the model name pattern when the component context is built below.
# Presumably a Hugging Face hub identifier -- confirm.
MODEL = "textattack/bert-base-uncased-RTE"

# Both file names are joined onto PER_EXAMPLES_FISHERS_DIR below. The NMF file
# name embeds the per-example-Fisher file name as its suffix.
PEF_FILENAME = "bert_base_rte.sci_tail_train.all_vars.all_ex.65536.h5"
NMF_FILENAME = f"spH.nmf_decomp.c{1024}_2kIters_{65536}pe_mvpp{8}.{PEF_FILENAME}"

# Joined onto FISHERS_DIR by get_fisher().
# NOTE(review): this name spells "scit_tail" where PEF_FILENAME spells
# "sci_tail" -- confirm it matches the file that actually exists on disk.
FISHER_FILENAME = "bert_base_rte.scit_tail_final_train.all_vars.all_ex.h5"

##########################################################################

# Tokenizer checkpoint name; the loaded tokenizer is handed to the component
# context below.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################


def get_fisher():
    """Load the diagonal Fisher saved as ``FISHER_FILENAME`` in ``FISHERS_DIR``.

    Returns:
        Whatever ``diagonal.DiagonalFisher.load`` returns for that path.
    """
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)
    return diagonal.DiagonalFisher.load(fisher_path)

##########################################################################


# Component context built from the model name, the per-example-Fisher file,
# the NMF decomposition file, and the tokenizer.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
)
# NOTE(review): the meaning of the 'a' argument is not visible in this file;
# presumably it fills in the "*_pattern" arguments above -- confirm. None of
# the patterns here contain a placeholder.
mc = hacc.make_model_context('a')

##########################################################################
# Evaluation contexts derived from the model context. Only gen_eval_ctx is
# used by the visible code; rte_eval_ctx is presumably kept for interactive
# use (the script is run with `python -i`).
gen_eval_ctx = scitail_ablations.get_scitail_eval_ctx(mc)
rte_eval_ctx = scitail_ablations.get_rte_eval_ctx(mc)
##########################################################################

"""
 Component 81 looks like it significantly improves generalization performance, but only for N_SG = 12.

 Sort by average coeff of top-k examples (k=12), look at ablating those components first.


- Try to make fisher-aware gradient following ablation strategy.
- Try only using incorrectly labeled examples in the sign guide, perhaps taking into
  account the logits to make this a more continuous criterion.
- Need to experiment with various values of delta and N_SG.
- Try to have W coeffs weigh the examples when computing the sign guide.
- Baseline of random examples, probably set the fraction of correct predictions.
- Maybe see which generalizing set examples change the most.

"""

# COMP_INDEX = 824
# COMP_INDEX = 660
# COMP_INDEX = 35
# COMP_INDEX = 50
# COMP_INDEX = 71


# Component selected for this run; it is passed as `component_index` to the
# experiment config below. The commented-out alternatives record what was
# observed for other components (N_SG is the N_EXAMPLES_SG setting).
COMP_INDEX = 81   # big val loss drop with N_SG=12
# COMP_INDEX = 238  # small, possibly trivial val loss drop with N_SG=12
# COMP_INDEX = 85   # small, possibly trivial val loss drop with N_SG=12
# COMP_INDEX = 94   # slight, maybe trivial val loss drop with N_SG=12
# COMP_INDEX = 121  # maybe slight, maybe trivial val loss drop with N_SG=12 and delta=-2e-5
# COMP_INDEX = 127  # maybe slight, maybe trivial val loss drop with N_SG=12
# COMP_INDEX = 227  # maybe slight, maybe trivial val loss drop with N_SG=12
# COMP_INDEX = 887  # significant val loss drop with N_SG=48, small with N_SG=24,75


# COMP_INDEX = 50


# COMP_INDEX = 69

# n_coefficients is later read back as exp.config.n_coefficients to size the
# merge-coefficient grid.
exp_config = BAE.AblationExperimentConfig(
    component_index=COMP_INDEX,
    n_coefficients=21,
)
# The experiment receives the `.fishers` attribute of the diagonal Fisher
# loaded from disk. No extra eval contexts are registered; the SciTail context
# is evaluated explicitly further down instead.
exp = BAE.AblationExperiment(
    config=exp_config,
    mc=mc,
    dense_fisher=get_fisher().fishers,
    extra_eval_contexts=[],
)

# Keep the experiment's own example ordering so it can be restored after
# baselines overwrite exp.sorted_example_indices.
true_sorted_example_indices = exp.sorted_example_indices

##########################################################################

# Predicted class per example: argmax over the last axis of the stored logits.
predictions = np.argmax(mc.container.predicted_logits, axis=-1)
labels = mc.container.labels

# Masks over the examples. In the visible file they are only referenced by the
# commented-out random-example baselines further down.
# NOTE(review): the elementwise comparisons assume `labels` is a NumPy array
# (or behaves like one) -- confirm.
incorrect_indicator = predictions != labels

label0_indicator = labels == 0
label1_indicator = labels == 1

preds0_indicator = predictions == 0
preds1_indicator = predictions == 1

##########################################################################

# NOTE(review): N_SAME_EVAL / eval_inds are not used anywhere in the visible
# part of the file.
N_SAME_EVAL = 3000
eval_inds = np.arange(N_SAME_EVAL)


# Number of leading entries of exp.sorted_example_indices used to build the
# sign guide and the ablating Fisher, and evaluated as the "sub" set.
# N_EXAMPLES_SG = 4
# N_EXAMPLES_SG = 8
N_EXAMPLES_SG = 12
# N_EXAMPLES_SG = 16
# N_EXAMPLES_SG = 24
# N_EXAMPLES_SG = 48
# N_EXAMPLES_SG = 75
# N_EXAMPLES_SG = 16
# N_EXAMPLES_SG = 24
# N_EXAMPLES_SG = 48
# N_EXAMPLES_SG = 75

# exp.sorted_example_indices, = np.nonzero(np.argmax(mc.container.predicted_logits, axis=-1) != mc.container.labels)


def run_ablation(delta):
    """Run a merge-based ablation sweep and report metrics for every merged model.

    The first ``N_EXAMPLES_SG`` entries of ``exp.sorted_example_indices`` are
    used three ways: to compute the loss gradient passed as the sign guide, to
    compute the ablating Fisher, and as the subset evaluated alongside the
    SciTail eval context. Metrics are printed for each model yielded by
    ``exp.stream_merge_based_ablation``.

    Args:
        delta: Forwarded unchanged to ``exp.stream_merge_based_ablation``.

    Returns:
        A pair of lists ``(scitail_val_kls, example_kls)`` with one entry per
        yielded model: the ``kl()`` of the SciTail eval results and the
        ``kl()`` of the guide-example eval results.
    """
    guide_indices = exp.sorted_example_indices[:N_EXAMPLES_SG]
    guide_ds = exp.sign_guider.get_ds_for_examples(guide_indices)

    # Gradient first, Fisher second -- same order as they were computed before.
    gradient_batches = guide_ds.batch(exp_config.compute_gradient_batch_size)
    guide_gradient = exp.sign_guider.compute_loss_gradient(
        gradient_batches,
        normalize_gradients_by_example=False,
    )
    guide_fisher = diagonal.compute_fisher_for_model(
        mc.model,
        guide_ds.batch(4),
        variables=mc.variables,
    )

    merged_models = exp.stream_merge_based_ablation(
        retaining_fisher=exp.dense_fisher,
        ablating_fisher=guide_fisher,
        sign_guide=guide_gradient,
        delta=delta,
    )

    val_kls, guide_kls = [], []
    for coeffs, merged_model in merged_models:
        val_results = gen_eval_ctx.evaluate(merged_model)
        guide_results = exp.eval_context.evaluate(merged_model, guide_indices)

        val_kls.append(val_results.kl())
        guide_kls.append(guide_results.kl())

        print(coeffs)
        print(f'  {val_results.kl()}, {val_results.acc():.4f}, {val_results.loss()}')
        print(f'  {guide_results.kl()}, {guide_results.acc()}')
        print('')

    return val_kls, guide_kls


# exp.sorted_example_indices = true_sorted_example_indices
# st_kls_comp, ex_kls_comp = run_ablation(
#     delta=-3e-5,
# )

# exp.sorted_example_indices = np.random.permutation(np.nonzero(label1_indicator & preds0_indicator)[0])
# st_kls_rand, ex_kls_rand = run_ablation(
#     delta=-3e-5,
# )

# exp.sorted_example_indices = np.random.permutation(np.nonzero(incorrect_indicator)[0])
# # exp.sorted_example_indices = np.random.permutation(np.nonzero(~incorrect_indicator)[0])
# st_kls_rand2, ex_kls_rand2 = run_ablation(
#     # delta=-3e-5,
#     delta=-8e-5,
# )

# exp.sorted_example_indices = np.random.permutation(np.arange(true_sorted_example_indices.shape[0]))
# st_kls_rand3, ex_kls_rand3 = run_ablation(
#     # delta=-3e-5,
#     delta=-8e-5,
# )


# plt.plot(st_kls_comp, ex_kls_comp, label='Component Examples')
# plt.plot(st_kls_rand, ex_kls_rand, label='Random Examples')
# plt.plot(st_kls_rand2, ex_kls_rand2, label='Random Examples')
# plt.plot(st_kls_rand3, ex_kls_rand3, label='Random Examples')
# plt.legend(loc='best')
# plt.show()



# Restore the experiment's own example ordering, in case one of the
# (currently commented-out) baselines above replaced it.
exp.sorted_example_indices = true_sorted_example_indices
# st_kls_comp, ex_kls_comp = run_ablation(
#     delta=-1.5e-4,
# )


# TODO: try using only incorrect examples in the sign guide


# Dataset over the first N_EXAMPLES_SG sorted examples, and the loss gradient
# over it that serves as the sign guide for the manual merge below.
sg_ds = exp.sign_guider.get_ds_for_examples(exp.sorted_example_indices[:N_EXAMPLES_SG])
sign_guide = exp.sign_guider.compute_loss_gradient(
    sg_ds.batch(exp_config.compute_gradient_batch_size),
    normalize_gradients_by_example=False,
    # normalize_gradients_by_example=True,
)

# Step size handed to apply_gradient below; the commented-out values are
# alternatives that were tried.
# delta = -1e-4
# delta = -8e-5
# delta = -5e-5
# delta = -2e-5

# delta = -8e-3
delta = -3e-3
# delta = -2e-2
# delta = -1e-1

# Model instance that is evaluated after every merge, plus the variable list
# handed to the merge routine as its output target.
# NOTE(review): this relies on _merge_with_coeffs writing into output_variables
# in place, so that output_model reflects each merge -- confirm.
output_model = exp.mc.load_model()
output_variables = hf_util.get_all_variables(output_model)

# Two merge inputs: the retaining variables as-is, and the result of applying
# the sign-guide gradient to them with step `delta`.
# variables_to_merge = [exp.retaining_variables, exp.get_ablating_variables(sign_guide, delta)]
variables_to_merge = [exp.retaining_variables, exp.sign_guider.apply_gradient(exp.retaining_variables, sign_guide, delta)]


# One Fisher per merge input, in the same order: the dense Fisher, and a Fisher
# computed on the sign-guide examples.
fishers_to_merge = [exp.dense_fisher, diagonal.compute_fisher_for_model(mc.model, sg_ds.batch(4), variables=mc.variables)]
# fishers_to_merge = [exp.dense_fisher, mc.make_fisher_for_components([COMP_INDEX])]

# Per-Fisher normalization constants (L2 norms, going by the helper's name).
norm_constants = [merging._l2_norm_of_fisher(f) for f in fishers_to_merge]

# Grid of coefficient pairs, sized by the experiment config's n_coefficients.
coeffs_set = merging.create_pairwise_grid_coeffs(exp.config.n_coefficients)
# coeffs_set = [(x / 10, y + 1.75) for x, y in coeffs_set[10:-1]]

# Sweep the grid: merge with each coefficient pair, then evaluate the result on
# the SciTail eval context and on the sign-guide examples.
for coefficients in coeffs_set:
    merging._merge_with_coeffs(
        output_variables,
        variables_to_merge,
        coefficients=coefficients,
        fishers=fishers_to_merge,
        # fisher_floor=exp.config.fisher_floor,
        fisher_floor=1e-9,
        favor_target_model=True,
        normalization_constants=norm_constants,
    )
    gen_eval_results = gen_eval_ctx.evaluate(output_model)
    sub_eval_results = exp.eval_context.evaluate(output_model, exp.sorted_example_indices[:N_EXAMPLES_SG])
    #
    print(coefficients)
    print(f'  {gen_eval_results.kl()}, {gen_eval_results.acc():.4f}, {gen_eval_results.loss()}')
    print(f'  {sub_eval_results.kl()}, {sub_eval_results.acc()}')
    print('')


#
# delta = -2e-5
# Reference point for the sweep above: evaluate the second merge input on its
# own, i.e. the variables returned by exp.sign_guider.apply_gradient, with no
# Fisher-weighted merge applied.
# NOTE(review): set_weights assigns by position, so this assumes
# variables_to_merge[-1] is ordered like output_model's weights -- confirm
# against hf_util.get_all_variables.
output_model.set_weights(variables_to_merge[-1])
gen_eval_results = gen_eval_ctx.evaluate(output_model)
sub_eval_results = exp.eval_context.evaluate(output_model, exp.sorted_example_indices[:N_EXAMPLES_SG])
#
# Print an explicit label rather than `coefficients`: that name is the loop
# variable of the sweep above, so it would show the last grid pair (which does
# not describe this model) and would be undefined if the grid were empty.
print('apply_gradient variables only (no merge)')
print(f'  {gen_eval_results.kl()}, {gen_eval_results.acc():.4f}, {gen_eval_results.loss()}')
print(f'  {sub_eval_results.kl()}, {sub_eval_results.acc()}')
print('')
#


#
