R"""Script for the first pass of experiments exploring HANS ablation.


===============================================================================
Things to explore:
- See how much the model’s performance on MNLI is affected by the ablations/merges.
- Multi-component ablation (how to combine their component Fishers)

Baselines:
- Move in direction of positive/negative gradient for a component's top examples.
- Use the full fisher (or PEF/NMF-derived full fisher) for a component's top examples for the merge.
- Isotropic merge restricted to subset of parameters present in a component.
- Explore choices of the "other" Fisher used in the merge.
    - Full dataset fisher.
    - MNLI vs HANS vs combo fisher.
    - PEF-derived fisher.
    - NMF-reconstructed Fisher.
    - Whether to include the ablated component or not in the Fisher.
- Maybe something to do with the loss sign guide.

"""
import dataclasses
import os
from typing import Sequence

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, PreTrainedTokenizer

from em import datasets as em_datasets
from em.fishers import diagonal
from em.merging import merging
from em.util import hf_util

from em.projects.ll import hans_components_context as HCC
from em.projects.ll import hans_merging_context as HMC

###############################################################################

# TODO: Hardcoding some stuff now, maybe make settable later.
# Root directory under which the data directories below are located.
EXPS_DIR = '/fruitbasket/users/m/project_data/extract_merge1/ll1'
# Subdirectories of EXPS_DIR. Of these, only PER_EXAMPLES_FISHERS_DIR is read
# by the code currently in this file (see load_component_context); MODELS_DIR
# and FISHERS_DIR are defined but not referenced here yet.
MODELS_DIR = os.path.join(EXPS_DIR, 'models')
FISHERS_DIR = os.path.join(EXPS_DIR, 'fishers')
PER_EXAMPLES_FISHERS_DIR = os.path.join(EXPS_DIR, 'per_example_fishers')

###############################################################################

# Command-line flags. All help strings are currently empty, so the notes below
# record how each flag is consumed by the code in this file.
FLAGS = flags.FLAGS

# Both are forwarded as keyword arguments to `make_model_context` in main().
# They default to None and are not marked as required.
flags.DEFINE_integer("model_number", None, '')
flags.DEFINE_integer("n_fisher_values", None, '')

# Forwarded as `with_flipped` to HCC.HansLoneComponentContext.
flags.DEFINE_bool('with_flipped', True, '')


# Number of examples taken from the MNLI validation split in load_mnli().
flags.DEFINE_integer("n_mnli_examples", 3072, '')

# `sequence_length` is passed to every em_datasets.load call in this file.
flags.DEFINE_integer("sequence_length", 64, '')
# Not referenced by the code in this file yet.
flags.DEFINE_integer("eval_batch_size", 128, '')
# Batch size applied to the sign-guide dataset in get_sign_guide().
flags.DEFINE_integer("get_loss_gradient_batch_size", 16, '')
# Not referenced by the code in this file yet.
flags.DEFINE_integer("dense_fisher_batch_size", 4, '')


# Index of the component under study; read in main().
flags.DEFINE_integer("component_index", None, '')

# TODO: Maybe make these lists.
# `n_top_examples_sign_guide` is the slice length applied to the sorted example
# indices in get_sign_guide(); when left as None the slice keeps every index.
flags.DEFINE_integer("n_top_examples_sign_guide", None, '')
# Only referenced from the commented-out `apply_sign_guide` sketch in main().
flags.DEFINE_float("delta", None, '')

# Not referenced by the code in this file yet.
flags.DEFINE_float("fisher_floor", 1e-7, '')
flags.DEFINE_integer("n_coefficients", None, '')


# Name passed to AutoTokenizer.from_pretrained in main().
flags.DEFINE_string("tokenizer", 'bert-base-uncased', '')

# Filename patterns. `pef_pattern` and `nmf_pattern` are joined onto
# PER_EXAMPLES_FISHERS_DIR in load_component_context(), where the `{pef}`
# placeholder of `nmf_pattern` is replaced with the raw `pef_pattern` string.
# The remaining `{...}` placeholders are not filled in by this file.
# `fisher_pattern` is not referenced by the code in this file yet.
flags.DEFINE_string('pef_pattern', "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.10k.131072.h5", '')
flags.DEFINE_string('nmf_pattern', "spH.nmf_decomp.c1024_2kIters_{n_fisher_values}pe.{pef}", '')
flags.DEFINE_string('fisher_pattern', "feather_berts_{model_number}.hans_lone_with_flipped.all_vars.h5", '')

# Forwarded as `model_name_pattern` to HCC.HansLoneComponentContext.
flags.DEFINE_string('model_pattern', 'connectivity/feather_berts_{model_number}', '')


