R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/snli_tuned_ablations_05.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util

##########################################################################

# Root of all artifacts for this project; the model, Fisher and per-example
# Fisher directories are all direct children of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, subdir)
    for subdir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# Hugging Face tokenizer shared by every model in this script; it is passed as
# `tokenizer=` to QCC.QqpComponentContext when the experiment context is built.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Which of the `connectivity/feather_berts_*` checkpoints to analyze.
MODEL_NUMBER = 0

MODEL = f"connectivity/feather_berts_{MODEL_NUMBER}"

# Per-example Fisher (PEF) file the NMF decomposition was built from. The NMF
# filenames embed this name, so it is kept separate from PEF_FILENAME below,
# which points at a different per-example Fisher file. (Previously this was
# encoded by assigning NMF_FILENAME *before* reassigning PEF_FILENAME, which
# silently broke if the two lines were swapped.)
NMF_SOURCE_PEF_FILENAME = f"feather_berts_{MODEL_NUMBER}.snli_train.all_vars.50000ex.65536.h5"

# Alternative NMF files derived from the same source PEF (swap in as needed):
# NMF_FILENAME = f"spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{NMF_SOURCE_PEF_FILENAME}"
# NMF_FILENAME = f"refit_w.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{NMF_SOURCE_PEF_FILENAME}"

# NMF file with the "fit_w.skip50000.50000ex.65536vpe." prefix.
# NOTE(review): presumably W was re-fit on examples after skipping the first
# 50000 (matching PEF_FILENAME below) — confirm against the NMF tooling.
NMF_FILENAME = f"fit_w.skip50000.50000ex.65536vpe.spH.nmf_decomp2.c{512}_1250Iters_{65536}pe_mvpp{10}_{50000}ex.{NMF_SOURCE_PEF_FILENAME}"

# Per-example Fisher file actually loaded by the component context.
PEF_FILENAME = f'feather_berts_{MODEL_NUMBER}.snli_train.all_vars.skip50000.250000ex.131072.h5'


# Diagonal Fisher file (under FISHERS_DIR) loaded as the experiment's
# retaining Fisher.
FISHER_FILENAME = f"feather_berts_{MODEL_NUMBER}.mnli_snli_train.all_vars.{50000}ex.h5"

##########################################################################

# Bundle the model name, per-example Fisher file, NMF file (both stored under
# PER_EXAMPLES_FISHERS_DIR) and tokenizer into a component context, then build
# the model context for key 'a'.
# NOTE(review): the meaning of the 'a' key and of special_processing='HF_MNLI'
# is defined inside QqpComponentContext — confirm there.
hacc = QCC.QqpComponentContext(
    model_name_pattern=MODEL,
    pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILENAME),
    nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILENAME),
    tokenizer=tokenizer,
    special_processing='HF_MNLI',
)
mc = hacc.make_model_context('a')

# The experiment pairs the model context with the diagonal Fisher loaded from
# FISHERS_DIR/FISHER_FILENAME, passed in as `retaining_fisher` (presumably the
# Fisher of the behavior to preserve while ablating — TODO confirm).
exp = ablation_exp_util.Experiment1(
    mc=mc,
    retaining_fisher=diagonal.DiagonalFisher.load(os.path.join(FISHERS_DIR, FISHER_FILENAME)).fishers,
)


# exp.apply_sign_guide = exp.apply_gradient


##########################################################################

# NMF component under study: a column index into exp.nmf.W (see the
# kl_range_targeter_ex_indices argument below).
COMP_INDEX = 111
# COMP_INDEX = 222


