R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i em/projects/pi/exps/mains/hans/hans_comp_analysis02.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util
from em.projects.pi.exps import guided_ablations
from em.projects.pi.exps import multi_comp_util

from em.projects.pi.exps import hans_stuff

# from em.projects.ll import hans_util
from em.projects.ll import hans_labeling

##########################################################################

# Root directory for this experiment's artifacts; the directories below are
# all sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Directory holding the file that read_in_to_exp() loads via
# diagonal.DiagonalFisher.load.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory holding both PEF_FILENAME and NMF_FILENAME.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Name passed to AutoTokenizer.from_pretrained wherever a tokenizer is needed.
TOKENIZER = 'bert-base-uncased'
# Passed as `model_name_pattern` to QCC.QqpComponentContext.
MODEL = "textattack/bert-base-uncased-RTE"

##########################################################################

# Passed (joined with PER_EXAMPLES_FISHERS_DIR) as `pef_filepath_pattern`.
PEF_FILENAME = 'bert_base_rte.lexical_overlap.validation.all_vars.all_ex.131072.h5'
# Passed (joined with PER_EXAMPLES_FISHERS_DIR) as `nmf_filepath_pattern`.
# The name embeds PEF_FILENAME; the numeric fields are presumably the settings
# of the NMF run that produced the file -- TODO confirm their meaning.
NMF_FILENAME = f"spH.nmf_decomp.c{256}_{2500}Iters_{131072}pe_mvpp{1}.{PEF_FILENAME}"

# Loaded (joined with FISHERS_DIR) as the experiment's `retaining_fisher`.
FISHER_FILENAME = "bert_base_rte.lexical_overlap.validation.all_vars.all_ex.h5"


##########################################################################

def read_in_to_exp() -> ablation_exp_util.Experiment1:
    """Assemble the ``Experiment1`` object used by the rest of this script.

    A ``QqpComponentContext`` is built for ``MODEL`` from the per-example
    Fisher file (``PEF_FILENAME``) and the NMF file (``NMF_FILENAME``), the
    diagonal Fisher stored in ``FISHER_FILENAME`` is loaded, and both are
    bundled into the returned experiment.

    Returns:
        An ``ablation_exp_util.Experiment1`` whose ``mc`` comes from
        ``make_model_context('a')`` on the component context and whose
        ``retaining_fisher`` is the ``fishers`` attribute of the loaded
        ``DiagonalFisher``.
    """
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME)
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)

    component_ctx = QCC.QqpComponentContext(
        model_name_pattern=MODEL,
        pef_filepath_pattern=pef_path,
        nmf_filepath_pattern=nmf_path,
        tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
        # special_processing=FLAGS.exp_special_processing,
        special_processing=None,
    )

    # Load the Fisher before creating the model context, matching the order
    # in which the artifacts have always been read.
    fisher_to_retain = diagonal.DiagonalFisher.load(fisher_path).fishers
    model_ctx = component_ctx.make_model_context('a')

    return ablation_exp_util.Experiment1(
        mc=model_ctx,
        retaining_fisher=fisher_to_retain,
    )


def print_dict(d):
    """Print every entry of ``d`` as a ``key: value`` line, ordered by key.

    A blank line is printed after the last entry (or on its own when ``d`` is
    empty) so consecutive dumps are visually separated. The keys must be
    mutually comparable because they are passed to ``sorted``.
    """
    ordered_keys = sorted(d)
    for key in ordered_keys:
        print(f'{key}: {d[key]}')
    print()


##########################################################################

# Build the experiment once. `eval_ctx` is reused further down to score the
# ablated models.
exp = read_in_to_exp()
eval_ctx = exp.mc.get_evaluation_context()

##########################################################################

# hans_examples = hans_stuff.get_hans_le_examples("validation")

# subcase_indicators = hans_labeling.compute_subcase_indicators(hans_examples)
# template_indicators = hans_labeling.compute_template_indicators(hans_examples)


def _load_batched(task, split, *, sequence_length=64, batch_size=64):
    """Load ``task``/``split`` tokenized with ``TOKENIZER`` and batch it.

    A fresh tokenizer is created on every call, exactly as when each dataset
    was loaded inline.
    """
    dataset = em_datasets.load(
        task,
        split=split,
        tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
        sequence_length=sequence_length)
    return dataset.batch(batch_size)


# The HANS lexical-overlap 'train' split serves as the held-out set here.
heldout_ds = _load_batched('hans/lexical_overlap', 'train')
rte_ds = _load_batched('glue/rte', 'validation')


##########################################################################

super_hans = hans_stuff.HansHelper1(exp=exp, split="validation")

# Per-subcase accuracy of the model before any ablation is applied.
results = super_hans.evaluate(exp.mc.model)
print_dict(results.get_acc_by_subcase())

##########################################################################

# Number of examples (indices 0..n_all_eval-1) scored by `eval_ctx` when
# measuring the overall effect of an ablation.
n_all_eval = 10_000

# component_index = 163
# component_index = 228
# component_index = 239
# component_index = 134
# component_index = 123
# component_index = 69
# component_index = 156
component_index = 225

exp.mc.model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
# With one metric compiled in, evaluate() returns [loss, accuracy]; element 1
# is therefore the RTE accuracy of the un-ablated model.
print(exp.mc.model.evaluate(rte_ds)[1])