def load_component_context(tokenizer):
    """Build the HANS lone-component context from the filename-pattern flags.

    The `{pef}` placeholder in `--nmf_pattern` is filled in with the raw
    `--pef_pattern` string, and both patterns are then joined onto
    PER_EXAMPLES_FISHERS_DIR.

    Args:
        tokenizer: Tokenizer handed through to the component context.

    Returns:
        The `HCC.HansLoneComponentContext` constructed from the flags.
    """
    pef = FLAGS.pef_pattern
    # NOTE: Kinda hacky -- a plain string substitution rather than str.format,
    # so the other `{...}` fields inside both patterns are left untouched here.
    nmf = FLAGS.nmf_pattern.replace('{pef}', pef)

    context_kwargs = dict(
        model_name_pattern=FLAGS.model_pattern,
        pef_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, pef),
        nmf_filepath_pattern=os.path.join(PER_EXAMPLES_FISHERS_DIR, nmf),
        with_flipped=FLAGS.with_flipped,
        tokenizer=tokenizer,
    )
    return HCC.HansLoneComponentContext(**context_kwargs)


def get_sign_guide(mc, component_index):
    """Return the loss gradient computed on a component's leading examples.

    The example indices are ordered by `mc.sort_example_indices_for_component`
    and the first `--n_top_examples_sign_guide` of them are kept (all of them
    when the flag is None). The matching examples are selected from the
    'hans/lexical_overlap_ne_with_flipped' validation split and batched with
    `--get_loss_gradient_batch_size`.

    Args:
        mc: Model context providing `tokenizer`, `model`, `variables` and
            `sort_example_indices_for_component`.
        component_index: Index of the component whose examples are used.

    Returns:
        Whatever `HMC.get_loss_gradient` returns for the model, its variables
        and the batched example subset.
    """
    ranked_indices = mc.sort_example_indices_for_component(component_index)

    hans_ds = em_datasets.load(
        'hans/lexical_overlap_ne_with_flipped',
        split='validation',
        sequence_length=FLAGS.sequence_length,
        tokenizer=mc.tokenizer,
    )

    # Keep only the leading entries of the sorted index list.
    top_indices = ranked_indices[:FLAGS.n_top_examples_sign_guide]
    batched_subset = HMC.get_ds_by_example_indices(hans_ds, top_indices).batch(
        FLAGS.get_loss_gradient_batch_size)

    return HMC.get_loss_gradient(mc.model, mc.variables, batched_subset)


def load_mnli(tokenizer):
    """Load a cached prefix of the MNLI validation split.

    Args:
        tokenizer: Tokenizer passed to `em_datasets.load`.

    Returns:
        The dataset from `em_datasets.load('glue/mnli', ...)`, limited to the
        first `--n_mnli_examples` examples via `take` and then `cache`d.
    """
    mnli_validation = em_datasets.load(
        'glue/mnli',
        split='validation',
        sequence_length=FLAGS.sequence_length,
        tokenizer=tokenizer,
    )
    subset = mnli_validation.take(FLAGS.n_mnli_examples)
    return subset.cache()

# def get_top_examples_fishers(mc, component_index)

###############################################################################
###############################################################################

###############################################################################
###############################################################################


def main(_):
    """Build the model and evaluation contexts for one HANS ablation run.

    The argument supplied by `app.run` is ignored; every setting comes from
    FLAGS. Only the setup is live so far: the sign-guide and merge steps are
    kept below as commented-out sketches.
    """
    component_ctx = load_component_context(
        AutoTokenizer.from_pretrained(FLAGS.tokenizer))
    model_ctx = component_ctx.make_model_context(
        'a',
        model_number=FLAGS.model_number,
        n_fisher_values=FLAGS.n_fisher_values,
    )

    # Neither of the next two values is consumed yet; they are set up for the
    # steps sketched in the comments below.
    eval_ctx = component_ctx.get_evaluation_context(
        og_logits=model_ctx.container.predicted_logits,
    )
    target_component = FLAGS.component_index

    ############################################################
    # Sketch (disabled): sign guide. This is equal to the gradient on the
    # subset of the component's top examples.
    #
    # sign_guide = get_sign_guide(model_ctx, target_component)

    ############################################################
    # Sketch (disabled): Fishers and variables for the merge.
    #
    # fisher1 = model_ctx.make_fisher_for_components(
    #     set(range(model_ctx.n_components)) - {target_component})
    # fisher2 = model_ctx.make_fisher_for_components([target_component])
    #
    # variables1 = list(model_ctx.variables)
    # variables2 = HMC.apply_sign_guide(model_ctx.variables, sign_guide, FLAGS.delta)
    #
    # sorted_example_indices = model_ctx.sort_example_indices_for_component(target_component)


# Script entry point: hand `main` to absl's `app.run` when executed directly.
if __name__ == "__main__":
    app.run(main)
