R"""HANS ablation experiment (lexical-overlap subset) on feather-BERT models.

Usage:

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

CUDA_VISIBLE_DEVICES=3 python -i local_scripts/ll/hans_ablate_02.py

"""
import dataclasses
from importlib import reload
import itertools
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.fishers import per_example
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_util
from em.projects.ll import hans_analysis as ha
from em.projects.wino import nmf_components_fisher as ncf
from em.tools.clustering import vat
from em.projects.ll import hans_labeling
from em.projects.ll import hans_labeling_analysis as hla
from em.tools.nmf import nmf_common
from em.util import hf_util

from em.util.color_util import cu

###############################################################################

# Root directory for this project's experiment artifacts; the directories below live under it.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# NOTE(review): MODELS_DIR is not referenced in the code visible in this script.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Diagonal Fisher files (read by get_fisher).
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Per-example Fisher files (read by make_container); the per-model NMF files are expected here too.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Hugging Face tokenizer checkpoint name, loaded with AutoTokenizer.from_pretrained.
TOKENIZER = 'bert-base-uncased'

###############################################################################

# Per-example Fisher (PEF) file name for a feather-BERT model; format with ``model_number``.
HANS_LONE_PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
# NMF decomposition file name derived from a PEF file; format with ``n_components`` and
# ``pef`` (the PEF file name).
HANS_LONE_NMF_FILENAME = "nmf_decomp.full.5k.32k.{n_components}.{pef}"

# Diagonal Fisher file name for a feather-BERT model; format with ``model_number``.
FISHER_FILENAME = "feather_berts_{model_number}.hans_lone.validation.h5"

##########################################################################


def make_container(model_number: int, n_components: int = 256, *, use_canonical_nmf: bool = False):
    """Build a PEF/NMF analysis container for feather-BERT ``model_number``.

    Loads the model's per-example Fisher (PEF) file without loading the Fisher
    values themselves, and loads an NMF decomposition. It normalizes the NMF
    components via ``normalize_components_to_unit_norm()`` and wraps both objects
    in an ``am.PefNmfAnalysisContainer`` built with the module-level ``tokenizer``.

    Args:
        model_number: Index of the feather-BERT checkpoint; selects the PEF file.
        n_components: Number of NMF components used in the per-model NMF file
            name. Only has an effect when ``use_canonical_nmf`` is True.
        use_canonical_nmf: If False (the default, matching the original script),
            load the scratch NMF file at a fixed path. That file ignores
            ``model_number`` and ``n_components``. If True, load the per-model
            NMF file named by ``HANS_LONE_NMF_FILENAME`` in
            ``PER_EXAMPLES_FISHERS_DIR``.

    Returns:
        The container, after ``hans_util.fix_up_hans_container`` has been applied to it.
    """
    pef_file = HANS_LONE_PEF_FILENAME.format(model_number=model_number)
    nmf_file = HANS_LONE_NMF_FILENAME.format(n_components=n_components, pef=pef_file)
    #
    pef = per_example.PerExampleFlatFishers.load(
        os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_file),
        n_examples=None,
        # This leads to the Fishers not being loaded, which ends up being much faster.
        start_fisher_index=0,
        end_fisher_index=0,
    )
    if use_canonical_nmf:
        nmf_path = os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_file)
    else:
        # NOTE(review): this scratch file is not tied to model_number or n_components;
        # make sure it was produced from the PEF loaded above.
        nmf_path = "/fruitbasket/users/m/tmp/cuda_nmf_test2.h5"
    nmf = nmf_common.NmfDecomposition.load(nmf_path)
    nmf.normalize_components_to_unit_norm()
    if isinstance(nmf.full_dense_size, np.ndarray):
        # The stored size can come back as an array (presumably an artifact of how it is
        # serialized). Unwrap it to a single scalar. ravel() also handles 0-d arrays,
        # which len() would reject.
        sizes = np.ravel(nmf.full_dense_size)
        assert sizes.size == 1
        nmf.full_dense_size = sizes[0]
    #
    container = am.PefNmfAnalysisContainer(
        pef=pef,
        nmfs=[nmf],
        tokenizer=tokenizer,
        shift_labels=False,
    )
    hans_util.fix_up_hans_container(container)

    return container


