R"""


def f(n):
    tot = 0
    for x in range(1, n + 1):
        for y in range(1, n + 1):
            tot += x / y
    return tot / n**2



To run this script interactively:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i em/projects/pi/exps/mains/guided_ablations/snli_guided_ablation_02.py
"""

import dataclasses
from importlib import reload
import random
import os
from typing import Sequence, Tuple

from em.util import vat_da_faak_vpn

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
from sklearn.feature_selection import mutual_info_regression
from scipy import stats

import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.evaluation import tf_metrics
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.tools.nmf import nmf_common
from em.util import hf_util
from em.util import sparse_util

from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf

from em.projects.pi import binary_ablation_experiment as BAE
from em.projects.pi import qqp_components_context as QCC
from em.projects.pi import qqp_merging_context as QMC
from em.projects.pi import scitail_ablations

from em.projects.pi.exps import ablation_exp_util
from em.projects.pi.exps import coeff_kl_relationship_util

from em.projects.pi.exps import guided_ablations

# Module-level aliases for names defined in `guided_ablations`, so the code
# below can refer to them without the module prefix.
OutputForAblator = guided_ablations.OutputForAblator
OutputForComponent = guided_ablations.OutputForComponent
KlSelectivityInfo = guided_ablations.KlSelectivityInfo

##########################################################################

# Root directory of the experiment data; the three directories below are
# sub-directories of it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/pi1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

##########################################################################

# Name handed to `AutoTokenizer.from_pretrained`. The tokenizer is loaded as a
# module-level statement, i.e. every time this file is executed or imported.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Identifier of the model under study.
# NOTE(review): presumably a Hugging Face hub id -- confirm.
MODEL = "connectivity/feather_berts_0"

# The artifact file names are built by nesting: `og_h_name` embeds
# `og_pef_name`, and `NMF_NAME` in turn embeds `og_h_name`.
og_pef_name = "feather_berts_0.snli_train.all_vars.50000ex.65536.h5"
og_h_name = f"spH.nmf_decomp2.c512_1250Iters_65536pe_mvpp10_50000ex.{og_pef_name}"
NMF_NAME = f"fit_w.skip50000.50000ex.65536vpe.{og_h_name}"

# NOTE(review): "PEF" presumably abbreviates "per-example Fisher" (cf.
# PER_EXAMPLES_FISHERS_DIR) -- confirm.
PEF_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"

FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"

##########################################################################

# Directories of saved result files, both located under EXPS_DIR. The driver at
# the bottom of this file passes one of them to do_it_1 / do_it_2.
RESULTS_DIR_1 = f'{EXPS_DIR}/guided_kl_ablations/attempt01'
RESULTS_DIR_2 = f'{EXPS_DIR}/guided_kl_ablations/attempt01_02'

##########################################################################


def geometric_mean_ratios(infos: Sequence[KlSelectivityInfo]):
    """Return the geometric mean of `info.ratio()` taken over all `infos`.

    Args:
        infos: Objects exposing a zero-argument `ratio()` method.

    Returns:
        Whatever `scipy.stats.gmean` yields for the list of collected ratios.
    """
    ratios = []
    for info in infos:
        ratios.append(info.ratio())
    return stats.gmean(ratios)


def do_it_1(results_dir):
    """Tabulate geometric-mean ratios for every result file in `results_dir`.

    Each directory entry is printed, loaded with `OutputForComponent.load`,
    and reduced to one row.

    Args:
        results_dir: Directory whose entries are all loadable result files.

    Returns:
        A list of rows sorted by their first element. Each row is
        `[component_index, H_ratio, top_fisher_ratio, random_examples_ratio]`,
        where every ratio is `geometric_mean_ratios` applied to the matching
        `get_*_kl_selectivities()` result.
    """
    table = []
    for entry in os.listdir(results_dir):
        print(entry)
        output = OutputForComponent.load(os.path.join(results_dir, entry))
        # Column order of the row follows the order of these getters.
        selectivity_getters = (
            output.get_component_H_kl_selectivities,
            output.get_component_top_fisher_kl_selectivities,
            output.get_random_examples_kl_selectivities,
        )
        ratios = [geometric_mean_ratios(getter()) for getter in selectivity_getters]
        table.append([output.component_index, *ratios])
    # os.listdir gives no ordering guarantee, so order rows by component index.
    table.sort(key=lambda row: row[0])
    return table


def do_it_2(results_dir, k: int):
    """Tabulate random-example H selectivity ratios for each result file.

    Each directory entry is printed, loaded with `OutputForComponent.load`,
    and reduced to one row.

    Args:
        results_dir: Directory whose entries are all loadable result files.
        k: How many of the largest entries in the component's column of `W`
            to select; their row indices are passed to the second
            selectivity query.

    Returns:
        A list of rows sorted by their first element. Each row is
        `[component_index, ratio_all, ratio_top_k]`: the geometric-mean ratio
        of `get_random_examples_H_selectivities()` called without arguments
        and called with the top-`k` row indices, respectively.
    """
    table = []
    for entry in os.listdir(results_dir):
        print(entry)
        output = OutputForComponent.load(os.path.join(results_dir, entry))

        # Row indices of W ordered by descending value in this component's
        # column (argsort of the negated column), truncated to the first k.
        component_column = output.W[:, output.component_index]
        top_k_rows = np.argsort(-component_column)[:k]

        ratio_all = geometric_mean_ratios(
            output.get_random_examples_H_selectivities())
        ratio_top_k = geometric_mean_ratios(
            output.get_random_examples_H_selectivities(top_k_rows))

        table.append([output.component_index, ratio_all, ratio_top_k])
    # os.listdir gives no ordering guarantee, so order rows by component index.
    table.sort(key=lambda row: row[0])
    return table

##########################################################################


# TODO: Filter out the runs where the KL search did not work.


# Alternative analysis over the RESULTS_DIR_1 runs; swap with the line below
# to use it instead.
# rows = do_it_1(RESULTS_DIR_1)
rows = do_it_2(RESULTS_DIR_2, k=128)

# Print one comma-separated line per row.
for r in rows:
    print(', '.join(map(str, r)))


