R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/multi_comp/snli_multi_comp_01.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util
from em.projects.pi.exps import guided_ablations
from em.projects.pi.exps import multi_comp_util

# Module-level shorthand for multi_comp_util.TopExampleIndices (used below via
# TopExampleIndices.select).
TopExampleIndices = multi_comp_util.TopExampleIndices

##########################################################################

# Root directory of this project's experiment data; the three directories
# below are fixed subdirectories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
# Not referenced again in the visible part of this file.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Joined with FISHER_FILENAME when the diagonal Fisher is loaded.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Joined with both PEF_FILENAME and NMF_FILENAME when building the component context.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Name handed to AutoTokenizer.from_pretrained.
TOKENIZER = 'bert-base-uncased'
# Created at import time; passed as `tokenizer=` to QCC.QqpComponentContext below.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Index interpolated into the model name and every file name in this section.
MODEL_NUMBER = 0

# Passed as `model_name_pattern=` to QCC.QqpComponentContext below.
MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# NOTE: PEF_FILENAME and NMF_FILENAME are each assigned twice in this section.
# f-strings are evaluated when the assignment runs, so BOTH NMF_FILENAME values
# embed this first PEF_FILENAME (the "50000ex.65536" one), not the value that
# PEF_FILENAME is rebound to further down. Reordering these lines changes the
# resulting file names.
PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.snli_train.all_vars.50000ex.65536.h5"
# NMF_FILENAME = f"spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{PEF_FILENAME}"
# This "refit_w" value is overwritten below before it is ever read.
NMF_FILENAME = f"refit_w.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{PEF_FILENAME}"


# fit_w.skip50000.50000ex.65536vpe.
# Final NMF_FILENAME: still built from the first PEF_FILENAME above.
NMF_FILENAME = f"fit_w.skip50000.50000ex.65536vpe.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{PEF_FILENAME}"
# Final PEF_FILENAME: rebound only after NMF_FILENAME has captured the old value.
PEF_FILENAME = f'feather_berts_{MODEL_NUMBER}.snli_train.all_vars.skip50000.250000ex.131072.h5'


# File name of the diagonal Fisher loaded from FISHERS_DIR below.
FISHER_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{50000}ex.h5"

##########################################################################

# Component context built from the model name, the per-example-Fisher file and
# the NMF file (both looked up under PER_EXAMPLES_FISHERS_DIR), plus the tokenizer.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
# NOTE(review): the meaning of the 'a' argument is defined by
# QqpComponentContext.make_model_context and is not visible here — confirm there.
mc = hacc.make_model_context('a')

# Experiment object; `retaining_fisher` is the `.fishers` attribute of the
# DiagonalFisher loaded from FISHERS_DIR/FISHER_FILENAME.
exp = ablation_exp_util.Experiment1(
    mc=mc,
    retaining_fisher=diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER_FILENAME)).fishers,
)
# Passed to TopExampleIndices.select further down.
nmf = exp.nmf

# Module-level evaluation context; do_run reads it as a global.
eval_ctx = exp.mc.get_evaluation_context()

##########################################################################


# cos_sim_H = multi_comp_util.compute_H_cos_sim_matrix(nmf)
# cos_sim_W = multi_comp_util.compute_W_cos_sim_matrix(nmf)

# fig, axs = plt.subplots(1, 2)
# axs[0].imshow(
#     cos_sim_H,
#     vmin=0, vmax=1,
#     cmap=sns.color_palette("rocket", as_cmap=True),
# )
# axs[1].imshow(
#     cos_sim_W,
#     vmin=0, vmax=1,
#     cmap=sns.color_palette("rocket", as_cmap=True),
# )
# plt.show()

##########################################################################
"""
Possible component grouping criteria:
    - High avg coeff value.
    - Correlated/similar W.
    - Correlated/similar H.
    - Tuning to particular features:
        - Prediction
        - Some manually determined thing.
"""
##########################################################################
# do_run evaluates example indices 0 .. n_all_eval - 1 for its "General" KL line.
n_all_eval = 8 * 1024


