R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/ll/ablate/ablate_01.py

"""
from importlib import reload
import os

import numpy as np
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hf_util

from em.projects.wino import nmf_components_fisher as ncf

from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_merging_context as HMC

###############################################################################

# Root directory for this experiment series and its three sub-directories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

###############################################################################

# Full-path filename patterns. The `{model_number}` / `{n_fisher_values}`
# placeholders are filled in later via str.format (here or by the contexts
# these patterns are handed to).
PEF_FILEPATH = os.path.join(
    PER_EXAMPLES_FISHERS_DIR,
    "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.10k.131072.h5",
)
# The NMF decomposition file is named after the per-example Fisher file it was
# computed from, with an extra prefix.
NMF_FILEPATH = os.path.join(
    PER_EXAMPLES_FISHERS_DIR,
    "spH.nmf_decomp.c1024_2kIters_{n_fisher_values}pe." + os.path.basename(PEF_FILEPATH),
)

FISHER_FILEPATH = os.path.join(
    FISHERS_DIR,
    "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.h5",
)

MODEL_PATTERN = 'connectivity/feather_berts_{model_number}'

##########################################################################

# One tokenizer instance, shared by the HANS component context and the MNLI
# evaluation context below.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Component context for the HANS "lone" data (with flipped examples), wired to
# the model / per-example-Fisher / NMF filename patterns.
hacc = HCC.HansLoneComponentContext(
    tokenizer=tokenizer,
    with_flipped=True,
    model_name_pattern=MODEL_PATTERN,
    pef_filepath_pattern=PEF_FILEPATH,
    nmf_filepath_pattern=NMF_FILEPATH,
)

# Context for model number 0 with n_fisher_values=65536.
# NOTE(review): the meaning of the positional 'a' argument is not visible
# here — confirm against HCC.make_model_context.
mc = hacc.make_model_context('a', model_number=0, n_fisher_values=65536)

# HANS evaluation, given the unmodified model's predicted logits as reference.
eval_ctx = hacc.get_evaluation_context(
    og_logits=mc.container.predicted_logits,
)

# MNLI evaluation on 3072 examples; presumably set_up_og_data records the
# unmodified model's outputs as the reference — confirm in HMC.
mnli_eval_ctx = HMC.MnliEvaluationContext(tokenizer=tokenizer, n_examples=3072)
mnli_eval_ctx.set_up_og_data(mc.model)

# One index array per gold label: 0 = entailment, 1 = non-entailment.
univ_example_subset_indices = [
    np.nonzero(mc.container.labels == label)[0] for label in (0, 1)
]

############################################################

# NMF component whose top examples drive this ablation. Other components
# that were tried are kept commented out for reference.
# COMP_INDEX = 11  # Has incorrect predictions.
# COMP_INDEX = 69
COMP_INDEX = 82
# COMP_INDEX = 97  # Has incorrect predictions.

# Example indices ranked for this component (presumably by decreasing
# coefficient — cf. the coefficient note in the COMP_INDEX == 69 branch).
sorted_example_indices = mc.sort_example_indices_for_component(COMP_INDEX)

# Prefixes of the ranked examples; KL/accuracy on each prefix is reported
# separately for every merge in the coefficient sweep.
if COMP_INDEX == 69:
    example_subset_indices = [
        # All contain the word "mentioned".
        sorted_example_indices[:5],
        # Up until the drop in coeff from 0.2462 to 0.1520.
        sorted_example_indices[:7],
        # Up until the prediction is incorrect.
        sorted_example_indices[:9],
    ]
elif COMP_INDEX in {11, 82, 97}:
    example_subset_indices = [
        sorted_example_indices[:12],
        sorted_example_indices[:24],
    ]
else:
    # Name the offending value so it is obvious which component still needs
    # its subset definitions added (a bare ValueError gave no hint).
    raise ValueError(
        f'No example subsets are defined for COMP_INDEX={COMP_INDEX}; '
        'add a branch for it.'
    )


############################################################

# The "sign guide" is the loss gradient of the model over the component's
# N_EXAMPLES_SG top-ranked HANS examples; it is consumed later by
# HMC.apply_sign_guide to build the second merge endpoint.
N_EXAMPLES_SG = 12

GET_LOSS_GRADIENT_BATCH_SIZE = 16

# NOTE(review): assumes this dataset's example order matches the indexing
# used by mc.sort_example_indices_for_component — confirm.
sg_ds = HMC.get_ds_by_example_indices(
    em_datasets.load(
        'hans/lexical_overlap_ne_with_flipped',
        split='validation',
        sequence_length=64,
        tokenizer=tokenizer,
    ),
    sorted_example_indices[:N_EXAMPLES_SG],
).batch(GET_LOSS_GRADIENT_BATCH_SIZE)

sign_guide = HMC.get_loss_gradient(mc.model, mc.variables, sg_ds)

############################################################

# Fishers for the two-endpoint merge:
#   fisher1 - the precomputed diagonal Fisher stored at FISHER_FILEPATH,
#   fisher2 - a Fisher built from component COMP_INDEX alone.
# fisher1 = mc.make_fisher_for_components(set(range(1024)) - {COMP_INDEX})
fisher2 = mc.make_fisher_for_components([COMP_INDEX])

# Reuse the FISHER_FILEPATH pattern instead of re-spelling the filename, so
# the two cannot drift apart. model_number=0 matches the model context used
# for `mc`.
fisher1 = diagonal.DiagonalFisher.load(
    FISHER_FILEPATH.format(model_number=0)
).fishers
# fisher2 = diagonal.compute_fisher_for_model(mc.model, sg_ds.unbatch().batch(4), variables=mc.variables)

# Merge endpoints: the original variables, and (presumably) a copy offset by
# `delta` in the direction given by `sign_guide` — see HMC.apply_sign_guide.
variables1 = list(mc.variables)

# Alternative second endpoints that were tried:
# variables2 = [-w for w in mc.variables]
# variables2 = [0 * w for w in mc.variables]
# variables2 = [1e-1 + w for w in mc.variables]
# variables2 = [-1e-1 + w for w in mc.variables]

# delta = 1e-1
delta = 1e-3
# delta = 4e-3
# delta = -1e-3
# delta = -4e-3
variables2 = HMC.apply_sign_guide(mc.variables, sign_guide, delta)

############################################################

# A separately loaded copy of the model receives each merge. _merge_with_coeffs
# presumably writes into output_variables in place: its return value is unused
# and output_model is evaluated right after it.
output_model = mc.load_model()
output_variables = hf_util.get_all_variables(output_model)

fishers_to_merge = [fisher1, fisher2]
variables_to_merge = [variables1, variables2]

# Per-Fisher L2 norms, passed to the merge as normalization constants.
norm_constants = list(map(merging._l2_norm_of_fisher, fishers_to_merge))

# Other coefficient grids that were tried:
#   merging.create_pairwise_grid_coeffs(21)
#   merging.create_pairwise_grid_coeffs(76)[1:22]
for coefficients in merging.create_pairwise_grid_coeffs(201)[:21]:
    merging._merge_with_coeffs(
        output_variables,
        variables_to_merge,
        coefficients=coefficients,
        fishers=fishers_to_merge,
        normalization_constants=norm_constants,
        favor_target_model=True,
        fisher_floor=1e-7,
    )
    eval_results = eval_ctx.evaluate(output_model)
    mnli_eval_results = mnli_eval_ctx.evaluate(output_model)

    print(coefficients)

    # MNLI: KL and accuracy.
    print(f'  {mnli_eval_results.kl()}, {mnli_eval_results.acc()}')
    print('')

    # Whole HANS evaluation set.
    # print(f'{eval_results.loss()}, {eval_results.acc()}')
    print(f'  {eval_results.kl()}, {eval_results.acc()}')
    print('')

    # Per-label subsets first, then the component's top-example prefixes;
    # each group is followed by a blank line.
    # (loss variant: eval_results.loss_for_examples(ex_inds))
    for subset_group in (univ_example_subset_indices, example_subset_indices):
        for ex_inds in subset_group:
            print(f'  {eval_results.kl_for_examples(ex_inds)}, {eval_results.acc_for_examples(ex_inds)}')
        print('')

############################################################

# output_model = mc.load_model()
# output_variables = hf_util.get_all_variables(output_model)

# subset_loss_gradient = list(sign_guide)
# subset_loss_gradient_mag = merging._l2_norm_of_fisher(subset_loss_gradient)
# subset_loss_gradient = [g / subset_loss_gradient_mag for g in subset_loss_gradient]

# # delta = -1e-1
# delta = 1e-1

# for coefficients in merging.create_pairwise_grid_coeffs(21):
#     _, coeff = coefficients
#     #
#     for outv, ogv, grad in zip(output_variables, variables1, subset_loss_gradient):
#         outv.assign(ogv + delta * coeff * grad)
#     #
#     eval_results = eval_ctx.evaluate(output_model)
#     mnli_eval_results = mnli_eval_ctx.evaluate(output_model)
#     #
#     print(coefficients)
#     #
#     print(f'  {mnli_eval_results.kl()}, {mnli_eval_results.acc()}')
#     print('')
#     #
#     # print(f'{eval_results.loss()}, {eval_results.acc()}')
#     print(f'  {eval_results.kl()}, {eval_results.acc()}')
#     print('')
#     #
#     for ex_inds in univ_example_subset_indices:
#         print(f'  {eval_results.kl_for_examples(ex_inds)}, {eval_results.acc_for_examples(ex_inds)}')
#     print('')
#     #
#     for ex_inds in example_subset_indices:
#         # print(f'{eval_results.loss_for_examples(ex_inds)}, {eval_results.acc_for_examples(ex_inds)}')
#         print(f'  {eval_results.kl_for_examples(ex_inds)}, {eval_results.acc_for_examples(ex_inds)}')
#     print('')

############################################################

"""
- See how much the model’s performance on MNLI is affected by the ablations/merges.


Baselines:
- Move in the direction of the positive/negative gradient for a component's top examples.
- Use the full fisher (or PEF/NMF-derived full fisher) for a component's top examples for the merge.
- Isotropic merge restricted to subset of parameters present in a component.
- Explore choices of the "other" Fisher used in the merge.
    - Full dataset fisher.
    - MNLI vs HANS vs combo fisher.
    - PEF-derived fisher.
    - NMF-reconstructed Fisher.
    - Whether to include the ablated component or not in the Fisher.
- Maybe something to do with the loss sign guide.

"""



# TODO: TRY TO LOOK AT KL-DIVERGENCE OF PREDICTIONS FOR DIFFERENT EXAMPLES, which might make
# the most theoretical sense for these types of experiments.


# TODO: Use the sign of the loss gradient to choose the direction of the fixed-size offsets in each dimension.





# sel_params = ncf.SelectionParameters(
#     coeff_factor=0.3,
#     frac_threshold=0.925,
#     p_value_threshold=0.15,
# )
# tuning_info = mc.compute_tuning_info(sel_params)
