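R"""Guided ablation experiment on SNLI for a single component.

Builds one ablated model from the top examples of `component_index` and, as a
baseline, one from an equally sized random set of examples, then prints the KL
measured on the selected examples and on the first `n_all_eval` examples.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/snli_guided_ablations_01.py
"""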
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import kl_targeting

##########################################################################

# Root directory under which this script looks for its input artifacts.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'

# Sub-directories of EXPS_DIR, one per artifact type.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Name of the pretrained tokenizer to load.
TOKENIZER = 'bert-base-uncased'
# Tokenizer instance handed to the QqpComponentContext constructed below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Index of the checkpoint to analyse; interpolated into every name below.
MODEL_NUMBER = 0

# Model identifier passed to QqpComponentContext as `model_name_pattern`.
MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# Per-example Fisher (PEF) file name that is embedded in the NMF file names.
# NOTE(review): this differs from PEF_FILENAME below. The original code built
# NMF_FILENAME from this name and only afterwards rebound PEF_FILENAME to the
# "skip50000" file; that behavior is preserved here -- confirm it is intended.
_NMF_SOURCE_PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.snli_train.all_vars.50000ex.65536.h5"

# NMF file to load. Alternatives that were previously selected:
# NMF_FILENAME = f"spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{_NMF_SOURCE_PEF_FILENAME}"
# NMF_FILENAME = f"refit_w.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{_NMF_SOURCE_PEF_FILENAME}"
NMF_FILENAME = f"fit_w.skip50000.50000ex.65536vpe.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{_NMF_SOURCE_PEF_FILENAME}"

# PEF file to load (alternative: the file named by _NMF_SOURCE_PEF_FILENAME).
PEF_FILENAME = f'feather_berts_{MODEL_NUMBER}.snli_train.all_vars.skip50000.250000ex.131072.h5'

# Diagonal Fisher file loaded as the experiment's `retaining_fisher`.
FISHER_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{50000}ex.h5"

##########################################################################

# Full paths of the artifacts consumed below.
pef_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME)
nmf_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME)
retaining_fisher_filepath = os.path.join(FISHERS_DIR, FISHER_FILENAME)

# Component context tying together the model, its per-example Fishers and the
# NMF file; `mc` is the model context it creates for the key 'a'.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=pef_filepath,
    nmf_filepath_pattern=nmf_filepath,
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
mc = hacc.make_model_context('a')

# Experiment wrapper around `mc`; the diagonal Fisher is loaded from disk and
# supplied as the `retaining_fisher`.
exp = ablation_exp_util.Experiment1(
    mc=mc,
    retaining_fisher=diagonal.DiagonalFisher.load(retaining_fisher_filepath).fishers,
)

# Evaluation context used by every KL measurement in this script.
eval_ctx = exp.mc.get_evaluation_context()

##########################################################################


def get_kl_fn(ablating_fisher, ablating_sign_guide, example_inds):
    """Build a `(delta, lmbda) -> kl` callable for a fixed ablation setup.

    Args:
        ablating_fisher: Forwarded to `exp.create_model` as its first argument.
        ablating_sign_guide: Forwarded to `exp.create_model` as its second
            argument.
        example_inds: Example indices on which each created model is evaluated.

    Returns:
        A function taking `delta` and `lmbda` that creates a model through the
        module-level `exp`, evaluates it on `example_inds` with the
        module-level `eval_ctx`, and returns the evaluation's `.kl()`.
    """
    def kl_fn(delta: float, lmbda: float):
        ablated_model = exp.create_model(
            ablating_fisher, ablating_sign_guide, delta, lmbda)
        evaluation = eval_ctx.evaluate(ablated_model, example_inds)
        return evaluation.kl()

    return kl_fn


def get_model(examples_fisher, examples_loss_grad, example_inds, kl_range=(.125, .175)):
    """Search for ablation parameters hitting a target KL and return the model.

    Args:
        examples_fisher: Fisher used as the ablating Fisher.
        examples_loss_grad: Gradient used as the ablating sign guide.
        example_inds: Example indices on which the KL is measured.
        kl_range: Target KL range handed to the targeter.

    Returns:
        `exp.output_model` as it stands once the search has finished.
        NOTE(review): presumably this is the model created by the last
        `exp.create_model` call made during the search -- confirm against
        `Experiment1`.
    """
    kl_of_ablation = get_kl_fn(examples_fisher, examples_loss_grad, example_inds)
    kl_search = kl_targeting.GenericKlTargeter(
        kl_fn=kl_of_ablation,
        kl_range=kl_range,
        delta_mag_range=[1e-5, 3],
    )
    # The search's own return value is ignored; the model is read off `exp`.
    kl_search.search(max_iters=25)
    return exp.output_model


##########################################################################

# Number of leading example indices (0 .. n_all_eval - 1) used for the
# "all examples" KL measurement.
n_all_eval = 8192  # 8 * 1024

# Component whose top examples drive the ablation. Candidates, grouped by the
# author's notes on what each component selects for:
#   Mixed predictions:        111, 207, 378, 32, 393
#   Entailment selective:     222, 491
#   Contradiction selective:  288, 203, 259
#   Neutral selective:        345 (sort of, also some contradictions),
#                             234, 455, 7,
#                             330 (a lot of the predictions appear wrong as well)
component_index = 203

# Size of both the top-example set and the random-example set.
# Alternative values: 32, 64, 256, 2048.
n_examples = 128


# TODO:
# - Also try the gradient of the loss, assuming the prediction is the correct
#   label, or the gradient of the KL with the original logits.
# - Use the gradient of the KL with the original logits as the sign guide.
# - Pick all coeffs above a value to determine the size of the top/random
#   example indices.
# - Try a merge with a gradient-based update.


# KL range targeted for both the top-example and the random-example model.
# Alternatives tried: get_model's default range and [.06, .1].
ABLATION_KL_RANGE = (.35, .45)


def _selected_and_all_kl(model, selected_inds):
    """Return `.kl()` of `model` on `selected_inds` and on the first `n_all_eval` examples."""
    kl_on_selected = eval_ctx.evaluate(model, selected_inds).kl()
    kl_on_all = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
    return kl_on_selected, kl_on_all


# Fisher and sign guide for the top examples of the chosen component.
# Alternative sign guide: exp.compute_loss_gradient(top_example_inds).
top_example_inds = exp.get_top_example_indices(component_index)[:n_examples]
top_examples_fisher = exp.compute_fisher(top_example_inds)
top_examples_loss_grad = exp.compute_kl_gradient(top_example_inds, allow_recompile=True)

# Same quantities for an equally sized random baseline set.
# Alternative baseline, restricted to the predicted class of the first top example:
# random_example_inds = np.random.permutation(np.nonzero(exp.predictions == exp.predictions[top_example_inds[0]])[0])[:n_examples]
# Alternative sign guide: exp.compute_loss_gradient(random_example_inds).
random_example_inds = exp.random_example_indices(n_examples)
random_examples_fisher = exp.compute_fisher(random_example_inds)
random_examples_loss_grad = exp.compute_kl_gradient(random_example_inds, allow_recompile=True)


# Model guided by the top examples; it is evaluated before the random model is
# built, exactly as before.
# Alternative sign guide: the negated gradient, [-g for g in top_examples_loss_grad].
top_model = get_model(
    top_examples_fisher,
    top_examples_loss_grad,
    top_example_inds,
    list(ABLATION_KL_RANGE),
)
top_selected_kl, top_all_kl = _selected_and_all_kl(top_model, top_example_inds)
print(top_selected_kl, top_all_kl)


# Model guided by the random examples.
# Alternative sign guide: the negated gradient, [-g for g in random_examples_loss_grad].
random_model = get_model(
    random_examples_fisher,
    random_examples_loss_grad,
    random_example_inds,
    list(ABLATION_KL_RANGE),
)
random_selected_kl, random_all_kl = _selected_and_all_kl(random_model, random_example_inds)
print(random_selected_kl, random_all_kl)




# 
