R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/trust_ver/snli_coeff_kl_analysis_03.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

# Short local aliases for the result/metrics containers defined in coeff_kl_relationship_util.
RunOutput = coeff_kl_relationship_util.RunOutput
OutputForComponent = coeff_kl_relationship_util.OutputForComponent
MetricsInfo = coeff_kl_relationship_util.MetricsInfo

##########################################################################

# Root directory for this project's experiment artifacts, plus its standard subdirectories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# NOTE(review): the tokenizer is loaded eagerly at import time (may hit the network/cache),
# but it is not referenced in the code visible here — presumably kept for interactive use
# under `python -i`; confirm before removing.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Model identifier (presumably a Hugging Face Hub id — confirm).
MODEL = "connectivity/feather_berts_0"

# Artifact filenames built by prefixing: NMF_NAME embeds og_h_name, which embeds og_pef_name.
# NOTE(review): the names suggest og_pef_name is the per-example Fisher file an NMF
# decomposition (og_h_name) was fit on, and NMF_NAME a W fit derived from it — confirm.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

# Per-example Fisher file (under PER_EXAMPLES_FISHERS_DIR) whose norms the script section
# further down loads with load_fisher_norms.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

# Directory holding one saved result file per component.
RESULTS_DIR = f'{EXPS_DIR}/coeff_kl_relationships/attempt01'
# Path *template*, not an f-string: fill it with
# RESULT_FILEPATH.format(component_index=...).
RESULT_FILEPATH = os.path.join(RESULTS_DIR, 'test_coeff_kl_relationship.comp{component_index}.h5')

##########################################################################


def load_fisher_norms(pef_path):
    """Return the dense per-example Fisher norms stored in a per-example Fisher file.

    Args:
        pef_path: Path to a file readable by `per_example.PerExampleFlatFishers.load`.

    Returns:
        The `dense_fisher_norms` attribute of the loaded object.
    """
    # Both Fisher index bounds are 0, so the Fishers themselves are not loaded
    # (much faster); the norms are still available on the returned object.
    load_kwargs = dict(
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )
    flat_fishers = per_example.PerExampleFlatFishers.load(pef_path, **load_kwargs)
    return flat_fishers.dense_fisher_norms


def compute_mean_metrics_mag(results):
    """Average the magnitudes of every component's correlation metrics.

    Calls `results.compute_metrics(i)` for each `i in range(results.W.shape[-1])`
    and averages `abs(pearson)` and `abs(spearman)` over those components. As a
    result, a strong negative correlation counts as much as a strong positive one.

    Args:
        results: Object exposing a `W` array (one component per index of its last
            axis) and a `compute_metrics(i)` method whose return value has
            `pearson` and `spearman` attributes.

    Returns:
        A `MetricsInfo` with the mean absolute Pearson and Spearman values.

    Raises:
        ValueError: If `results.W` has no components, so the mean is undefined.
    """
    n_components = results.W.shape[-1]
    if n_components == 0:
        raise ValueError('compute_mean_metrics_mag: results.W has no components.')
    per_comp_metrics = [results.compute_metrics(i) for i in range(n_components)]
    # sum() starts at 0 and adds in component order, i.e. the same running total
    # as an explicit accumulator loop.
    pearson = sum(np.abs(m.pearson) for m in per_comp_metrics)
    spearman = sum(np.abs(m.spearman) for m in per_comp_metrics)
    return MetricsInfo(
        pearson=pearson / n_components,
        spearman=spearman / n_components,
    )


def compute_metrics_for_all_comps(results):
    """Return `results.compute_metrics(i)` for every component, in index order.

    The number of components is the size of the last axis of `results.W`.
    """
    n_components = results.W.shape[-1]
    metrics_per_comp = []
    for comp_index in range(n_components):
        metrics_per_comp.append(results.compute_metrics(comp_index))
    return metrics_per_comp


##########################################################################

