R"""

# For efficiency, I can have components restricted to each layer/variable and then
# learn coefficients over all of them. There should be a combinatorial increase in
# the number of representable full-model components. Furthermore, we could use
# the learned coefficients to examine cross-layer relationships between components.



cd ~/Desktop/projects/extract_merge1
export PYTHONPATH=$PYTHONPATH:~/Desktop/projects/extract_merge1


python3 -i local_scripts/soc/soc_dev001.py



CUDA_VISIBLE_DEVICES=0 python -i local_scripts/soc/soc_dev001.py


"""
import collections
import dataclasses
import os
from importlib import reload
import itertools
from typing import Any, List, Sequence

import h5py
# import matplotlib.pyplot as plt
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import tensorflow as tf
import tensorflow_probability as tfp

from em.datasets import glue
from em.evaluation import evaluation
from em.fishers import diagonal
from em.fishers import per_example
from em.fishers import sparse_diagonal
from em.merging import merging
from em.models.generative import vae
# from em.models.generative import soc
from em.util import hf_util
from em.util import vat_da_faak_vpn

tfd = tfp.distributions

METRIC_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.metric.131k.h5"
UNIFORM_VAR_FISHER = "bert_small_mnli_sparse_fisher_variances_32k.sp05.uniform.131k.h5"

if os.path.exists('/fruitbasket'):
    FISHER_DIR = '/fruitbasket/users/m/project_data/extract_merge1/fishers0'
else:
    FISHER_DIR = os.path.expanduser('~/Desktop/projects_data/extract_merge1/fishers0')

TASK = 'mnli'
MODEL = "prajjwal1/bert-small-mnli"
PRETRAINED_MODEL = "prajjwal1/bert-small"

SEQ_LEN = 128


###############################################################################
###############################################################################


def make_relu_layers(widths: List[int]) -> List[tf.keras.layers.Layer]:
    """Create the Dense layers of a ReLU multi-layer perceptron.

    Args:
        widths: Number of units of each layer, in order. Every entry except
            the last yields a Dense layer with a ReLU activation; the last
            entry yields a Dense layer without an activation.

    Returns:
        The newly created layers, hidden layers first and the linear output
        layer last.
    """
    hidden = [
        tf.keras.layers.Dense(width, activation='relu')
        for width in widths[:-1]
    ]
    # The final layer is left linear so callers can apply their own output
    # transformation.
    output = tf.keras.layers.Dense(widths[-1], activation=None)
    return [*hidden, output]


def filter_out_embeddings_variables(named_variables: Sequence[tf.Variable], *other_lists: Sequence[Sequence[Any]]):
    """Drop embeddings variables, and the matching entries of parallel lists.

    A variable counts as an embeddings variable when its `name` contains the
    substring '/embeddings/'.

    Args:
        named_variables: Objects exposing a `name` attribute.
        *other_lists: Any number of sequences with the same length as
            `named_variables`; they are filtered at the same positions.

    Returns:
        A tuple of lists: the filtered `named_variables` first, followed by
        each of `other_lists` filtered at the same positions.
    """
    # Make sure all inputs have the same length.
    lengths = {len(named_variables)}
    lengths.update(len(other) for other in other_lists)
    assert len(lengths) == 1
    # NOTE: I can probably generalize this later to accept a general VariableFilter object.
    keep_mask = ['/embeddings/' not in variable.name for variable in named_variables]
    return tuple(
        [item for item, keep in zip(sequence, keep_mask) if keep]
        for sequence in (named_variables, *other_lists)
    )


def compute_component_size(sparse_fishers: Sequence[tf.sparse.SparseTensor]) -> int:
    """Return the total number of entries in the `.values` of all given tensors.

    Args:
        sparse_fishers: Objects whose `.values.shape[0]` is the number of
            stored values.

    Returns:
        The sum of `values.shape[0]` over `sparse_fishers` (0 when empty).
    """
    return sum(fisher.values.shape[0] for fisher in sparse_fishers)

# def make_encoder():
#     pass


# def make_decoder():
#     pass


