# snli_coeff_kl_analysis_01.py
R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/snli_coeff_kl_analysis_01.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

# Short local aliases for types defined in coeff_kl_relationship_util.
RunOutput, OutputForComponent, MetricsInfo = (
    coeff_kl_relationship_util.RunOutput,
    coeff_kl_relationship_util.OutputForComponent,
    coeff_kl_relationship_util.MetricsInfo,
)

##########################################################################

# Root directory holding this project's experiment artifacts; the other
# directories below are sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR, FISHERS_DIR, PER_EXAMPLES_FISHERS_DIR = (
    os.path.join(EXPS_DIR, sub_dir)
    for sub_dir in ('models', 'fishers', 'per_example_fishers')
)

##########################################################################

# The tokenizer is loaded eagerly, as a side effect of running this module.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

MODEL = "connectivity/feather_berts_0"

# Artifact file names. Each later name embeds the previous one as its suffix.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

RESULTS_DIR = f'{EXPS_DIR}/coeff_kl_relationships/attempt01'
# NOTE: `{component_index}` is intentionally left as a placeholder (this is
# not an f-string); fill it via RESULT_FILEPATH.format(component_index=...).
RESULT_FILEPATH = os.path.join(RESULTS_DIR, 'test_coeff_kl_relationship.comp{component_index}.h5')

##########################################################################


def compute_mean_metrics_mag(results):
    """Average the absolute Pearson and Spearman scores over all components.

    Args:
        results: a loaded results object exposing ``W`` (whose last axis is
            iterated as component indices) and ``compute_metrics(i)``, which
            returns an object with ``pearson`` and ``spearman`` attributes.

    Returns:
        A ``MetricsInfo`` whose ``pearson`` and ``spearman`` fields are the
        mean magnitudes (absolute values) of those scores across components.

    Raises:
        ValueError: if ``results.W`` has no components to average over.
    """
    num_components = results.W.shape[-1]
    if num_components == 0:
        raise ValueError('results.W has no components to average over.')

    # Compute each component's metrics exactly once. Only magnitudes matter
    # here, so the sign of each correlation is discarded before averaging.
    per_component = [results.compute_metrics(i) for i in range(num_components)]
    # sum() adds sequentially from 0, matching the original running total.
    pearson_total = sum(np.abs(m.pearson) for m in per_component)
    spearman_total = sum(np.abs(m.spearman) for m in per_component)

    return MetricsInfo(
        pearson=pearson_total / num_components,
        spearman=spearman_total / num_components,
    )


def compute_metrics_for_all_comps(results):
    """Return ``results.compute_metrics(i)`` for every component index ``i``.

    Components are enumerated over the last axis of ``results.W``. The
    returned list is ordered by component index.
    """
    num_components = results.W.shape[-1]
    return list(map(results.compute_metrics, range(num_components)))

##########################################################################


# Component to analyze. Previously inspected alternatives:
# # # COMP_INDEX = 474
# # # COMP_INDEX = 216
# # # COMP_INDEX = 288

# # COMP_INDEX = 69
# # COMP_INDEX = 111
# # COMP_INDEX = 333
COMP_INDEX = 222

comp_results = OutputForComponent.load(
    RESULT_FILEPATH.format(component_index=COMP_INDEX)
)

# Uncomment to print one "index, pearson, spearman" line per component:
# scores_for_all_comps = compute_metrics_for_all_comps(comp_results)
#
# for i, score in enumerate(scores_for_all_comps):
#     print(f'{i}, {score.pearson}, {score.spearman}')


comp_results.compute_metrics(comp_results.component_index).log()
# # compute_mean_metrics_mag(comp_results).log()


# Stack the logits of every run along axis 0 and pass them to _compute_kl.
kls = comp_results._compute_kl(
    np.concatenate([run.logits for run in comp_results.runs], axis=0)
)
# Repeat the coefficients once per run so that they line up row-for-row with
# the stacked logits above.
# NOTE(review): this assumes every run covers the same examples in the same
# order as the rows of W; confirm.
coeffs = np.concatenate(
    len(comp_results.runs) * [comp_results.W[:, comp_results.component_index]],
    axis=0,
)
all_coeffs = np.concatenate(
    len(comp_results.runs) * [comp_results.W],
    axis=0,
)

