R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/snli_tuned_ablations_02.py
"""
import dataclasses
from importlib import reload
import random
import os

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util

##########################################################################

EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'

# Each artifact family lives in its own sub-directory of the experiment root.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Hugging Face tokenizer identifier, loaded via transformers.AutoTokenizer.
# The loaded `tokenizer` is handed to QCC.QqpComponentContext further down.
# NOTE(review): presumably this must match the vocabulary of the feather_berts
# checkpoints selected below -- confirm.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Which of the feather_berts checkpoints to analyze.
MODEL_NUMBER = 0

MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# Settings baked into the artifact filenames below. They are named here so that
# switching to another set of precomputed artifacts only means editing values,
# not hand-editing the filename templates.
# NOTE(review): the meaning of each field is inferred from the filename tokens;
# confirm against the scripts that wrote these files.
PEF_TRAIN_SET = "snli_train"
PEF_N_EXAMPLES = 50000
PEF_SIZE = 65536  # integer suffix of the per-example Fisher (PEF) filename
NMF_N_COMPONENTS = 512  # the "c" field; presumably the NMF component count
NMF_N_ITERS = 1250
NMF_PE = 65536  # the "...pe" field
NMF_MVPP = 10  # the "mvpp..." field
NMF_N_EXAMPLES = 50000
FISHER_N_EXAMPLES = 50000

# Alternative configuration previously used (MNLI per-example Fishers):
#   PEF_TRAIN_SET = "mnli_train"; PEF_N_EXAMPLES = 100000; PEF_SIZE = 131072
#   NMF_N_COMPONENTS = 1024; NMF_MVPP = 16

PEF_FILENAME = (
    f"feather_berts_{MODEL_NUMBER}.{PEF_TRAIN_SET}.all_vars."
    f"{PEF_N_EXAMPLES}ex.{PEF_SIZE}.h5"
)
NMF_FILENAME = (
    f"spH.nmf_decomp.c{NMF_N_COMPONENTS}_{NMF_N_ITERS}Iters_{NMF_PE}pe_"
    f"mvpp{NMF_MVPP}_{NMF_N_EXAMPLES}ex.{PEF_FILENAME}"
)

FISHER_FILENAME = (
    f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{FISHER_N_EXAMPLES}ex.h5"
)

##########################################################################

# Component context tying together the model checkpoint, its per-example Fisher
# (PEF) file, and the NMF decomposition file (whose name embeds PEF_FILENAME).
# NOTE(review): special_processing='HF_MNLI' is interpreted inside
# QCC.QqpComponentContext -- presumably MNLI-style input handling; confirm there.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
# Model context for key 'a' (what the key selects is defined by QCC).
mc = hacc.make_model_context('a')

# Ablation experiment over `mc`. `retaining_fisher` is the `.fishers` attribute
# of the diagonal Fisher stored at FISHERS_DIR/FISHER_FILENAME.
# NOTE(review): presumably these are the per-parameter importances whose
# behaviour the ablation should retain -- confirm in ablation_exp_util.Experiment1.
exp = ablation_exp_util.Experiment1(
    mc=mc,
    retaining_fisher=diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER_FILENAME)).fishers,
)

##########################################################################

# Evaluation context for `mc`; used by kl_fn and by the per-example analysis
# further down.
eval_ctx = mc.get_evaluation_context()

##########################################################################

# Component (column of exp.W) that the KL-targeting search acts on. Earlier
# choices are kept commented out for quick switching.
# COMP_INDEX = 136
# COMP_INDEX = 439
# COMP_INDEX = 350
# COMP_INDEX = 252
# COMP_INDEX = 483
# COMP_INDEX = 91
COMP_INDEX = 156


# Size of the random example sample that kl_fn evaluates on during the search.
# N_EVAL = 1024
N_EVAL = 2 * 1024
# N_EVAL = 8 * 1024

evaluation_example_inds = exp.random_example_indices(N_EVAL)

# One standard-normal tensor per retaining variable, with matching shape;
# handed to KlRangeTargeter as `sign_guide`.
# NOTE(review): nothing in this file seeds the RNGs, so these draws and the
# sampled example indices presumably differ between runs unless seeding
# happens elsewhere.
sign_guide = [tf.random.normal(v.shape) for v in exp.retaining_variables]

##########################################################################


def moving_average(a, n=3):
    """Return the simple moving average of ``a`` over windows of ``n`` items.

    Element ``i`` of the result is the mean of ``a[i:i + n]``, so the result
    has ``len(a) - n + 1`` elements (empty when ``len(a) < n``).

    Args:
        a: Sequence or array of numbers (multi-dimensional input is flattened
            by ``np.cumsum``).
        n: Window length; must be at least 1.

    Returns:
        A float64 NumPy array of window means.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        # n == 0 would fail with an opaque broadcasting error and negative n
        # would silently return meaningless values.
        raise ValueError(f"window size n must be >= 1, got {n}")
    # Prefix sums turn each window sum into a difference of two cumsum entries.
    ret = np.cumsum(a, dtype=float)
    ret[n:] = ret[n:] - ret[:-n]
    return ret[n - 1:] / n


