R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i em/projects/pi/exps/mains/hans/hans_comp_analysis03.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util
from em.projects.pi.exps import guided_ablations
from em.projects.pi.exps import multi_comp_util

from em.projects.pi.exps import hans_stuff

# from em.projects.ll import hans_util
from em.projects.ll import hans_labeling

# Short alias; used below (`TopExampleIndices.select(nmf, ...)`) to pick the
# top examples for the chosen NMF components.
TopExampleIndices = multi_comp_util.TopExampleIndices

##########################################################################

# Root of all experiment artifacts, plus its standard sub-directories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Hugging Face tokenizer name and fine-tuned checkpoint that is analysed.
TOKENIZER = 'bert-base-uncased'
MODEL = "textattack/bert-base-uncased-RTE"

##########################################################################

# Per-example Fisher file and the NMF decomposition computed from it.
# The NMF filename embeds its settings; the positional values below fill the
# c{}, {}Iters, {}pe and mvpp{} fields (presumably #components, #iterations,
# #per-example Fishers and an mvpp setting -- confirm against the NMF tool).
PEF_FILENAME = 'bert_base_rte.lexical_overlap.validation.all_vars.all_ex.131072.h5'
NMF_FILENAME = "spH.nmf_decomp.c{}_{}Iters_{}pe_mvpp{}.{}".format(
    256, 2500, 131072, 1, PEF_FILENAME)

# Diagonal Fisher file (loaded as the retaining Fisher in read_in_to_exp).
# Other files that have been used here:
#   "bert_base_rte.lexical_overlap.validation.all_vars.all_ex.h5"
#   "bert_base_rte.lexical_overlap_ye.validation.all_vars.all_ex.h5"
FISHER_FILENAME = "bert_base_rte.rte.train.all_vars.all_ex.h5"


##########################################################################

def read_in_to_exp() -> ablation_exp_util.Experiment1:
    """Assemble the ``Experiment1`` analysed by this script.

    Wraps MODEL together with the per-example Fisher file and its NMF
    decomposition in a component context, loads the diagonal Fisher from
    FISHERS_DIR, and returns both bundled into an ``Experiment1``.

    NOTE(review): a ``QqpComponentContext`` is used even though MODEL is an
    RTE checkpoint; presumably the context is task-agnostic -- confirm.
    """
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME)
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)

    component_ctx = QCC.QqpComponentContext(
        model_name_pattern=MODEL,
        pef_filepath_pattern=pef_path,
        nmf_filepath_pattern=nmf_path,
        tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
        # The FLAGS.exp_special_processing variant is not used by this script.
        special_processing=None,
    )

    # Loaded before the model context is made, matching the original order.
    fisher_vars = diagonal.DiagonalFisher.load(fisher_path).fishers
    # The meaning of the 'a' argument is defined in QCC -- TODO document.
    model_ctx = component_ctx.make_model_context('a')
    return ablation_exp_util.Experiment1(
        mc=model_ctx,
        retaining_fisher=fisher_vars,
    )


def print_dict(d):
    """Print every ``key: value`` entry of *d* on its own line in key order.

    A blank line is always printed afterwards (even for an empty mapping) so
    that consecutive dumps are visually separated.
    """
    ordered_keys = sorted(d)
    for key in ordered_keys:
        value = d[key]
        print(f'{key}: {value}')
    print()


##########################################################################

# Build the experiment and keep handles to the pieces used repeatedly below.
exp = read_in_to_exp()
eval_ctx = exp.mc.get_evaluation_context()
nmf = exp.nmf

##########################################################################

# Alternative HANS example / indicator loading (currently unused):
# hans_examples = hans_stuff.get_hans_le_examples("validation")
# subcase_indicators = hans_labeling.compute_subcase_indicators(hans_examples)
# template_indicators = hans_labeling.compute_template_indicators(hans_examples)

# HANS lexical_overlap *train* split, batched. Judging by PEF_FILENAME the
# per-example Fishers come from the validation split, so this split acts as
# held-out data.
heldout_ds = em_datasets.load(
    'hans/lexical_overlap',
    split='train',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    sequence_length=64,
).batch(64)


##########################################################################

# Evaluation helper over the HANS "validation" split, reporting per subcase.
super_hans = hans_stuff.HansHelper1(exp=exp, split="validation")

# Components tuned to each HANS subcase; report how many each subcase got.
# NOTE(review): the positional arguments (12, .8) are interpreted inside
# hans_stuff -- confirm their meaning and consider passing them by keyword.
tuning_infos = super_hans.compute_component_tuning_infos_by_subcase(12, .8)
for k, v in tuning_infos.items():
    print('{}: {}'.format(k, len(v)))
print()

# Baseline: per-subcase accuracy of the unablated model.
results = super_hans.evaluate(exp.mc.model)
print_dict(results.get_acc_by_subcase())
print()

