"""Common stuff for these development scripts."""
import collections
from typing import Any, List, Sequence

import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp

from em.models.generative import vae

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions


###############################################################################
###############################################################################

def make_dense_layers(widths: List[int], activation='relu') -> List[tf.keras.layers.Layer]:
    """Create one Dense layer per entry of ``widths``.

    Every layer except the last uses ``activation``; the last layer is created
    with ``activation=None``.

    Args:
        widths: Unit counts for the layers, in order. Must be non-empty
            (the last entry is read unconditionally).
        activation: Activation passed to every layer but the last.

    Returns:
        The Dense layers, in the same order as ``widths``.
    """
    hidden = [
        tf.keras.layers.Dense(units, activation=activation)
        for units in widths[:-1]
    ]
    output = tf.keras.layers.Dense(widths[-1], activation=None)
    return [*hidden, output]


def filter_out_embeddings_variables(
    named_variables: 'Sequence[tf.Variable]',
    *other_lists: Sequence[Any],
) -> tuple:
    """Drop embedding variables and the matching entries of parallel lists.

    A variable is treated as an embedding variable when its ``name`` contains
    the substring ``'/embeddings/'``.

    Args:
        named_variables: Variables to filter; only their ``.name`` is read.
        *other_lists: Sequences parallel to ``named_variables``. Entries at
            the positions of dropped variables are dropped as well.

    Returns:
        A tuple of lists ``(filtered_variables, *filtered_other_lists)``, in
        the order the inputs were given.

    Raises:
        ValueError: If the inputs do not all have the same length.
    """
    all_lists = [named_variables, *other_lists]
    lengths = [len(lst) for lst in all_lists]
    if len(set(lengths)) != 1:
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        raise ValueError(
            f'All inputs must have the same length; got lengths {lengths}.')

    # NOTE: I can probably generalize this later to accept a general VariableFilter object.
    keep_mask = ['/embeddings/' not in v.name for v in named_variables]
    return tuple(
        [item for item, keep in zip(lst, keep_mask) if keep]
        for lst in all_lists
    )


def compute_component_size(sparse_fishers: Sequence[tf.sparse.SparseTensor]) -> int:
    """Return the summed leading dimension of ``values`` over all tensors.

    Each sparse tensor contributes ``values.shape[0]``, i.e. the number of
    entries it stores.
    """
    return sum(fisher.values.shape[0] for fisher in sparse_fishers)


###############################################################################
###############################################################################