class VaeSoc(tf.keras.Model):
    """Keras model that reconstructs its input as a weighted sum of learned components.

    Each input is passed through a `vae.Vae` built from the given encoder and
    decoder. The VAE output is used as (pre-)coefficients and matrix-multiplied
    with a trainable `[n_components, component_size]` component matrix.
    """

    def __init__(
        self,
        encoder: tf.keras.Model,
        decoder: tf.keras.Model,
        # component_size: int,
        flat_fisher: tf.Tensor,
        n_components: int,
        vae_latent_size: int,
        beta: float = 1.0,
        non_negative_components: bool = True,
        non_negative_coefficients: bool = True,
        # TODO: Add option to constraint components to have unit norm.
        **kwargs,
    ):
        """Set up the inner VAE and the trainable component matrix.

        Args:
            encoder: Encoder handed to `vae.Vae`.
            decoder: Decoder handed to `vae.Vae`; its output is multiplied with
                the component matrix, so its width must equal `n_components`.
            flat_fisher: Rank-1 tensor; its length defines `component_size`
                and it scales the components when they are non-negative.
            n_components: Number of rows of the component matrix.
            vae_latent_size: Passed to `vae.Vae` as `representation_size`.
            beta: Passed to `vae.Vae` as `beta` and stored on the model.
            non_negative_components: If True, components are produced through
                an exponential (see `get_components`).
            non_negative_coefficients: If True, a softplus is applied to the
                VAE output before it is used as coefficients.
            **kwargs: Forwarded to `tf.keras.Model`.
        """
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

        # The component width is dictated by the (rank-1) Fisher vector.
        self.flat_fisher = flat_fisher
        (component_size,) = self.flat_fisher.shape

        self.n_components = n_components
        self.component_size = component_size
        self.vae_latent_size = vae_latent_size

        self.beta = beta

        self.non_negative_coefficients = non_negative_coefficients
        self.non_negative_components = non_negative_components

        self.vae = vae.Vae(
            encoder,
            decoder,
            representation_size=vae_latent_size,
            beta=beta,
        )

        self.precomponents = self.add_weight(
            name='precomponents',
            shape=[n_components, component_size],
            dtype=tf.float32,
            initializer='glorot_uniform',
            trainable=True,
        )

        # Fixed (non-trainable) for now. A trainable variant was tried before:
        # self.pre_output_variances = self.add_weight(
        #     name='pre_output_variances',
        #     shape=[component_size],
        #     dtype=tf.float32,
        #     initializer='glorot_uniform',
        #     trainable=True,
        # )
        self.pre_output_variances = tf.ones([component_size])

    def get_components(self):
        """Return the `[n_components, component_size]` component matrix.

        With `non_negative_components` the result is
        `exp(precomponents + log(flat_fisher)) / n_components`; otherwise it is
        the raw `precomponents` weight.
        """
        if not self.non_negative_components:
            return tf.identity(self.precomponents)

        # Earlier parametrizations that were tried:
        # return self.flat_fisher * tf.nn.softplus(self.precomponents) / float(self.component_size)
        # return tf.nn.softplus(self.precomponents - 7.0)
        # return tf.exp(self.precomponents / 3 - 10.0)
        # return tf.exp(self.precomponents + tf.math.log(self.flat_fisher)) / float(self.component_size)
        log_fisher = tf.math.log(self.flat_fisher)
        return tf.exp(self.precomponents + log_fisher) / float(self.n_components)

    def call(self, x, training=None):
        """Reconstruct `x` as `coefficients @ components`."""
        raw_coeffs = self.vae(x, training=training)
        coeffs = tf.nn.softplus(raw_coeffs) if self.non_negative_coefficients else raw_coeffs
        components = self.get_components()

        # Cache these so we can enforce priors on them in the loss.
        self._call_coeffs = coeffs
        self._call_components = components

        return coeffs @ components


