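R"""HANS component analysis (interactive).

Loads the NMF decomposition of per-example Fishers for
textattack/bert-base-uncased-RTE (HANS lexical_overlap validation data) and
ablates NMF component `component_index` in three ways: guided by the
component's examples, by the component's H ablator, and by random examples.
For each, it prints the KL on the selected examples and on the first
`n_all_eval` evaluation examples.

Run with `python -i` so the results stay inspectable afterwards:
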
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=2 python -i em/projects/pi/exps/mains/hans/hans_comp_analysis01.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util
from em.projects.pi.exps import guided_ablations
from em.projects.pi.exps import multi_comp_util

##########################################################################

# Root of all artifacts for the "pi1" experiments.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

TOKENIZER = 'bert-base-uncased'
MODEL = "textattack/bert-base-uncased-RTE"

##########################################################################

# Parameters encoded in the per-example-Fisher and NMF filenames. They are
# named once here so the two filenames cannot drift apart (the example count
# previously appeared as a literal in both).
N_PEF_EXAMPLES = 131072
NMF_N_COMPONENTS = 256
NMF_N_ITERS = 2500
NMF_MVPP = 1  # value of the `mvpp` filename field; meaning defined by the NMF tool

PEF_FILENAME = f'bert_base_rte.lexical_overlap.validation.all_vars.all_ex.{N_PEF_EXAMPLES}.h5'
NMF_FILENAME = (
    f"spH.nmf_decomp.c{NMF_N_COMPONENTS}_{NMF_N_ITERS}Iters_"
    f"{N_PEF_EXAMPLES}pe_mvpp{NMF_MVPP}.{PEF_FILENAME}"
)

FISHER_FILENAME = "bert_base_rte.lexical_overlap.validation.all_vars.all_ex.h5"


##########################################################################

def read_in_to_exp() -> ablation_exp_util.Experiment1:
    """Assemble the ablation experiment from the files configured above.

    Loads the tokenizer named by ``TOKENIZER`` and wraps ``MODEL`` together
    with the per-example-Fisher and NMF files in a ``QqpComponentContext``.
    It then pairs that context's model context (obtained via
    ``make_model_context('a')``) with the diagonal Fisher loaded from
    ``FISHERS_DIR/FISHER_FILENAME``.

    Returns:
        An ``ablation_exp_util.Experiment1`` whose ``mc`` is the model context
        and whose ``retaining_fisher`` is the loaded Fisher's ``fishers``.
    """
    pef_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME)
    nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME)
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME)

    component_ctx = QCC.QqpComponentContext(
        model_name_pattern=MODEL,
        pef_filepath_pattern=pef_path,
        nmf_filepath_pattern=nmf_path,
        tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
        # Formerly taken from FLAGS.exp_special_processing; disabled in this script.
        special_processing=None,
    )
    fisher_weights = diagonal.DiagonalFisher.load(fisher_path).fishers

    return ablation_exp_util.Experiment1(
        mc=component_ctx.make_model_context('a'),
        retaining_fisher=fisher_weights,
    )


##########################################################################

# Build the experiment once. `exp` and `eval_ctx` stay module globals so they
# can be inspected from the `python -i` prompt.
exp = read_in_to_exp()
eval_ctx = exp.mc.get_evaluation_context()

##########################################################################

# The "all examples" KL is computed over evaluation indices 0..n_all_eval-1.
n_all_eval = 10_000

# component_index = 0
# component_index = 111
# component_index = 103
# component_index = 89
# component_index = 86
component_index = 228

helper = guided_ablations.ExperimentHelper1(
    exp=exp,
    component_index=component_index,
    #
    # n_selected_examples=128,
    n_selected_examples=8,
    # n_selected_examples=16,
    #
    # kl_target_range=[.35, .45],
    kl_target_range=[.25, .35],
    #
    # ablating_variable_style="fixed_offset",
    ablating_variable_style="gradient",
)

# Interactive hot-reload of the helper class without rebuilding `exp`:
# reload(guided_ablations);helper.__class__=guided_ablations.ExperimentHelper1


def _evaluate_ablator(label, ablator):
    """Find the ablated model for ``ablator``, print its KLs, and return them.

    Uses the module globals ``eval_ctx`` and ``n_all_eval`` (read at call time,
    so they can be changed interactively before calling again).

    Args:
        label: Short tag printed in front of the KL values.
        ablator: Any ablator from ``helper`` exposing ``find_model()`` and
            ``example_inds``.

    Returns:
        ``(model, selected_kl, all_kl)``: the model from ``find_model()``,
        the KL over ``ablator.example_inds`` and the KL over the first
        ``n_all_eval`` examples.
    """
    ablated_model = ablator.find_model()
    sel_kl = eval_ctx.evaluate(ablated_model, ablator.example_inds).kl()
    # Fresh list on every call in case `evaluate` mutates its index argument.
    tot_kl = eval_ctx.evaluate(ablated_model, list(range(n_all_eval))).kl()
    print(f"{label}: selected_kl={sel_kl} all_kl={tot_kl}")
    return ablated_model, sel_kl, tot_kl


# `model`, `selected_kl` and `all_kl` are rebound after each run (as before),
# so after the script finishes they hold the random-examples results.
comp_ablator = helper.get_component_examples_ablator()
model, selected_kl, all_kl = _evaluate_ablator("component examples", comp_ablator)

comp_ablator2 = helper.get_component_H_ablator()
model, selected_kl, all_kl = _evaluate_ablator("component H", comp_ablator2)

# helper.resample_random_examples()
rand_ablator = helper.get_random_examples_ablator()
model, selected_kl, all_kl = _evaluate_ablator("random examples", rand_ablator)


# -Remember to try to use loss gradient instead of KL for some stuff.
# -Sort the pdf comps by avg top W or something
# -Look into dividing by HANS categories.
# -Maybe weight by coeff for gradients and/or fishers


##########################################################################


# # component_index = 0
# # component_index = 234
# # component_index = 226
# # component_index = 217
# # component_index = 212
# # component_index = 210
# # component_index = 209
# # component_index = 204
# component_index = 111

# # n_kl_range_targeter_examples = 128
# # n_kl_range_targeter_examples = 256
# n_kl_range_targeter_examples = 2048

# kl_range_targeter_ex_indices = np.argsort(-exp.nmf.W[:, component_index])[:n_kl_range_targeter_examples]

# helper = ablation_exp_util.ExperimentHelper1(
#     exp=exp,
#     component_index=component_index,
#     n_evaluation_examples=10_000,
#     kl_target_range=[0.1, 0.15],
#     # kl_target_range=[0.05, 0.1],
#     n_kl_range_targeter_examples=n_kl_range_targeter_examples,
#     kl_range_targeter_ex_indices=kl_range_targeter_ex_indices,
#     ablate_top_k_params=None,
#     # ablate_top_k_params=256,
#     fixed_sign_guide=False,
# )

# eval_results = helper.do_run(250)
# print(eval_results.kl_for_examples(kl_range_targeter_ex_indices), eval_results.kl())

# kls = tf.keras.losses.kl_divergence(tf.math.softmax(eval_results.logits), tf.math.softmax(eval_results.og_logits)).numpy()
# coeffs = exp.nmf.W[:, component_index]

# # kls = comp_results._compute_kl(np.concatenate([r.logits for r in comp_results.runs], axis=0))
# # coeffs = np.concatenate(len(comp_results.runs) * [comp_results.W[:, comp_results.component_index]], axis=0)
# # all_coeffs = np.concatenate(len(comp_results.runs) * [comp_results.W], axis=0)

# q = tf_metrics.spearmanr_vv(tf.cast(coeffs, tf.float32), tf.cast(kls, tf.float32))
# print(q.numpy())