# Reload ablation_exp_util so source edits take effect before the helper is
# built. NOTE(review): `exp` was constructed before this reload, so it is still
# an instance of the pre-reload Experiment1 class.
reload(ablation_exp_util)
helper = ablation_exp_util.ExperimentHelper1(
    exp=exp,
    component_index=COMP_INDEX,
    # Presumably the number of examples the ablated model is evaluated on.
    n_evaluation_examples=16 * 1024,
    # n_evaluation_examples=50_000,
    # Presumably the [low, high] KL window the ablation strength is tuned into.
    kl_target_range=[.125, .175],
    # kl_target_range=[.025, .1],
    # kl_target_range=[1e-6, .01],
    #
    # n_kl_range_targeter_examples=2 * 1024,
    # kl_range_targeter_ex_indices=np.argsort(-exp.nmf.W[:, COMP_INDEX])[: 2 * 1024],
    # kl_range_targeter_ex_indices=np.argsort(-exp.nmf.W[:, COMP_INDEX])[: 1 * 1024],
    # kl_range_targeter_ex_indices=np.argsort(-exp.nmf.W[:, COMP_INDEX])[: 256],
    # argsort of the negated W column orders examples by descending coefficient
    # on COMP_INDEX, so this keeps the 128 most strongly loading examples.
    kl_range_targeter_ex_indices=np.argsort(-exp.nmf.W[:, COMP_INDEX])[: 128],
    # kl_range_targeter_ex_indices=np.argsort(-exp.nmf.W[:, COMP_INDEX])[: 4 * 1024],
    #
    # ablate_top_k_params=16 * 1024,
    # ablate_top_k_params=4 * 1024,
    # ablate_top_k_params=2 * 1024,
    ablate_top_k_params=256,
    #
    fixed_sign_guide=False,
)


##########################################################################

# Run the ablation experiment; widen the range to repeat it several times.
# for i in range(4):
for i in range(1):
    helper.do_run()
    print('Run', i)

# TODO: Add option for leaving the delta, lmbda constant and choosing
# random sign-guides with those fixed.


# Hot-reload ablation_exp_util and rebind the live objects to the fresh
# classes. Reload exactly once: reloading again between the two rebinds (as
# this used to do) creates another class object, leaving `helper` bound to a
# stale ExperimentHelper1 that no longer is `ablation_exp_util.ExperimentHelper1`.
reload(ablation_exp_util)
helper.__class__ = ablation_exp_util.ExperimentHelper1
exp.__class__ = ablation_exp_util.Experiment1


# Per-evaluation-example coefficients; presumably for the helper's own
# component when no index is given (other components can be passed explicitly).
coeffs = helper.get_evaluation_examples_coeffs()
# coeffs = helper.get_evaluation_examples_coeffs(207)

kls = helper.get_examples_kl_matrix()

# Collapse the KL matrix over its last axis to one value per example.
avg_kls = kls.mean(axis=-1)


# _ = helper.evaluate_with_different_sign_guide();

##########################################################################


# Raw scatter: one point per evaluation example (coefficient vs. mean KL).
plt.plot(coeffs, avg_kls, '.')
plt.show()

# plt.plot(coeffs, kls.max(axis=-1), '.')
# plt.show()


def plot_binned_coeffs_vs_kl(kls_vec, n_bins: int = 16, show: bool = True, coeffs_vec=None):
    """Plot mean KL against mean coefficient over equal-count coefficient bins.

    Examples are sorted by coefficient and split into ``n_bins`` contiguous
    chunks of (nearly) equal size; each non-empty chunk contributes one point
    ``(mean coefficient, mean KL)`` to a line plot.

    Args:
        kls_vec: Per-example KL values, indexed by the same example positions
            as the coefficients.
        n_bins: Number of bins. When the number of examples is not divisible
            by ``n_bins``, the leading bins hold one extra example each.
        show: Whether to call ``plt.show()`` afterwards; pass ``False`` to
            overlay several curves on one figure.
        coeffs_vec: Per-example coefficients. Defaults to the module-level
            ``coeffs`` array.

    Raises:
        ValueError: If ``n_bins`` is not positive.
    """
    if n_bins <= 0:
        raise ValueError(f'n_bins must be positive, got {n_bins}')
    if coeffs_vec is None:
        coeffs_vec = coeffs
    coeffs_vec = np.asarray(coeffs_vec)
    kls_vec = np.asarray(kls_vec)
    # TODO: Something with bin min width.
    sorted_inds = np.argsort(coeffs_vec)
    chunk_coeffs = []
    chunk_kls = []
    # array_split gives exactly equal chunks when the length divides evenly
    # (the only case the previous assert allowed) and near-equal ones otherwise.
    for inds in np.array_split(sorted_inds, n_bins):
        if len(inds) == 0:  # More bins than examples: nothing to average.
            continue
        chunk_kls.append(kls_vec[inds].mean())
        chunk_coeffs.append(coeffs_vec[inds].mean())
    #
    plt.plot(chunk_coeffs, chunk_kls)
    if show:
        plt.show()

