R"""




cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=0 python -i local_scripts/m_npeff/perturbations/ortho_reg_reject_explore003.py

"""

from importlib import reload
import os

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC

from em.util.color_util import cu


###############################################################################

# Diagonal Fisher file. In this script it is referenced only by the
# commented-out DiagonalFisher loading code further down.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved decomposition that is loaded into `nmf` via
# lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True).
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_006.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Per-example Fishers file. Use this only to get the predictions and example
# token ids without having to evaluate the model; it is passed to
# ctx.create_eval_ctx_from_pefs_file() to build `eval_ctx`.
PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(PEFS_FOR_PREDICTIONS_DIR, PEFS_FOR_PREDICTIONS_NAME)

# Identifiers handed to TFAutoModelForSequenceClassification.from_pretrained()
# and AutoTokenizer.from_pretrained(), respectively.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Reference copy of the model. perturb_weights() only reads its trainable
# variables, so it keeps the original weights throughout.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)


# Disabled: load the diagonal Fisher and scale it to unit L2 norm.
# fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
# flat_fisher = fisher.as_flat_fisher().numpy()
# flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

# Load the decomposition (including G) and normalize its components.
# find_unnormalized_perturbation() assumes the rows of nmf.G have unit norm;
# presumably this normalization call is what establishes that -- TODO confirm.
print('Starting to read in decomposition.')
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
print('Decomposition read in.')
nmf.normalize_components_to_unit_norm()
print('Decomposition components normalized.')

# SNLI context tied to the decomposition. Examples are loaded because
# get_example_indices_containing_subtext() iterates over ctx.examples.
ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    # load_examples=False,
    load_examples=True,
)
print('SNLI context made.')

###############################################################################

# Default number of leading examples (indices 0..N-1) that eval_ratio() uses
# as its "all examples" baseline.
N_TOTAL_EXAMPLES = 1014
# Re-import snli_context and rebind the existing ctx to the reloaded class.
# NOTE(review): looks intended to pick up edits to snli_context in a
# `python -i` session without rebuilding ctx -- confirm.
reload(snli_context); ctx.__class__ = snli_context.SnliContext
# Alternative that builds the eval context by evaluating og_model:
# eval_ctx = ctx.create_eval_ctx(og_model)
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)
print('Eval context made.')

# Second copy of the same checkpoint. perturb_weights() overwrites its
# trainable variables with og_model's weights plus a scaled offset.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

###############################################################################


def find_unnormalized_perturbation(component_index: int, max_sim: float = 1e9):
    """Build a full-length perturbation direction for one component.

    Starts from row ``component_index`` of ``nmf.G``. When ``max_sim`` is
    positive, the projection of the running direction onto each other row is
    subtracted in turn, except for rows whose absolute dot product with the
    original row exceeds ``max_sim``; those are left untouched. The result is
    not re-normalized.

    Assumes the rows of ``nmf.G`` have unit norm.

    Args:
        component_index: Row of ``nmf.G`` to start from.
        max_sim: Similarity threshold above which another row is not projected
            out. A non-positive value disables the projection step entirely.

    Returns:
        A float32 array of length ``nmf.n_parameters`` holding the direction
        at positions ``nmf.new_to_old_col_indices`` and zeros elsewhere.
    """
    components = nmf.G
    anchor = components[component_index]
    direction = np.copy(anchor)

    if max_sim > 0.0:
        other_rows = (
            row for row in range(components.shape[0]) if row != component_index
        )
        for row in other_rows:
            other = components[row]
            # Similarity is always measured against the original row, while the
            # projection is removed from the progressively updated direction.
            too_similar = np.abs(anchor.dot(other)) > max_sim
            if not too_similar:
                direction -= direction.dot(other) * other

    # Scatter the reduced-column direction back into the full parameter vector.
    full_direction = np.zeros([nmf.n_parameters], dtype=np.float32)
    full_direction[nmf.new_to_old_col_indices] = direction
    return full_direction

###############################################################################


def perturb_weights(offsets, multiplier):
    """Set ``model``'s trainable variables to ``og_model``'s plus a scaled offset.

    Each variable is assigned ``original + multiplier * offset``, so repeated
    calls do not accumulate: the result always depends only on ``og_model``.

    Args:
        offsets: One offset per trainable variable, in variable order.
        multiplier: Scalar applied to every offset.
    """
    reference_vars = og_model.trainable_variables
    target_vars = model.trainable_variables
    for reference, target, delta in zip(reference_vars, target_vars, offsets):
        target.assign(reference + multiplier * delta)


def eval_ratio(component_index: int, n_top_examples: int = 128, n_total_examples=N_TOTAL_EXAMPLES):
    """Return the ``kl()`` of the component's top examples over a baseline ``kl()``.

    Args:
        component_index: Column of ``nmf.W`` used to rank examples.
        n_top_examples: How many of the highest-coefficient examples to use.
        n_total_examples: The baseline is evaluated on example indices
            ``0 .. n_total_examples - 1``.

    Returns:
        ``kl()`` of the top-example results divided by ``kl()`` of the
        baseline results, both evaluated on the current ``model``.
    """
    coefficients = nmf.W[:, component_index]
    # Negating the column makes argsort list the largest coefficients first.
    strongest = np.argsort(-coefficients)[:n_top_examples]
    baseline_indices = np.arange(n_total_examples)

    strongest_results = eval_ctx.evaluate(model, strongest)
    baseline_results = eval_ctx.evaluate(model, baseline_indices)

    numerator = strongest_results.kl()
    denominator = baseline_results.kl()
    return numerator / denominator