# def compute_comp_to_scores(n_examples, all_fisher_norms=None):
#     comp_to_scores = {}
#     for filename in os.listdir(RESULTS_DIR):
#         print(filename)
#         comp_results = OutputForComponent.load(os.path.join(RESULTS_DIR, filename))
#         if all_fisher_norms is not None:
#             comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
            
#         else:
#             comp_fisher_norms = None
#         ratios = [
#             comp_results.compute_kl_ratioe__ratio_then_geom_avg(n_ex, fisher_norm_corrections=comp_fisher_norms)
#             for n_ex in n_examples
#         ]
#         comp_to_scores[comp_results.component_index] = tuple(ratios)
#     return comp_to_scores


def compute_comp_to_scores(n_examples, all_fisher_norms=None, results_dir=None):
    """Compute KL-ratio scores for every saved per-component result file.

    Each `.h5` file in `results_dir` is loaded with `OutputForComponent.load`, and
    `compute_kl_ratioe__ratio_then_geom_avg(n_ex)` is called once per entry of
    `n_examples`.

    Args:
        n_examples: Iterable of example counts passed to the ratio computation. It is
            materialised once so that one-shot iterators work for every file.
        all_fisher_norms: Optional array of per-example Fisher norms, indexed with
            each component's `evaluation_ex_indices`. When given, the component's `W`
            is scaled in place along its first axis by those norms before scoring.
        results_dir: Directory of result files. Defaults to `RESULTS_DIR`, which is
            read at call time so that reassigning it interactively takes effect.

    Returns:
        Dict mapping `component_index` to a tuple of scores, one per entry of
        `n_examples` and in the same order. If two files share a component index,
        the file that sorts later by filename wins.
    """
    n_examples = tuple(n_examples)
    if results_dir is None:
        results_dir = RESULTS_DIR

    comp_to_scores = {}
    # Sort for a reproducible processing order (os.listdir order is arbitrary), and
    # skip subdirectories and non-.h5 entries, which are not result files.
    for filename in sorted(os.listdir(results_dir)):
        filepath = os.path.join(results_dir, filename)
        if not (filename.endswith('.h5') and os.path.isfile(filepath)):
            continue
        print(filename)
        comp_results = OutputForComponent.load(filepath)
        if all_fisher_norms is not None:
            comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
            # Broadcast each example's norm across W's second axis.
            # NOTE(review): assumes W is (n_eval_examples, n_components) — confirm.
            comp_results.W *= comp_fisher_norms[:, None]
        comp_to_scores[comp_results.component_index] = tuple(
            comp_results.compute_kl_ratioe__ratio_then_geom_avg(n_ex)
            for n_ex in n_examples
        )
    return comp_to_scores


# NOTE: Instead, try to find the largest-coefficient examples using the unnormalized
# coefficients. Also, look at the largest-coefficient examples when using unnormalized
# coefficients, and compare them to the top examples from the normalized coefficients.


##########################################################################

# Per-example Fisher norms for the whole PEF file; compute_comp_to_scores indexes into
# them per component.
all_fisher_norms = load_fisher_norms(os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME))

# Scores at several example counts, with coefficients rescaled by the Fisher norms.
# (Omit all_fisher_norms, i.e. leave it None, to score the unscaled coefficients.)
comp_to_scores = compute_comp_to_scores(
    [8, 16, 32, 64, 128, 256, 512],
    all_fisher_norms,
)

# One comma-separated row per component, ordered by component index.
for ci in sorted(comp_to_scores):
    score = ", ".join(map(str, comp_to_scores[ci]))
    print(f'{ci}, {score}')


##########################################################################
# # n_ex = 32
# # n_ex = 64
# n_ex = 128

# comp_to_scores = {}
# # comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     ################################################################################
#     top_kl, all_kl = comp_results.compute_kl_ratio_as_tuple__avg_then_ratio(n_ex)
#     comp_to_scores[comp_results.component_index] = (top_kl, all_kl)
#     ################################################################################

# for ci in sorted(comp_to_scores.keys()):
#     score = comp_to_scores[ci]
#     print(f'{ci}, {score[0]}, {score[1]}')

##########################################################################