def get_model(model_number: int):
    """Load feather-BERT checkpoint ``model_number`` as a compiled TF sequence classifier.

    The weights are converted from the PyTorch checkpoint (``from_pt=True``).
    The model is compiled with sparse categorical cross-entropy on logits and a
    sparse categorical accuracy metric.
    """
    checkpoint = f'connectivity/feather_berts_{model_number}'
    classifier = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, from_pt=True)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    classifier.compile(loss=loss_fn, metrics=[accuracy_metric])
    return classifier


def get_mergeable_variables(model):
    """Return ``model``'s mergeable variables filtered by ``VariableFilter(merge_embeddings=False)``.

    The filter presumably drops the embedding variables; see ``tmv.VariableFilter``.
    """
    candidates = hf_util.get_mergeable_variables(model)
    return tmv.VariableFilter(merge_embeddings=False).filter_parallel_lists(candidates)


def get_embedding_variables(model):
    """Return the first sub-block group of ``model``'s mergeable variables.

    This script treats that group as the embedding variables. Presumably
    ``tmv.group_by_sub_blocks`` puts the embeddings first; confirm there.
    """
    grouped = tmv.group_by_sub_blocks(hf_util.get_mergeable_variables(model))
    return grouped[0]


def get_fisher(model_number: int):
    """Load the diagonal Fisher stored in FISHERS_DIR for feather-BERT ``model_number``."""
    fisher_path = os.path.join(FISHERS_DIR, FISHER_FILENAME.format(model_number=model_number))
    return diagonal.DiagonalFisher.load(fisher_path)


##########################################################################

# Shared tokenizer: make_container reads this global, and the validation dataset below is
# loaded with it.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)

# HANS validation examples filtered to the 'lexical_overlap' heuristic with label 1
# (presumably non-entailment, cf. the 'ne' indicator below). 5000 is passed as the example
# count; whether the cap applies before or after filtering is up to hans_util.
examples = hans_util.get_first_hans_examples(
    'validation',
    5000,
    lambda ds: ds.filter(lambda x: (x['heuristic'] == 'lexical_overlap') & (x['label'] == 1))
)

# Indicators over `examples`. Their `.subcase_indicators` select examples per subcase
# in evaluate_model.
indicators = hans_labeling.compute_full_indicator(examples, 'ne')

##########################################################################

# Which feather-BERT checkpoint to analyse.
model_number = 0
# model_number = 1




# model = get_model(model_number)
# for v in model.trainable_variables:
#     print(v.name)






# Base model, its PEF/NMF analysis container, and the `.fishers` of its diagonal Fisher.
model = get_model(model_number)
container = make_container(model_number)
dense_fisher = get_fisher(model_number).fishers

##########################################################################

# Thresholds for picking NMF components tuned to the HANS indicators
# (see ncf.SelectionParameters for the exact meaning of each field).
selection_parameters = ncf.SelectionParameters(
    # coeff_factor=1 / 3,
    coeff_factor=0.4,
    # frac_threshold=0.85,
    frac_threshold=0.925,
    # frac_threshold=0.75,
    p_value_threshold=0.05,
)

tuning_info = hla.compute_hans_tuning_info(container, indicators, selection_parameters)

# Report how many tuned components were found for each indicator key (a tuple of
# strings, joined with '/'), skipping keys with none.
for key, value in tuning_info.iterate_over_tuned_component_infos():
    if len(value) == 0:
        continue
    print(cu.hly(f'{"/".join(key)}: {len(value)}'))

##########################################################################