##########################################################################

# Number of leading evaluation examples (indices 0..n_all_eval-1) used for the
# "all examples" KL / accuracy numbers reported after each ablation.
n_all_eval = 10_000

# Single components inspected in earlier runs (kept for reference):
#   163, 228, 239, 134, 123, 69, 156

# Compile so that `model.evaluate(heldout_ds)` returns [loss, accuracy].
# NOTE(review): this only affects models that share this compile state;
# confirm that the ablators' `find_model()` returns such a model.
exp.mc.model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Number of top examples taken per component (8 has also been used).
n_ex_per_component = 6

# Earlier hand-picked component sets:
#   [156, 69, 123]
#   [156, 69, 123, 134, 239, 228]
#   [225, 81]
# Notes from an earlier run (presumably #tuned components per subcase):
#   ln_passive: 2, ln_subject/object_swap: 4, ln_conjunction: 1,
#   le_relative_clause: 2
# Single-subcase alternatives:
#   tuning_infos['ln_subject/object_swap'], tuning_infos['ln_passive'],
#   tuning_infos['ln_conjunction']

# Pool the components tuned to every subcase whose name starts with 'ln_'
# (presumably the HANS lexical-overlap non-entailment subcases).
# np.concatenate fails with an opaque message on an empty list, and an empty
# selection would silently reach TopExampleIndices.select, so check both.
ln_component_arrays = [v for k, v in tuning_infos.items() if k.startswith('ln_')]
if not ln_component_arrays:
    raise ValueError(
        f"No 'ln_' subcases found in tuning_infos (keys: {list(tuning_infos)}).")
component_indices = np.concatenate(ln_component_arrays)
if component_indices.size == 0:
    raise ValueError("No components are tuned to any 'ln_' subcase.")
# NOTE(review): a component tuned to several 'ln_' subcases appears more than
# once in component_indices; confirm TopExampleIndices.select handles that.

top_ex_info = TopExampleIndices.select(nmf, component_indices, n_ex_per_component)

helper = multi_comp_util.ExperimentHelper1(
    exp=exp,
    top_examples_info=top_ex_info,
    # Presumably: ablate along the negated gradient of the loss -- confirm
    # in multi_comp_util.
    gradient_target='loss',
    negate_gradient=True,
    # Presumably the KL window the ablation strength is tuned to hit.
    # Other ranges tried: [.45, .55] and [.15, .25].
    kl_target_range=[.25, .35],
    # Alternative: ablating_variable_style="fixed_offset".
    ablating_variable_style="gradient",
)

# To pick up edits to guided_ablations without restarting the session:
# reload(guided_ablations);helper.__class__=guided_ablations.ExperimentHelper1

def report_ablated_model(ablator):
    """Find *ablator*'s ablated model, print its metrics, and return them.

    Prints, in order:
      * the KL on the ablator's own ``example_inds`` and the KL over the first
        ``n_all_eval`` evaluation examples;
      * the accuracy over those ``n_all_eval`` examples;
      * per-HANS-subcase accuracy and KL (via ``super_hans``);
      * element [1] of ``model.evaluate(heldout_ds)``. Given the compile call
        above, this is the sparse categorical accuracy -- assumes the returned
        model carries that compile state; TODO confirm.

    Returns:
        ``(model, selected_kl, hans_results)``, so the script can keep them as
        module globals for interactive inspection (the script runs under
        ``python -i``).
    """
    model = ablator.find_model()
    selected_kl = eval_ctx.evaluate(model, ablator.example_inds).kl()
    all_results = eval_ctx.evaluate(model, list(range(n_all_eval)))
    print(selected_kl, all_results.kl())
    print(all_results.acc())

    hans_results = super_hans.evaluate(model)
    print_dict(hans_results.get_acc_by_subcase())
    print_dict(hans_results.get_kl_by_subcase())
    print(model.evaluate(heldout_ds)[1])
    return model, selected_kl, hans_results


# Ablation driven by the components' top examples.
comp_ablator = helper.get_component_examples_ablator()
model, selected_kl, results = report_ablated_model(comp_ablator)

# Ablation driven by the components' H factors.
comp_ablator2 = helper.get_component_H_ablator()
model, selected_kl, results = report_ablated_model(comp_ablator2)


# Try the auto-computed component tunings.


# '
# '
# '
# #
# # - Need new selection method (instead of KL-targeter) [maybe have fixed "budget" of number of predictions to change],
# #       or maybe just do something like a regular model merge.
# # - Do multiple component ablations.
# # - Auto select components based on tunings.
# # - See if the improvements generalize (i.e. evaluate on the train set (the one I didn't use for NPEFF)).
# # - Even if using random examples works better than comp stuff, see if it beats a plain loss-gradient step.
# # - Try using _ye-only Fishers.
# #
# '
# '
# '


##########################################################################
