R"""Compare per-example NMF coefficients (W) with per-example Fisher norms on SNLI.

Usage:
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i em/projects/pi/exps/mains/trust_ver/snli_coeff_vs_fisher_norm_01.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

# Short aliases for names defined in coeff_kl_relationship_util. None of them
# is referenced in the visible part of this script; presumably kept for
# interactive use (the module docstring runs this file with `python -i`).
RunOutput = coeff_kl_relationship_util.RunOutput
OutputForComponent = coeff_kl_relationship_util.OutputForComponent
MetricsInfo = coeff_kl_relationship_util.MetricsInfo

##########################################################################

# Root directory for this project's experiment artifacts, plus subdirectories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Both the per-example Fisher file and the NMF fit loaded further down are
# read from this directory.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Tokenizer for the BERT-base uncased vocabulary. It is loaded eagerly at
# import time (from_pretrained may download or read a local cache), but it is
# not referenced in the visible part of this script.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Hugging Face model identifier; not referenced in the visible part of this script.
MODEL = "connectivity/feather_berts_0"

# The NMF artifact name is built by nesting the names of the files it was
# derived from: og_pef_name (a per-example Fisher file) is wrapped in an
# "spH.nmf_decomp2..." name, which is in turn wrapped in a
# "fit_w.skip50000.50000ex..." name. Judging by the prefixes (confirm against
# the producing scripts), this is a W matrix fit against a previously computed
# H. NMF_NAME is the file loaded into W below.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

# Per-example Fisher file whose norms are compared against W. Its name says
# skip50000 / 250000 examples, while NMF_NAME says skip50000 / 50000 examples.
# That is presumably why the norms are truncated to W.shape[0] entries later.
# TODO confirm the example orderings of the two files line up.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

# Fisher file name; not referenced in the visible part of this script.
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

# Output location for this experiment's results. Neither name is referenced in
# the visible part of this script.
RESULTS_DIR = f'{EXPS_DIR}/coeff_kl_relationships/attempt01'
# Deliberately NOT an f-string: `{component_index}` stays as a literal
# placeholder, presumably filled later via
# RESULT_FILEPATH.format(component_index=...).
RESULT_FILEPATH = os.path.join(RESULTS_DIR, 'test_coeff_kl_relationship.comp{component_index}.h5')

##########################################################################


def load_fisher_norms(pef_path):
    """Return the dense per-example Fisher norms stored in a per-example Fisher file.

    Args:
        pef_path: Path to a file readable by
            ``per_example.PerExampleFlatFishers.load``.

    Returns:
        The loaded object's ``dense_fisher_norms`` attribute.
    """
    loader = per_example.PerExampleFlatFishers.load
    # Only the norms are needed. With start_fisher_index == end_fisher_index == 0
    # the Fishers themselves are not loaded, which is much faster.
    flat_fishers = loader(
        pef_path,
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )
    norms = flat_fishers.dense_fisher_norms
    return norms


def load_W(nmf_path):
    """Load a saved sparse NMF decomposition and return its W matrix.

    Before W is read, ``normalize_components_to_unit_norm()`` is called on the
    loaded decomposition. Presumably this rescales W to compensate for the
    normalized components; confirm in nmf_common.

    Args:
        nmf_path: Path to a file readable by
            ``nmf_common.SparseNmfDecomposition.load``.

    Returns:
        The decomposition's ``W`` attribute after normalization.
    """
    decomposition = nmf_common.SparseNmfDecomposition.load(nmf_path)
    decomposition.normalize_components_to_unit_norm()
    coefficients = decomposition.W
    return coefficients


##########################################################################

def plot_W_vs_norm(W, norms, component_index):
    """Scatter-plot one component's NMF coefficients against per-example Fisher norms.

    Args:
        W: 2-D coefficient matrix indexed as ``W[example, component]``.
        norms: Per-example Fisher norms. Only the first ``W.shape[0]`` entries
            are used (matching ``compute_ratios``), so the full norms array
            may be passed directly.
        component_index: Column of ``W`` to plot.
    """
    w = W[:, component_index]
    # Align lengths. matplotlib raises if x and y differ in length, and the
    # norms array may cover more examples than W has rows. Slicing is a no-op
    # when the caller already truncated.
    norms = norms[:W.shape[0]]
    plt.plot(norms, w, '.')
    plt.xlabel('Fisher norm')
    plt.ylabel(f'W[:, {component_index}]')
    plt.title(f'Component {component_index}: coefficient vs. Fisher norm')
    plt.show()


def compute_correlations(W, norms):
    """Correlate each NMF component's coefficients with per-example Fisher norms.

    Args:
        W: 2-D coefficient matrix indexed as ``W[example, component]``.
        norms: 1-D per-example Fisher norms. Only the first ``W.shape[0]``
            entries are used, consistent with ``compute_ratios``, so the full
            norms array can be passed directly. If fewer norms than rows of
            ``W`` are given, scipy raises on the length mismatch.

    Returns:
        A list with one ``[component_index, spearman_rho, pearson_r]`` row per
        column of ``W``. A constant column gives NaN correlations (scipy emits
        a warning).
    """
    # Previously a longer norms array made pearsonr raise on the length
    # mismatch; truncate here exactly like compute_ratios does.
    norms = norms[:W.shape[0]]
    ret = []
    for i in tqdm(range(W.shape[-1])):
        w = W[:, i]  # slice once, reuse for both statistics
        ret.append([
            i,
            stats.spearmanr(w, norms)[0],
            stats.pearsonr(w, norms)[0],
        ])
    return ret


def compute_ratios(W, norms, top_ks):
    """Compare each component's top-k examples' mean Fisher norm to the overall mean.

    For every column of ``W``, the examples (rows) are ranked by descending
    coefficient. For each ``k`` in ``top_ks``, the mean Fisher norm of the
    top-``k`` examples is divided by the mean Fisher norm over all
    ``W.shape[0]`` examples. A value above 1 means the component's strongest
    examples have above-average Fisher norm.

    Args:
        W: 2-D coefficient matrix indexed as ``W[example, component]``.
        norms: 1-D per-example Fisher norms; truncated to ``W.shape[0]``.
        top_ks: Iterable of ``k`` values.

    Returns:
        A list of rows ``[component_index, ratio_k1, ratio_k2, ...]``, one per
        column of ``W``.
    """
    n_examples = W.shape[0]
    clipped_norms = norms[:n_examples]
    baseline = np.mean(clipped_norms)

    rows = []
    for comp in tqdm(range(W.shape[-1])):
        # Row indices ordered from largest to smallest coefficient.
        order = np.argsort(-W[:, comp])
        rows.append(
            [comp]
            + [np.mean(clipped_norms[order[:k]]) / baseline for k in top_ks]
        )
    return rows


##########################################################################

# Load the per-example Fisher norms (without the Fishers themselves; see
# load_fisher_norms) and the normalized NMF coefficient matrix W. Both files
# are read from PER_EXAMPLES_FISHERS_DIR.
all_fisher_norms = load_fisher_norms(os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME))
W = load_W(os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_NAME))

##########################################################################

# TODO: Make sure this is right with how I select the fisher norms.
# plot_W_vs_norm(W, all_fisher_norms[:W.shape[0]], 112)


# # TODO: Look at the correlations
# corrs = compute_correlations(W, all_fisher_norms[:W.shape[0]])

# for r in corrs:
#     print(', '.join([str(c) for c in r]))

# One row per component: its index, then the top-k mean Fisher norm divided by
# the overall mean norm, for each k below. (The explicit slice is redundant
# with compute_ratios' own truncation, but harmless.)
ratios = compute_ratios(
    W,
    all_fisher_norms[:W.shape[0]],
    [8, 16, 32, 64, 128, 256, 512],
)

# Print as comma-separated rows so the output can be pasted as CSV.
for r in ratios:
    print(', '.join(map(str, r)))

# NOTE: compute_ratios (above) already computes the ratio of the top-k examples' average Fisher norm to the overall average Fisher norm.