# Looser thresholds than `selection_parameters` (frac 0.85 vs 0.925, p 0.15 vs 0.05). They are
# used to find components that appear tuned to the model's correct predictions.
ccsp = ncf.SelectionParameters(
    # coeff_factor=1 / 3,
    coeff_factor=0.4,
    frac_threshold=0.85,
    # frac_threshold=2 / 3,
    # frac_threshold=0.5,
    # p_value_threshold=0.05,
    p_value_threshold=0.15,
)

# Component infos (each carrying nmf_index / component_index) that appear tuned to the
# correct-prediction indicator.
correct_comp_infos = ncf.get_components_appearing_tuned(
    container,
    indicator=container.get_correct_prediction_indicator(),
    selection_parameters=ccsp,
)

##########################################################################

# Indicator key whose tuned components select what gets ablated; alternatives are kept commented out.
# example_type = ('subcase_indicators', 'ln_preposition')
# example_type = ('subcase_indicators', 'ln_conjunction')
example_type = ('subcase_indicators', 'ln_subject/object_swap')
# example_type = ('subcase_indicators', 'ln_relative_clause')

# example_type = ('subcase_indicators', 'ln_passive')

# Tuned component infos for `example_type`, indexed by NMF index. The container holds a
# single NMF, so only index 0 is used.
tcis_by_subset = tuning_info.get_tci_by_nmf(example_type)


def is_correctly_tuned_component(nmf_index: int, component_index: int):
    """Predicate used to select components when building `fisher_a`.

    Accepts a component only if it is both:
      * tuned to `example_type` (present in `tcis_by_subset[nmf_index]`), and
      * among `correct_comp_infos`, the components appearing tuned to correct predictions.

    Every accepted call increments the module-level `ct_comp_count`. If the caller
    queries a component more than once, it is counted more than once. Only NMF
    index 0 is supported.
    """
    global ct_comp_count
    assert nmf_index == 0
    tuned_to_subset = any(tci.component_index == component_index for tci in tcis_by_subset[nmf_index])
    if tuned_to_subset:
        tuned_to_correct = any(
            info.nmf_index == nmf_index and info.component_index == component_index
            for info in correct_comp_infos
        )
        if tuned_to_correct:
            ct_comp_count += 1
            return True
    return False


# Reset the counter that is_correctly_tuned_component increments for each accepted component.
ct_comp_count = 0
# Fisher over the non-embedding mergeable variables, built only from the components that
# is_correctly_tuned_component accepts.
fisher_a = tuning_info.make_fisher_for_components(
    [get_mergeable_variables(model)],
    is_correctly_tuned_component,
)
print(cu.hly(f'Number of components: {ct_comp_count}'))
# fisher_a = tuning_info.make_fisher_for_tuned_components(
#     example_type,
#     [get_mergeable_variables(model)],
# )
# Prepend all-zero Fishers for the embedding variables. Presumably this makes fisher_a cover
# the same variable list (embeddings first) as the dense Fisher; confirm the ordering.
embeddings_zeros = [tf.zeros_like(v) for v in get_embedding_variables(model)]
fisher_a = [*embeddings_zeros, *fisher_a]

# The full diagonal Fisher of the base model.
fisher_b = dense_fisher

##########################################################################

# Batch size used by evaluate_model.
eval_batch_size = 128

# HANS lexical-overlap 'ne' validation split, tokenized to length 64; evaluate_model iterates over it.
val_ds = em_datasets.load('hans/lexical_overlap_ne', split='validation', sequence_length=64, tokenizer=tokenizer)


