R"""




cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=1 python -i local_scripts/m_npeff/perturbations/ortho_reject_perturb001.py

"""

from importlib import reload
import os
import time

import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.tools.nmf import lrm_npeff
from em.util import flat_pack

from em.projects.m_npeff import perturbation_finder
from em.projects.m_npeff import snli_context
from em.projects.pi import qqp_components_context as QCC


###############################################################################

# Diagonal dataset Fisher; loaded below via `diagonal.DiagonalFisher.load`.
FISHER_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/fishers/"
FISHER_NAME = "feather_berts_0.mnli_snli_train.all_vars.50000ex.h5"
FISHER_PATH = os.path.join(FISHER_DIR, FISHER_NAME)

# Saved decomposition; loaded below via `lrm_npeff.LrmNpeffDecomposition.load`.
NMF_DIR = "/playpen/users/m/project_data/m_npeff1/per_example_fishers"
NMF_NAME = "test_mnpeff_002.coeffs_fit001.h5"
NMF_PATH = os.path.join(NMF_DIR, NMF_NAME)

# Use this only to get the predictions and example token ids without having to
# evaluate the model.
PEFS_FOR_PREDICTIONS_DIR = "/fruitbasket/users/m/project_data/extract_merge1/pi1/per_example_fishers/"
PEFS_FOR_PREDICTIONS_NAME = "feather_berts_0.snli_train.all_vars.skip50000.250000ex.131072.h5"
PEFS_FOR_PREDICTIONS_PATH = os.path.join(
    PEFS_FOR_PREDICTIONS_DIR,
    PEFS_FOR_PREDICTIONS_NAME,
)

# Identifiers handed to the `transformers` `from_pretrained` loaders below.
MODEL = "connectivity/feather_berts_0"
TOKENIZER = "bert-base-uncased"

###############################################################################

# Reference copy of the model. In this file, `run()` only reads its trainable
# variables (as the unperturbed baseline) and never assigns to them.
og_model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

# Diagonal Fisher flattened to one vector and rescaled in place to unit L2
# norm. Only the commented-out Fisher-metric experiments further down use it.
fisher = diagonal.DiagonalFisher.load(FISHER_PATH)
flat_fisher = fisher.as_flat_fisher().numpy()
flat_fisher /= np.sqrt(np.sum(flat_fisher**2))

# The perturbation finders below assume the rows of `nmf.G` have unit norm;
# presumably that is what this normalization call establishes -- TODO confirm.
nmf = lrm_npeff.LrmNpeffDecomposition.load(NMF_PATH, read_G=True)
nmf.normalize_components_to_unit_norm()

# Context object from which the evaluation context is created further down.
# Examples are not loaded here (`load_examples=False`).
ctx = snli_context.SnliContext(
    split='train_skip_50k',
    tokenizer=AutoTokenizer.from_pretrained(TOKENIZER),
    nmf=nmf,
    load_examples=False,
)

###############################################################################

# Default number of examples in the "total" evaluation set used by `eval()`,
# which evaluates indices 0 .. N_TOTAL_EXAMPLES - 1.
N_TOTAL_EXAMPLES = 1014
# Interactive-development hook: re-import `snli_context` and rebind the class
# of the already-constructed `ctx` to the reloaded class object, so the
# existing instance uses the freshly loaded class without being rebuilt.
reload(snli_context); ctx.__class__ = snli_context.SnliContext
# eval_ctx = ctx.create_eval_ctx(og_model)
# Build the evaluation context from the per-example-Fishers file rather than
# from `og_model` (see the comment on PEFS_FOR_PREDICTIONS_* above).
eval_ctx = ctx.create_eval_ctx_from_pefs_file(PEFS_FOR_PREDICTIONS_PATH)

# Second copy of the model; `run()` overwrites its trainable variables with
# perturbed values and `eval()` evaluates it.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)

###############################################################################


def find_unnormalized_perturbation(component_index: int, max_sim: float = 1e9):
    """Reject the other components' directions from one row of ``nmf.G``.

    Starting from a copy of row ``component_index``, every other row is
    visited in index order and the current residual's projection onto that
    row is subtracted. A row is left alone when the absolute dot product
    between it and the *original* target row exceeds ``max_sim``.

    Assumes the rows of ``nmf.G`` have unit norm, so a plain dot product is
    the projection coefficient. Under that assumption the default
    ``max_sim`` is never exceeded and no row is skipped.

    Args:
        component_index: Row of ``nmf.G`` to build the perturbation from.
        max_sim: Rows whose absolute similarity to the target row is above
            this value are not projected out.

    Returns:
        A float32 array of length ``nmf.n_parameters`` that is zero except
        at ``nmf.new_to_old_col_indices``, where it holds the residual. The
        result is not normalized.
    """
    components = nmf.G
    target_row = components[component_index]
    residual = np.copy(target_row)

    for other_index in range(components.shape[0]):
        other_row = components[other_index]
        # Written as `not (... > max_sim)` so the comparison is the same one
        # the skip rule is defined by.
        should_project_out = (
            other_index != component_index
            and not (np.abs(target_row.dot(other_row)) > max_sim)
        )
        if should_project_out:
            residual -= residual.dot(other_row) * other_row

    # Scatter the compact residual back into the full parameter vector.
    full_vector = np.zeros([nmf.n_parameters], dtype=np.float32)
    full_vector[nmf.new_to_old_col_indices] = residual
    return full_vector