class VaeSocLoss(tf.keras.losses.Loss):
    """Reconstruction-plus-KL loss for a `VaeSoc` model.

    The reconstruction term is the negative log-probability of `y_true` under
    a diagonal Gaussian centred on `y_pred`; the KL term is the mean KL
    divergence between the posterior and prior exposed by `model.vae`,
    weighted by `model.beta`.
    """

    def __init__(self, model, **kwargs):
        """Store the model whose VAE and output scale the loss reads.

        Args:
            model: Object exposing `pre_output_variances`, `beta`,
                `vae.posterior` and `vae.prior` (e.g. a `VaeSoc`).
            **kwargs: Forwarded to `tf.keras.losses.Loss`.
        """
        super().__init__(**kwargs)
        self.model = model

    def call(self, y_true, y_pred):
        """Return the per-example loss; Keras applies the batch reduction.

        Args:
            y_true: Target tensor, event dimension last.
            y_pred: Reconstruction with the same shape as `y_true`.

        Returns:
            A tensor with one loss value per example.
        """
        # NOTE: despite its name, softplus(pre_output_variances) is passed as
        # `scale_diag`, i.e. it acts as a standard deviation.
        output_distr = tfd.MultivariateNormalDiag(
            loc=y_pred,
            scale_diag=tf.nn.softplus(self.model.pre_output_variances))
        # Keep the negative log-likelihood per example and do NOT divide by the
        # batch size here: the Keras `Loss` reduction already averages over the
        # batch, so an extra division would shrink the reconstruction term by
        # a factor of batch_size relative to the KL term.
        loss_recon = -output_distr.log_prob(y_true)
        #
        posterior = self.model.vae.posterior
        prior = self.model.vae.prior
        # Scalar mean KL; broadcasting adds it to every per-example entry, so
        # it is unchanged by the mean reduction.
        loss_kl = tf.reduce_mean(tfd.kl_divergence(posterior, prior))
        #
        return loss_recon + self.model.beta * loss_kl


###############################################################################
###############################################################################

# Smaller settings used previously (kept for quick switching):
# PER_EXAMPLE_FISHER_BATCH_SIZE = 2
# SOC_BATCH_SIZE = 32

# PER_EXAMPLE_FISHER_BATCH_SIZE = 16
# Batch size applied to the GLUE dataset before it is streamed through the
# per-example Fisher computation.
PER_EXAMPLE_FISHER_BATCH_SIZE = 256
# Batch size of the dataset fed to the VaeSoc model.
SOC_BATCH_SIZE = 256

# The prefetch buffer holds SOC_BATCH_SIZE * PREFETCH_FACTOR elements.
PREFETCH_FACTOR = 4

#######################################

# Load the precomputed sparse diagonal Fisher from disk (metric variant by
# default; the uniform variant is kept below for switching).
sparse_fisher = sparse_diagonal.SparseDiagonalFisher.load(os.path.join(FISHER_DIR, METRIC_VAR_FISHER))
# sparse_fisher = sparse_diagonal.SparseDiagonalFisher.load(os.path.join(FISHER_DIR, UNIFORM_VAR_FISHER))

# Fine-tuned classifier (converted from its PyTorch checkpoint) and its tokenizer.
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL, from_pt=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL)


# NOTE(review): `fishers` is assumed to be aligned index-by-index with
# `model_variables`; both are sliced identically below — confirm.
model_variables = hf_util.get_mergeable_variables(model)
fishers = sparse_fisher.fishers

# Filter out embeddings variables because they are very large.
# model_variables, fishers = filter_out_embeddings_variables(model_variables, fishers)
# model_variables = model_variables[:2]
# fishers = fishers[:2]

# Downsize further, at least for local testing.
# model_variables = model_variables[::11]
# fishers = fishers[::11]

# Currently only variables 6 and 7 (and their Fishers) are used.
model_variables = model_variables[6:8]
fishers = fishers[6:8]


# Concatenate the stored values of every selected Fisher into a single vector.
flat_fisher = tf.concat([f.values for f in fishers], axis=-1)

# Divisor applied to every per-example Fisher batch in `gen`; 1.0 disables
# the normalization.
# meanish_fisher = tf.reduce_mean([tf.reduce_mean(f.values) for f in fishers]).numpy()
meanish_fisher = 1.0
# meanish_fisher = flat_fisher


# flat_fisher *= 1e3
# meanish_fisher = 1 / 1e3


# Total number of stored Fisher values; used as the feature width of `soc_ds`.
component_size = compute_component_size(fishers)
print(component_size)


