R"""Computes and saves a sparse version of a diagonal Fisher.


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


RTE_MODEL=textattack/roberta-base-RTE
MNLI_MODEL=textattack/roberta-base-MNLI
BASE_MODEL=roberta-base
FISHER_DIR=~/Desktop/projects_data/extract_merge1/fishers0

python3 scripts1/sparse/sparsify_diagonal_fisher.py \
    --fisher_path="$FISHER_DIR/rte_fisher.h5" \
    --output_path="$FISHER_DIR/rte_fisher.sparse.03.md.h5" \
    --method=metric_derived \
    --sparsity=0.03 \
    --finetuned_model=$RTE_MODEL \
    --pretrained_model=$BASE_MODEL

python3 scripts1/sparse/sparsify_diagonal_fisher.py \
    --fisher_path="$FISHER_DIR/mnli_fisher.h5" \
    --output_path="$FISHER_DIR/mnli_fisher.sparse.03.uniform.h5" \
    --method=uniform \
    --sparsity=0.03 \
    --finetuned_model=$MNLI_MODEL \
    --pretrained_model=$BASE_MODEL

"""

import os

from absl import app
from absl import flags
from absl import logging

from transformers import TFAutoModelForSequenceClassification

from em.fishers import diagonal
from em.fishers import sparse_diagonal
from em.util import hf_util
from em.util import vat_da_faak_vpn

FLAGS = flags.FLAGS

# Valid values for the --method flag; _sparsify dispatches on these names.
_METHODS = [
    'uniform',
    'metric_derived',
]

flags.DEFINE_string("fisher_path", None, "Path of hdf5 file to read Fisher from.")

flags.DEFINE_string("output_path", None, "Path of hdf5 file to save sparse Fisher to.")


flags.DEFINE_enum(
    "method", None, _METHODS,
    "Method used to sparsify the dense Fisher. "
    "'metric_derived' additionally requires --pretrained_model.")

# The bounds make absl reject out-of-range values while parsing the command
# line, so the user gets a flag error instead of a failure inside main().
flags.DEFINE_float(
    'sparsity', None, 'Fraction of parameters to keep. Must be between 0 and 1.',
    lower_bound=0.0, upper_bound=1.0)

flags.DEFINE_string(
    "finetuned_model", None,
    "Name or path of the fine-tuned model, passed to "
    "TFAutoModelForSequenceClassification.from_pretrained.")
flags.DEFINE_bool(
    "from_pt_finetuned", True,
    "Value of the from_pt argument used when loading the fine-tuned model.")

# These are only needed if method = "metric_derived".
flags.DEFINE_string(
    "pretrained_model", None,
    "Name or path of the pretrained model, passed to "
    "TFAutoModelForSequenceClassification.from_pretrained. "
    "Only used when --method=metric_derived.")
flags.DEFINE_bool(
    "from_pt_pretrained", True,
    "Value of the from_pt argument used when loading the pretrained model.")


flags.mark_flags_as_required(['fisher_path', 'output_path', 'method', 'sparsity', 'finetuned_model'])


def _validate_flags():
    """Checks the flag values that the rest of the script relies on.

    Explicit exceptions are raised instead of using `assert`, because
    assertions are stripped when Python runs with `-O` and would then let
    invalid flag values through silently.

    Raises:
        ValueError: If `sparsity` is outside [0, 1], or if `method` is
            'metric_derived' and either `finetuned_model` or
            `pretrained_model` is unset.
    """
    if not 0 <= FLAGS.sparsity <= 1:
        raise ValueError(
            f'--sparsity must be between 0 and 1, got {FLAGS.sparsity}.')

    if FLAGS.method == 'metric_derived':
        # The metric-derived path needs both models' variables.
        if FLAGS.finetuned_model is None:
            raise ValueError(
                '--finetuned_model is required when --method=metric_derived.')
        if FLAGS.pretrained_model is None:
            raise ValueError(
                '--pretrained_model is required when --method=metric_derived.')


def _sparsify_uniform(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Sparsifies `dense_fisher` with `sparse_diagonal.from_dense_uniformly`.

    Args:
        dense_fisher: The dense diagonal Fisher to sparsify.
        finetuned_model: Model whose mergeable variables (as returned by
            `hf_util.get_mergeable_variables`) are handed to the sparsifier.

    Returns:
        The sparse Fisher built using the `sparsity` flag value.
    """
    mergeable_variables = hf_util.get_mergeable_variables(finetuned_model)
    return sparse_diagonal.from_dense_uniformly(
        dense_fisher, mergeable_variables, FLAGS.sparsity)


def _sparsify_metric_derived(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Sparsifies `dense_fisher` with `sparse_diagonal.from_dense_by_metric_approximation`.

    The pretrained model named by the `pretrained_model` flag is loaded here,
    since this is the only method that uses it.

    Args:
        dense_fisher: The dense diagonal Fisher to sparsify.
        finetuned_model: The already-loaded fine-tuned model.

    Returns:
        The sparse Fisher built using the `sparsity` flag value.
    """
    pretrained_name_or_path = os.path.expanduser(FLAGS.pretrained_model)
    pretrained_model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_name_or_path,
        from_pt=FLAGS.from_pt_pretrained,
    )

    finetuned_variables = hf_util.get_mergeable_variables(finetuned_model)
    pretrained_variables = hf_util.get_mergeable_variables(pretrained_model)

    return sparse_diagonal.from_dense_by_metric_approximation(
        dense_fisher, finetuned_variables, pretrained_variables, FLAGS.sparsity)


def _sparsify(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Dispatches to the sparsifier selected by the `method` flag.

    Args:
        dense_fisher: The dense diagonal Fisher to sparsify.
        finetuned_model: The already-loaded fine-tuned model.

    Returns:
        The sparse Fisher produced by the selected method.

    Raises:
        ValueError: If the `method` flag names no known sparsifier.
    """
    method = FLAGS.method
    sparsifiers = {
        'uniform': _sparsify_uniform,
        'metric_derived': _sparsify_metric_derived,
    }

    sparsifier = sparsifiers.get(method)
    if sparsifier is None:
        raise ValueError(method)
    return sparsifier(dense_fisher, finetuned_model)


def main(_):
    """Loads the dense Fisher and fine-tuned model, sparsifies, and saves.

    All paths taken from flags are passed through `os.path.expanduser`, so
    values beginning with `~` are accepted.

    Args:
        _: Unused positional argument supplied by `app.run`.
    """
    _validate_flags()

    fisher_path = os.path.expanduser(FLAGS.fisher_path)
    dense_fisher = diagonal.DiagonalFisher.load(fisher_path)

    finetuned_name_or_path = os.path.expanduser(FLAGS.finetuned_model)
    finetuned_model = TFAutoModelForSequenceClassification.from_pretrained(
        finetuned_name_or_path,
        from_pt=FLAGS.from_pt_finetuned,
    )

    _sparsify(dense_fisher, finetuned_model).save(
        os.path.expanduser(FLAGS.output_path))


# Script entry point: hand `main` to absl's `app.run`, which invokes it when
# this file is executed directly (not when it is imported).
if __name__ == "__main__":
    app.run(main)
