from collections.abc import Sequence, Mapping

import tensorflow as tf
import tensorflow_probability as tfp

from typing import TYPE_CHECKING, Union, Any

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

if TYPE_CHECKING:  # Support auto-completion in IDEs.
    from keras.api._v2 import keras
    from tensorflow_probability.python.distributions import Normal, Bernoulli
else:
    from tensorflow import keras
    Normal = tfp.distributions.Normal
    Bernoulli = tfp.distributions.Bernoulli

from ..modules import GRUD, GRUDInput, GRUDState, DecayInterpolate, InterpolateInput

from .. import functional as F


# Public API of this module: only the model class is exported.
__all__ = ["SupNotMIWAEModel"]


# Imputation strategies accepted by `SupNotMIWAEModel(impute_type=...)`.
# "mean" / "single" / "multiple" choose how unobserved values are drawn while training;
# a "-decay" suffix additionally passes the combined series through a decay interpolator.
ImputeType = Literal[
    "mean",
    "single",
    "multiple",
    "mean-decay",
    "single-decay",
    "multiple-decay",
]


class SupNotMIWAEModel(keras.Model):
    """Supervised not-MIWAE classifier for time series with missing values.

    Model structure:
    - A sequential VAE proposes values for the unobserved entries. It has an encoder
      q(z | xᵒ), a prior p(z) and a decoder p(x | z).
    - A recurrent classifier is run on every imputed series.
    - The per-imputation predictions are combined with importance weights. The weights
      are built from p(xᵒ | z), p(z), q(z | xᵒ) and the observation-dropout likelihood
      p(m). With ``missing_type="mnar"`` they also include a learned missingness model
      p(s | x).

    ``call`` expects the 5-tuple ``(statics, times, values, measurements, lengths)``;
    see its docstring for shapes.
    """

    def __init__(
        self,
        output_activation,
        output_dims,
        n_train_latents: int = 10,
        n_train_samples: int = 1,
        n_test_latents: int = 20,
        n_test_samples: int = 30,
        n_hidden: int = 128,
        n_units: int = 128,
        z_dim: int = 32,
        dropout: float = 0.,
        recurrent_dropout: float = 0.,
        observe_dropout: Union[float, Sequence[float], Mapping[Any, float]] = 0.,  # Supports feature-wise dropout
        prior_type: Literal["standard", "autoregressive"] = "autoregressive",
        encoder_type: Literal["mlp", "cnn", "gru", "grud"] = "grud",
        decoder_type: Literal["mlp", "cnn", "gru", "grud"] = "gru",
        classifier_type: Literal["gru", "grud"] = "gru",
        missing_type: Literal["mar", "mnar"] = "mar",
        impute_type: ImputeType = "multiple-decay",
        min_latent_sigma: float = 0.,
        min_sigma: float = 0.,
        objective: str = "iwae",
        sparsity_normalize: bool = False,
    ):
        """Store hyper-parameters; sub-layers are created lazily in `build`.

        Args:
            output_activation: Activation of the final classifier layer. Its output is
                used as Bernoulli probabilities, so it should map into [0, 1].
            output_dims: ``y_dim`` for one prediction per series, or ``(None, y_dim)``
                for one prediction per time step (online prediction).
            n_train_latents / n_test_latents: Number of latent draws K.
            n_train_samples: Imputation draws per latent for "single"/"multiple"
                imputation while training.
            n_test_samples: Currently unused. At evaluation, `generate` always imputes
                with the decoder mean.
            n_hidden: Width of the encoder/decoder layers.
            n_units: Width of the classifier recurrent layer.
            z_dim: Latent dimension.
            dropout / recurrent_dropout: Classifier dropout rates, clamped to [0, 1].
            observe_dropout: Training-time probability of hiding an observed entry.
                May be a scalar, a per-feature sequence, or a mapping whose values are
                ordered by sorted key. Values are clamped to [0, 1].
            prior_type: "standard" N(0, 1) prior or a GRU "autoregressive" prior.
            encoder_type / decoder_type / classifier_type: Network families.
            missing_type: "mar" ignores missingness; "mnar" adds a model p(s | x).
            impute_type: See `ImputeType`.
            min_latent_sigma: Floor on the std of q(z | xᵒ).
            min_sigma: Floor on the std of p(x | z) and of the autoregressive prior.
            objective: "elbo" for the ELBO. Any other value selects the
                importance-weighted bound.
            sparsity_normalize: Apply `F.sparsity_normalize` to the values when the
                encoder is not GRU-D.
        """
        self._config = {k: v for k, v in locals().items() if k not in ["self", "__class__"]}
        super().__init__()

        if isinstance(output_dims, Sequence):
            # We have an online prediction scenario
            assert output_dims[0] is None
            self.return_sequences = True
            output_dims = output_dims[1]
        else:
            self.return_sequences = False

        if isinstance(observe_dropout, Mapping):
            # Support sweep argument inputs for W&B (values ordered by sorted key).
            observe_dropout = [v for _, v in sorted(observe_dropout.items())]

        if isinstance(observe_dropout, Sequence):
            # Feature-wise rates of shape [x_dim], so that a Bernoulli over them has batch
            # shape [x_dim] and broadcasts against the last axis of the data.
            has_observe_dropout = any(v > 0. for v in observe_dropout)
            observe_dropout = tf.clip_by_value(tf.constant(observe_dropout, dtype=tf.float32), 0., 1.)
        else:
            observe_dropout = min(1., max(0., observe_dropout))
            has_observe_dropout = observe_dropout > 0.
            observe_dropout = tf.constant(observe_dropout, dtype=tf.float32)

        if missing_type not in ["mar", "mnar"]:
            raise ValueError(f"Unknown missing type: {missing_type}")

        self.output_activation = output_activation
        self.output_dims = output_dims
        self.n_train_latents = n_train_latents
        self.n_train_samples = n_train_samples
        self.n_test_latents = n_test_latents
        self.n_test_samples = n_test_samples
        self.n_hidden = n_hidden
        self.n_units = n_units
        self.z_dim = z_dim
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.observe_dropout = observe_dropout                          # [] | [x_dim], float32
        # Static Python flag so the dropout branch is chosen at trace time, not via a tensor.
        self._has_observe_dropout = bool(has_observe_dropout)
        self.prior_type = prior_type
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.classifier_type = classifier_type
        self.missing_type = missing_type
        self.impute_type = impute_type
        self.min_latent_sigma = min_latent_sigma
        self.min_sigma = min_sigma
        self.objective = objective
        self.sparsity_normalize = sparsity_normalize

    def build(self, input_shape):
        """Create all sub-layers; `input_shape` mirrors the 5-tuple taken by `call`."""
        _, _, values_shape, _, _ = input_shape

        self.x_dim = values_shape[-1]

        # === Encoder ===

        if self.encoder_type == "mlp":
            self.encoder = keras.Sequential([
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="encoder/dense1"),
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="encoder/dense2"),
            ], name="encoder")
        elif self.encoder_type == "cnn":
            self.encoder = keras.Sequential([
                keras.layers.Conv1D(self.n_hidden, kernel_size=3, padding="same", activation=tf.nn.tanh, name="encoder/conv1"),
                keras.layers.Conv1D(self.n_hidden, kernel_size=3, padding="same", activation=tf.nn.tanh, name="encoder/conv2"),
            ], name="encoder")
        elif self.encoder_type == "gru":
            self.encoder = keras.layers.GRU(self.n_hidden, return_sequences=True, name="encoder")
        elif self.encoder_type == "grud":
            self.encoder = GRUD(self.n_hidden, return_sequences=True, x_imputation="decay", feed_masking=True, name="encoder")
        else:
            raise ValueError(f"Unknown encoder type: {self.encoder_type}")

        # std = floor + (1 - floor) * softplus(x), i.e. strictly above the configured floor.
        latent_softplus = lambda x: self.min_latent_sigma + (1 - self.min_latent_sigma) * tf.nn.softplus(x)
        self.encoder_mu    = keras.layers.Dense(self.z_dim, activation=None,            name="encoder/mu")
        self.encoder_sigma = keras.layers.Dense(self.z_dim, activation=latent_softplus, name="encoder/sigma")

        # === Decoder ===

        if self.decoder_type == "mlp":
            self.decoder = keras.Sequential([
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="decoder/dense1"),
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="decoder/dense2"),
            ], name="decoder")
        elif self.decoder_type == "cnn":
            self.decoder = keras.Sequential([
                keras.layers.Conv1D(self.n_hidden, kernel_size=3, padding="same", activation=tf.nn.tanh, name="decoder/conv1"),
                keras.layers.Conv1D(self.n_hidden, kernel_size=3, padding="same", activation=tf.nn.tanh, name="decoder/conv2"),
            ], name="decoder")
        elif self.decoder_type == "gru":
            self.decoder = keras.layers.GRU(self.n_hidden, return_sequences=True, name="decoder")
        elif self.decoder_type == "grud":
            self.decoder = GRUD(
                self.n_hidden, return_sequences=True, input_decay=None, hidden_decay="exp_relu", feed_masking=False,
                name="decoder", x_imputation="raw",
            )
        else:
            raise ValueError(f"Unknown decoder type: {self.decoder_type}")

        obs_softplus = lambda x: self.min_sigma + (1 - self.min_sigma) * tf.nn.softplus(x)
        self.decoder_mu    = keras.layers.Dense(self.x_dim, activation=None,         name="decoder/mu")
        self.decoder_sigma = keras.layers.Dense(self.x_dim, activation=obs_softplus, name="decoder/sigma")

        # === Classifier ===

        if self.classifier_type == "gru":
            self.classifier_gru = keras.layers.GRU(
                self.n_units, return_sequences=self.return_sequences, dropout=self.dropout, recurrent_dropout=self.recurrent_dropout,
                name="classifier/gru",
            )
        elif self.classifier_type == "grud":
            self.classifier_gru = GRUD(
                self.n_units, return_sequences=self.return_sequences, dropout=self.dropout, recurrent_dropout=self.recurrent_dropout,
                name="classifier/grud", x_imputation="raw",
            )
        else:
            raise ValueError(f"Unknown classifier type: {self.classifier_type}")

        self.classifier_dense = keras.layers.Dense(self.output_dims, activation=self.output_activation, name="classifier/dense")

        # === Misc ===

        # Maps the static features to the classifier's initial recurrent state.
        self.initial_encoder = keras.Sequential([
            keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="initial_encoder/dense1"),
            keras.layers.Dense(self.n_units,  activation=tf.nn.tanh, name="initial_encoder/dense2"),
        ], name="initial_encoder")

        if self.impute_type.endswith("-decay"):
            self.interpolator = DecayInterpolate(name="interpolator")

        if self.prior_type == "autoregressive":
            # NOTE(review): the prior std is floored by `min_sigma` (the decoder floor), not
            # `min_latent_sigma`; kept as-is to preserve existing behaviour — confirm intent.
            self.prior_gru = keras.layers.GRU(self.z_dim,   return_sequences=True,   name="prior/gru")
            self.prior_mu  = keras.layers.Dense(self.z_dim, activation=None,         name="prior/mu")
            self.prior_std = keras.layers.Dense(self.z_dim, activation=obs_softplus, name="prior/std")
        elif self.prior_type != "standard":
            raise ValueError(f"Unknown prior type: {self.prior_type}")

        if self.missing_type == "mnar":
            # NOTE(review): the tanh on the last layer bounds the Bernoulli logits to [-1, 1]
            # (p(s) within ~[0.27, 0.73]); kept for checkpoint compatibility — confirm intent.
            self.mnar_encoder = keras.Sequential([
                keras.layers.Dense(self.n_hidden, activation=tf.nn.tanh, name="mnar_encoder/dense1"),
                keras.layers.Dense(self.x_dim,    activation=tf.nn.tanh, name="mnar_encoder/dense2"),
            ], name="mnar_encoder")

    def get_config(self):
        """Return the constructor arguments captured in `__init__`."""
        return self._config

    def reconstruct_loss(self, x_tilde, x, masks):
        """Mean absolute error between `x_tilde` and `x` over entries where `masks` is 1."""
        return tf.reduce_sum(tf.abs(x_tilde - x) * masks) / (tf.reduce_sum(masks) + 1e-9)

    def call(self, inputs, output=None, training=False, return_loss=False, return_aux=False):
        """Impute, classify and compute the objective.

        Args:
            inputs: ``(statics, times, values, measurements, lengths)``:
                statics:      [n_batch, static_dim]
                times:        [n_batch, n_times]
                values:       [n_batch, n_times, x_dim]
                measurements: [n_batch, n_times, x_dim] boolean, True where observed
                lengths:      [n_batch] or [n_batch, 1]
            output: Optional labels, [n_batch, (n_times), y_dim]. A [n_batch] vector is
                expanded to [n_batch, 1]. Labels only enter the loss and ``pred_err``.
                Without labels, those are computed against y = 1.
            training: Selects train/test latent counts, sampled vs. mean imputation, and
                observation dropout.
            return_loss: Also return the scalar loss.
            return_aux: Also return a dict of intermediate tensors and metrics.

        Returns:
            ``y_prob`` ([n_batch, (n_times), y_dim], probability of the positive class),
            followed by ``loss`` and/or ``aux`` when requested.
        """
        statics, times, values, measurements, lengths = inputs

        n_latents = self.n_train_latents if training else self.n_test_latents

        # Preprocess
        lengths = tf.convert_to_tensor(lengths)
        if lengths.shape.rank == 2:
            lengths = tf.squeeze(lengths, axis=-1)                                                                      # [n_batch]

        if output is not None:
            output = tf.convert_to_tensor(output)
            if output.shape.rank == 1:
                output = tf.expand_dims(output, axis=1)                                                                 # [n_batch, 1]

        # NOTE: missing_mask includes padding_mask
        x_obs = values                                                                                                  # [n_batch, n_times, x_dim]
        missing_mask = measurements                                                                                     # [n_batch, n_times, x_dim]
        # maxlen pins the mask to the padded length of `values`; by default it would only be
        # max(lengths) wide and would not line up with `values`.
        padding_mask = tf.sequence_mask(lengths, maxlen=tf.shape(values)[1])                                           # [n_batch, n_times]

        # === VAE ===

        # Encode: g(xᵒ; γ) -> q(z | xᵒ)
        if self.encoder_type != "grud" and self.sparsity_normalize:
            # NOTE(review): the normalized values are also reconstructed, imputed and
            # classified below, not only fed to the encoder — confirm this is intended.
            x_obs = F.sparsity_normalize(x_obs, missing_mask, axis=-1, constant=4)                                      # [n_batch, n_times, x_dim]

        q_z = self.encode(times, x_obs, missing_mask, padding_mask)                                                     # [n_batch, n_times, z_dim] x 2

        # Latent: zₖ ~ q(z | xᵒ)
        z_samples = q_z.sample(n_latents)                                                                               # [n_latents, n_batch, n_times, z_dim]

        # Prior: p(z)
        p_z = self.prior(z_samples, padding_mask=padding_mask)                                                          # [] | [n_latents, n_batch, n_times, z_dim]

        # Decode: h(zₖ; θ) -> p(xₖᵐ | zₖ)
        p_x_tilde = self.decode(times, z_samples, padding_mask)                                                         # [n_latents, n_batch, n_times, x_dim]

        # === Impute ===

        # log p(xᵒ | zₖ)
        log_p_x_obs_given_z = tf.reduce_sum(tf.where(                                                                   # [n_latents, n_batch, n_times]
            missing_mask,                                                         # [           n_batch, n_times, x_dim]
            p_x_tilde.log_prob(x_obs),                                            # [n_latents, n_batch, n_times, x_dim]
            0.,
        ), axis=-1)

        # log p(zₖ)
        log_p_z = tf.reduce_sum(tf.where(                                                                               # [n_latents, n_batch, n_times]
            tf.expand_dims(padding_mask, axis=-1),                                # [           n_batch, n_times,     1]
            p_z.log_prob(z_samples),                                              # [n_latents, n_batch, n_times, z_dim]
            0.,
        ), axis=-1)

        # log q(zₖ | xᵒ)
        log_q_z_given_x_obs = tf.reduce_sum(tf.where(                                                                   # [n_latents, n_batch, n_times]
            tf.expand_dims(padding_mask, axis=-1),                                # [           n_batch, n_times,     1]
            q_z.log_prob(z_samples),                                              # [n_latents, n_batch, n_times, z_dim]
            0.,
        ), axis=-1)

        # Online prediction accumulates evidence up to each step; otherwise over the series.
        if self.return_sequences:
            log_p_x_obs_given_z = tf.cumsum(log_p_x_obs_given_z, axis=-1)                                               # [n_latents, n_batch, n_times]
            log_p_z             = tf.cumsum(log_p_z, axis=-1)                                                           # [n_latents, n_batch, n_times]
            log_q_z_given_x_obs = tf.cumsum(log_q_z_given_x_obs, axis=-1)                                               # [n_latents, n_batch, n_times]
        else:
            log_p_x_obs_given_z = tf.reduce_sum(log_p_x_obs_given_z, axis=-1, keepdims=True)                            # [n_latents, n_batch, 1]
            log_p_z             = tf.reduce_sum(log_p_z, axis=-1, keepdims=True)                                        # [n_latents, n_batch, 1]
            log_q_z_given_x_obs = tf.reduce_sum(log_q_z_given_x_obs, axis=-1, keepdims=True)                            # [n_latents, n_batch, 1]

        # wₗ = softmax( p(xᵒ | zₖ) p(z) / q(z | xᵒ) )
        log_w_latents = tf.nn.log_softmax(log_p_x_obs_given_z + log_p_z - log_q_z_given_x_obs, axis=0)                  # [n_latents, n_batch, (n_times)]

        # Generate: xₖⱼᵐ ~ p(xₖᵐ | zₖ)
        x_tilde = self.generate(p_x_tilde, log_w_latents, training=training)                                            # [(n_samples), (n_latents), n_batch, n_times, x_dim]

        # Dropout
        drop_mask, log_m = self.generate_observe_dropout_mask(tf.shape(x_tilde), missing_mask, training=training)       # [(n_samples), (n_latents), n_batch, n_times, x_dim], [(n_samples), (n_latents), n_batch, (n_times)]

        # Impute: xₖⱼ' = xᵒ and xₖⱼᵐ
        x_impute = self.impute(times, x_obs, x_tilde, drop_mask, padding_mask)                                          # [(n_samples), (n_latents), n_batch, n_times, x_dim]

        # log p(s | xᵒ, xₖⱼᵐ, mₖⱼ)
        if self.missing_type == "mar":
            log_p_s_given_x = tf.zeros(tf.shape(x_impute)[:-1])                                                         # [(n_samples), (n_latents), n_batch, n_times]
        elif self.missing_type == "mnar":
            logits_miss = self.mnar_encoder(x_impute)                                                                   # [(n_samples), (n_latents), n_batch, n_times, x_dim]
            p_s = Bernoulli(logits=logits_miss)                                                                         # [(n_samples), (n_latents), n_batch, n_times, x_dim]

            log_p_s_given_x = tf.reduce_sum(tf.where(                                                                   # [(n_samples), (n_latents), n_batch, n_times]
                tf.expand_dims(padding_mask, axis=-1),              # [                          n_batch, n_times,     1]
                p_s.log_prob(tf.cast(drop_mask, dtype=tf.float32)), # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                0.,
            ), axis=-1)

        if self.return_sequences:
            log_p_s_given_x = tf.cumsum(log_p_s_given_x, axis=-1)                                                       # [(n_samples), (n_latents), n_batch, n_times]
        else:
            log_p_s_given_x = tf.reduce_sum(log_p_s_given_x, axis=-1, keepdims=True)                                    # [(n_samples), (n_latents), n_batch, 1]

        # Classify: f(xᵒ, xₖⱼᵐ; ϕ) -> p(y = 1 | xᵒ, xₖⱼᵐ, mₖⱼ), computed once for every imputation.
        probs = self._classify_probs(statics, times, x_impute, drop_mask, padding_mask)                                 # [(n_samples), (n_latents), n_batch, (n_times), y_dim]

        # log p(y = 1 | ...) drives the prediction; log p(y | ...) of the given labels drives the
        # objective (it falls back to y = 1 when no labels are supplied).
        log_p_y_pos = self._bernoulli_log_prob(probs, None)                                                             # [(n_samples), (n_latents), n_batch, (n_times), y_dim]
        log_p_y = log_p_y_pos if output is None else self._bernoulli_log_prob(probs, output)                            # [(n_samples), (n_latents), n_batch, (n_times), y_dim]

        # rₖⱼ = p(s | xᵒ, xₖⱼᵐ, mₖⱼ) p(xᵒ | zₖ) p(zₖ) p(mₖⱼ) / q(zₖ | xᵒ)
        log_r = tf.expand_dims(log_p_s_given_x + log_p_x_obs_given_z + log_p_z + log_m - log_q_z_given_x_obs, axis=-1)  # [(n_samples), n_latents, n_batch, (n_times), 1]

        # wₖⱼ = rₖⱼ / ∑ₖ rₖⱼ  (normalized over the latents k, separately for each sample j)
        log_w = tf.nn.log_softmax(log_r, axis=1)                                                                        # [(n_samples), n_latents, n_batch, (n_times), 1]

        # p(y = 1 | xᵒ) ≈ Eⱼ[ ∑ₖ wₖⱼ p(y = 1 | xᵒ, xₖⱼᵐ, mₖⱼ) ], with the mean over j taken in
        # probability space (logsumexp - log J).
        log_p_y_per_sample = tf.reduce_logsumexp(log_w + log_p_y_pos, axis=1)                                          # [(n_samples), n_batch, (n_times), y_dim]
        n_imp = tf.cast(tf.shape(log_p_y_per_sample)[0], tf.float32)
        y_logit = tf.reduce_logsumexp(log_p_y_per_sample, axis=0) - tf.math.log(n_imp)                                 # [n_batch, (n_times), y_dim]
        y_prob = tf.exp(y_logit)                                                                                        # [n_batch, (n_times), y_dim]

        if not self.return_sequences:
            y_prob = tf.squeeze(y_prob, axis=-2)                                                                        # [n_batch, y_dim]

        # log K for the IWAE average. It is read from log_r, which always spans the K latents on
        # axis 1; log_p_y has size 1 there under "single" imputation.
        n_lat = tf.math.log(tf.cast(tf.shape(log_r)[1], tf.float32))

        if self.objective == "elbo":
            # R(xᵒ, y) = Eₖⱼ[ log p(y | xᵒ, xₖⱼᵐ, mₖⱼ) p(s | xᵒ, xₖⱼᵐ, mₖⱼ) p(xᵒ | zₖ) p(zₖ) / q(zₖ | xᵒ) ]
            loss = -tf.reduce_mean(log_p_y + log_r)
        else:
            # R(xᵒ, y) = Eⱼ[ log 1 / K ⋅ ∑ₖ p(y | xᵒ, xₖⱼᵐ, mₖⱼ) p(s | xᵒ, xₖⱼᵐ, mₖⱼ) p(xᵒ | zₖ) p(zₖ) / q(zₖ | xᵒ) ]
            loss = -tf.reduce_mean(tf.reduce_logsumexp(log_p_y + log_r, axis=1) - n_lat)

        # ESS = Eⱼ[ 1 / ∑ₖ wₖⱼ² ]
        ess = tf.reduce_mean(1. / tf.exp(tf.reduce_logsumexp(log_w * 2, axis=1)))                                       # []

        # Prediction error = -Eⱼ[ log 1 / K ⋅ ∑ₖ p(y | xᵒ, xₖⱼᵐ) q(zₖ | xᵒ) ]
        # NOTE(review): q is multiplied in here (log q is added); confirm whether p / q was meant.
        pred_error = -tf.reduce_mean(tf.reduce_logsumexp(log_p_y + tf.expand_dims(log_q_z_given_x_obs, axis=-1), axis=1) - n_lat)

        # Reconstruction error = p(xᵒ | z)
        recon_error = -tf.reduce_mean(log_p_x_obs_given_z)

        # Regularization = p(zₖ) / q(z | xᵒ)
        regular = -tf.reduce_mean(log_p_z - log_q_z_given_x_obs)

        # others
        prefix = "Train" if training else "Valid"
        aux = {
            "log_w": log_w,
            "x_impute": x_impute,
            "x_mu": p_x_tilde.loc,
            "x_sigma": p_x_tilde.scale,
            "metrics": {
                f"{prefix}/ess": ess,
                f"{prefix}/pred_err": pred_error,
                f"{prefix}/recon_err": recon_error,
                f"{prefix}/regular": regular,
            }
        }

        if return_loss and return_aux:
            return y_prob, loss, aux
        elif return_loss:
            return y_prob, loss
        elif return_aux:
            return y_prob, aux
        else:
            return y_prob

    def encode(self, times, values, missing_mask, padding_mask):
        """Return q(z | xᵒ) as a Normal of shape [n_batch, n_times, z_dim]."""
        if self.encoder_type == "grud":
            times = tf.expand_dims(times, axis=-1)
            r = self.encoder(GRUDInput(values=values, mask=missing_mask, times=times), mask=padding_mask)               # [n_batch, n_times, n_hidden]
        else:
            r = self.encoder(values)                                                                                    # [n_batch, n_times, n_hidden]

        z_mu = self.encoder_mu(r)                                                                                       # [n_batch, n_times, z_dim]
        z_sigma = self.encoder_sigma(r)                                                                                 # [n_batch, n_times, z_dim]

        q_z = Normal(loc=z_mu, scale=z_sigma)                                                                           # [n_batch, n_times, z_dim]
        return q_z                                                                                                      # [n_batch, n_times, z_dim]

    def decode(self, times, z, padding_mask):
        """Return p(x | z) as a Normal of shape [n_latents, n_batch, n_times, x_dim].

        Recurrent decoders fold the latent axis into the batch axis and unfold it afterwards.
        """
        if self.decoder_type == "mlp":
            h = self.decoder(z)                                                                                         # [n_latents, n_batch, n_times, n_hidden]
        elif self.decoder_type == "cnn":
            h = self.decoder(z)                                                                                         # [n_latents, n_batch, n_times, n_hidden]
        elif self.decoder_type == "gru" or self.decoder_type == "grud":
            shape = tf.shape(z)  # (n_latents, n_batch, n_times, z_dim)

            z = tf.reshape(z, shape=[shape[0] * shape[1], shape[2], shape[3]])                                          # [n_latents x n_batch, n_times, z_dim]
            padding_mask = tf.tile(padding_mask, multiples=[shape[0], 1])                                               # [n_latents x n_batch, n_times]

            if self.decoder_type == "gru":
                h = self.decoder(z, mask=padding_mask)                                                                  # [n_latents x n_batch, n_times, n_hidden]
            elif self.decoder_type == "grud":
                times = tf.expand_dims(tf.tile(times, multiples=[shape[0], 1]), axis=-1)                                # [n_latents x n_batch, n_times, 1]
                h = self.decoder(GRUDInput(values=z, mask=tf.ones_like(z, dtype=bool), times=times), mask=padding_mask) # [n_latents x n_batch, n_times, n_hidden]

            h = tf.reshape(h, shape=[shape[0], shape[1], shape[2], self.n_hidden])                                      # [n_latents, n_batch, n_times, n_hidden]

        x_tilde_mu = self.decoder_mu(h)                                                                                 # [n_latents, n_batch, n_times, x_dim]
        x_tilde_sigma = self.decoder_sigma(h)                                                                           # [n_latents, n_batch, n_times, x_dim]

        p_x_tilde = Normal(loc=x_tilde_mu, scale=x_tilde_sigma)                                                         # [n_latents, n_batch, n_times, x_dim]
        return p_x_tilde                                                                                                # [n_latents, n_batch, n_times, x_dim] x 2

    def prior(self, z_samples, padding_mask):
        """Return p(z): a scalar N(0, 1), or an autoregressive Normal over `z_samples`.

        The autoregressive prior predicts step t from steps < t; step 0 gets N(0, 1).
        """
        if self.prior_type == "standard":
            p_z = Normal(loc=0., scale=1.)                                                                              # []
        elif self.prior_type == "autoregressive":
            shape = tf.shape(z_samples)  # (n_latents, n_batch, n_times, z_dim)
            z_samples = tf.reshape(z_samples, shape=(shape[0] * shape[1], shape[2], shape[3]))                          # [n_latents x n_batch, n_times, z_dim]
            padding_mask = tf.tile(padding_mask, multiples=[shape[0], 1])                                               # [n_latents x n_batch, n_times]

            r = self.prior_gru(z_samples[:, 1:, :], initial_state=z_samples[:, 0, :], mask=padding_mask[:, 1:])         # [n_latents x n_batch, n_times - 1, z_dim]
            r = tf.reshape(r, shape=(shape[0], shape[1], -1, shape[3]))                                                 # [n_latents,  n_batch, n_times - 1, z_dim]

            p_z_mu    = self.prior_mu(r)                                                                                # [n_latents,  n_batch, n_times - 1, z_dim]
            p_z_sigma = self.prior_std(r)                                                                               # [n_latents,  n_batch, n_times - 1, z_dim]

            p_z_mu    = tf.pad(p_z_mu,    [[0, 0], [0, 0], [1, 0], [0, 0]], constant_values=0)                          # [n_latents,  n_batch, n_times, z_dim]
            p_z_sigma = tf.pad(p_z_sigma, [[0, 0], [0, 0], [1, 0], [0, 0]], constant_values=1)                          # [n_latents,  n_batch, n_times, z_dim]

            p_z = Normal(loc=p_z_mu, scale=p_z_sigma)                                                                   # [n_latents,  n_batch, n_times, z_dim]

        return p_z                                                                                                      # [] | [n_latents, n_batch, n_times, z_dim]

    def generate(self, p_x_tilde, log_w_latents, training=False):
        """Draw candidate values for the unobserved entries.

        At evaluation, or with a "mean*" impute type, the decoder mean is used. Otherwise
        "single" draws one weighted average per sample and "multiple" keeps every
        (sample, latent) draw.
        """
        n_samples = self.n_train_samples if training else self.n_test_samples

        # Generate
        if self.impute_type.startswith("mean") or not training:
            # xₖ₁ᵐ = xₖᵐ_μ
            x_tilde = tf.expand_dims(p_x_tilde.loc, axis=0)                                                             # [        1, n_latents, n_batch, n_times, x_dim]
        elif self.impute_type.startswith("single"):
            # x₁ⱼᵐ = ∑ₖ wₖ xₖⱼᵐ, where xₖⱼᵐ ~ p(xₖᵐ | zₖ)
            x_tilde = tf.reduce_sum(                                                                                    # [n_samples,         1, n_batch, n_times, x_dim]
                tf.expand_dims(tf.exp(log_w_latents), axis=-1)         # [           n_latents, n_batch, (n_times),   1]
                * p_x_tilde.sample(n_samples),                         # [n_samples, n_latents, n_batch, n_times, x_dim]
                axis=1, keepdims=True,
            )
        elif self.impute_type.startswith("multiple"):
            # xₖⱼᵐ ~ p(xₖᵐ | zₖ)
            x_tilde = p_x_tilde.sample(n_samples)                                                                       # [n_samples, n_latents, n_batch, n_times, x_dim]
        else:
            raise NotImplementedError(f"Not supported impute type: {self.impute_type}")

        return x_tilde                                                                                                  # [(n_samples), (n_latents), n_batch, n_times, x_dim]

    def generate_observe_dropout_mask(self, shape, missing_mask, training=False):
        """Return the observed-and-kept mask and its log-likelihood log p(m).

        Args:
            shape: Full shape of the imputation tensor, ((n_samples), (n_latents), n_batch, n_times, x_dim).
            missing_mask: [n_batch, n_times, x_dim] boolean, True where observed.
            training: Observation dropout is only applied while training.

        Returns:
            drop_mask: [(n_samples), (n_latents), n_batch, n_times, x_dim] boolean.
            log_m: [(n_samples), (n_latents), n_batch, (n_times)], accumulated over time
                like the other log-weights.
        """
        if training and self._has_observe_dropout:
            p_m = Bernoulli(probs=(1. - self.observe_dropout))                                                          # [] | [x_dim]
            # A scalar rate needs the full sample shape; per-feature rates supply the x_dim axis.
            sample_shape = shape if self.observe_dropout.shape.rank == 0 else shape[:-1]
            keep = p_m.sample(sample_shape)                                                                             # [(n_samples), (n_latents), n_batch, n_times, x_dim]

            # Only entries that were actually observed can be dropped, so only they count.
            log_m = tf.reduce_sum(tf.where(                                                                             # [(n_samples), (n_latents), n_batch, n_times]
                missing_mask,                                      # [                          n_batch, n_times, x_dim]
                p_m.log_prob(keep),                                # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                0.,
            ), axis=-1)

            drop_mask = tf.cast(keep, dtype=bool) & missing_mask                                                        # [(n_samples), (n_latents), n_batch, n_times, x_dim]

        else:
            drop_mask = tf.broadcast_to(missing_mask, shape)                                                            # [(n_samples), (n_latents), n_batch, n_times, x_dim]

            log_m = tf.zeros(shape[:-1])                                                                                # [(n_samples), (n_latents), n_batch, n_times]

        if self.return_sequences:
            log_m = tf.cumsum(log_m, axis=-1)                                                                           # [(n_samples), (n_latents), n_batch, n_times]
        else:
            log_m = tf.reduce_sum(log_m, axis=-1, keepdims=True)                                                        # [(n_samples), (n_latents), n_batch, 1]

        return drop_mask, log_m                                                                                         # [(n_samples), (n_latents), n_batch, n_times, x_dim], [(n_samples), (n_latents), n_batch, (n_times)]

    def impute(self, times, x_obs, x_tilde, missing_mask, padding_mask):
        """Combine observed values (where `missing_mask` is True) with generated ones.

        With a "-decay" impute type the combined series is also passed through the
        decay interpolator, with the leading sample/latent axes folded into the batch.
        """
        # Interpolate obs and combine with generated missing
        if self.impute_type.endswith("-decay"):
            # xₖⱼ' = xᵒ if x is observed else xₖⱼᵐ ~ p(xᵐ | z)
            x_comb = tf.where(                                                                                          # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                missing_mask,                                      # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                x_obs,                                             # [                          n_batch, n_times, x_dim]
                x_tilde,                                           # [(n_samples), (n_latents), n_batch, n_times, x_dim]
            )

            shape = tf.shape(x_comb)  # ((n_samples), (n_latents), n_batch, n_times, x_dim)
            n_tiles = shape[0] * shape[1]

            times = tf.expand_dims(tf.tile(times, multiples=[n_tiles, 1]), axis=-1)                                     # [(n_samples) x (n_latents) x n_batch, n_times, 1]
            x_comb = tf.reshape(x_comb, shape=[n_tiles * shape[2], shape[3], shape[4]])                                 # [(n_samples) x (n_latents) x n_batch, n_times, x_dim]
            missing_mask = tf.reshape(missing_mask, shape=[n_tiles * shape[2], shape[3], shape[4]])                     # [(n_samples) x (n_latents) x n_batch, n_times, x_dim]
            padding_mask = tf.tile(padding_mask, multiples=[n_tiles, 1])                                                # [(n_samples) x (n_latents) x n_batch, n_times]

            x_impute = self.interpolator(                                                                               # [(n_samples) x (n_latents) x n_batch, n_times, x_dim]
                InterpolateInput(values=x_comb, mask=missing_mask, times=times),
                mask=padding_mask,
            )

            x_impute = tf.reshape(x_impute, shape=[shape[0], shape[1], shape[2], shape[3], shape[4]])                   # [(n_samples), (n_latents), n_batch, n_times, x_dim]

        else:
            # xₖⱼ' = xᵒ if x is observed else xₖⱼᵐ ~ p(xᵐ | z)
            x_impute = tf.where(                                                                                        # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                missing_mask,                                      # [(n_samples), (n_latents), n_batch, n_times, x_dim]
                x_obs,                                             # [                          n_batch, n_times, x_dim]
                x_tilde,                                           # [(n_samples), (n_latents), n_batch, n_times, x_dim]
            )

        return x_impute                                                                                                 # [(n_samples), (n_latents), n_batch, n_times, x_dim]

    def _classify_probs(self, statics, times, x_impute, impute_mask, padding_mask):
        """Run the classifier on every imputation and return p(y = 1 | ...).

        Returns:
            [(n_samples), (n_latents), n_batch, n_times, y_dim] when `return_sequences`,
            else [(n_samples), (n_latents), n_batch, 1, y_dim].
        """
        shape = tf.shape(x_impute)  # ((n_samples), (n_latents), n_batch, n_times, x_dim)
        n_tiles = shape[0] * shape[1]
        x_impute = tf.reshape(x_impute, shape=[n_tiles * shape[2], shape[3], shape[4]])                                 # [(n_samples) x (n_latents) x n_batch, n_times, x_dim]
        padding_mask = tf.tile(padding_mask, multiples=[n_tiles, 1])                                                    # [(n_samples) x (n_latents) x n_batch, n_times]

        initial_state = tf.tile(self.initial_encoder(statics), multiples=[n_tiles, 1])                                  # [(n_samples) x (n_latents) x n_batch, state_dim]

        if self.classifier_type == "gru":
            v = self.classifier_gru(x_impute, mask=padding_mask, initial_state=initial_state)                           # [(n_samples) x (n_latents) x n_batch, (n_times,) n_units]

        elif self.classifier_type == "grud":
            impute_mask = tf.reshape(impute_mask, shape=[n_tiles * shape[2], shape[3], shape[4]])                       # [(n_samples) x (n_latents) x n_batch, n_times, x_dim]
            times = tf.expand_dims(tf.tile(times, multiples=[n_tiles, 1]), axis=-1)                                     # [(n_samples) x (n_latents) x n_batch, n_times, 1]

            grud_initial_state = GRUDState(
                h=initial_state,                                                                                        # [(n_samples) x (n_latents) x n_batch, state_dim]
                x_keep=tf.zeros((n_tiles * shape[2], shape[4])),                                                        # [(n_samples) x (n_latents) x n_batch, x_dim]
                s_prev=tf.tile(times[:, 0, :], [1, shape[4]]),                                                          # [(n_samples) x (n_latents) x n_batch, x_dim]
            )

            v = self.classifier_gru(                                                                                    # [(n_samples) x (n_latents) x n_batch, (n_times,) n_units]
                GRUDInput(values=x_impute, mask=impute_mask, times=times), mask=padding_mask,
                initial_state=grud_initial_state,
            )

        probs = self.classifier_dense(v)                                                                                # [(n_samples) x (n_latents) x n_batch, (n_times,) y_dim]

        n_steps = shape[3] if self.return_sequences else 1
        return tf.reshape(probs, shape=(shape[0], shape[1], shape[2], n_steps, self.output_dims))                       # [(n_samples), (n_latents), n_batch, (n_times), y_dim]

    def _bernoulli_log_prob(self, probs, output=None):
        """Bernoulli log-likelihood of `output` under `probs` (of y = 1 when `output` is None).

        Args:
            probs: [(n_samples), (n_latents), n_batch, (n_times), y_dim].
            output: [n_batch, n_times, y_dim] when `return_sequences`, else [n_batch, y_dim]
                or [n_batch]; broadcast over the sample and latent axes.
        """
        if output is None:
            labels = tf.ones_like(probs)
        else:
            labels = tf.cast(output, dtype=tf.float32)
            if labels.shape.rank == 1:
                labels = tf.expand_dims(labels, axis=1)                                                                 # [n_batch, 1]
            if not self.return_sequences:
                labels = tf.expand_dims(labels, axis=-2)                                                                # [n_batch, 1, y_dim]
            labels = tf.broadcast_to(labels, tf.shape(probs))                                                           # [(n_samples), (n_latents), n_batch, (n_times), y_dim]

        return Bernoulli(probs=probs).log_prob(labels)                                                                  # [(n_samples), (n_latents), n_batch, (n_times), y_dim]

    def classify(self, statics, times, x_impute, impute_mask, padding_mask, output=None):
        """Return log p(y | xᵒ, xₖⱼᵐ, mₖⱼ) of `output`, or of y = 1 when `output` is None.

        Returns:
            [(n_samples), (n_latents), n_batch, (n_times), y_dim].
        """
        probs = self._classify_probs(statics, times, x_impute, impute_mask, padding_mask)
        return self._bernoulli_log_prob(probs, output)

    def train_step(self, data):
        """One optimisation step on `(x, y)`; returns the loss and the training metrics."""
        x, y = data

        with tf.GradientTape() as tape:
            _, loss, aux = self(x, output=y, training=True, return_loss=True, return_aux=True)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        return {"loss": loss, **aux.get("metrics", {})}

    def test_step(self, data):
        """Evaluate on `(x, y)`; the labels are passed so the loss uses the true targets."""
        x, y = data
        _, loss, aux = self(x, output=y, training=False, return_loss=True, return_aux=True)

        return {"loss": loss, **aux.get("metrics", {})}