class AeSoc(tf.keras.Model):
    """Autoencoder sum of components.

    The wrapped ``autoencoder`` maps an input to coefficients, and the model
    output is the matrix product of those coefficients with a learned
    ``[n_components, component_size]`` component matrix.
    """

    def __init__(
        self,
        autoencoder: tf.keras.Model,
        component_size: int,
        n_components: int,
        #
        non_negative_components: bool = True,
        non_negative_coefficients: bool = True,
        unit_norm_components: bool = False,
        **kwargs,
    ):
        """Store the configuration and create the trainable component weights.

        Args:
            autoencoder: Model called on the input to produce pre-coefficients.
            component_size: Length of each component (second weight dimension).
            n_components: Number of components (first weight dimension).
            non_negative_components: If true, components are ``exp`` of the
                raw weights divided by ``n_components``.
            non_negative_coefficients: If true, coefficients are passed
                through softplus.
            unit_norm_components: If true, components are L2-normalized along
                the last axis.
            **kwargs: Forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)
        self.autoencoder = autoencoder

        # Dimensions of the component matrix.
        self.component_size = component_size
        self.n_components = n_components

        # Switches controlling how raw weights / outputs are transformed.
        self.non_negative_components = non_negative_components
        self.non_negative_coefficients = non_negative_coefficients
        self.unit_norm_components = unit_norm_components

        # Raw (unconstrained) components; see get_components() for the
        # transformation applied before use.
        self.precomponents = self.add_weight(
            name='precomponents',
            shape=[n_components, component_size],
            dtype=tf.float32,
            initializer='glorot_uniform',
            trainable=True,
        )

    def get_components(self):
        """Return the component matrix after applying the configured constraints."""
        raw = self.precomponents
        constrained = (
            tf.exp(raw) / float(self.n_components)
            if self.non_negative_components
            else tf.identity(raw)
        )
        if not self.unit_norm_components:
            return constrained
        return tf.linalg.l2_normalize(constrained, axis=-1)

    def call(self, x, training=None):
        """Compute ``coeffs @ components`` for input ``x``.

        The coefficients and components used are stored on the instance as
        ``_call_coeffs`` and ``_call_components``.
        """
        encoded = self.autoencoder(x, training=training)
        weights = tf.nn.softplus(encoded) if self.non_negative_coefficients else encoded
        basis = self.get_components()

        # Cache these so we can enforce priors on them in the loss.
        self._call_coeffs = weights
        self._call_components = basis

        return weights @ basis


class VaeSocLoss(tf.keras.losses.Loss):
    """Gaussian negative log-likelihood plus a ``beta``-weighted KL term.

    ``model`` must expose ``beta`` and an ``autoencoder`` attribute that in
    turn provides ``pre_output_variances``, ``posterior`` and ``prior``.
    """

    def __init__(self, model, **kwargs):
        """Keep a reference to ``model``; ``kwargs`` go to ``tf.keras.losses.Loss``."""
        super().__init__(**kwargs)
        self.model = model

    def call(self, y_true, y_pred):
        """Return the per-example reconstruction loss plus the weighted KL term."""
        # The loss object has no `autoencoder` attribute of its own; it must
        # be reached through the wrapped model.
        autoencoder = self.model.autoencoder

        # NOTE(review): the softplus of `pre_output_variances` is used as
        # `scale_diag`, which TFP treats as a scale (std-dev-like), not a
        # variance -- confirm the naming/intent.
        output_distr = tfd.MultivariateNormalDiag(
            loc=y_pred,
            scale_diag=tf.nn.softplus(autoencoder.pre_output_variances))

        # NOTE(review): the per-example NLL is divided by the batch size here,
        # and the Keras Loss wrapper applies its own `reduction` afterwards --
        # confirm the intended overall scaling.
        batch_size = tf.cast(tf.shape(y_pred)[0], tf.float32)
        loss_recon = -output_distr.log_prob(y_true) / batch_size

        posterior = autoencoder.posterior
        prior = autoencoder.prior
        loss_kl = tf.reduce_mean(tfd.kl_divergence(posterior, prior))

        return loss_recon + self.model.beta * loss_kl


class SparsityLoss(tf.keras.losses.Loss):
    """Reconstruction loss plus penalties on an ``AeSoc``'s cached tensors.

    Adds the mean absolute value of the model's components and coefficients
    (each scaled by its own weight) and a penalty on the deviation of each
    component's norm from 1. The tensors are read from ``model._call_components``
    and ``model._call_coeffs``, which ``AeSoc.call`` sets.
    """

    def __init__(
        self,
        recon_loss: tf.keras.losses.Loss,
        model: AeSoc,
        lmbda_comp: float,
        lmbda_coeff: float,
        lmbda_comp_unit_norm: float = 0.0,
        **kwargs
    ):
        """Store the wrapped loss, the model, and the penalty weights.

        Args:
            recon_loss: Loss evaluated on ``(y_true, y_pred)``.
            model: Model whose cached components/coefficients are penalized.
            lmbda_comp: Weight of the mean-absolute-component penalty.
            lmbda_coeff: Weight of the mean-absolute-coefficient penalty.
            lmbda_comp_unit_norm: Weight of the unit-norm penalty.
            **kwargs: Forwarded to ``tf.keras.losses.Loss``.
        """
        super().__init__(**kwargs)
        self.model = model
        self.recon_loss = recon_loss
        self.lmbda_comp = lmbda_comp
        self.lmbda_coeff = lmbda_coeff
        self.lmbda_comp_unit_norm = lmbda_comp_unit_norm

    def call(self, y_true, y_pred):
        """Return the reconstruction loss plus the three weighted penalties."""
        reconstruction = self.recon_loss(y_true, y_pred)

        components = self.model._call_components
        coeffs = self.model._call_coeffs

        component_penalty = self.lmbda_comp * tf.reduce_mean(tf.abs(components))
        coefficient_penalty = self.lmbda_coeff * tf.reduce_mean(tf.abs(coeffs))

        # TODO: Maybe replace norm with squared norm.
        row_norms = tf.linalg.norm(components, axis=-1, keepdims=True)
        norm_deviation = tf.reduce_mean((row_norms - 1)**2)
        unit_norm_penalty = self.lmbda_comp_unit_norm * norm_deviation

        return reconstruction + component_penalty + coefficient_penalty + unit_norm_penalty

###############################################################################
###############################################################################
# Probably some (Fashion)MNIST-specific stuff here.


def imshow_mnist(x):
    """Reshape ``x`` to 28x28 and display it as a grayscale image."""
    image = tf.reshape(x, [28, 28])
    plt.imshow(image, cmap='gray')
    plt.show()


def plot_component(soc_model: AeSoc, index: int):
    """Display component ``index`` of ``soc_model`` via ``imshow_mnist``."""
    components = soc_model.get_components()
    imshow_mnist(components[index])


def imshow_mnist_multi(x, row_size: int):
    """Display a batch of 28x28 images in a grid with ``row_size`` columns.

    Args:
        x: Image data that can be reshaped to ``[-1, 28, 28]``.
        row_size: Number of images per row (number of grid columns).
    """
    x = tf.reshape(x, [-1, 28, 28])
    n_images = x.shape[0]
    # Ceiling division: a partially filled last row still needs a row.
    n_rows = -(-n_images // row_size)
    n_cols = row_size

    # squeeze=False keeps `axs` two-dimensional even when the grid has a
    # single row or column, so the [row, col] indexing below always works.
    _fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)

    # Hide the axes of every cell, including unused cells in the last row.
    for ax in axs.flat:
        ax.axis('off')

    for i in range(n_images):
        row, col = divmod(i, row_size)
        axs[row, col].imshow(x[i], cmap='gray')
    plt.tight_layout()

    plt.show()

###############################################################################
###############################################################################


# Return container with a single `logits` field.
# NOTE(review): presumably mirrors the output objects of Hugging Face models
# (per the `Hf` prefix) -- confirm against the callers.
_HfLogitsReturn = collections.namedtuple('HfLogitsReturn', ['logits'])


class HfLogitsSequential(tf.keras.Sequential):
    """Sequential model whose ``call`` wraps its output in ``_HfLogitsReturn``."""

    def call(self, *args, **kwargs):
        """Run ``Sequential.call`` and return the result as the ``logits`` field."""
        return _HfLogitsReturn(logits=super().call(*args, **kwargs))

# def hack_call_to_be_like_hf(model: tf.keras.Model):
#     call = model.__call__

#     def hacked_and_whacked(*args, **kwargs):
#         logits = call(*args, **kwargs)
#         return _HfLogitsReturn(logits=logits)

#     model.__call__ = hacked_and_whacked

#     return model