# Spearman correlation between this component's coefficients and the KLs
# (vector vs. vector, per tf_metrics.spearmanr_vv).
q = tf_metrics.spearmanr_vv(
    tf.cast(coeffs, tf.float32),
    tf.cast(kls, tf.float32),
)
print(q.numpy())

# The same, but for the coefficients of every component at once
# (matrix vs. vector, presumably one score per column; see tf_metrics).
w = tf_metrics.spearmanr_mv(
    tf.cast(all_coeffs, tf.float32),
    tf.cast(kls, tf.float32),
)
print(w.numpy())


# w = tf_metrics._rank(tf.cast(kls, tf.float32))
# r = tf_metrics._rank(tf.cast(coeffs, tf.float32))


##########################################################################

# nmf = nmf_common.SparseNmfDecomposition.load(comp_results.nmf_path)
# nmf.normalize_components_to_unit_norm()



# TODO: Look at cosine similarities between component Ws and compare to
# KL results and the similarities between Hs.





# Compute KL only for the top examples (but still choose a large number of them)?
# Ablate Gaussian noise vs. offsets.
# Ablate only the top-k parameters (or maybe square the Fisher, or raise it to a power greater than 1)?





# @tf.function
# def _sp_mv_mul(H, vec):
#     return tf.sparse.sparse_dense_matmul(
#         H, tf.sparse.to_dense(vec)[:, None]
#     )


# def compute_H_cos_sim_matrix(nmf: nmf_common.SparseNmfDecomposition):
#     Hs = nmf.get_full_sparse_H()
#     H = sparse_util.stack_as_rows(Hs)
#     # NOTE: Can probably speed up by doing the sp-dense matmul with multiple
#     # components at once.
#     return tf.concat([
#         _sp_mv_mul(H, h) for h in tqdm(Hs)
#     ], axis=-1).numpy()


# H_cos_sims = compute_H_cos_sim_matrix(nmf)

# computed_inds = np.array([5,6,7,32,39,41,42,45,51,61,63,68,69,73,74,79,87,90,92,101,111,124,131,132,136,142,149,151,156,161,162,167,171,175,177,184,189,202,203,207,214,215,216,217,218,222,224,234,240,259,261,265,269,275,276,280,288,293,295,297,298,299,307,308,321,322,330,333,334,345,347,348,363,366,369,377,378,389,393,399,409,412,413,414,415,417,423,428,432,439,446,453,455,457,474,477,486,491,500,510], dtype=np.int32)

# plt.imshow(
#     # H_cos_sims,
#     H_cos_sims[computed_inds][:, computed_inds],
#     vmin=0,
#     vmax=1,
#     cmap=sns.color_palette("rocket", as_cmap=True),
# )
# plt.show()

"""
Hi:
7, 87
32, 79
101, 202
101, 132

Med:
7, 202
87, 90
142, 217
"""
# 

##########################################################################

# def compute_prediction_selectivity(results, top_k: int):
#     coeffs = results.W[:, results.component_index]
#     top_inds = np.argsort(-coeffs)[:top_k]
#     top_logits = results.og_logits[top_inds]
#     preds = np.argmax(top_logits, axis=-1)
#     pred_rates = [
#         (preds == i).mean()
#         for i in range(top_logits.shape[-1])
#     ]
#     return max(pred_rates)


# TUNING_K = 16


# comp_to_scores = {}
# comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     comp_to_scores[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)
#     comp_to_label_selectivity[comp_results.component_index] = compute_prediction_selectivity(comp_results, TUNING_K)

# # for component_index, score in comp_to_scores.items():
# #     print(f'Component: {component_index}')
# #     score.log()

# # for component_index, score in comp_to_scores.items():
# #     print(f'{component_index}, {score.pearson}, {score.spearman}')


# for ci in sorted(comp_to_scores.keys()):
#     score = comp_to_scores[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}, {comp_to_label_selectivity[ci]}')


##########################################################################