def run_and_eval(component_index: int, multiplier_mag=1.0, max_sim: float = 1e9):
    """Try both signs of a unit-norm perturbation and report the better one.

    The perturbation direction for ``component_index`` is scaled to unit L2
    norm, applied to ``model`` with ``+multiplier_mag`` and then with
    ``-multiplier_mag``, and ``eval_ratio`` is computed after each. Prints
    ``1`` when the positive sign gives the strictly larger ratio and ``-1``
    otherwise.

    Note that ``model`` is left perturbed with ``-multiplier_mag`` on return.

    Args:
        component_index: Component whose direction is used.
        multiplier_mag: Magnitude of the perturbation along the unit direction.
        max_sim: Forwarded to ``find_unnormalized_perturbation``.

    Returns:
        The larger of the two selectivity ratios.
    """
    g = find_unnormalized_perturbation(component_index, max_sim)
    g /= np.sqrt(np.sum(g**2))
    # Split the flat direction into per-variable offsets matching the model.
    packer = flat_pack.FlatPacker([v.shape for v in model.trainable_variables])
    offsets = packer.decode_tf(tf.cast(g, tf.float32))
    #
    perturb_weights(offsets, multiplier_mag)
    r1 = eval_ratio(component_index)
    #
    perturb_weights(offsets, -multiplier_mag)
    r2 = eval_ratio(component_index)
    # Compare the two signs against each other. The previous check compared
    # r1 with itself, which is never true, so -1 was always printed.
    if r1 > r2:
        print(1)
    else:
        print(-1)
    #
    return max(r1, r2)


def get_example_indices_containing_subtext(text: str) -> np.ndarray:
    """Return the indices of examples whose premise or hypothesis contains ``text``.

    Matching is case-insensitive: the query and both fields are lowercased
    before the substring test.

    Args:
        text: Substring to look for.

    Returns:
        An int32 array with the ``index`` of every matching example in
        ``ctx.examples``, in iteration order.
    """
    needle = text.lower()
    matches = []
    for example in ctx.examples:
        if needle in example.premise.lower() or needle in example.hypothesis.lower():
            matches.append(example.index)
    return np.array(matches, dtype=np.int32)


def run_and_eval2(component_index: int, example_indices, multiplier=1.0, max_sim: float = 1e9):
    """Perturb ``model`` along one component and evaluate the given examples.

    The direction for ``component_index`` is scaled to unit L2 norm,
    multiplied by ``multiplier`` (the sign is significant) and added to the
    original weights before evaluating.

    Args:
        component_index: Component whose direction is used.
        example_indices: Indices of the examples to evaluate.
        multiplier: Signed step size along the unit direction.
        max_sim: Forwarded to ``find_unnormalized_perturbation``.

    Returns:
        Whatever ``eval_ctx.evaluate`` returns for the perturbed model on
        ``example_indices``.
    """
    direction = find_unnormalized_perturbation(component_index, max_sim)
    l2_norm = np.sqrt(np.sum(direction**2))
    direction /= l2_norm

    # Split the flat direction into per-variable offsets matching the model.
    variable_shapes = [variable.shape for variable in model.trainable_variables]
    unflattener = flat_pack.FlatPacker(variable_shapes)
    per_variable_offsets = unflattener.decode_tf(tf.cast(direction, tf.float32))

    perturb_weights(per_variable_offsets, multiplier)
    return eval_ctx.evaluate(model, example_indices)


###############################################################################


# Exploration driver: perturb the model along one component and compare the
# resulting kl() on several subsets of examples.
component_index = 1

# Step size along the unit direction and the similarity threshold forwarded to
# find_unnormalized_perturbation.
mp = 2e-1
max_sim = 0.35
# -mp is best. Sign probe used to check this:
# run_and_eval(component_index, multiplier_mag=mp, max_sim=max_sim)


def _top_weighted_examples(column, count=128):
    """Return the indices of the `count` largest entries of nmf.W[:, column]."""
    return np.argsort(-nmf.W[:, column])[:count]


def _eval_negative_perturbation(indices):
    """Evaluate `indices` with the model perturbed by -mp along component_index."""
    return run_and_eval2(component_index, indices, multiplier=-mp, max_sim=max_sim)


# Examples containing the space-padded word (unpadded alternative: 'playing').
example_indices = get_example_indices_containing_subtext(' playing ')
top_example_indices = _top_weighted_examples(component_index)
top_example_indices_70 = _top_weighted_examples(70)

# Examples containing the subtext that are NOT among the component's top ones.
diff_example_indices = np.array(
    sorted(set(example_indices).difference(top_example_indices)),
    dtype=np.int32,
)

eval_results = _eval_negative_perturbation(example_indices)
eval_results_top = _eval_negative_perturbation(top_example_indices)
eval_results_top_70 = _eval_negative_perturbation(top_example_indices_70)

eval_results_diff = _eval_negative_perturbation(diff_example_indices)

# Printed order: subtext matches, component top, matches minus top, top of 70.
for subset_results in (eval_results, eval_results_top, eval_results_diff, eval_results_top_70):
    print(subset_results.kl())

#


"""
- Maybe look for patterns in what components are easy/hard to perturb selectively.
- Look at cosine sims of G for ortho_reg=0 vs ortho_reg=1.
- Look at the parameter-space allocation of ortho-rejected components.
    - Are they focused on the vocab layer for a single word (or multiple words)?
    - Compare selectivity to just perturbing the word in the vocab.
"""