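R"""Merge two `connectivity/feather_berts_*` models using per-model Fishers
and print per-subcase HANS accuracy for each set of merge coefficients.

Usage (run interactively from the project root):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


CUDA_VISIBLE_DEVICES=3 python -i local_scripts/ll/hans_merging_01.py

"""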
import dataclasses
from importlib import reload
import itertools
import os
import pickle
import time

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import tensorflow as tf
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.models import transformer_model_vars as tmv
from em.projects.anli import anli_misc1 as am
from em.projects.ll import hans_util
from em.projects.ll import hans_analysis as ha
from em.projects.wino import nmf_components_fisher as ncf
from em.tools.clustering import vat
from em.projects.ll import hans_labeling
from em.projects.ll import hans_labeling_analysis as hla
from em.util import hf_util

###############################################################################

# Root directory for this experiment's artifacts; the directories below are
# its sub-directories.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
# Directory of the diagonal Fisher files read by `get_fisher`.
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
# Directory of the per-example Fisher ("PEF") files and their NMF
# decompositions, both read by `make_container`.
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Checkpoint name used to load the module-level tokenizer.
PRETRAINED_MODEL = 'bert-base-uncased'
# NOTE(review): presumably corresponds to the `from_pt` argument of
# `from_pretrained` -- confirm intended use.
FROM_PT = True

# Passed as `n_nmfs` when loading the PEF/NMF analysis container.
N_DECOMPS = 25

##########################################################################

# Alternative (16k/16k) file naming, currently disabled:
# PEF_FILENAME = "feather_berts_{model_number}.hans.no_embeddings.16k.16k.h5"
# NMF_FILENAME = "nmf_decomp.per_sub_block.16k.16k.256.{pef}"

# Diagonal Fisher filename template; joined onto FISHERS_DIR in `get_fisher`.
FISHER_FILENAME = "feather_berts_{model_number}.hans_lone.validation.h5"

# Per-example Fisher filename template, and the template of its NMF
# decomposition file (`{pef}` is filled with the formatted PEF filename).
# Both are joined onto PER_EXAMPLES_FISHERS_DIR in `make_container`.
PEF_FILENAME = "feather_berts_{model_number}.hans_lone.no_embeddings.5k.32k.h5"
NMF_FILENAME = "nmf_decomp.per_sub_block.5k.32k.256.{pef}"

###############################################################################


def get_model(model_number: int):
    """Load and compile one `connectivity/feather_berts_*` classifier.

    Args:
        model_number: Index substituted into the checkpoint name
            `connectivity/feather_berts_{model_number}`.

    Returns:
        The loaded TF sequence-classification model, compiled with a sparse
        categorical cross-entropy loss on logits and a sparse categorical
        accuracy metric. No optimizer is specified in the compile call.
    """
    model = TFAutoModelForSequenceClassification.from_pretrained(
        f'connectivity/feather_berts_{model_number}',
        # Use the module-level flag instead of duplicating the literal here.
        from_pt=FROM_PT,
    )
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    return model


def get_mergeable_variables(model):
    """Return `model`'s mergeable variables after the no-embeddings filter.

    The variables come from `hf_util.get_mergeable_variables` and are passed
    through a `tmv.VariableFilter` built with `merge_embeddings=False`.
    """
    all_mergeable = hf_util.get_mergeable_variables(model)
    no_embeddings = tmv.VariableFilter(merge_embeddings=False)
    kept = no_embeddings.filter_parallel_lists(all_mergeable)
    return kept


def get_variables_by_subset(model):
    """Group `model`'s filtered mergeable variables by sub-block.

    Applies `tmv.group_by_sub_blocks` to the result of
    `get_mergeable_variables(model)`.
    """
    return tmv.group_by_sub_blocks(get_mergeable_variables(model))


def get_embedding_variables(model):
    """Return the first sub-block group of `model`'s unfiltered variables.

    NOTE(review): the function name implies group 0 holds the embedding
    variables -- confirm against `tmv.group_by_sub_blocks`.
    """
    unfiltered = hf_util.get_mergeable_variables(model)
    groups = tmv.group_by_sub_blocks(unfiltered)
    return groups[0]


def get_fisher(model_number: int):
    """Load the diagonal Fisher stored for the given model number.

    The file name is `FISHER_FILENAME` formatted with `model_number`, looked
    up inside `FISHERS_DIR`.
    """
    fisher_path = os.path.join(
        FISHERS_DIR,
        FISHER_FILENAME.format(model_number=model_number),
    )
    return diagonal.DiagonalFisher.load(fisher_path)


def make_container(model_number: int):
    """Load the PEF/NMF analysis container for the given model number.

    Both files are looked up in `PER_EXAMPLES_FISHERS_DIR`. The container is
    passed through `hans_util.fix_up_hans_container` and its NMFs are
    force-loaded before it is returned. Reads the module-level `tokenizer`,
    so it must be called after that global is defined.
    """
    pef_name = PEF_FILENAME.format(model_number=model_number)
    nmf_name = NMF_FILENAME.format(pef=pef_name)

    load_kwargs = dict(
        pef_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, pef_name),
        nmf_filepath=os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf_name),
        n_nmfs=N_DECOMPS,
        tokenizer=tokenizer,
        shift_labels=False,
    )
    analysis = am.load_pef_nmf_analysis_container(**load_kwargs)

    hans_util.fix_up_hans_container(analysis)
    analysis.nmfs.force_load_all()
    return analysis


##########################################################################

# Tokenizer for PRETRAINED_MODEL; also read as a global by `make_container`.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL)

# HANS validation examples, filtered to the `lexical_overlap` heuristic with
# label 1.
# NOTE(review): whether the 5000 limit is applied before or after the filter
# depends on `get_first_hans_examples` -- confirm.
examples = hans_util.get_first_hans_examples(
    'validation',
    5000,
    lambda ds: ds.filter(lambda x: x['heuristic'] == 'lexical_overlap' and x['label'] == 1)
)

# Indicators computed over `examples`; `evaluate_model` reads their
# `subcase_indicators` mapping.
# NOTE(review): 'ne' presumably stands for non-entailment -- confirm.
indicators = hans_labeling.compute_full_indicator(examples, 'ne')

##########################################################################

# Indices of the two `connectivity/feather_berts_*` checkpoints used below.
model_number_a = 0
model_number_b = 15

# For each model: its PEF/NMF analysis container and the compiled classifier.
# Loading the dense diagonal Fishers is currently disabled.
container_a = make_container(model_number_a)
model_a = get_model(model_number_a)
# dense_fisher_a = get_fisher(model_number_a).fishers

container_b = make_container(model_number_b)
model_b = get_model(model_number_b)
# dense_fisher_b = get_fisher(model_number_b).fishers

##########################################################################

# Alternative selection parameters, currently disabled:
# selection_parameters = ncf.SelectionParameters(
#     coeff_factor=0.485,
#     frac_threshold=0.75,
#     # frac_threshold=0.95,
#     p_value_threshold=0.03,
# )

# The same selection parameters and indicators are used to compute the
# tuning info of both models.
selection_parameters = ncf.SelectionParameters(
    coeff_factor=1 / 3,
    frac_threshold=0.825,
    p_value_threshold=0.04,
)
tuning_info_a = hla.compute_hans_tuning_info(container_a, indicators, selection_parameters)
tuning_info_b = hla.compute_hans_tuning_info(container_b, indicators, selection_parameters)

##########################################################################

"""
Model 0 scores:
    ln_subject/object_swap: 0.415
    ln_preposition: 0.547
    ln_relative_clause: 0.449
    ln_passive: 0.013
    ln_conjunction: 0.524

Model 15 scores:
    ln_subject/object_swap: 0.0
    ln_preposition: 0.115
    ln_relative_clause: 0.135
    ln_passive: 0.0
    ln_conjunction: 0.238


"""

# Example type (indicator group, subcase name) the Fishers below are built
# around; the commented-out line selects a different subcase.
# example_type = ('subcase_indicators', 'ln_conjunction')
example_type = ('subcase_indicators', 'ln_subject/object_swap')

# Model A contributes a Fisher for its components tuned to `example_type`;
# model B contributes a Fisher for all of its components except the tuned
# ones (as the method names indicate). The commented-out variant below swaps
# these two roles.
fisher_a = tuning_info_a.make_fisher_for_tuned_components(
    example_type,
    get_variables_by_subset(model_a),
)
fisher_b = tuning_info_b.make_fisher_for_all_but_tuned_components(
    example_type,
    get_variables_by_subset(model_b),
)
# fisher_a = tuning_info_a.make_fisher_for_all_but_tuned_components(
#     example_type,
#     get_variables_by_subset(model_a),
# )
# fisher_b = tuning_info_b.make_fisher_for_tuned_components(
#     example_type,
#     get_variables_by_subset(model_b),
# )

# Hack to not error with embeddings: the variables handed to the Fisher
# builders above exclude embeddings (see `get_mergeable_variables`), so
# all-zero tensors shaped like model A's embedding variables are prepended to
# both Fisher lists.
# NOTE(review): presumably the merging code expects one Fisher entry per
# mergeable variable with embeddings first -- confirm.
embeddings_zeros = [tf.zeros_like(v) for v in get_embedding_variables(model_a)]
fisher_a = [*embeddings_zeros, *fisher_a]
fisher_b = [*embeddings_zeros, *fisher_b]

##########################################################################

# Batch size used by `evaluate_model` when iterating over `val_ds`.
eval_batch_size = 128

# Validation split of the 'hans/lexical_overlap_ne' dataset, loaded with the
# module-level tokenizer and sequence_length=64; read by `evaluate_model`.
val_ds = em_datasets.load('hans/lexical_overlap_ne', split='validation', sequence_length=64, tokenizer=tokenizer)


def evaluate_model(output_model):
    """Print the accuracy of `output_model` on each HANS subcase.

    Runs the model over the module-level `val_ds` in batches of
    `eval_batch_size`, passes the logits through
    `hans_util.fix_up_hans_logits`, and takes the argmax as the prediction.
    For every entry of the module-level `indicators.subcase_indicators`, the
    mean correctness over the examples selected by that entry is printed as
    `"<key>: <accuracy>"`.

    Args:
        output_model: Model called as `output_model(x, training=False)`; its
            result must expose a `.logits` tensor.

    Returns:
        None. Results are only printed.
    """
    batch_labels = []
    batch_preds = []
    for x, y in val_ds.batch(eval_batch_size):
        logits = output_model(x, training=False).logits.numpy()
        logits = hans_util.fix_up_hans_logits(logits)
        batch_preds.append(np.argmax(logits, axis=-1))
        batch_labels.append(y.numpy())
    labels = np.concatenate(batch_labels, axis=0)
    preds = np.concatenate(batch_preds, axis=0)

    # Correctness does not depend on the subcase, so compute it once instead
    # of once per loop iteration.
    corrects = preds == labels
    for key, value in indicators.subcase_indicators.items():
        # NOTE(review): assumes each indicator value indexes `corrects` in
        # the same example order as `val_ds` -- confirm.
        acc = corrects[value].astype(np.float64).mean()
        print(f'{key}: {acc}')


##########################################################################

# Separate model instance handed to the merging generator as `output_model`.
# NOTE(review): presumably the generator writes each merged set of weights
# into this model before yielding, which is why it is evaluated inside the
# loop below -- confirm against `merging.generate_merged_for_coeffs_set`.
output_model = get_model(model_number_b)
# output_model = get_model(model_number_a)

# Model B is listed first and the Fishers follow the same order; the
# commented-out lines give the A-first ordering.
gen = merging.generate_merged_for_coeffs_set(
    mergeable_models=[model_b, model_a],
    # mergeable_models=[model_a, model_b],
    coefficients_set=merging.create_pairwise_grid_coeffs(21),
    fishers=[fisher_b, fisher_a],
    # fishers=[fisher_a, fisher_b],
    fisher_floor=1e-8,
    favor_target_model=True,
    normalize_fishers=True,
    output_model=output_model,
)
# For every coefficient setting: print it, then the per-subcase accuracies,
# then a blank separator line.
for coefficients, _ in gen:
    print(coefficients)
    evaluate_model(output_model)
    print('')



# Try regular merge with: true fishers, pef fishers, nmf approximation to the pef fishers.
# Try the fisher difference thing instead of top raw.

