R"""


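Guided-ablation experiment on SNLI for a single NMF component.

To run it interactively, set up the project path and start Python with `-i`:

cd ~/Desktop/projects/extract_merge1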
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/snli_guided_ablations_02.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import kl_targeting
from em.projects.pi.exps import guided_ablations

##########################################################################

# Root directory under which all artifact directories below live.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'

# Artifact sub-directories; each one is a direct child of EXPS_DIR.
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Name of the tokenizer checkpoint loaded on the next line.
TOKENIZER = 'bert-base-uncased'
# Module-level tokenizer instance; it is passed to `QCC.QqpComponentContext`
# when the component context is built.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Index of the `feather_berts` checkpoint under study; it is interpolated into
# the model name and into every artifact filename below.
MODEL_NUMBER = 0

MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# Per-example-Fisher (PEF) filename that is embedded in the NMF filename.
# It intentionally differs from PEF_FILENAME below, so it gets its own name
# instead of relying on PEF_FILENAME being rebound later in the file.
# NOTE(review): presumably this is the PEF file the NMF decomposition was
# computed from -- confirm against the job that produced the NMF artifact.
NMF_SOURCE_PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.snli_train.all_vars.50000ex.65536.h5"
NMF_DECOMP_FILENAME = f"spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{NMF_SOURCE_PEF_FILENAME}"

# Alternatives previously used for NMF_FILENAME (kept for quick switching):
# NMF_FILENAME = NMF_DECOMP_FILENAME
# NMF_FILENAME = f"refit_w.{NMF_DECOMP_FILENAME}"
NMF_FILENAME = f"fit_w.skip50000.50000ex.65536vpe.{NMF_DECOMP_FILENAME}"

# PEF file that this script actually loads.
PEF_FILENAME = f'feather_berts_{MODEL_NUMBER}.snli_train.all_vars.skip50000.250000ex.131072.h5'


# Diagonal Fisher file; it is loaded as the "retaining" Fisher of the experiment.
FISHER_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{50000}ex.h5"

##########################################################################

# Both the per-example-Fisher file and the NMF file are looked up in the
# per-example-Fishers directory.
pef_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME)
nmf_filepath = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME)

# Component context tying together the model name, the PEF/NMF artifacts and
# the tokenizer.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=pef_filepath,
    nmf_filepath_pattern=nmf_filepath,
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
# NOTE(review): the meaning of the 'a' argument is defined by
# `make_model_context` and is not visible from this script.
mc = hacc.make_model_context('a')

# Only the `.fishers` attribute of the loaded diagonal Fisher is kept.
retaining_fisher = diagonal.DiagonalFisher.load(
    os.path.join(FISHERS_DIR, FISHER_FILENAME)
).fishers
exp = ablation_exp_util.Experiment1(mc=mc, retaining_fisher=retaining_fisher)

# Evaluation context used by the KL measurements at the end of the script.
eval_ctx = exp.mc.get_evaluation_context()

##########################################################################

# Size of the index range 0..n_all_eval-1 that is evaluated as the
# "all examples" set by the KL measurements below.
n_all_eval = 8 * 1024


# Component handed to the helper; other indices tried previously are kept
# below for quick switching.
# COMP_INDEX = 203
# COMP_INDEX = 207
# COMP_INDEX = 393
COMP_INDEX = 330

# Helper settings, named here so the constructor call stays compact.
N_SELECTED_EXAMPLES = 128
# KL_TARGET_RANGE = [.35, .45]
KL_TARGET_RANGE = [.25, .35]
# ABLATING_VARIABLE_STYLE = "fixed_offset"
ABLATING_VARIABLE_STYLE = "gradient"


helper = guided_ablations.ExperimentHelper1(
    exp=exp,
    component_index=COMP_INDEX,
    n_selected_examples=N_SELECTED_EXAMPLES,
    kl_target_range=KL_TARGET_RANGE,
    ablating_variable_style=ABLATING_VARIABLE_STYLE,
)

# Interactive-session snippet: reload the module and re-point the existing
# helper at the reloaded class.
# reload(guided_ablations);helper.__class__=guided_ablations.ExperimentHelper1


def _report_ablator_kl(ablator, evaluation_context, n_eval):
    """Find the ablated model for `ablator` and print its two KL values.

    The same model is evaluated twice: once on the ablator's own
    `example_inds`, and once on the indices `0 .. n_eval - 1`. Both KL values
    are printed on one line, selected-examples value first.

    Args:
        ablator: object exposing `find_model()` and `example_inds`.
        evaluation_context: object exposing `evaluate(model, indices)`, whose
            result exposes `kl()`.
        n_eval: size of the leading index range used for the second evaluation.

    Returns:
        Tuple `(model, selected_kl, all_kl)` so the caller can keep the values
        around for interactive inspection.
    """
    ablated_model = ablator.find_model()
    kl_on_selected = evaluation_context.evaluate(
        ablated_model, ablator.example_inds).kl()
    kl_on_all = evaluation_context.evaluate(
        ablated_model, list(range(n_eval))).kl()
    print(kl_on_selected, kl_on_all)
    return ablated_model, kl_on_selected, kl_on_all


# Each ablator goes through the same measurement. `model`, `selected_kl` and
# `all_kl` are rebound every time, so after the script finishes they refer to
# the last (random-examples) run.
comp_ablator = helper.get_component_examples_ablator()
model, selected_kl, all_kl = _report_ablator_kl(comp_ablator, eval_ctx, n_all_eval)

comp_ablator2 = helper.get_component_H_ablator()
model, selected_kl, all_kl = _report_ablator_kl(comp_ablator2, eval_ctx, n_all_eval)

rand_ablator = helper.get_random_examples_ablator()
model, selected_kl, all_kl = _report_ablator_kl(rand_ablator, eval_ctx, n_all_eval)



