R"""


cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/misc1/pe_fisher_variance_dev001.py


"""
import collections
import dataclasses
import os
from importlib import reload
import itertools
from typing import List

import h5py
import matplotlib.pyplot as plt
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.merging import merging
from em.tools import welford_variance
from em.util import hf_util
from em.util import vat_da_faak_vpn

# Hugging Face hub ids: the MNLI fine-tuned checkpoint and the pretrained
# checkpoint it was fine-tuned from.
MODEL = "prajjwal1/bert-small-mnli"
PRETRAINED_MODEL = "prajjwal1/bert-small"

# Directory holding the sparse-Fisher HDF5 files read by this script.
FISHER_DIR = os.path.expanduser('~/Desktop/projects_data/extract_merge1/fishers0')

# The two Fisher-variance files that are compared. The ".metric" / ".uniform"
# parts of the names distinguish the runs (presumably different label-sampling
# schemes for the Fisher estimate -- NOTE(review): confirm).
METRIC_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.metric.131k.h5"
UNIFORM_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.uniform.131k.h5"


def _flatten(lst: List[tf.Tensor]) -> tf.Tensor:
    """Join a list of tensors into a single rank-1 tensor.

    Every tensor is reshaped to 1-D and the pieces are concatenated end to
    end, preserving list order.
    """
    flat_pieces = [tf.reshape(tensor, shape=(-1,)) for tensor in lst]
    return tf.concat(flat_pieces, axis=0)


@dataclasses.dataclass
class Data:
    """Per-variable tensors loaded from one Fisher file, plus flat copies.

    The four list fields hold one tensor per variable. After construction,
    ``flat_parameters``, ``flat_deltas``, ``flat_fishers`` and
    ``flat_variances`` hold the corresponding list flattened and concatenated
    into one rank-1 tensor (via ``_flatten``).
    """

    # Parameter values (the sparse ``values`` when built by read_from_file).
    parameters: List[tf.Tensor]
    # Distance from the pretrained weights; read_from_file stores the
    # *squared* difference here, not the raw difference.
    deltas: List[tf.Tensor]
    # Diagonal Fisher values (presumably aligned with ``parameters``).
    fishers: List[tf.Tensor]
    # Variances read from the file (presumably aligned with ``parameters``).
    variances: List[tf.Tensor]

    def __post_init__(self):
        # Build flat_<name> for each list field, in declaration order.
        for field_name in ("parameters", "deltas", "fishers", "variances"):
            setattr(self, "flat_" + field_name, _flatten(getattr(self, field_name)))


def read_from_file(filename: str, pretrained_model):
    """Load one sparse-Fisher HDF5 file from ``FISHER_DIR`` into a ``Data``.

    Args:
        filename: File name, relative to ``FISHER_DIR``. The file must be
            loadable by ``sparse_diagonal.SparseDiagonalFisher.load``. It must
            also contain an ``n_variables`` attribute on its ``data`` group
            and one dataset per variable under ``data/variances``, keyed
            ``"0"`` .. ``str(n_variables - 1)``.
        pretrained_model: Model whose ``hf_util.get_mergeable_variables``
            serve as the reference point for the deltas.

    Returns:
        A ``Data`` holding, per variable:
        - the sparse parameter values;
        - the *squared* difference between those values and the pretrained
          values at the same indices (stored in ``deltas``);
        - the sparse Fisher values;
        - the stored variances, cast to float32.

    Raises:
        ValueError: If there are fewer pretrained variables than sparse
            parameters, or if the number of stored variances differs from the
            number of sparse parameters. Either mismatch would otherwise
            misalign the flattened tensors that ``Data`` builds.
    """
    filepath = os.path.join(FISHER_DIR, filename)
    sparse_fisher = sparse_diagonal.SparseDiagonalFisher.load(filepath)
    sparse_parameters = list(sparse_fisher.parameters)

    pretrained_variables = list(hf_util.get_mergeable_variables(pretrained_model))
    if len(pretrained_variables) < len(sparse_parameters):
        # zip() below would silently drop the trailing sparse parameters,
        # leaving ``deltas`` shorter than ``parameters``.
        raise ValueError(
            f"{filepath}: {len(sparse_parameters)} sparse parameters but only "
            f"{len(pretrained_variables)} pretrained mergeable variables."
        )
    # Squared difference between fine-tuned and pretrained values, gathered
    # at the sparse indices so both sides have the same shape.
    deltas = [
        (p.values - tf.gather_nd(v, p.indices)) ** 2
        for v, p in zip(pretrained_variables, sparse_parameters)
    ]

    # We don't really care about the sparse structure here, so we just take the values.
    parameters = [p.values for p in sparse_parameters]
    fishers = [f.values for f in sparse_fisher.fishers]

    with h5py.File(filepath, "r") as h5_file:
        var_group = h5_file["data/variances"]
        n_variables = int(h5_file["data"].attrs["n_variables"])
        # Datasets are keyed by the stringified variable index.
        variances = [
            tf.cast(var_group[str(index)], dtype=tf.float32)
            for index in range(n_variables)
        ]
    if len(variances) != len(sparse_parameters):
        raise ValueError(
            f"{filepath}: n_variables={len(variances)} does not match the "
            f"{len(sparse_parameters)} sparse parameters."
        )

    return Data(
        parameters=parameters,
        deltas=deltas,
        fishers=fishers,
        variances=variances,
    )


# Pretrained (not fine-tuned) checkpoint, converted from the PyTorch weights.
# It is the reference point for the deltas computed in read_from_file.
pretrained_model = TFAutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL, from_pt=True)

metric_data = read_from_file(METRIC_VAR_FISHER, pretrained_model)
uniform_data = read_from_file(UNIFORM_VAR_FISHER, pretrained_model)


# Alternative quantities that were also inspected:
# (a) raw variances, sorted descending:
# mvs = tf.sort(metric_data.flat_variances, direction="DESCENDING")
# uvs = tf.sort(uniform_data.flat_variances, direction="DESCENDING")

# (b) standard deviation relative to the Fisher value:
# mvs = tf.sort(tf.sqrt(metric_data.flat_variances) / metric_data.flat_fishers, direction="DESCENDING")
# uvs = tf.sort(tf.sqrt(uniform_data.flat_variances) / uniform_data.flat_fishers, direction="DESCENDING")


# Per-element sqrt(variance) * delta, sorted largest first. read_from_file
# stores *squared* deltas, so this is std * (theta - theta_pretrained)^2.
mvs = tf.sort(tf.sqrt(metric_data.flat_variances) * metric_data.flat_deltas, direction="DESCENDING")
uvs = tf.sort(tf.sqrt(uniform_data.flat_variances) * uniform_data.flat_deltas, direction="DESCENDING")

# Natural log of the sorted values against their rank. Entries that are
# exactly zero become -inf and are not drawn, so a curve can end early.
plt.plot(tf.math.log(mvs), label='metric')
plt.plot(tf.math.log(uvs), label='uniform')
# Linear-scale variant:
# plt.plot(mvs, label='metric')
# plt.plot(uvs, label='uniform')
plt.xlabel("rank (sorted descending)")
plt.ylabel("ln(sqrt(variance) * squared delta)")
plt.legend()
plt.show()