# Binned curve for the mean KL (shown on its own figure).
plot_binned_coeffs_vs_kl(avg_kls)


# One binned curve per index along the last axis of `kls`; show=False defers
# plt.show() so all curves are overlaid on a single figure.
for i in range(kls.shape[-1]):
    plot_binned_coeffs_vs_kl(kls[:, i], show=False)

plt.show()


##########################################################################

@dataclasses.dataclass
class ScoresInfo:
    """Association scores between per-example coefficients and KLs."""

    mutual_info: float  # mutual-information estimate (filled by mutual_info_regression in compute_scores)
    spearman: float  # Spearman rank correlation coefficient
    pearson: float  # Pearson correlation coefficient

    def log(self):
        """Print one labelled line per score, then a blank separator line."""
        report = (
            ('Mutual Info', self.mutual_info),
            ('Spearman Corr', self.spearman),
            ('Pearson Corr', self.pearson),
        )
        for label, value in report:
            print(f'{label}: {value}')
        print('')
        print('')


def compute_scores(coeffs, kls, random_state=None):
    """Score how strongly per-example coefficients track per-example KLs.

    Args:
        coeffs: 1-D array of per-example coefficients, length ``n``.
        kls: KL values whose first axis indexes the same ``n`` examples. A
            1-D array is scored directly. For higher-rank input, every value
            in row ``i`` is paired with ``coeffs[i]`` (the coefficient is
            repeated) and all pairs are scored together.
        random_state: Forwarded to ``mutual_info_regression``, which is
            randomized; pass an int for reproducible mutual-information
            estimates. ``None`` keeps the previous unseeded behavior.

    Returns:
        ScoresInfo with the mutual information, Spearman and Pearson
        correlations of the paired coefficients and KLs.

    Raises:
        ValueError: If ``coeffs`` and ``kls`` disagree on the number of examples.
    """
    coeffs = np.asarray(coeffs)
    kls = np.asarray(kls)
    if coeffs.shape[0] != kls.shape[0]:
        raise ValueError(
            f'coeffs has {coeffs.shape[0]} examples but kls has {kls.shape[0]}'
        )
    if kls.ndim >= 2:
        # np.repeat emits coeffs[0] k times, then coeffs[1] k times, ... which
        # lines up with the row-major order of kls.reshape([-1]).
        n_per_example = int(np.prod(kls.shape[1:]))
        coeffs = np.repeat(coeffs, n_per_example)
        kls = kls.reshape([-1])
    return ScoresInfo(
        mutual_info=mutual_info_regression(coeffs[:, None], kls, random_state=random_state)[0],
        spearman=stats.spearmanr(coeffs, kls)[0],
        pearson=stats.pearsonr(coeffs, kls)[0],
    )


# Compare how well coefficients track the KLs: this component's coefficients
# vs. those of components 207/209 (presumably other NMF components), first
# against the mean KL and then against the full per-example KL matrix.
compute_scores(coeffs, avg_kls).log()
compute_scores(helper.get_evaluation_examples_coeffs(207), avg_kls).log()
compute_scores(helper.get_evaluation_examples_coeffs(209), avg_kls).log()

