R"""See the variance of per-example diagonal Fishers WRT to a sparse subset of parameters.

# NOTE: Just remembered that the mean should be the Fisher value. I'll also maybe want to
# compare variances vs mean fisher value and/or the metric quantity.
#
# Look at sqrt of variances divided by Fisher value.


"""

import os

from absl import app
from absl import flags
from absl import logging

import h5py

from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

from em.datasets import glue
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.tools import welford_variance
from em.util import hdf5_util
from em.util import hf_util
from em.util import vat_da_faak_vpn


# Supported values of --sparsification_method; each one is dispatched to a
# `_sparsify_*` helper in this file.
_METHODS = [
    'uniform',
    'metric_derived',
]


FLAGS = flags.FLAGS


flags.DEFINE_string(
    "model", None,
    "Path to the fine-tuned model whose per-example Fishers are computed. "
    "The tokenizer is loaded from the same path.")
flags.DEFINE_string(
    "pretrained_model", None,
    "Path to the pretrained model. Only used (and then required) when "
    "--sparsification_method=metric_derived.")
flags.DEFINE_string(
    "dense_fisher", None,
    "Path to the dense diagonal Fisher (loaded with DiagonalFisher.load) to sparsify.")

flags.DEFINE_string(
    "output_path", None,
    "Path of the HDF5 file to write the sparse Fisher and the per-example variances to.")

flags.DEFINE_bool(
    "from_pt", True,
    "Passed as `from_pt` to `from_pretrained` when loading the models.")

flags.DEFINE_string(
    "glue_task", None,
    "GLUE task whose examples are used to compute per-example Fishers.")

flags.DEFINE_string("split", "train", "Dataset split to draw examples from.")
flags.DEFINE_integer("n_examples", 4096, "Maximum number of examples to use.")
flags.DEFINE_integer(
    "batch_size", 2,
    "Number of examples per batch when streaming per-example Fishers.")
flags.DEFINE_integer(
    "sequence_length", 128,
    "Maximum tokenized sequence length (`max_length` for the GLUE loader).")


flags.DEFINE_enum(
    "sparsification_method", None, _METHODS,
    "How to choose the subset of parameters kept from the dense Fisher.")
flags.DEFINE_float('sparsity', None, 'Fraction of parameters to keep. Must be between 0 and 1.')

# Enforce the range documented in the --sparsity help text at flag-parsing time.
# `None` is let through here; whether the flag is required is decided separately.
flags.register_validator(
    "sparsity",
    lambda value: value is None or 0.0 <= value <= 1.0,
    message="--sparsity must be between 0 and 1.")


def _pretrained_model_given_if_needed(flag_values) -> bool:
    """Return False iff 'metric_derived' sparsification is requested without --pretrained_model.

    Without this check the script would only fail after loading the fine-tuned
    model and the dense Fisher, with a TypeError from `os.path.expanduser(None)`.
    """
    return (flag_values["sparsification_method"] != "metric_derived"
            or flag_values["pretrained_model"] is not None)


flags.register_multi_flags_validator(
    ["sparsification_method", "pretrained_model"],
    _pretrained_model_given_if_needed,
    message="--pretrained_model is required when --sparsification_method=metric_derived.")


