R"""


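Scratch script for interactively trying out the per-example Fisher tooling.

Usage (run from the project root so that the `em` package is importable):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/misc1/per_example_fisher_dev001.py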


"""

import os
from importlib import reload
import itertools

from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.merging import merging
from em.tools import welford_variance
from em.util import hf_util
from em.util import vat_da_faak_vpn

###############################################################################

# Smoke test for the Welford variance accumulator on synthetic Gaussian data.
# The samples are drawn with stddev=2.0, so the printed estimate should
# presumably land near 4.0 -- TODO confirm which axis the accumulator reduces.
acc = welford_variance.VarianceAccumulator()
samples = tf.random.normal(shape=[10000, 2048], mean=0.0, stddev=2.0)

# Feed the whole sample matrix in a single batched update.
acc.batch_update(samples)

# Pull the estimate out of TensorFlow and show it as a NumPy array.
variance = acc.variance
print(variance.numpy())

# def random_sparse_tensors(variables):
#     ret = []
#     for v in variables:
#         mask = tf.random.normal(tf.shape(v)) > 0
#         sparse = tf.sparse.from_dense(tf.cast(mask, tf.float32))
#         ret.append(sparse)
#     return ret


# task = 'mnli'
# model_str = 'prajjwal1/bert-small-mnli'
# model = TFAutoModelForSequenceClassification.from_pretrained(model_str, from_pt=True)
# tokenizer = AutoTokenizer.from_pretrained(model_str)

# variables = [
#     v for v in hf_util.get_mergeable_variables(model) 
#     if '/embeddings/' not in v.name
# ]


# ds = glue.load_glue_dataset(
#     task=task,
#     split='train',
#     tokenizer=tokenizer,
#     max_length=128,
# )
# ds = ds.batch(2)

# sparse_indices = random_sparse_tensors(variables)


# # gen = per_example.stream_per_example_diagonal_fishers(model, ds, unbatch=True)
# gen = per_example.stream_per_example_sparse_diagonal_fishers(model, ds, sparse_indices, variables, unbatch=True)
# for fisher in itertools.islice(gen, 10):
#     print(fisher)