# Tokenized training split of the configured GLUE task.
glue_ds = glue.load_glue_dataset(
    task=TASK,
    split='train',
    tokenizer=tokenizer,
    max_length=SEQ_LEN,
)
# Repeat indefinitely, then shuffle with a 1000-element buffer. Because the
# shuffle comes after the repeat, examples from adjacent passes can mix.
glue_ds = glue_ds.repeat().shuffle(1000)


def gen():
    """Yield batches of per-example Fishers, concatenated along the last axis.

    Streams `glue_ds` (batched by PER_EXAMPLE_FISHER_BATCH_SIZE) through the
    per-example sparse diagonal Fisher computation and divides each
    concatenated batch by `meanish_fisher`.
    """
    batched_examples = glue_ds.batch(PER_EXAMPLE_FISHER_BATCH_SIZE)
    fisher_stream = per_example.stream_per_example_sparse_diagonal_fishers(
        model, batched_examples, fishers, model_variables, unbatch=False
    )
    for per_variable_fishers in fisher_stream:
        flat_batch = tf.concat(per_variable_fishers, axis=-1)
        yield flat_batch / meanish_fisher


# TODO: replace with something from per-example
prefetch_size = SOC_BATCH_SIZE * PREFETCH_FACTOR
# Wrap the generator; each yielded element is declared as a float32
# [None, component_size] tensor.
soc_ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.TensorSpec(shape=[None, component_size], dtype=tf.float32),
)
# Split the generator's batches into single examples, prefetch them, and
# re-batch to the model's batch size.
soc_ds = soc_ds.unbatch().prefetch(prefetch_size).batch(SOC_BATCH_SIZE)
# Autoencoding setup: the input doubles as the target.
soc_ds = soc_ds.map(lambda x: (x, x))

#######################################

# Weight passed to VaeSoc as `beta`; 0.0 is the current setting.
# beta = 1.0
beta = 0.0

vae_latent_size = 512

n_components = 512

soc_model = VaeSoc(
    # encoder=tf.keras.Sequential(make_relu_layers([512, 512, 2 * vae_latent_size])),
    # decoder=tf.keras.Sequential(make_relu_layers([512, 512, n_components])),
    # NOTE(review): the encoder emits 2 * vae_latent_size units — presumably
    # the two parameter sets of the posterior; confirm against vae.Vae.
    encoder=tf.keras.Sequential(make_relu_layers([512, 2 * vae_latent_size])),
    # The decoder width must equal n_components because its output is
    # multiplied with the [n_components, component_size] component matrix.
    decoder=tf.keras.Sequential(make_relu_layers([512, n_components])),
    flat_fisher=flat_fisher,
    n_components=n_components,
    vae_latent_size=vae_latent_size,
    beta=beta,
    non_negative_components=True,
    non_negative_coefficients=True,
)
# Constructed but currently unused: `compile` below uses MeanSquaredError.
loss = VaeSocLoss(soc_model)

# NOTE: I was getting negative loss values, which may or may not be wrong. IDK.
# (A negative log-likelihood of a continuous density can legitimately be negative.)

# lr = 1e-2
# lr = 1e-3
lr = 1e-4
# lr = 1e-5
soc_model.compile(
    # loss=loss,
    loss=tf.keras.losses.MeanSquaredError(),
    # loss=tf.keras.losses.MeanAbsoluteError(),
    optimizer=tf.keras.optimizers.Adam(lr),
    # optimizer=tf.keras.optimizers.SGD(lr),
)
# `soc_ds` is infinite (the underlying GLUE dataset is repeated), so
# steps_per_epoch is what bounds each epoch.
soc_model.fit(soc_ds, steps_per_epoch=16, epochs=8)

# soc_model.fit(soc_ds, steps_per_epoch=64, epochs=64)

# TODO: Maybe see if there is some way to run on PCA-whitened subcomponents.
# Maybe I can ensure non-negativity of components by running them in reverse
# through the PCA transform.


R"""

for _, x in soc_ds:
    break

y = soc_model(x)

y[:10,:5]

tf.reduce_all(y[0] == y[1]).numpy()


x[:10,:5]

"""