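R"""
Usage notes for running this speed-test script.

Setup (from a shell):

cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


Run and drop into an interactive prompt afterwards:

python3 -i local_scripts/soc/soc_speed_dev001.py


Same, with only CUDA device 0 visible to the process:

CUDA_VISIBLE_DEVICES=0 python -i local_scripts/soc/soc_speed_dev001.py


"""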
import collections
import dataclasses
import os
from importlib import reload
import itertools
import time
from typing import Any, List, Sequence

import h5py
# import matplotlib.pyplot as plt
import numpy as np
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf
import tensorflow_probability as tfp

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.merging import merging
from em.models.generative import vae
# from em.models.generative import soc
from em.util import hf_util
from em.util import vat_da_faak_vpn

# Short alias for the TensorFlow Probability distributions module.
# NOTE(review): not referenced in the visible part of this script; presumably
# kept for use at the interactive prompt (`python -i`).
tfd = tfp.distributions

# File names of the saved sparse Fisher files, looked up inside FISHER_DIR.
METRIC_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.metric.131k.h5"
UNIFORM_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.uniform.131k.h5"

# Directory holding the Fisher files: use the /fruitbasket location when that
# path exists on this machine, otherwise fall back to the per-user Desktop tree.
FISHER_DIR = (
    '/fruitbasket/users/m/project_data/extract_merge1/fishers0'
    if os.path.exists('/fruitbasket')
    else os.path.expanduser('~/Desktop/projects_data/extract_merge1/fishers0')
)

# GLUE task name and the HuggingFace model identifiers used by this script.
TASK = 'mnli'
MODEL = "prajjwal1/bert-small-mnli"
PRETRAINED_MODEL = "prajjwal1/bert-small"

# Maximum tokenized sequence length passed to the dataset loader.
SEQ_LEN = 128

###############################################################################
###############################################################################


def compute_component_size(sparse_fishers: Sequence[tf.sparse.SparseTensor]) -> int:
    """Return the combined number of stored values across `sparse_fishers`.

    Each tensor contributes the leading static dimension of its `values`
    tensor; an empty sequence yields 0.
    """
    return sum(fisher.values.shape[0] for fisher in sparse_fishers)


###############################################################################
###############################################################################

# Batch size applied to the dataset before it is handed to the per-example
# Fisher stream. Exactly one assignment below is active; the commented-out
# lines are alternative values (presumably tried during speed comparisons).
# PER_EXAMPLE_FISHER_BATCH_SIZE = 128
PER_EXAMPLE_FISHER_BATCH_SIZE = 256
# PER_EXAMPLE_FISHER_BATCH_SIZE = 64
# PER_EXAMPLE_FISHER_BATCH_SIZE = 16
# PER_EXAMPLE_FISHER_BATCH_SIZE = 32

# PER_EXAMPLE_FISHER_BATCH_SIZE = 2

#######################################

# Load the saved sparse diagonal Fisher named by METRIC_VAR_FISHER.
sparse_fisher = sparse_diagonal.SparseDiagonalFisher.load(
    os.path.join(FISHER_DIR, METRIC_VAR_FISHER)
)

# Fine-tuned classifier and its tokenizer; from_pt=True loads the TF model
# from PyTorch checkpoint weights.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Keep only entries 6..11 of the mergeable variables and of the Fishers so the
# two lists stay aligned. (A narrower [6:8] window is the other slice that was
# used during development.)
model_variables = hf_util.get_mergeable_variables(model)[6:12]
fishers = sparse_fisher.fishers[6:12]

# All retained Fisher values joined end-to-end along the last axis.
flat_fisher = tf.concat([fisher.values for fisher in fishers], axis=-1)

# Total number of stored Fisher values in the retained slice.
component_size = compute_component_size(fishers)
print(component_size)


# Training split of TASK, tokenized up to SEQ_LEN, then repeated and shuffled
# with a buffer of 1000 (assumes load_glue_dataset returns a tf.data.Dataset,
# in which case the argument-less repeat() makes the stream endless).
glue_ds = (
    glue.load_glue_dataset(
        task=TASK,
        split='train',
        tokenizer=tokenizer,
        max_length=SEQ_LEN,
    )
    .repeat()
    .shuffle(1000)
)


def something(n_iters):
    """Record a timestamp for each of up to `n_iters` batches from the Fisher stream.

    Builds the generator returned by
    `per_example.stream_per_example_diagonal_fishers` over the module-level
    `model`, `glue_ds` (batched by `PER_EXAMPLE_FISHER_BATCH_SIZE`) and
    `model_variables`, then takes a clock reading each time the generator
    yields.

    Args:
        n_iters: Maximum number of items to pull from the generator.

    Returns:
        A 1-D `np.ndarray` of `time.perf_counter()` readings, one per item
        received. The clock's reference point is undefined, so only the
        differences between entries are meaningful (e.g. `np.diff(result)`
        gives the seconds elapsed between consecutive items).
    """
    # Sparse variant of the stream; swap it in to time that path instead.
    # gen = per_example.stream_per_example_sparse_diagonal_fishers(
    #     model, glue_ds.batch(PER_EXAMPLE_FISHER_BATCH_SIZE), fishers, model_variables, unbatch=False
    # )
    gen = per_example.stream_per_example_diagonal_fishers(
        model, glue_ds.batch(PER_EXAMPLE_FISHER_BATCH_SIZE), model_variables, unbatch=False
    )
    timestamps = []
    # The yielded item itself is not needed, only the moment it arrives.
    for _ in itertools.islice(gen, n_iters):
        # perf_counter() is monotonic and uses the highest-resolution clock
        # available, so intervals are not distorted by system clock
        # adjustments the way time.time() readings can be.
        timestamps.append(time.perf_counter())
    return np.array(timestamps)


# Pull 16 items through the stream and print the gaps between consecutive
# timestamps (16 timestamps give 15 intervals, in seconds).
n_iters = 16
times = something(n_iters)
print(np.diff(times))


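R"""

Approximate time between consecutive batches, as printed by np.diff(times)
above, for each stream variant and batch size.

Sparse (stream_per_example_sparse_diagonal_fishers):
batch_size=256 => ~500ms
batch_size=128 => ~250ms


Dense (stream_per_example_diagonal_fishers):
batch_size=256 => ~500ms

"""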