def evaluate_model(output_model, ds=None, batch_size=None):
    """Print (and return) per-subcase accuracy of ``output_model``.

    Runs the model on every batch of ``ds``. The logits go through
    ``hans_util.fix_up_hans_logits`` before argmax. Accuracy is then reported for
    each entry of the module-level ``indicators.subcase_indicators``.

    Args:
        output_model: Model called as ``output_model(x, training=False)``; its output
            must have a ``.logits`` tensor.
        ds: Unbatched dataset of ``(x, y)`` pairs. Defaults to the module-level ``val_ds``.
        batch_size: Evaluation batch size. Defaults to the module-level ``eval_batch_size``.

    Returns:
        Dict mapping each subcase key to its accuracy (as printed).

    NOTE(review): the subcase indicators come from ``examples``, not from ``ds``.
    This assumes both enumerate the same examples in the same order.
    """
    if ds is None:
        ds = val_ds
    if batch_size is None:
        batch_size = eval_batch_size
    labels = []
    preds = []
    for x, y in ds.batch(batch_size):
        logits = output_model(x, training=False).logits.numpy()
        logits = hans_util.fix_up_hans_logits(logits)
        labels.append(y.numpy())
        preds.append(np.argmax(logits, axis=-1))
    labels = np.concatenate(labels, axis=0)
    preds = np.concatenate(preds, axis=0)
    #
    # Per-example correctness is the same for every subcase, so compute it once.
    corrects = preds == labels
    accuracies = {}
    for key, value in indicators.subcase_indicators.items():
        acc = corrects[value].astype(np.float64).mean()
        print(f'{key}: {acc}')
        accuracies[key] = acc
    return accuracies


##########################################################################


def compute_average_gradient(model, ds):
    """Unimplemented placeholder, intended to average the loss gradient of ``model`` over ``ds``.

    Currently does nothing and returns ``None``.
    """
    # Sketch of the intended accumulator setup:
    #   fishers = [
    #       tf.Variable(tf.zeros(w.shape), trainable=False, name=f"fisher/{w.name}")
    #       for w in variables
    #   ]
    return None


# The unmodified base model.
model_a = model

# A second copy of the same checkpoint. Every weight array is shifted by +0.3, then the
# embedding weights are restored from model_a. Other perturbations are left commented out.
model_b = get_model(model_number)
# model_b.set_weights([0 * w for w in model_a.get_weights()])
# model_b.set_weights([-0.25 * w for w in model_a.get_weights()])
model_b.set_weights([w + 3e-1 for w in model_a.get_weights()])
# model_b.set_weights([-w for w in model_a.get_weights()])
# model_b.set_weights([2 * w for w in model_a.get_weights()])
# model_b.set_weights([3 * w for w in model_a.get_weights()])
model_b.bert.embeddings.set_weights(model_a.bert.embeddings.get_weights())
# model_b.bert.pooler.set_weights(model_a.bert.pooler.get_weights())

# Receives the merged weights for each coefficient setting.
output_model = get_model(model_number)

# Merge model_a and model_b over create_pairwise_grid_coeffs(21). Each merged result is
# written into output_model.
# NOTE(review): models are [model_a, model_b] but fishers are [fisher_b, fisher_a]. Assuming
# the lists are paired positionally, the dense Fisher goes with model_a and the component
# Fisher with model_b. Confirm this cross-pairing is intended; the commented-out model order
# suggests the models may have been swapped without swapping the fishers.
gen = merging.generate_merged_for_coeffs_set(
    # mergeable_models=[model_b, model_a],
    mergeable_models=[model_a, model_b],
    coefficients_set=merging.create_pairwise_grid_coeffs(21),
    fishers=[fisher_b, fisher_a],
    fisher_floor=1e-8,
    favor_target_model=True,
    normalize_fishers=True,
    output_model=output_model,
)
# Print each coefficient setting followed by the merged model's per-subcase accuracy.
for coefficients, _ in gen:
    print(coefficients)
    evaluate_model(output_model)
    print('')


# TODO: Compute the gradient of the loss on the subset to ablate, and use the sign of
# each parameter's gradient to decide which direction to move that parameter.


# TODO: Also support computing the per-example (PE) Fisher of the classification head.

