R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/guided_ablations/snli_guided_kl_ablation_output_01.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

from em.projects.pi.exps import guided_ablations

# Short module-level aliases for names provided by `guided_ablations`.
KlSelectivityInfo = guided_ablations.KlSelectivityInfo
OutputForAblator = guided_ablations.OutputForAblator
OutputForComponent = guided_ablations.OutputForComponent

##########################################################################

# Root directory of this experiment's data; the directories below are its
# direct children.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Name handed to `AutoTokenizer.from_pretrained` to select the tokenizer.
TOKENIZER = 'bert-base-uncased'
# Loaded eagerly when the script runs. It is not referenced again in the
# visible part of this file.
# NOTE(review): presumably kept for interactive use under `python -i` -- confirm.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Model identifier string.
MODEL = "connectivity/feather_berts_0"

# The NMF file name is built by stacking prefixes onto the original
# per-example-Fisher file name:
#   NMF_NAME = <fit_w prefix> + og_h_name
#   og_h_name = <spH prefix> + og_pef_name
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = (
    "spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex."
    + og_pef_name
)
NMF_NAME = "fit_w.skip50000.50000ex.65536vpe." + og_h_name

# File names of the per-example Fisher and the Fisher used with this model.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

# Directory whose entries are loaded by the aggregation loop in this script.
RESULTS_DIR = EXPS_DIR + '/guided_kl_ablations/attempt01'
# Path template for one component's result file; fill it in with
# RESULT_FILEPATH.format(component_index=...).
RESULT_FILEPATH = os.path.join(
    RESULTS_DIR,
    'test_guided_kl_ablation.comp{component_index}.h5',
)

##########################################################################

# Some selectivity info.
# Spearman correlations?

##########################################################################

# # COMP_INDEX = 69
# # COMP_INDEX = 207
# # COMP_INDEX = 330
# COMP_INDEX = 393

# results = OutputForComponent.load(RESULT_FILEPATH.format(component_index=COMP_INDEX))


# kl_info = KlSelectivityInfo.mean(results.get_component_top_fisher_kl_selectivities())
# kl_info.log()

# kl_info = KlSelectivityInfo.mean(results.get_component_H_kl_selectivities())
# kl_info.log()

# kl_info = KlSelectivityInfo.mean(results.get_random_examples_kl_selectivities())
# kl_info.log()


# print(results._compute_spearmanr(results.component_top_fisher_ablation, results.component_index))
# print(results._compute_spearmanr(results.component_H_ablation, results.component_index))


##########################################################################


# TODO: Filter out ablations where the KL didn't get into the range.

# Build one summary row per entry in RESULTS_DIR. Row layout:
#   [0]   component_index
#   [1-3] top-Fisher ablation: mean selected_examples_kl, mean all_examples_kl,
#         value returned by _compute_spearmanr
#   [4-6] H ablation: same three quantities
#   [7-8] random examples: mean selected_examples_kl, mean all_examples_kl
rows = []

# NOTE(review): every entry of RESULTS_DIR is loaded, so the directory is
# presumably expected to contain only saved OutputForComponent files -- confirm.
for entry_name in os.listdir(RESULTS_DIR):
    # Progress output: show which file is being processed.
    print(entry_name)
    results = OutputForComponent.load(os.path.join(RESULTS_DIR, entry_name))
    #
    # Average the KL selectivity info for each of the three example selections.
    top_fisher_info = KlSelectivityInfo.mean(
        results.get_component_top_fisher_kl_selectivities())
    h_info = KlSelectivityInfo.mean(
        results.get_component_H_kl_selectivities())
    random_info = KlSelectivityInfo.mean(
        results.get_random_examples_kl_selectivities())
    #
    # Assemble the row piece by piece, in the column order documented above.
    row = [results.component_index]
    row += [
        top_fisher_info.selected_examples_kl,
        top_fisher_info.all_examples_kl,
        results._compute_spearmanr(
            results.component_top_fisher_ablation, results.component_index),
    ]
    row += [
        h_info.selected_examples_kl,
        h_info.all_examples_kl,
        results._compute_spearmanr(
            results.component_H_ablation, results.component_index),
    ]
    row += [
        random_info.selected_examples_kl,
        random_info.all_examples_kl,
    ]
    rows.append(row)

# Order the table by component index, then print it as comma-separated lines.
rows.sort(key=lambda row: row[0])
for r in rows:
    print(*r, sep=', ')



# TODO: Evaluate the ablations on a separate set of examples that was not used
# to fit the KL; this may give somewhat stronger evidence that the results
# generalize.
