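R"""Fisher-weighted merging of a base model with other models, using sparse Fishers.

Loads the base model (--base_model), its sparse diagonal Fisher (--base_fisher)
and the sparse diagonal Fishers of the models to merge in (--fishers), evaluates
the merge for every set of merging coefficients on a GLUE task (--glue_task),
and prints the best-scoring merge.

Example usage:
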
cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1

EVAL_TASK=rte
RTE_MODEL=textattack/roberta-base-RTE
MNLI_MODEL=textattack/roberta-base-MNLI
FISHER_DIR=~/Desktop/projects_data/extract_merge1/fishers0

# Fisher merge
python3 ./scripts1/sparse/sparse_merge_and_evaluate.py  \
    --base_model=$RTE_MODEL \
    --base_fisher=$FISHER_DIR/rte_fisher.sparse.03.md.h5 \
    --fishers=$FISHER_DIR/mnli_fisher.sparse.03.md.h5 \
    --glue_task=$EVAL_TASK


"""
import os

from absl import app
from absl import flags
from absl import logging
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer

from em.datasets import glue
from em.evaluation import evaluation
from em.merging import merging
from em.merging import sparse_merging
from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.models import transformer_model_vars as tmv
from em.util import hdf5_util
from em.util import vat_da_faak_vpn


FLAGS = flags.FLAGS


# Inputs: the model merged into, its Fisher, and the evaluation task.
flags.DEFINE_string(
    "base_model",
    None,
    "Local path or Hugging Face model name of the base model; also used to "
    "load the tokenizer.",
)
flags.DEFINE_string(
    "base_fisher", None, "Path to the sparse diagonal Fisher of the base model."
)
flags.DEFINE_string(
    "glue_task", None, "GLUE task used to evaluate the merged models."
)

flags.DEFINE_list(
    "fishers",
    None,
    "Comma-separated paths to the sparse diagonal Fishers of the models to "
    "merge with the base model.",
)

flags.DEFINE_bool(
    "from_pt", True, "Load --base_model from PyTorch weights (from_pt=True)."
)

# Evaluation dataset settings.
flags.DEFINE_string("split", "validation", "Dataset split used for evaluation.")
flags.DEFINE_integer(
    "n_examples", 4096, "Maximum number of evaluation examples to use."
)
flags.DEFINE_integer("batch_size", 32, "Evaluation batch size.")
flags.DEFINE_integer(
    "sequence_length", 128, "Maximum tokenized sequence length."
)

# Merging-coefficient search settings.
flags.DEFINE_integer(
    "n_coeffs", 51, "Number of coefficient settings to generate (see --coeff_mode)."
)
flags.DEFINE_enum(
    "coeff_mode",
    "grid",
    ["grid", "random"],
    'How merging coefficients are generated; "grid" requires exactly one '
    "entry in --fishers.",
)

# Passed straight through to sparse_merging.merging_coefficients_search.
flags.DEFINE_float(
    "fisher_floor", 1e-6, "fisher_floor argument of the merging search."
)
flags.DEFINE_bool(
    "favor_target_model", True, "favor_target_model argument of the merging search."
)
flags.DEFINE_bool(
    "normalize_fishers", True, "normalize_fishers argument of the merging search."
)

tmv.add_variable_filter_flags()


def load_fishers():
    """Load one sparse diagonal Fisher per path given in --fishers.

    A leading `~` in each path is expanded before loading.

    Returns:
      A list holding the result of `SparseDiagonalFisher.load` for each
      path, in the same order as the --fishers flag.
    """
    return [
        sparse_diagonal.SparseDiagonalFisher.load(os.path.expanduser(path))
        for path in FLAGS.fishers
    ]


def get_coeffs_set():
    """Build the merging-coefficient sets to search over.

    The number of models is one per entry in --fishers plus the base model.

    Returns:
      In "grid" mode, `merging.create_pairwise_grid_coeffs(FLAGS.n_coeffs)`;
      in "random" mode, `merging.create_random_coeffs(n_models, FLAGS.n_coeffs)`.

    Raises:
      ValueError: if --coeff_mode is "grid" but there are not exactly two
        models, or if --coeff_mode has an unrecognized value.
    """
    n_models = len(FLAGS.fishers) + 1
    if FLAGS.coeff_mode == "grid":
        # A pairwise grid needs exactly two models. Validate with an exception
        # rather than `assert`, which is stripped under `python -O`.
        if n_models != 2:
            raise ValueError(
                '--coeff_mode="grid" requires exactly one entry in --fishers '
                f"(2 models total), but got {n_models} models; "
                'use --coeff_mode="random" instead.'
            )
        return merging.create_pairwise_grid_coeffs(FLAGS.n_coeffs)
    if FLAGS.coeff_mode == "random":
        return merging.create_random_coeffs(n_models, FLAGS.n_coeffs)
    # Unreachable via the command line (DEFINE_enum restricts values), but
    # guards against FLAGS being overridden programmatically.
    raise ValueError(f"Unknown --coeff_mode: {FLAGS.coeff_mode!r}")


def get_best_results(results):
    """Return the merge result with the highest average score.

    Args:
      results: iterable of merge results, each exposing a `.score` that is
        passed to `evaluation.average_score`.

    Returns:
      The first result attaining the maximal average score (ties resolve to
      the earliest one, as with `max`).

    Raises:
      ValueError: if `results` is empty (raised by `max`).
    """

    def _average(merge_result):
        return evaluation.average_score(merge_result.score)

    return max(results, key=_average)


def main(_):
    """Search merging coefficients for the base model and print the best merge.

    Loads --base_model (and its tokenizer), --base_fisher and --fishers, then
    runs `sparse_merging.merging_coefficients_search` over the coefficient
    sets from `get_coeffs_set()`, evaluating on the first --n_examples
    examples of the --glue_task/--split dataset. The result with the highest
    average score is printed under a banner.
    """
    model_path = os.path.expanduser(FLAGS.base_model)
    base_model = TFAutoModelForSequenceClassification.from_pretrained(
        model_path, from_pt=FLAGS.from_pt
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    fisher_path = os.path.expanduser(FLAGS.base_fisher)
    base_fisher = sparse_diagonal.SparseDiagonalFisher.load(fisher_path)

    other_fishers = load_fishers()

    eval_dataset = glue.load_glue_dataset(
        task=FLAGS.glue_task,
        split=FLAGS.split,
        tokenizer=tokenizer,
        max_length=FLAGS.sequence_length,
    )
    # Truncate to the first n_examples, then group into batches.
    eval_dataset = eval_dataset.take(FLAGS.n_examples)
    eval_dataset = eval_dataset.batch(FLAGS.batch_size)

    glue_metric = evaluation.load_metric_for_glue_task(FLAGS.glue_task)
    coeff_sets = get_coeffs_set()
    var_filter = tmv.get_variable_filter_from_flags()

    search_results = sparse_merging.merging_coefficients_search(
        base_model=base_model,
        base_model_fisher=base_fisher,
        fishers=other_fishers,
        coefficients_set=coeff_sets,
        dataset=eval_dataset,
        metric=glue_metric,
        variable_filter=var_filter,
        fisher_floor=FLAGS.fisher_floor,
        favor_target_model=FLAGS.favor_target_model,
        normalize_fishers=FLAGS.normalize_fishers,
        print_results=True,
    )

    top_result = get_best_results(search_results)
    rule = "*" * 80
    print(rule)
    print(" Best Merge")
    print(rule)
    merging.print_merge_result(top_result)


if __name__ == "__main__":
    # These flags default to None but main() cannot run without them (e.g.
    # os.path.expanduser(None) or len(None) would raise a bare TypeError).
    # Marking them required makes absl report a clear usage error instead.
    flags.mark_flags_as_required(
        ["base_model", "base_fisher", "glue_task", "fishers"]
    )
    app.run(main)
