R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i em/projects/pi/exps/mains/snli_coeff_kl_analysis_02.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

# Short module-level aliases for names defined in coeff_kl_relationship_util.
# OutputForComponent and MetricsInfo are used further down in this file;
# RunOutput is not referenced by the code visible here.
RunOutput = coeff_kl_relationship_util.RunOutput
OutputForComponent = coeff_kl_relationship_util.OutputForComponent
MetricsInfo = coeff_kl_relationship_util.MetricsInfo

##########################################################################

# Root path from which the sub-directories below (and RESULTS_DIR) are built.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory that PEF_NAME is joined with when the Fisher norms are loaded.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Name handed to AutoTokenizer.from_pretrained below.
TOKENIZER = 'bert-base-uncased'
# Instantiated at module load time. NOTE(review): `tokenizer` is not referenced
# by the code visible in this file; it may only be kept for interactive use.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# NOTE(review): MODEL, NMF_NAME and FISHER_NAME are not referenced by the live
# code visible in this file; presumably kept as a record of the artifacts the
# stored results were produced from -- confirm before removing.
MODEL = "connectivity/feather_berts_0"

# The artifact names below are nested: og_h_name embeds og_pef_name, and
# NMF_NAME embeds og_h_name.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

# File name joined with PER_EXAMPLES_FISHERS_DIR when loading the Fisher norms.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

# Directory whose entries are loaded one by one by the analysis loop at the
# bottom of this file.
RESULTS_DIR = f'{EXPS_DIR}/coeff_kl_relationships/attempt01'
# Path template (a plain string, not an f-string): the literal
# `{component_index}` placeholder is meant to be filled in with str.format.
RESULT_FILEPATH = os.path.join(RESULTS_DIR, 'test_coeff_kl_relationship.comp{component_index}.h5')

##########################################################################


def load_fisher_norms(pef_path):
    """Return the ``dense_fisher_norms`` of a per-example Fisher file.

    Args:
        pef_path: Path handed to ``PerExampleFlatFishers.load``.

    Returns:
        The ``dense_fisher_norms`` attribute of the loaded object.
    """
    load_options = {
        'n_examples': None,
        # Using the same start and end index leads to the Fishers not being
        # loaded, which ends up being much faster.
        'start_fisher_index': 0,
        'end_fisher_index': 0,
    }
    loaded = per_example.PerExampleFlatFishers.load(pef_path, **load_options)
    return loaded.dense_fisher_norms


def compute_mean_metrics_mag(results):
    """Average the magnitudes of the Pearson and Spearman scores over components.

    ``results.compute_metrics(i)`` is called once for every index along the
    last axis of ``results.W``; the absolute values of the returned
    ``pearson`` and ``spearman`` attributes are averaged.

    Args:
        results: Object exposing ``W`` (with a ``shape``) and
            ``compute_metrics(index)``.

    Returns:
        A ``MetricsInfo`` whose ``pearson`` and ``spearman`` fields hold the
        mean absolute scores.
    """
    # Read the component count once instead of re-deriving it for every use.
    n_components = results.W.shape[-1]
    pearson = 0
    spearman = 0
    for i in range(n_components):
        metrics = results.compute_metrics(i)
        # Only the running sums are needed; the individual metrics objects
        # are deliberately not retained.
        pearson += np.abs(metrics.pearson)
        spearman += np.abs(metrics.spearman)
    return MetricsInfo(
        pearson=pearson / n_components,
        spearman=spearman / n_components,
    )


def compute_metrics_for_all_comps(results):
    """Return the metrics of every component as a list, in index order.

    Args:
        results: Object exposing ``W`` (with a ``shape``) and
            ``compute_metrics(index)``.

    Returns:
        A list with one ``results.compute_metrics(i)`` value per index along
        the last axis of ``results.W``.
    """
    n_components = results.W.shape[-1]
    per_component = []
    for comp_idx in range(n_components):
        per_component.append(results.compute_metrics(comp_idx))
    return per_component

