R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/snli_01.py
"""
import dataclasses
from importlib import reload
import random
import os

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util


# Root directory holding this project's experiment artifacts; the
# sub-directories below are all derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
# Not referenced again in the visible part of this script.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Directory that FISHER_FILENAME is loaded from when the experiment is built.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory that both PEF_FILENAME and NMF_FILENAME are resolved against.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Identifier handed to AutoTokenizer.from_pretrained.
TOKENIZER = 'bert-base-uncased'
# Shared tokenizer instance: passed to the component context and later used to
# decode input ids back to text when printing examples.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Index substituted into the model name and into every artifact filename below.
MODEL_NUMBER = 0

# Passed as `model_name_pattern` to the component context
# (presumably a Hugging Face hub id — TODO confirm).
MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# Disabled alternative: the MNLI-train variant of the two files below.
# PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_train.all_vars.100000ex.131072.h5"
# NMF_FILENAME = f"spH.nmf_decomp.c{1024}_1250Iters_{65536}pe_mvpp{16}_{50000}ex.{PEF_FILENAME}"

# Active choice: SNLI-train files. PEF_FILENAME is passed as the
# `pef_filepath_pattern` (presumably "per-example Fisher" — TODO confirm) and
# NMF_FILENAME, which embeds PEF_FILENAME as its suffix, as the
# `nmf_filepath_pattern`.
# NOTE(review): the numeric fields in these names (c512, 65536pe, mvpp10, ...)
# are not explained anywhere in this script; confirm their meaning against the
# code that produced the files.
PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.snli_train.all_vars.50000ex.65536.h5"
NMF_FILENAME = f"spH.nmf_decomp.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{PEF_FILENAME}"

# Diagonal Fisher file loaded from FISHERS_DIR and used as the experiment's
# `retaining_fisher`.
FISHER_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{50000}ex.h5"

##########################################################################

# Component context bundling the model name, the per-example Fisher file, the
# NMF file and the tokenizer.
# NOTE(review): `special_processing='HF_MNLI'` is used although the data files
# are SNLI-based — presumably a label/format convention; confirm.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
# Model context derived from the component context (the meaning of the 'a'
# argument is not visible here — TODO confirm). It later supplies the
# evaluation context and the per-component Fisher.
mc = hacc.make_model_context('a')

# Experiment helper used for everything below. The retaining Fisher is the
# `.fishers` attribute of a DiagonalFisher loaded from FISHERS_DIR.
exp = ablation_exp_util.Experiment1(
    mc=mc,
    retaining_fisher=diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER_FILENAME)).fishers,
)

##########################################################################

# Component ordering by coefficient magnitude, as returned by the experiment
# helper (the meaning of the argument 12 is not visible here — TODO confirm).
# Not referenced again in the visible part of this script; it stays available
# in the interactive session because the script is launched with `python -i`.
ordered_comps = exp.get_components_order_by_coeff_magnitude(12)

##########################################################################
# Number of examples requested for the general evaluation context below.
# N_EVAL = 2048
N_EVAL = 1024

# General evaluation context for 'snli/default', requested with N_EVAL
# examples at sequence length 96.
gen_eval_ctx = exp.create_eval_ctx('snli/default', sequence_length=96, n_examples=N_EVAL)
# gen_eval_ctx = exp.create_eval_ctx('snli/default', sequence_length=96, n_examples=4096)
# Evaluation context supplied by the model context; print_evaluations calls it
# with an explicit list of example indices.
eval_ctx = mc.get_evaluation_context()


def print_evaluations(coefficients):
    """Evaluate the current output model and print a three-part report.

    The model in ``exp.output_model`` is scored twice: once with the general
    evaluation context (``gen_eval_ctx``) and once with the model context's
    evaluation context (``eval_ctx``) restricted to the module-level
    ``top_example_inds``.

    Printed output, in order:
      * ``coefficients`` as given,
      * general results: kl, accuracy (4 decimals), loss,
      * subset results: kl, accuracy,
      * an empty line.

    Args:
        coefficients: Value printed as the report header; it is not used in
            any computation here.
    """
    # Both evaluations run before anything is printed.
    general = gen_eval_ctx.evaluate(exp.output_model)
    subset = eval_ctx.evaluate(exp.output_model, top_example_inds)

    print(coefficients)

    general_line = f'  {general.kl()}, {general.acc():.4f}, {general.loss()}'
    print(general_line)

    subset_line = f'  {subset.kl()}, {subset.acc()}'
    print(subset_line)

    print('')


##########################################################################

# COMP_INDEX = 204  # Based on tuning, looks to be a "bad heuristic" component; maybe with positive delta and N_SG=128
# COMP_INDEX = 483  # hard maybe
# COMP_INDEX = 136  # hard maybe, also possibly "interpretable" component
# COMP_INDEX = 357  # hard maybe
# COMP_INDEX = 93  # hard maybe
# COMP_INDEX = 34  # hard maybe
# COMP_INDEX = 25  # Maybe with positive delta?
# COMP_INDEX = 348  # The top kl difference validation examples show some similar tuning to the component itself.
# COMP_INDEX = 13  # maybe
# COMP_INDEX = 397


"""
Components with random gradients but H ablating fisher:
439: selective for contradiction and probably mostly correct, consistently bad ablation.
136: maybe wrong-ish component, consistently good ablation.
483: consistently mildly good ablation.
93: consistently OK/good ablation.
34: consistently OK/good ablation.
91: consistently OK/good ablation.
402: mixed ablation.
114: looks to be a probably correct component; ablation is sometimes bad, sometimes good.
144: consistently neutral to Good
481: Seems to be fairly consistently good.
"""

# Index of the component under study. In the active configuration it is used
# only to build the ablating Fisher via make_fisher_for_components.
COMP_INDEX = 136

# Base count from which the example-selection sizes below are derived.
# The commented values are alternative settings.
# N_SG = 12
# N_SG = 16
# N_SG = 32
# N_SG = 64
N_SG = 128
# N_SG = 256
# N_SG = 512
# N_SG = 1024

# Scalar passed to exp.apply_gradient together with the normalized gradient
# (presumably a step size whose sign picks the direction — TODO confirm).
# delta = -1e1
# delta = -5e0
# delta = -1.2
delta = -6e-1
# delta = -2e-1  # ##
# delta = -6e-2
# delta = -8e-3
# delta = -3e-3

# delta = 6e-2
# delta = 2e-1


# Example indices used for the loss gradient and for the subset evaluation in
# print_evaluations. The first three commented alternatives select examples
# through COMP_INDEX; the remaining ones draw random examples.
# top_example_inds = exp.get_top_example_indices(COMP_INDEX)[:N_SG]
# top_example_inds = exp.remove_correct_example_indices(exp.get_top_example_indices(COMP_INDEX)[:N_SG])
# top_example_inds = exp.get_top_incorrect_example_indices(COMP_INDEX)[:N_SG]
#
# top_example_inds = exp.random_example_indices_by_correctness(N_SG // 2, N_SG // 2)
# top_example_inds = exp.random_example_indices_by_correctness(2 * N_SG // 3, N_SG // 3)
# NOTE(review): with N_SG = 128 the two floor divisions below give 108 and 19,
# i.e. 127 requested examples rather than 128. Which argument counts correct
# vs. incorrect examples is not visible here — TODO confirm.
top_example_inds = exp.random_example_indices_by_correctness(85 * N_SG // 100, 15 * N_SG // 100)
# top_example_inds = exp.random_example_indices_by_correctness(0, N_SG)
# top_example_inds = exp.random_example_indices(N_SG)


# Loss gradient over the selected examples. The active call passes only the
# indices; the commented variants additionally pass the W entries of those
# examples for COMP_INDEX (and optionally normalize_by_example=True).
gradient = exp.compute_loss_gradient(top_example_inds)
# gradient = exp.compute_loss_gradient(top_example_inds, exp.W[top_example_inds, COMP_INDEX])
# gradient = exp.compute_loss_gradient(top_example_inds, exp.W[top_example_inds, COMP_INDEX], normalize_by_example=True)
#
# Normalize the gradient with the experiment's L2 helper before applying it.
gradient = exp.l2_normalize(gradient)


# Variables produced by applying the normalized gradient with scalar `delta`;
# they are fed to stream_merge as `ablating_variables`.
ablating_variables = exp.apply_gradient(gradient, delta)
# ablating_variables = exp.apply_sign_guide(gradient, delta)


# Fisher fed to stream_merge as `ablating_fisher`. Active choice: the Fisher
# built by the model context for the single component COMP_INDEX. The
# commented alternatives compute it from example indices or use a dummy.
ablating_fisher = mc.make_fisher_for_components([COMP_INDEX])
# ablating_fisher = exp.compute_fisher(top_example_inds)
# ablating_fisher = exp.compute_fisher(exp.random_example_indices(N_SG))
# ablating_fisher = exp.compute_fisher(top_example_inds, exp.W[top_example_inds, COMP_INDEX])
# ablating_fisher = exp.compute_fisher(top_example_inds, exp.W[top_example_inds, COMP_INDEX], normalize_by_example=True)
# ablating_fisher = exp.get_dummy_ablating_fisher()


# Disabled control: shuffle each Fisher tensor with tf.random.shuffle.
# ablating_fisher = [tf.random.shuffle(f) for f in ablating_fisher]

#######################################

# print(exp.W[top_example_inds, COMP_INDEX])


# Merge sweep. stream_merge is iterated below and yields one `coefficients`
# value per step; print_evaluations then scores exp.output_model, so the merge
# presumably updates exp.output_model before each yield — TODO confirm.
# `coefficients=21` is an integer here, while the commented alternatives pass
# an explicit list of pairs (presumably 21 is the number of sweep points —
# TODO confirm).
gen = exp.stream_merge(
    ablating_variables=ablating_variables,
    ablating_fisher=ablating_fisher,
    coefficients=21,
    # coefficients=[(0.9, 0.1)]
    # coefficients=[(0.75, 0.25)]
    # coefficients=[(0.6, 0.4)],
    # coefficients=[(0.7, 0.3)],
    # coefficients=[(0.95, 0.05)]
    # coefficients=[(0.85, 0.15)]
    # NOTE(review): the next pair is the only one listed that does not sum to 1.
    # coefficients=[(0.8, 0.1)]
    # coefficients=[(0.3, 0.7)]
    # coefficients=[(0.4, 0.6)]
    # coefficients=[(0.35, 0.65)]
)
# Print one evaluation report per yielded coefficient setting.
for coefficients in gen:
    print_evaluations(coefficients)


#


#

# Re-evaluate the output model as left by the last merge step, then collect
# the examples whose prediction changed and the ordering of examples by KL.
gen_eval_results = gen_eval_ctx.evaluate(exp.output_model)
altered_preds_inds = gen_eval_results.indices_of_altered_predictions()
inds_by_kl = gen_eval_results.indices_ordered_by_kl()

# Report 1: every example whose prediction was altered.
# Header format: "[label] new_prediction <- argmax of original logits".
for i in altered_preds_inds:
    input_ids = gen_eval_ctx.all_examples[0]['input_ids'][i]
    label = gen_eval_results.labels[i]
    new_pred = gen_eval_results.predictions[i]
    og_pred = np.argmax(gen_eval_results.og_logits[i])
    print(f'[{label}] {new_pred} <- {og_pred}')
    decoded = tokenizer.decode(input_ids)
    # Strip padding tokens so only the actual text is shown.
    print(decoded.replace(' [PAD]', ''))
    print('')


# Report 2: the first 12 examples in KL order, with the per-example KL
# appended to the header in square brackets.
for i in inds_by_kl[:12]:
    input_ids = gen_eval_ctx.all_examples[0]['input_ids'][i]
    label = gen_eval_results.labels[i]
    new_pred = gen_eval_results.predictions[i]
    og_pred = np.argmax(gen_eval_results.og_logits[i])
    example_kl = gen_eval_results.kl_for_examples([i])
    print(f'[{label}] {new_pred} <- {og_pred} [{example_kl}]')
    decoded = tokenizer.decode(input_ids)
    print(decoded.replace(' [PAD]', ''))
    print('')