compute_scores(coeffs, kls).log()
compute_scores(helper.get_evaluation_examples_coeffs(209), kls).log()


# Baseline: shuffling the coefficients breaks any example-level association;
# the real scores are printed again right before it for a side-by-side read.
# The shuffle is unseeded, so baseline numbers vary from run to run.
rand_coeffs = coeffs.copy()
np.random.shuffle(rand_coeffs)
compute_scores(coeffs, avg_kls).log()
compute_scores(rand_coeffs, avg_kls).log()



##########################################################################



# (exp.pef.input_ids != 0).any(axis=0)
##########################################################################

# mutual_info_regression
# stats.spearmanr
# stats.pearsonr


# def moving_average(a, n=3) :
#     ret = np.cumsum(a, dtype=float)
#     ret[n:] = ret[n:] - ret[:-n]
#     return ret[n - 1:] / n


# def kl_fn(model):
#     kl = eval_ctx.evaluate(model, evaluation_example_inds).kl()
#     print(kl)
#     return kl


# reload(ablation_exp_util)

# targeter = ablation_exp_util.KlRangeTargeter(
#     exp=exp,
#     sign_guide=sign_guide,
#     kl_fn=kl_fn,
#     # kl_range=[.001, .01],
#     # kl_range=[.025, .075],
#     # kl_range=[.075, .175],
#     kl_range=[.15, .25],
#     delta_mag_range=[1e-5, 3],
# )

# kl_model = targeter.search(COMP_INDEX, max_iters=25)


# #


# real_evaluation_example_inds = exp.random_example_indices(16 * 1024)

# eval_results = eval_ctx.evaluate(kl_model, real_evaluation_example_inds)
# inds_by_kl = eval_results.indices_ordered_by_kl()


# coeffs = []
# kls = []
# for _ind in inds_by_kl:
#     kl = eval_results.kl_for_examples([_ind])
#     ind = real_evaluation_example_inds[_ind]
#     #
#     kls.append(kl)
#     coeffs.append(exp.W[ind, COMP_INDEX])
#     # coeffs.append(exp.W[ind, 352])

# coeffs = np.array(coeffs)
# kls = np.array(kls)

# # inds = np.argsort(coeffs)
# # kls = kls[inds]
# # coeffs = coeffs[inds]

# # plot_kls = moving_average(kls, 768)

# # plt.plot(coeffs[768 - 1:], plot_kls, )
# # plt.show()


# plt.plot(coeffs, kls, '.')
# plt.show()


# # n_bins = 10
# # max_coeff = coeffs.max()
# # inds_by_bins = []
# # mean_coeffs = []
# # for i in range(n_bins):
# #     bin_min_coeff = i * max_coeff / n_bins
# #     bin_max_coeff = (i + 1) * max_coeff / n_bins
# #     mean_coeffs.append((bin_max_coeff + bin_min_coeff) / 2)
# #     inds_by_bins.append(np.nonzero((bin_min_coeff <= coeffs) & (coeffs <= bin_max_coeff))[0])

# # mean_kls = []
# # for inds in inds_by_bins:
# #     mean_kls.append(kls[inds].mean())


# # plt.plot(mean_coeffs, mean_kls)
# # plt.show()



# # TODO: Something with bin min width.
# n_bins = 8
# # n_bins = 128
# sorted_inds = np.argsort(coeffs)
# mean_coeffs = []
# mean_kls = []
# for i in range(n_bins):
#     assert len(sorted_inds) % n_bins == 0, "TODO: Support this"
#     chunk_size = len(sorted_inds) // n_bins
#     inds = sorted_inds[i * chunk_size : (i + 1) * chunk_size]
#     mean_kls.append(kls[inds].mean())
#     mean_coeffs.append(coeffs[inds].mean())


# plt.plot(mean_coeffs, mean_kls)
# plt.show()




# # Fit the coeffs of the data used to learn the data to the NMF and
# # see if the tunings/ablations are sharper.