##########################################################################


# all_fisher_norms = load_fisher_norms(os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME))


# # # # COMP_INDEX = 474
# # # # COMP_INDEX = 216
# # # # COMP_INDEX = 288

# # # COMP_INDEX = 69
# # # COMP_INDEX = 111
# # # COMP_INDEX = 333
# COMP_INDEX = 222

# comp_results = OutputForComponent.load(RESULT_FILEPATH.format(component_index=COMP_INDEX))


# ################################################################################
# # Scale by the norm of the Fisher.
# comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
# comp_results.W *= comp_fisher_norms[:, None]
# ################################################################################


# # scores_for_all_comps = compute_metrics_for_all_comps(comp_results)

# # for i, score in enumerate(scores_for_all_comps):
# #     print(f'{i}, {score.pearson}, {score.spearman}')


# comp_results.compute_metrics(comp_results.component_index).log()
# # # compute_mean_metrics_mag(comp_results).log()


# kls = comp_results._compute_kl(np.concatenate([r.logits for r in comp_results.runs], axis=0))

# coeffs = np.concatenate(len(comp_results.runs) * [comp_results.W[:, comp_results.component_index]], axis=0)
# all_coeffs = np.concatenate(len(comp_results.runs) * [comp_results.W], axis=0)

# q = tf_metrics.spearmanr_vv(tf.cast(coeffs, tf.float32), tf.cast(kls, tf.float32))
# print(q.numpy())

# w = tf_metrics.spearmanr_mv(tf.cast(all_coeffs, tf.float32), tf.cast(kls, tf.float32))
# print(w.numpy())


# w = tf_metrics._rank(tf.cast(kls, tf.float32))
# r = tf_metrics._rank(tf.cast(coeffs, tf.float32))


##########################################################################

# nmf = nmf_common.SparseNmfDecomposition.load(comp_results.nmf_path)
# nmf.normalize_components_to_unit_norm()



# TODO: Look at cosine similarities between component Ws and compare to
# KL results and the similarities between Hs.





# Choose KL for only top examples. (still choose large number)?
# Ablate Gaussian noise vs offsets.
# Ablate for only top-k parameters (or maybe square fisher (or take to power greater than 1))?





# @tf.function
# def _sp_mv_mul(H, vec):
#     return tf.sparse.sparse_dense_matmul(
#         H, tf.sparse.to_dense(vec)[:, None]
#     )


# def compute_H_cos_sim_matrix(nmf: nmf_common.SparseNmfDecomposition):
#     Hs = nmf.get_full_sparse_H()
#     H = sparse_util.stack_as_rows(Hs)
#     # NOTE: Can probably speed up by doing the sp-dense matmul with multiple
#     # components at once.
#     return tf.concat([
#         _sp_mv_mul(H, h) for h in tqdm(Hs)
#     ], axis=-1).numpy()


# H_cos_sims = compute_H_cos_sim_matrix(nmf)

# computed_inds = np.array([5,6,7,32,39,41,42,45,51,61,63,68,69,73,74,79,87,90,92,101,111,124,131,132,136,142,149,151,156,161,162,167,171,175,177,184,189,202,203,207,214,215,216,217,218,222,224,234,240,259,261,265,269,275,276,280,288,293,295,297,298,299,307,308,321,322,330,333,334,345,347,348,363,366,369,377,378,389,393,399,409,412,413,414,415,417,423,428,432,439,446,453,455,457,474,477,486,491,500,510], dtype=np.int32)

# plt.imshow(
#     # H_cos_sims,
#     H_cos_sims[computed_inds][:, computed_inds],
#     vmin=0,
#     vmax=1,
#     cmap=sns.color_palette("rocket", as_cmap=True),
# )
# plt.show()

"""
Hi:
7, 87
32, 79
101, 202
101, 132

Med:
7, 202
87, 90
142, 217
"""
# 

##########################################################################


