R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ll/ablate/ablate_results_01.py

"""
from importlib import reload
import random
import os
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
from transformers import AutoTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hf_util

from em.projects.wino import nmf_components_fisher as ncf

from em.projects.ll import hans_ablation_experiment as HAE
from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_merging_context as HMC

# Use matplotlib's 'ggplot' style for every plot produced by this script.
plt.style.use('ggplot')

###############################################################################

# Root directory for this project's artifacts; the directories below are derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# NOTE(review): MODELS_DIR is not referenced in the visible part of this file.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

# NOTE(review): EXP_RESULTS_DIR is not referenced in the visible part of this
# file; the RESULTS_*_PATTERN paths further down join onto EXPS_DIR instead.
# Confirm which directory the ablation result files actually live in.
EXP_RESULTS_DIR = os.path.join(EXPS_DIR, 'exps/hans_ablate_dev1')

###############################################################################

# File-name patterns. The `{model_number}` / `{n_fisher_values}` placeholders are
# deliberately left unformatted here; the patterns are handed to the HCC context
# as *_pattern arguments, which presumably fills them in — TODO confirm.
PEF_FILEPATH = "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.10k.131072.h5"
# Order matters: NMF_FILEPATH is built from the *bare* PEF_FILEPATH file name, so
# this line must run before PEF_FILEPATH is rebound to a full path below.
NMF_FILEPATH = "spH.nmf_decomp.c1024_2kIters_{n_fisher_values}pe." + PEF_FILEPATH

FISHER_FILEPATH = "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.h5"

# Model name pattern passed to the HCC context as `model_name_pattern`; it is not
# joined onto EXPS_DIR.
MODEL_PATTERN = 'connectivity/feather_berts_{model_number}'

##################################

# Rebind the bare file-name patterns above to full paths inside their directories.
PEF_FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, PEF_FILEPATH)
NMF_FILEPATH = os.path.join(PER_EXAMPLES_FISHERS_DIR, NMF_FILEPATH)

# NOTE(review): FISHER_FILEPATH is not referenced in the visible part of this file.
FISHER_FILEPATH = os.path.join(FISHERS_DIR, FISHER_FILEPATH)

##########################################################################

# Feather-BERT model to analyze, and the per-example Fisher count that selects
# which NMF decomposition to load (presumably it fills the `{n_fisher_values}`
# slot of NMF_FILEPATH inside make_model_context — TODO confirm).
MODEL_NUMBER = 0
N_FISHER_VALUES = 65536

# Tokenizer handed to the component context below.
TOKENIZER = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

##########################################################################

# Component context built from the model-name, per-example-Fisher and NMF
# patterns configured above (flipped examples included).
hacc = HCC.HansLoneComponentContext(
    tokenizer=tokenizer,
    with_flipped=True,
    model_name_pattern=MODEL_PATTERN,
    pef_filepath_pattern=PEF_FILEPATH,
    nmf_filepath_pattern=NMF_FILEPATH,
)

# Model context for the selected model / Fisher count. The leading 'a' is passed
# through positionally; its meaning is defined in HCC and not visible here.
mc = hacc.make_model_context(
    'a',
    model_number=MODEL_NUMBER,
    n_fisher_values=N_FISHER_VALUES,
)

##########################################################################

# Result-file patterns for the three ablation methods; load_stuff formats them
# with `model_number` and `component_index`. plot_kls labels them as
#   mm_1 -> 'H Retaining Fisher', mm_2 -> 'Top Examples Retaining Fisher',
#   gf_1 -> 'Gradient Following'.
# NOTE(review): these join onto EXPS_DIR, not EXP_RESULTS_DIR
# (EXPS_DIR/exps/hans_ablate_dev1), which is otherwise unused — confirm the
# intended location of the result files.
RESULTS_MM_1_PATTERN = os.path.join(EXPS_DIR, 'hans_ablate_results.mm_1.model{model_number}.comp{component_index}.h5')
RESULTS_MM_2_PATTERN = os.path.join(EXPS_DIR, 'hans_ablate_results.mm_2.model{model_number}.comp{component_index}.h5')
RESULTS_GF_1_PATTERN = os.path.join(EXPS_DIR, 'hans_ablate_results.gf_1.model{model_number}.comp{component_index}.h5')


def load_stuff(component_index: int, model_number=None):
    """Load the mm_1, mm_2 and gf_1 ablation result lists for one component.

    Each file path is produced by formatting the matching RESULTS_*_PATTERN with
    ``model_number`` and ``component_index``, then read with
    ``HAE.read_results_list_from_h5(filepath, hacc)``.

    Args:
        component_index: Index of the component whose ablation results to load.
        model_number: Model whose results to load. When omitted (``None``), the
            module-level ``MODEL_NUMBER`` is used, read at call time so that
            rebinding it in an interactive session is still honored.

    Returns:
        ``(config, results_mm_1, results_mm_2, results_gf_1)``, where ``config`` is
        ``metadata['config']`` from the mm_1 file. Metadata from the mm_2 and
        gf_1 files is discarded.
    """
    if model_number is None:
        model_number = MODEL_NUMBER

    def _read(pattern):
        # Format one result-file pattern for this (model, component) and read it.
        filepath = pattern.format(model_number=model_number, component_index=component_index)
        return HAE.read_results_list_from_h5(filepath, hacc)

    results_mm_1, metadata = _read(RESULTS_MM_1_PATTERN)
    # Extracted right after the mm_1 read, before the other files are touched.
    config = metadata['config']

    results_mm_2, _ = _read(RESULTS_MM_2_PATTERN)
    results_gf_1, _ = _read(RESULTS_GF_1_PATTERN)

    return config, results_mm_1, results_mm_2, results_gf_1


##########################################################################

def get_top_components_examples_hans_kls(
    results: Sequence[HAE.AblationEvaluationResult],
    component_index: int,
    n_top_examples: int,
):
    """Return, per ablation result, the HANS KL over a component's top examples.

    The example ordering comes from the module-level model context
    (``mc.sort_example_indices_for_component``); the first ``n_top_examples``
    indices of that ordering are used. Presumably the ordering is by decreasing
    component coefficient — TODO confirm in HCC.

    Returns:
        A ``np.float32`` array built from one ``kl_for_examples`` value per
        element of ``results``, in order.
    """
    top_example_inds = mc.sort_example_indices_for_component(component_index)[:n_top_examples]
    per_result_kls = [
        result.hans_results.kl_for_examples(top_example_inds)
        for result in results
    ]
    return np.array(per_result_kls, dtype=np.float32)


def get_full_hans_kls(
    results: Sequence[HAE.AblationEvaluationResult],
):
    """Return the KL over the full HANS evaluation set for every ablation result.

    Collects ``hans_results.kl()`` from each result, in order, into a
    ``np.float32`` array.
    """
    full_kls = [result.hans_results.kl() for result in results]
    return np.array(full_kls, dtype=np.float32)


def get_hans_non_entailing_kls(
    results: Sequence['HAE.AblationEvaluationResult'],
):
    """Return, per ablation result, the HANS KL restricted to label-1 examples.

    Example indices are taken from ``results[0].hans_results.labels``. All
    results are assumed to share the same HANS examples in the same order
    (NOTE(review): not checked here). Label 1 is treated as "non-entailing" per
    this function's name — confirm against the label encoding used by HAE.

    Args:
        results: Ablation evaluation results, each exposing ``hans_results``.

    Returns:
        A ``np.float32`` array with one ``kl_for_examples`` value per result;
        an empty float32 array when ``results`` is empty.
    """
    # Guard: the label indices are read from results[0], which would raise
    # IndexError on an empty sequence.
    if len(results) == 0:
        return np.array([], dtype=np.float32)
    non_entailing_inds, = np.nonzero(results[0].hans_results.labels == 1)
    return np.array([r.hans_results.kl_for_examples(non_entailing_inds) for r in results], dtype=np.float32)


##########################################################################

# Idea: look at the "tail size" of the component coefficients (along with their absolute magnitude).
#   - First impressions: mm1 seems to outperform mm2 when the tail falls off quickly. This
#     might be an artifact of example selection, or a sign of a different "type" of component
#     (a broader one that the component doesn't fully capture?).
#        - If so, that would be consistent with my earlier interpretations of the components.
#   - Investigate how many top examples to use in the comparison.
# TODO: Also compare the top examples against their complement instead of the full dataset.


# Goal: for each ablation method, scatter-plot the KL on one example set against the KL on another,
# e.g. the component's top-k examples vs. all of HANS/MNLI (or just the non-entailing / entailing subsets).


def plot_kls(component_index: int, n_top_examples=None):
    """Scatter full-HANS KL against top-component-example KL for each ablation method.

    Loads the mm_1, mm_2 and gf_1 results for ``component_index`` via
    ``load_stuff`` and draws one scatter series per method: x is the KL over
    the full HANS set, y is the KL over the component's top ``n_top_examples``
    examples. Then shows the figure.

    Args:
        component_index: Index of the component to plot.
        n_top_examples: How many top-ranked component examples to use for the
            y axis. When omitted (``None``), the module-level ``N_TOP_EXAMPLES``
            is used, read at call time.
    """
    if n_top_examples is None:
        n_top_examples = N_TOP_EXAMPLES

    _, results_mm_1, results_mm_2, results_gf_1 = load_stuff(component_index)

    # Series are drawn in this order so the style's color cycle is assigned consistently.
    series = [
        (results_mm_1, 'H Retaining Fisher'),
        (results_mm_2, 'Top Examples Retaining Fisher'),
        (results_gf_1, 'Gradient Following'),
    ]
    for results, label in series:
        plt.scatter(
            get_full_hans_kls(results),
            get_top_components_examples_hans_kls(results, component_index, n_top_examples),
            label=label,
        )

    plt.xlabel('Full HANS KLD')
    plt.ylabel(f'Top {n_top_examples} Component Examples KLD')
    plt.legend(loc='best')
    plt.show()


# Component to plot (124 is an alternative that was also examined).
COMP_INDEX = 138

# Number of top-ranked component examples used for the y-axis KL.
# Other values tried: 24, 48, 96.
N_TOP_EXAMPLES = 12

plot_kls(component_index=COMP_INDEX)