def do_run(helper, ablator):
    """Print KL values for the model that ``ablator`` finds.

    ``ablator.find_model()`` is called once. The resulting model is then
    evaluated with the module-level ``eval_ctx`` and the following lines are
    printed, in this order:

    * ``Comp <index>: <kl>`` for each component in
      ``helper.top_examples_info``, evaluated on that component's example
      indices;
    * ``Comp group: <kl>``, evaluated on ``ablator.example_inds``;
    * ``General: <kl>``, evaluated on indices ``0 .. n_all_eval - 1``;
    * an empty line.

    Args:
        helper: Object exposing ``top_examples_info``, which in turn exposes
            ``component_indices`` and ``example_indices_by_component``.
        ablator: Object exposing ``find_model()`` and ``example_inds``.

    Returns:
        None; results are only printed.
    """
    model = ablator.find_model()

    def kl_on(example_inds):
        # Value of .kl() on the evaluation result for `model` on these examples.
        return eval_ctx.evaluate(model, example_inds).kl()

    # One line per selected component, on that component's own examples.
    info = helper.top_examples_info
    per_component = zip(info.component_indices, info.example_indices_by_component)
    for comp_ind, ex_inds in per_component:
        kl = kl_on(ex_inds)
        print(f'Comp {comp_ind}: {kl}')

    # The ablator's own example set.
    all_comps_kl = kl_on(ablator.example_inds)
    print(f'Comp group: {all_comps_kl}')

    # The first n_all_eval example indices.
    all_kl = kl_on(list(range(n_all_eval)))
    print(f'General: {all_kl}')
    print()


##########################################################################

# cos_sim_W2 = cos_sim_W - 10 * np.eye(cos_sim_W.shape[0])

# q = np.argmax(cos_sim_W2.reshape([-1]))
# component_indices = [
#     q // cos_sim_W2.shape[-1],
#     q % cos_sim_W2.shape[-1],
# ]

##########################################################################

# Third argument to TopExampleIndices.select below; the commented lines are
# alternative settings.
# n_ex_per_component = 32
n_ex_per_component = 64
# n_ex_per_component = 128

## Components with high avg coeff value.
component_indices = [446, 216, 136, 73, 474, 413, 156]
#
# component_indices = component_indices[:2]
# component_indices = component_indices[:3]
# Only the first five entries of the list above are used.
component_indices = component_indices[:5]


# Selection built from the NMF object, the chosen components and the
# per-component example count; handed to the helper as `top_examples_info`.
top_ex_info = TopExampleIndices.select(nmf, component_indices, n_ex_per_component)

helper = multi_comp_util.ExperimentHelper1(
    exp=exp,
    top_examples_info=top_ex_info,
    #
    kl_target_range=[.25, .35],
    #
    # "gradient" is the commented-out alternative to "fixed_offset".
    ablating_variable_style="fixed_offset",
    # ablating_variable_style="gradient",
)

# Report KLs for the ablator returned by get_component_examples_ablator().
do_run(helper, helper.get_component_examples_ablator())

# Same report for the ablator returned by get_component_H_ablator().
do_run(helper, helper.get_component_H_ablator())


# comp_ablator = helper.get_component_examples_ablator()
# model = comp_ablator.find_model()
# selected_kl = eval_ctx.evaluate(model, comp_ablator.example_inds).kl()
# all_kl = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
# print(selected_kl, all_kl)

# comp_ablator2 = helper.get_component_H_ablator()
# model = comp_ablator2.find_model()
# selected_kl = eval_ctx.evaluate(model, comp_ablator2.example_inds).kl()
# all_kl = eval_ctx.evaluate(model, list(range(n_all_eval))).kl()
# print(selected_kl, all_kl)