def compute_prediction_selectivity(results, top_k: int):
    """Return the share of the most frequent predicted class among top examples.

    Examples are ranked by descending value of column
    ``results.component_index`` of ``results.W``. For the first ``top_k`` of
    them, the predicted class is the argmax over the last axis of
    ``results.og_logits``.

    Args:
        results: Object exposing ``W``, ``component_index`` and ``og_logits``.
        top_k: Number of highest-coefficient examples to inspect.

    Returns:
        The largest fraction of inspected examples assigned to a single class.
    """
    component_coeffs = results.W[:, results.component_index]
    # Negating before argsort yields a descending ordering.
    ranked_examples = np.argsort(-component_coeffs)
    selected_logits = results.og_logits[ranked_examples[:top_k]]
    predicted = np.argmax(selected_logits, axis=-1)
    n_classes = selected_logits.shape[-1]
    return max((predicted == label).mean() for label in range(n_classes))


# `top_k` value for compute_prediction_selectivity; in the code visible here it
# is only used by experiments that are currently commented out.
TUNING_K = 16

# Loaded once at module level; the analysis loop below indexes it with each
# result's evaluation_ex_indices.
all_fisher_norms = load_fisher_norms(os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME))

# comp_to_scores = {}
# comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     comp_to_label_selectivity[comp_results.component_index] = compute_prediction_selectivity(comp_results, TUNING_K)
#     ################################################################################
#     # Scale by the norm of the Fisher.
#     comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
#     comp_results.W *= comp_fisher_norms[:, None]
#     ################################################################################
#     comp_to_scores[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)

# # for component_index, score in comp_to_scores.items():
# #     print(f'Component: {component_index}')
# #     score.log()

# # for component_index, score in comp_to_scores.items():
# #     print(f'{component_index}, {score.pearson}, {score.spearman}')


# for ci in sorted(comp_to_scores.keys()):
#     score = comp_to_scores[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}, {comp_to_label_selectivity[ci]}')


##########################################################################

# # TODO: Compute correlations of dense fisher norm with KLs

# comp_to_scores2 = {}
# # comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     # comp_to_label_selectivity[comp_results.component_index] = compute_prediction_selectivity(comp_results, TUNING_K)
#     ################################################################################
#     # Scale by the norm of the Fisher.
#     comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
#     comp_results.W *= 0
#     comp_results.W += comp_fisher_norms[:, None]
#     ################################################################################
#     comp_to_scores2[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)

# for ci in sorted(comp_to_scores2.keys()):
#     score = comp_to_scores2[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}')


##########################################################################

# TODO: Compute correlations of randomly shuffled coefficients

# Shuffled-coefficient run: maps component index -> metrics computed after the
# rows of W have been randomly permuted.
comp_to_scores3 = {}
# comp_to_label_selectivity = {}

# NOTE(review): os.listdir returns entries in arbitrary order and every entry
# of RESULTS_DIR is passed to OutputForComponent.load; the printout below is
# sorted by component index, so only the progress output order is affected.
for filename in os.listdir(RESULTS_DIR):
    print(filename)
    filepath = os.path.join(RESULTS_DIR, filename)
    comp_results = OutputForComponent.load(filepath)
    ################################################################################
    # Fisher norms selected by this result's evaluation_ex_indices. Currently
    # unused in this loop because the scaling line below is commented out.
    comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
    # Randomly permute the rows of W (all columns move together). No seed is
    # set in the code visible here, so the scores differ from run to run.
    # NOTE(review): presumably this breaks the row alignment between W and
    # whatever compute_metrics compares it against -- confirm.
    comp_results.W = comp_results.W[np.random.permutation(comp_results.W.shape[0])]
    # comp_results.W *= comp_fisher_norms[:, None]
    ################################################################################
    comp_to_scores3[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)

# Print one "component_index, pearson, spearman" line per component, in
# ascending component-index order.
for ci in sorted(comp_to_scores3.keys()):
    score = comp_to_scores3[ci]
    print(f'{ci}, {score.pearson}, {score.spearman}')