# float32 tensor copy of `nmf.G`, read by `find_unnormalized_perturbation_tf_`.
tf_G = tf.cast(nmf.G, dtype=tf.float32)


@tf.function
def find_unnormalized_perturbation_tf_(component_index, max_sim):
    """TensorFlow counterpart of ``find_unnormalized_perturbation``.

    Starts from row ``component_index`` of ``tf_G`` and, for every other row
    in index order, subtracts the current residual's projection onto that
    row. Rows whose absolute dot product with the original target row
    exceeds ``max_sim`` are skipped, matching the NumPy implementation.

    Args:
        component_index: Scalar int32 tensor selecting the row of ``tf_G``
            (the wrapper below casts it to int32).
        max_sim: Scalar float32 tensor; similarity threshold above which a
            row is not projected out.

    Returns:
        The un-normalized residual as a 1-D tensor in the compact column
        space of ``tf_G`` (it is not scattered back to the full parameter
        vector here).
    """
    # Assumes rows of G have unit norm.
    g_main = tf_G[component_index]
    #
    for i in tf.range(tf_G.shape[0]):
        if i == component_index:
            continue
        # Threshold on the magnitude of the similarity, as the NumPy version
        # does; a signed comparison would never skip strongly anti-correlated
        # rows.
        if tf.abs(tf.tensordot(tf_G[component_index], tf_G[i], 1)) > max_sim:
            continue
        g_main -= tf.tensordot(g_main, tf_G[i], 1) * tf_G[i]
    #
    return g_main


def find_unnormalized_perturbation_tf(component_index: int, max_sim: float = 1e9):
    """Run the TensorFlow rejection and expand it to the full parameter space.

    Args:
        component_index: Row of the component matrix to build the
            perturbation from; cast to an int32 tensor before the call.
        max_sim: Similarity threshold forwarded to
            ``find_unnormalized_perturbation_tf_``; cast to float32.

    Returns:
        A float32 array of length ``nmf.n_parameters`` that is zero except
        at ``nmf.new_to_old_col_indices``, where it holds the un-normalized
        residual.
    """
    index_tensor = tf.cast(component_index, tf.int32)
    threshold_tensor = tf.cast(max_sim, tf.float32)
    compact_residual = find_unnormalized_perturbation_tf_(index_tensor, threshold_tensor)

    full_vector = np.zeros([nmf.n_parameters], dtype=np.float32)
    full_vector[nmf.new_to_old_col_indices] = compact_residual.numpy()
    return full_vector


# ff2 = flat_fisher[nmf.new_to_old_col_indices]
# f_norms = np.sqrt(np.sum(nmf.G**2 * ff2[None, :], axis=-1, keepdims=True))
# G2 = nmf.G / f_norms


# # Try using the diagonal dataset Fisher as the metric.
# def find_unnormalized_perturbation2(component_index: int, max_sim: float = 1e9):
#     # Assumes rows of G have unit norm.
#     G = G2
#     #
#     g_main = np.copy(G[component_index])
#     #
#     for i in range(G.shape[0]):
#         if i == component_index:
#             continue
#         if np.abs(G[component_index].dot(ff2 * G[i])) > max_sim:
#             continue
#         g_main -= g_main.dot(ff2 * G[i]) * G[i]
#     #
#     g = np.zeros([nmf.n_parameters], dtype=np.float32)
#     g[nmf.new_to_old_col_indices] = g_main
#     return g


###############################################################################

