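R"""Summarize per-component KL results for the ResNet-50 / ImageNet experiment.

Loads each per-component result file found in RESULTS_DIR with
`OutputForComponent.load`, calls `compute_kl_ratio_as_tuple__avg_then_ratio(n_ex)`
on it, and prints one `component_index, top_kl, all_kl` line per component.

Usage:
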
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/imagenet/mains/trust_ver/coeff_kl_analysis_output01.py
"""
import dataclasses
from importlib import reload
import random
import os
from typing import Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.perturbations import examples_context
from em.perturbations import h_to_fishers
from em.perturbations import kl_targeter
from em.perturbations import mm_perturbations
from em.perturbations import perturbation_exp_util as pe_util
from em.perturbations.scripts_util import coeff_kl_relationship_util

from em.util.color_util import cu

##########################################################################

# Expose project classes under their short names at module level, so they can
# be used without the module prefix (e.g. from the `python -i` session).
ExamplesContext = examples_context.ExamplesContext
RunOutput, OutputForComponent = (
    coeff_kl_relationship_util.RunOutput,
    coeff_kl_relationship_util.OutputForComponent,
)

##########################################################################

# Root directory of this experiment's artifacts; the other directories below
# are subdirectories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/imagenet1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory holding per-example Fisher files (PEF_NAME is joined onto it in the
# commented-out Fisher-norm analysis below).
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Model identifier string. Not referenced in the visible part of this script.
MODEL = "resnet:resnet50_imagenet"

# File-name chain: `og_pef_name` is embedded in `h_name`, which is in turn
# embedded in `NMF_NAME`. The naming suggests an NMF decomposition fit on the
# `og_pef_name` per-example Fishers (NOTE(review): inferred from names only).
og_pef_name = "resnet50_imagenet.imagenet_train.all_vars.20000ex.nvpe131072.mpc3e-3.h5"
h_name = f"spH.nmf_decomp.c512_2500Iters_65536pe_mvpp6_20000ex.{og_pef_name}"
NMF_NAME = f"fit_w.65536vpe.{h_name}"

# Per-example Fisher file whose norms are loaded (via load_fisher_norms) in the
# commented-out analyses below.
PEF_NAME = "resnet50_imagenet.imagenet_validation.all_vars.30000ex.nvpe131072.mpc3e-3.h5"

# Fisher file name; presumably resolved under FISHERS_DIR. Not used in the
# visible code -- confirm before relying on it.
FISHER_NAME = "resnet50_imagenet.imagenet_train.all_vars.20000ex.mpc3e-3.h5"

##########################################################################

# Directory scanned by the analysis below for per-component result files.
RESULTS_DIR = f'{EXPS_DIR}/coeff_kl_relationships/kl_resnet_imagenet_validation_01'
# Path template, deliberately NOT an f-string: `{component_index}` stays a literal
# placeholder, presumably filled with `.format(component_index=...)` by the code
# that writes the results. Not used in the visible code.
RESULT_FILEPATH = os.path.join(RESULTS_DIR, 'coeff_kl_relationship.comp{component_index}.h5')

##########################################################################


def load_fisher_norms(pef_path):
    """Return only the dense Fisher norms from a per-example Fisher file.

    Args:
        pef_path: Path to a file readable by
            `per_example.PerExampleFlatFishers.load`.

    Returns:
        The `dense_fisher_norms` attribute of the loaded
        `PerExampleFlatFishers` object.
    """
    # Setting both Fisher indices to 0 means the per-example Fishers themselves
    # are not loaded, which is much faster. Only the norms are needed here.
    load_kwargs = dict(
        n_examples=None,
        start_fisher_index=0,
        end_fisher_index=0,
    )
    flat_fishers = per_example.PerExampleFlatFishers.load(pef_path, **load_kwargs)
    return flat_fishers.dense_fisher_norms


##########################################################################
##########################################################################

# all_fisher_norms = load_fisher_norms(os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_NAME))

# comp_to_scores = {}
# comp_to_scores_normed = {}
# # comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     ################################################################################
#     comp_to_scores[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)
#     ################################################################################
#     # Scale by the norm of the Fisher.
#     comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
#     comp_results.W *= comp_fisher_norms[:, None]
#     comp_to_scores_normed[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)
#     ################################################################################

# for ci in sorted(comp_to_scores.keys()):
#     score = comp_to_scores[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}')

# print(3 * '\n')

# for ci in sorted(comp_to_scores_normed.keys()):
#     score = comp_to_scores_normed[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}')


##########################################################################

# Number of examples passed to `compute_kl_ratio_as_tuple__avg_then_ratio`.
n_ex = 64

# component_index -> (top_kl, all_kl) as returned by
# `OutputForComponent.compute_kl_ratio_as_tuple__avg_then_ratio(n_ex)`.
comp_to_scores = {}
# comp_to_label_selectivity = {}

# Per-component results are written as
# 'coeff_kl_relationship.comp{component_index}.h5' (see RESULT_FILEPATH).
# `os.listdir` returns every entry in arbitrary order, so skip anything that
# is not an `.h5` file. Otherwise a stray file or subdirectory would crash
# `OutputForComponent.load`. Sorting makes the progress output reproducible.
for filename in sorted(os.listdir(RESULTS_DIR)):
    if not filename.endswith('.h5'):
        continue
    print(filename)
    filepath = os.path.join(RESULTS_DIR, filename)
    comp_results = OutputForComponent.load(filepath)
    ################################################################################
    top_kl, all_kl = comp_results.compute_kl_ratio_as_tuple__avg_then_ratio(n_ex)
    comp_to_scores[comp_results.component_index] = (top_kl, all_kl)
    ################################################################################

# One CSV-style line per component, ordered by component index:
# `component_index, top_kl, all_kl`. Keys are unique, so sorting the items
# only ever compares the keys.
for ci, (ci_top_kl, ci_all_kl) in sorted(comp_to_scores.items()):
    print(f'{ci}, {ci_top_kl}, {ci_all_kl}')

##########################################################################

# # TODO: Compute correlations of dense fisher norm with KLs

# comp_to_scores2 = {}
# # comp_to_label_selectivity = {}

# for filename in os.listdir(RESULTS_DIR):
#     print(filename)
#     filepath = os.path.join(RESULTS_DIR, filename)
#     comp_results = OutputForComponent.load(filepath)
#     # comp_to_label_selectivity[comp_results.component_index] = compute_prediction_selectivity(comp_results, TUNING_K)
#     ################################################################################
#     # Scale by the norm of the Fisher.
#     comp_fisher_norms = all_fisher_norms[comp_results.evaluation_ex_indices]
#     comp_results.W *= 0
#     comp_results.W += comp_fisher_norms[:, None]
#     ################################################################################
#     comp_to_scores2[comp_results.component_index] = comp_results.compute_metrics(comp_results.component_index)

# for ci in sorted(comp_to_scores2.keys()):
#     score = comp_to_scores2[ci]
#     print(f'{ci}, {score.pearson}, {score.spearman}')

