R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/ead/by_layer_metrics001.py

"""
import collections
from importlib import reload
import itertools
import os
import time
from typing import Sequence

import numpy as np
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
import matplotlib.pyplot as plt

import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets.antiderivative import antiderivative_ds
from em.fishers import per_example
from em.tools.nmf import nmf_common
from em.util.color_util import cu
from em.projects.anli import anli_misc1 as am
from em.projects.wino import nmf_components_fisher as ncf
from em.projects.ead import ead_misc1 as eadm
from em.util import vat_da_faak_vpn

###############################################################################

# Root directory of this experiment's artifacts; the sub-directories below are
# derived from it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ead1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory holding both the per-example Fishers file and the NMF decomposition
# file that are loaded further down in this script.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hugging Face checkpoint name; used below only to build the tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'


# Alternative (model, Fishers, decomposition) configurations. Exactly one
# MODEL / PER_EXAMPLES_FISHERS / DECOMP_FILENAME triple should be active.

# MODEL = "best_base_ead_infix_ds5s01_35k_dev001"
# PER_EXAMPLES_FISHERS = f"{MODEL}.ds5s01.no_embeddings.sparse_dynamic_raw.16k.16k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"

# MODEL = "best_base_ead_infix_150k_dev002"
# PER_EXAMPLES_FISHERS = f"{MODEL}.val.no_embeddings.sparse_dynamic_raw.16k.16k.h5"
# DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"


# Active configuration. The Fishers filename embeds the model name, and the
# decomposition filename embeds the Fishers filename.
MODEL = "best_base_ead_infix_75k_dev003"
PER_EXAMPLES_FISHERS = f"{MODEL}.val.no_embeddings_pooler.sparse_dynamic_raw.16k.16k.h5"
DECOMP_FILENAME = f"nmf_decomp.per_sub_block.16k.16k.256.{PER_EXAMPLES_FISHERS}"

# NOTE(review): not referenced anywhere in the visible part of this script.
FROM_PT = False

# Number of NMF decompositions to load (passed as `n_nmfs` below); also used to
# size `subset_names`.
N_DECOMPS = 23

###############################################################################

# Tokenizer for the pretrained checkpoint; handed to the analysis container below.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

###############################################################################

# Load the saved per-example Fishers file and report how long the load took.
print('Starting to load saved per-example Fishers.')
start = time.time()
pef = per_example.PerExampleFlatFishers.load(
    os.path.expanduser(os.path.join(PER_EXAMPLES_FISHERS_DIR, PER_EXAMPLES_FISHERS)),
    n_examples=16 * 1024,
    # This leads to the Fishers not being loaded, which ends up being much faster.
    # (Both indices are 0, i.e. an empty index range is requested.)
    start_fisher_index=0,
    end_fisher_index=0,
)
print('Load saved per-example Fishers time: ', time.time() - start)

# List of the N_DECOMPS NMF decompositions stored in DECOMP_FILENAME.
# NOTE(review): judging by the class name, entries are loaded lazily — confirm.
nmfs = am._LazyNmfList(os.path.join(PER_EXAMPLES_FISHERS_DIR, DECOMP_FILENAME), n_nmfs=N_DECOMPS)

# Bundle everything the analysis helpers below need into a single object.
container = eadm.EadAnalysisContainer(
    pef=pef,
    nmfs=nmfs,
    tokenizer=tokenizer,
)

# Presumably loads every decomposition up front instead of on first access —
# verify against `_LazyNmfList.force_load_all`.
container.nmfs.force_load_all()

###############################################################################


def get_fractions_per_subset(indicator, selection_parameters):
    """Return, for each NMF in ``container.nmfs``, the selected-component fraction.

    Components are selected with ``ncf.get_components_appearing_tuned`` for the
    given example indicator and selection parameters, then grouped per NMF with
    ``ncf.group_by_nmf``.

    Args:
        indicator: Example indicator forwarded to
            ``ncf.get_components_appearing_tuned``.
        selection_parameters: Selection settings forwarded to
            ``ncf.get_components_appearing_tuned``.

    Returns:
        A list with one float per NMF: the number of selected components in
        that NMF's group divided by the last dimension of its ``W`` matrix
        (presumably the NMF's total number of components — TODO confirm).
    """
    selected = ncf.get_components_appearing_tuned(
        container,
        indicator=indicator,
        selection_parameters=selection_parameters,
    )
    grouped = ncf.group_by_nmf(container, selected)

    fractions = []
    for decomposition, members in zip(container.nmfs, grouped):
        fractions.append(len(members) / decomposition.W.shape[-1])
    return fractions


###############################################################################

# Earlier variant that labelled the last subset 'Pooler':
# subset_names = am.make_per_sub_block_subset_names(N_DECOMPS - 1)
# subset_names.append('Pooler')

# One name per decomposition: generate N_DECOMPS + 1 names and keep the first
# N_DECOMPS of them.
subset_names = am.make_per_sub_block_subset_names(N_DECOMPS + 1)[:N_DECOMPS]
# subset_names.append('Pooler')

###############################################################################

# Indicators of the examples whose prediction equals 1 / equals 0.
# NOTE(review): assumes `container.predictions` supports elementwise `==`
# (e.g. a NumPy array) — confirm.
pred_trues = container.predictions == 1
pred_falses = container.predictions == 0


# Selection settings shared by the "predicted true" and "predicted false" runs.
sel_params_tf = ncf.SelectionParameters(
    coeff_factor=0.5,
    frac_threshold=0.95,
    p_value_threshold=0.01,
)

# Per-decomposition fraction of components selected for each predicted label.
frac_trues = get_fractions_per_subset(pred_trues, sel_params_tf)
frac_falses = get_fractions_per_subset(pred_falses, sel_params_tf)


def plot_true_vs_false(show=True):
    """Draw grouped bars of ``frac_trues`` and ``frac_falses`` per parameter subset.

    Reads the module-level ``subset_names`` (for the number of bar groups),
    ``frac_trues`` (green bars, shifted left) and ``frac_falses`` (red bars,
    shifted right).

    The bars are drawn on a newly created figure. Without that, they would be
    added to whatever figure is current, which still holds an earlier plot
    whenever ``plt.show()`` returned without closing it (non-GUI backends,
    interactive mode).

    Args:
        show: If true, call ``plt.show()`` after drawing. Pass ``False`` to
            save the figure (``plt.savefig``) before showing it.
    """
    # Start from a clean figure so earlier plots cannot bleed into this one.
    plt.figure()
    x_axis = np.arange(len(subset_names))
    # Two bars of width 0.4 per subset, centred at x - 0.2 and x + 0.2.
    plt.bar(x_axis - 0.2, frac_trues, 0.4, label='Pred: True', color='green')
    plt.bar(x_axis + 0.2, frac_falses, 0.4, label='Pred: False', color='red')
    #
    # plt.xticks(x_axis, subset_names, rotation=30)
    plt.xlabel("Parameter Subset")
    plt.ylabel("Fraction of Components")
    plt.legend()
    if show:
        plt.show()


# plot_true_vs_false()


###############################################################################

# Indicator of the examples the model predicted incorrectly, as provided by the
# analysis container.
incorrects = container.get_incorrect_prediction_indicator()

# Selection settings for the incorrect-prediction run; compared to
# `sel_params_tf` only `frac_threshold` differs (0.25 instead of 0.95).
sel_params_inc = ncf.SelectionParameters(
    coeff_factor=0.5,
    frac_threshold=0.25,
    p_value_threshold=0.01,
)

# Per-decomposition fraction of components selected for incorrect predictions.
frac_incorrects = get_fractions_per_subset(incorrects, sel_params_inc)


def plot_incorrects(show=True):
    """Draw a bar chart of the module-level ``frac_incorrects``, one bar per entry.

    The bars are drawn on a newly created figure. Without that, they would be
    added to whatever figure is current, which still holds an earlier plot
    whenever ``plt.show()`` returned without closing it (non-GUI backends,
    interactive mode) — and a subsequent ``plt.savefig`` would then write both
    plots into one file.

    Args:
        show: If true, call ``plt.show()`` after drawing. Pass ``False`` to
            save the figure (``plt.savefig``) before showing it.
    """
    # Start from a clean figure so earlier plots cannot bleed into this one.
    plt.figure()
    x_axis = np.arange(len(frac_incorrects))
    plt.title('Fraction of Incorrect Components by Layer')
    plt.bar(x_axis, frac_incorrects, 0.8)
    if show:
        plt.show()


# plot_incorrects()
###############################################################################

# fraction_per_subset = container.compute_mass_fraction_by_subset()
# print(fraction_per_subset)

###############################################################################
# Directory the rendered SVG files are written to.
OUT_DIR = '/fruitbasket/users/m/tmp'


def _render_save_show(plot_fn, svg_name):
    """Draw with ``plot_fn`` (not shown yet), save it under OUT_DIR, then show it."""
    plot_fn(show=False)
    plt.savefig(os.path.join(OUT_DIR, svg_name))
    plt.show()


_render_save_show(plot_true_vs_false, 'true_vs_false_comps_by_layer.dsv3.svg')
_render_save_show(plot_incorrects, 'incorrect_comps_by_layer.dsv3.svg')

R"""


FILENAME1=true_vs_false_comps_by_layer.dsv3.svg
FILENAME2=incorrect_comps_by_layer.dsv3.svg

rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/$FILENAME1" \
    "$HOME/Downloads/$FILENAME1"

rsync -ra -e ssh \
    "m@banana.cs.unc.edu:/fruitbasket/users/m/tmp/$FILENAME2" \
    "$HOME/Downloads/$FILENAME2"


"""