def kl_fn(model):
    """Evaluate ``model`` on ``evaluation_example_inds`` and return its KL.

    The value returned by ``.kl()`` on the evaluation result is printed, so
    every evaluation made during the targeting search is visible, and then
    returned unchanged.
    """
    results = eval_ctx.evaluate(model, evaluation_example_inds)
    divergence = results.kl()
    print(divergence)
    return divergence


# Re-execute ablation_exp_util so edits to it are picked up when this section is
# re-run interactively (the script is launched with `python -i`).
# NOTE(review): `exp` was constructed above from the pre-reload module, so it
# remains an instance of the old Experiment1 class; isinstance checks against
# the reloaded module's classes would fail.
reload(ablation_exp_util)

# Targeting search on component COMP_INDEX, using kl_fn (KL on
# evaluation_example_inds) as the objective and sign_guide as guidance.
# NOTE(review): presumably looks for a modified model whose KL lies within
# kl_range, with perturbation magnitude bounded by delta_mag_range, in at most
# max_iters steps -- confirm in ablation_exp_util.KlRangeTargeter. The returned
# `kl_model` is evaluated as a model further down.
targeter = ablation_exp_util.KlRangeTargeter(
    exp=exp,
    sign_guide=sign_guide,
    kl_fn=kl_fn,
    # kl_range=[.075, .175],
    kl_range=[.15, .25],
    delta_mag_range=[1e-5, 3],
)

kl_model = targeter.search(COMP_INDEX, max_iters=25)


#


# Evaluate the targeted model on a larger, freshly drawn sample of examples.
real_evaluation_example_inds = exp.random_example_indices(16 * 1024)

eval_results = eval_ctx.evaluate(kl_model, real_evaluation_example_inds)
inds_by_kl = eval_results.indices_ordered_by_kl()


# For every evaluated example (in the order given by indices_ordered_by_kl),
# record its KL alongside its coefficient on component COMP_INDEX.
# `_ind` is a position within the evaluated sample; `ind` is the corresponding
# example index, which selects the row of exp.W.
kls, coeffs = [], []
for _ind in inds_by_kl:
    kl = eval_results.kl_for_examples([_ind])
    ind = real_evaluation_example_inds[_ind]
    kls.append(kl)
    coeffs.append(exp.W[ind, COMP_INDEX])
    # Baseline variant -- correlate against a different component instead:
    # coeffs.append(exp.W[ind, 352])

kls, coeffs = np.array(kls), np.array(coeffs)

# Disabled alternative view: sort by coefficient and plot a 768-wide moving
# average of KL against coefficient.
# inds = np.argsort(coeffs)
# kls = kls[inds]
# coeffs = coeffs[inds]
# plot_kls = moving_average(kls, 768)
# plt.plot(coeffs[768 - 1:], plot_kls, )
# plt.show()

# Scatter plot of per-example KL against the example's component coefficient.
plt.plot(coeffs, kls, '.')
plt.show()


# Split [0, max_coeff] into n_bins equal-width coefficient bins and average the
# per-example KL within each bin.
# Bins are half-open [lo, hi) except the last, which is closed [lo, max_coeff].
# This puts every non-negative coefficient in exactly one bin: with both ends
# inclusive, values on an interior edge would be counted twice. np.linspace
# also makes the top edge exactly max_coeff, whereas (n_bins * max_coeff) /
# n_bins can round below it and drop the largest coefficient.
n_bins = 10
max_coeff = coeffs.max()
bin_edges = np.linspace(0, max_coeff, n_bins + 1)
inds_by_bins = []
mean_coeffs = []
for i in range(n_bins):
    bin_min_coeff, bin_max_coeff = bin_edges[i], bin_edges[i + 1]
    if i == n_bins - 1:
        below_max = coeffs <= bin_max_coeff
    else:
        below_max = coeffs < bin_max_coeff
    mean_coeffs.append((bin_max_coeff + bin_min_coeff) / 2)
    inds_by_bins.append(np.nonzero((bin_min_coeff <= coeffs) & below_max)[0])

# Empty bins map to NaN explicitly, so matplotlib leaves a gap in the line and
# numpy's "Mean of empty slice" RuntimeWarning is avoided.
mean_kls = [kls[inds].mean() if inds.size else np.nan for inds in inds_by_bins]


plt.plot(mean_coeffs, mean_kls)
plt.show()


#


#


"""
Given the component tuning PDFs, we can say that the learned W's roughly reflect groups
of examples processed by a particular strategy/heuristic. The ablations in this file can be
used to demonstrate that the corresponding H's reflect parameters disproportionately
important to the W-based groupings of examples. Transitively, we can thus infer that
the H's reflect the parameters important to the strategy/heuristic.

The former (interpreting the groups of examples given by W) can be assessed
qualitatively by inspecting the top examples. I think diagnostic datasets (and
maybe toy datasets) are pretty much the only way to get somewhat quantitative results.

I should also do some theoretical investigation on very simple models and datasets.
For example, something like linear/logistic regression on a toy dataset with known
correct/incorrect heuristics, perhaps engineered to separate nicely in the PEFs.


* Baseline for the latter: ablate using one component's H, but compare a different
  component's coefficients against the examples' KL.


* Need to repeat these ablations with W's learned on other examples while keeping H frozen.
"""