def _sparsify_uniform(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Sparsify `dense_fisher` with `sparse_diagonal.from_dense_uniformly`.

    Args:
        dense_fisher: The dense diagonal Fisher to sparsify.
        finetuned_model: Model whose mergeable variables (via
            `hf_util.get_mergeable_variables`) the Fisher is indexed against.

    Returns:
        The sparse diagonal Fisher keeping the `--sparsity` fraction of parameters.
    """
    mergeable_vars = hf_util.get_mergeable_variables(finetuned_model)
    keep_fraction = FLAGS.sparsity
    return sparse_diagonal.from_dense_uniformly(dense_fisher, mergeable_vars, keep_fraction)


def _sparsify_metric_derived(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Sparsify `dense_fisher` with `sparse_diagonal.from_dense_by_metric_approximation`.

    Loads the pretrained model from `--pretrained_model` (honouring `--from_pt`)
    and passes both the fine-tuned and pretrained mergeable variables to the
    sparsifier, keeping the `--sparsity` fraction of parameters.

    Args:
        dense_fisher: The dense diagonal Fisher to sparsify.
        finetuned_model: The fine-tuned model the Fisher was computed for.

    Returns:
        The resulting sparse diagonal Fisher.
    """
    pretrained_path = os.path.expanduser(FLAGS.pretrained_model)
    pretrained_model = TFAutoModelForSequenceClassification.from_pretrained(
        pretrained_path, from_pt=FLAGS.from_pt)

    finetuned_vars = hf_util.get_mergeable_variables(finetuned_model)
    pretrained_vars = hf_util.get_mergeable_variables(pretrained_model)

    return sparse_diagonal.from_dense_by_metric_approximation(
        dense_fisher, finetuned_vars, pretrained_vars, FLAGS.sparsity,
    )


def _sparsify(
        dense_fisher: diagonal.DiagonalFisher, finetuned_model) -> sparse_diagonal.SparseDiagonalFisher:
    """Sparsify `dense_fisher` using the method named by `--sparsification_method`.

    Raises:
        ValueError: If `--sparsification_method` is not a known method; the
            offending value is the exception's argument.
    """
    method_name = FLAGS.sparsification_method
    sparsifiers = {
        'uniform': _sparsify_uniform,
        'metric_derived': _sparsify_metric_derived,
    }
    if method_name not in sparsifiers:
        raise ValueError(method_name)
    return sparsifiers[method_name](dense_fisher, finetuned_model)


def _write_variances_to_file(sparse_fisher_path, variances):
    """Add per-variable variance arrays to an existing HDF5 file.

    Opens the file in read/write mode and creates a new group `variances`
    inside the existing `data` group. Each array is stored as a dataset named
    by its position in `variances` ("0", "1", ...), with the array's shape and
    dtype, and filled via `hdf5_util.set_h5_ds`.

    Args:
        sparse_fisher_path: Path to an HDF5 file that already has a `data` group.
        variances: Sequence of array-likes exposing `.shape` and `.dtype`.
    """
    with h5py.File(sparse_fisher_path, "r+") as h5_file:
        data_group = h5_file['data']
        variances_group = data_group.create_group('variances')
        for position, variance in enumerate(variances):
            dataset_name = str(position)
            dataset = variances_group.create_dataset(
                dataset_name, shape=variance.shape, dtype=variance.dtype)
            hdf5_util.set_h5_ds(dataset, variance)


def _num_batches(n_examples: int, batch_size: int) -> int:
    """Return the upper bound on batches yielded by `ds.take(n_examples).batch(batch_size)`.

    `Dataset.batch` keeps the final partial batch by default, so this is the
    ceiling of the division, not the floor.
    """
    return (n_examples + batch_size - 1) // batch_size


def _load_batched_glue_dataset(tokenizer_path: str):
    """Load the tokenized GLUE split, keep the first `--n_examples`, and batch it.

    Args:
        tokenizer_path: Path the tokenizer is loaded from (the fine-tuned model's path).
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    ds = glue.load_glue_dataset(
        task=FLAGS.glue_task,
        split=FLAGS.split,
        tokenizer=tokenizer,
        max_length=FLAGS.sequence_length,
    )
    return ds.take(FLAGS.n_examples).batch(FLAGS.batch_size)


def _accumulate_variances(model, ds, sparse_fisher, model_variables, total_batches):
    """Stream per-example sparse Fishers over `ds` and return their per-variable variances.

    Returns:
        One numpy array per entry of `model_variables`, taken from a
        `welford_variance.VarianceAccumulator` fed every batch in `ds`.
    """
    variance_accs = [welford_variance.VarianceAccumulator() for _ in model_variables]
    gen = per_example.stream_per_example_sparse_diagonal_fishers(
        model, ds, sparse_fisher.fishers, model_variables, unbatch=False
    )
    # `total` only sizes the progress bar; the dataset may be shorter.
    for pe_fisher in tqdm(gen, total=total_batches):
        for acc, f in zip(variance_accs, pe_fisher):
            acc.batch_update(f)
    return [acc.variance.numpy() for acc in variance_accs]


def main(_):
    """Sparsify a dense Fisher, measure per-example Fisher variances, and save both.

    1. Load the fine-tuned model (`--model`) and the dense diagonal Fisher
       (`--dense_fisher`).
    2. Sparsify the dense Fisher per `--sparsification_method` / `--sparsity`.
    3. Stream per-example sparse diagonal Fishers over up to `--n_examples`
       GLUE examples and accumulate their variance per mergeable variable.
    4. Save the sparse Fisher to `--output_path`, then add the variances under
       the file's `data/variances` group.
    """
    model_str = os.path.expanduser(FLAGS.model)
    model = TFAutoModelForSequenceClassification.from_pretrained(
        model_str, from_pt=FLAGS.from_pt
    )

    dense_fisher_str = os.path.expanduser(FLAGS.dense_fisher)
    dense_fisher = diagonal.DiagonalFisher.load(dense_fisher_str)

    # TODO: Find a way to set an option that filters which variables are
    # considered when sparsifying.
    sparse_fisher = _sparsify(dense_fisher, model)
    model_variables = hf_util.get_mergeable_variables(model)

    ds = _load_batched_glue_dataset(model_str)
    variances = _accumulate_variances(
        model,
        ds,
        sparse_fisher,
        model_variables,
        total_batches=_num_batches(FLAGS.n_examples, FLAGS.batch_size),
    )

    output_path = os.path.expanduser(FLAGS.output_path)
    sparse_fisher.save(output_path)
    _write_variances_to_file(output_path, variances)


if __name__ == "__main__":
    app.run(main)