def run(component_index: int, multiplier=1.0, max_sim: float = 1e9):
    """Set ``model``'s weights to ``og_model``'s plus a scaled perturbation.

    The perturbation direction comes from ``find_unnormalized_perturbation``
    and is rescaled to unit L2 norm before being multiplied by
    ``multiplier``. The time spent finding the direction is printed.

    Args:
        component_index: Component whose perturbation direction is applied.
        multiplier: Signed length of the applied offset along the
            unit-norm direction.
        max_sim: Similarity threshold forwarded to the perturbation finder.

    Returns:
        None. The trainable variables of the module-level ``model`` are
        assigned in place; ``og_model`` is only read.
    """
    t0 = time.time()
    direction = find_unnormalized_perturbation(component_index, max_sim)
    # direction = find_unnormalized_perturbation_tf(component_index, max_sim)
    print("Time:", time.time() - t0)

    # Rescale in place to unit L2 norm.
    direction /= np.sqrt(np.sum(np.square(direction)))

    # Presumably `decode_tf` splits the flat vector into one tensor per
    # variable shape -- TODO confirm against `flat_pack.FlatPacker`.
    variable_shapes = [variable.shape for variable in model.trainable_variables]
    flat_offset = tf.cast(multiplier * direction, tf.float32)
    per_variable_offsets = flat_pack.FlatPacker(variable_shapes).decode_tf(flat_offset)

    # TODO: Both +/- offset would work here. Right now, just doing plus.
    variable_triples = zip(
        og_model.trainable_variables,
        model.trainable_variables,
        per_variable_offsets,
    )
    for baseline, perturbed, delta in variable_triples:
        perturbed.assign(baseline + delta)


# def run2(component_index: int, multiplier=1.0, max_sim: float = 1e9, renormalize_in_f_metric: bool = False):
#     g = find_unnormalized_perturbation2(component_index, max_sim)
#     if renormalize_in_f_metric:
#         g /= np.sqrt(np.sum(flat_fisher * g**2))
#     else:
#         g /= np.sqrt(np.sum(g**2))
#     #
#     packer = flat_pack.FlatPacker([v.shape for v in model.trainable_variables])
#     offsets = packer.decode_tf(tf.cast(multiplier * g, tf.float32))
#     #
#     # TODO: Both +/- offset would work here. Right now, just doing plus.
#     for ogv, v, offset in zip(og_model.trainable_variables, model.trainable_variables, offsets):
#         v.assign(ogv + offset)


def eval(component_index: int, n_top_examples: int = 128, n_total_examples: int = N_TOTAL_EXAMPLES):
    """Compare the KL of ``model`` on a component's top examples vs. overall.

    NOTE: the name shadows the builtin ``eval``; it is kept because the
    rest of the script calls it under this name.

    The top set is the ``n_top_examples`` indices with the largest values in
    column ``component_index`` of ``nmf.W`` (presumably one row per example
    -- TODO confirm). The total set is indices ``0 .. n_total_examples - 1``.
    Both are evaluated with ``eval_ctx.evaluate`` on the current ``model``.

    Args:
        component_index: Column of ``nmf.W`` used to rank examples.
        n_top_examples: Number of highest-ranked examples in the top set.
        n_total_examples: Number of leading examples in the total set.

    Returns:
        The top-set KL divided by the total-set KL. The two KLs and the
        ratio are also printed.
    """
    top_inds = np.argsort(-nmf.W[:, component_index])[:n_top_examples]
    top_results = eval_ctx.evaluate(model, top_inds)
    total_results = eval_ctx.evaluate(model, np.arange(n_total_examples))
    # Compute each KL exactly once and reuse it for printing and the ratio.
    top_kl = top_results.kl()
    total_kl = total_results.kl()
    ratio = top_kl / total_kl
    #
    print('Top:', top_kl)
    print('Total:', total_kl)
    print('Ratio:', ratio)
    #
    return ratio


###############################################################################

"""
The orthogonalization appears to REALLY boost performance for components with already good
selectivity, and also for some with already decent selectivity.

It can make things worse for some other components. It is unclear whether changing the
multiplier magnitude might make up for it.

"""


# Length of the offset applied along each unit-norm perturbation direction.
mp = 2e-1
# mp = 4e-1
# mp = 1
# Components more similar than this to the target are not projected out.
# max_sim = .15
max_sim = .35

component_indices = range(32)
# component_indices = [
#     159,  # 0.6703704596
#     180,  # 11.68937302
#     184,  # 0.08225885779
#     191,  # 0.07570057362
#     213,  # 1.329680085
# ]

# For every component, perturb in the positive and then the negative
# direction, evaluating the top/total KL ratio after each perturbation.
for i in component_indices:
    print('Component', i)
    for signed_mp in (mp, -mp):
        run(i, multiplier=signed_mp, max_sim=max_sim)
        _ = eval(i, n_top_examples=128)
    print('')


# mp = 8e-1

# for i in component_indices:
#     print('Component', i)
#     run2(i, multiplier=mp, max_sim=max_sim, renormalize_in_f_metric=False)
#     _ = eval(i, n_top_examples=128)
#     run2(i, multiplier=-mp, max_sim=max_sim, renormalize_in_f_metric=False)
#     _ = eval(i, n_top_examples=128)
#     print('')


###############################################################################

# g = find_unnormalized_perturbation(0)
# print(np.sqrt(np.sum(g**2)))

# for i in range(16):
#     g = find_unnormalized_perturbation2(i)
#     print(np.sqrt(np.sum(g**2)))