helper = guided_ablations.ExperimentHelper1(
    exp=exp,
    component_index=component_index,
    #
    random_example_selection='same_preds',
    # random_example_selection='uniform',
    #
    gradient_target='loss',
    negate_gradient=True,
    #
    # gradient_target='kl',
    #
    # n_selected_examples=128,
    # n_selected_examples=8,
    n_selected_examples=16,
    #
    # kl_target_range=[.75, 1.],
    # kl_target_range=[.35, .45],
    kl_target_range=[.25, .35],
    #
    # ablating_variable_style="fixed_offset",
    ablating_variable_style="gradient",
)

# reload(guided_ablations);helper.__class__=guided_ablations.ExperimentHelper1


def _ablate_and_report(ablator):
    """Build the ablated model for ``ablator`` and print its evaluation.

    The same sequence used to be copy-pasted once per ablator; it lives here
    so that the three runs below cannot drift apart.

    Printed, in order:
      1. KL on the ablator's own ``example_inds`` and KL on example indices
         ``0..n_all_eval-1``;
      2. accuracy on example indices ``0..n_all_eval-1``;
      3. per-subcase HANS accuracy (via ``super_hans``);
      4. element 1 of ``model.evaluate`` on ``heldout_ds`` and on ``rte_ds``
         (assumes the ablated model is compiled like ``exp.mc.model`` so that
         element 1 is accuracy -- TODO confirm).

    Args:
        ablator: object exposing ``find_model()`` and ``example_inds``, as
            returned by the ``helper.get_*_ablator()`` methods.

    Returns:
        Tuple ``(model, selected_kl, hans_results)`` so the caller can rebind
        the module-level names that interactive (``python -i``) sessions
        inspect afterwards.
    """
    ablated_model = ablator.find_model()
    # `example_inds` is read only after find_model(), as before.
    kl_on_selected = eval_ctx.evaluate(ablated_model, ablator.example_inds).kl()
    overall = eval_ctx.evaluate(ablated_model, list(range(n_all_eval)))
    print(kl_on_selected, overall.kl())
    print(overall.acc())

    hans_results = super_hans.evaluate(ablated_model)
    print_dict(hans_results.get_acc_by_subcase())
    print(ablated_model.evaluate(heldout_ds)[1])
    print(ablated_model.evaluate(rte_ds)[1])
    return ablated_model, kl_on_selected, hans_results


# Each ablator is created immediately before it is evaluated, so the helper
# calls happen in the same order as in the original straight-line code.
comp_ablator = helper.get_component_examples_ablator()
model, selected_kl, results = _ablate_and_report(comp_ablator)


comp_ablator2 = helper.get_component_H_ablator()
model, selected_kl, results = _ablate_and_report(comp_ablator2)


# helper.resample_random_examples()
rand_ablator = helper.get_random_examples_ablator()
model, selected_kl, results = _ablate_and_report(rand_ablator)


# print(exp.mc.model.evaluate(heldout_ds)[1])

# '
# '
# '
# #
# # - Need a new selection method (instead of the KL-targeter) [maybe have a fixed "budget" of the number of predictions to change].
# # - Do multiple component ablations.
# # - Auto-select components based on tunings.
# # - See if the improvement generalizes (i.e. evaluate on the train set (the one I didn't use for NPEFF)).
# # - Even if using random examples works better than comp stuff, see if it is better than a plain loss gradient step.
# #
# '
# '
# '



##########################################################################

# n_all_eval = 10_000

# # component_index = 0
# # component_index = 111
# # component_index = 103
# # component_index = 89
# # component_index = 86
# component_index = 228

# helper = guided_ablations.ExperimentHelper1(
#     exp=exp,
#     component_index=component_index,
#     #
#     # n_selected_examples=128,
#     n_selected_examples=8,
#     # n_selected_examples=16,
#     #
#     # kl_target_range=[.35, .45],
#     kl_target_range=[.25, .35],
#     #
#     # ablating_variable_style="fixed_offset",
#     ablating_variable_style="gradient",
# )

# # reload(guided_ablations);helper.__class__=guided_ablations.ExperimentHelper1


# comp_ablator = helper.get_component_examples_ablator()
# model = comp_ablator.find_model()
# selected_kl = eval_ctx.evaluate(model, comp_ablator.example_inds).kl()
# all_kl = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
# print(selected_kl, all_kl)

# comp_ablator2 = helper.get_component_H_ablator()
# model = comp_ablator2.find_model()
# selected_kl = eval_ctx.evaluate(model, comp_ablator2.example_inds).kl()
# all_kl = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
# print(selected_kl, all_kl)

# # helper.resample_random_examples()
# rand_ablator = helper.get_random_examples_ablator()
# model = rand_ablator.find_model()
# selected_kl = eval_ctx.evaluate(model, rand_ablator.example_inds).kl()
# all_kl = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
# print(selected_kl, all_kl)


# # -Remember to try to use loss gradient instead of KL for some stuff.
# # -Sort the pdf comps by avg top W or something
# # -Look into dividing by HANS categories.
# # -Maybe weight by coeff for gradients and/or fishers


##########################################################################


