from typing import Any, Optional, Sequence

import distrax
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp


def default_init(scale=1.0):
    """Build the default kernel initializer.

    Args:
        scale: Scaling factor forwarded to the variance-scaling initializer.

    Returns:
        A variance-scaling initializer in 'fan_avg' mode with a uniform distribution.
    """
    initializer = nn.initializers.variance_scaling(scale=scale, mode='fan_avg', distribution='uniform')
    return initializer


def ensemblize(cls, num_qs, in_axes=None, out_axes=0, **kwargs):
    """Wrap a module class with `nn.vmap` so that it runs as an ensemble.

    Args:
        cls: Module class to vectorize.
        num_qs: Number of ensemble members (used as the vmap axis size).
        in_axes: Input axes specification forwarded to `nn.vmap`.
        out_axes: Output axes specification forwarded to `nn.vmap`.
        **kwargs: Extra keyword arguments forwarded to `nn.vmap`.

    Returns:
        The vmapped module class.
    """
    # These collections are stacked along axis 0, one slice per ensemble member.
    stacked_collections = {'params': 0, 'intermediates': 0}
    # The 'params' RNG is split across the vmapped axis.
    rng_split = {'params': True}

    return nn.vmap(
        cls,
        variable_axes=stacked_collections,
        split_rngs=rng_split,
        in_axes=in_axes,
        out_axes=out_axes,
        axis_size=num_qs,
        **kwargs,
    )


class Identity(nn.Module):
    """Pass-through layer that returns its input unchanged."""

    def __call__(self, x):
        """Return `x` as is."""
        output = x
        return output


class MLP(nn.Module):
    """Multi-layer perceptron (MLP).

    Each entry of `hidden_dims` produces one dense layer. Non-final layers (and the final layer too, when
    `activate_final` is set) are followed by the activation and then, optionally, layer normalization.
    The output of the second-to-last layer is sown into the 'intermediates' collection under 'feature'.

    Attributes:
        hidden_dims: Output width of each dense layer.
        activations: Activation function.
        activate_final: Whether to apply the activation (and layer norm) after the final layer.
        kernel_init: Kernel initializer for the dense layers.
        layer_norm: Whether to apply layer normalization after each activation.
    """

    hidden_dims: Sequence[int]
    activations: Any = nn.gelu
    activate_final: bool = False
    kernel_init: Any = default_init()
    layer_norm: bool = False

    @nn.compact
    def __call__(self, x):
        num_layers = len(self.hidden_dims)
        for index, width in enumerate(self.hidden_dims):
            x = nn.Dense(width, kernel_init=self.kernel_init)(x)

            is_output_layer = index + 1 >= num_layers
            if self.activate_final or not is_output_layer:
                x = self.activations(x)
                if self.layer_norm:
                    x = nn.LayerNorm()(x)

            # Expose the penultimate layer's output as an intermediate feature.
            if index == num_layers - 2:
                self.sow('intermediates', 'feature', x)
        return x


class ResMLP(nn.Module):
    """Residual MLP.

    An input projection (dense, layer norm, activation) is followed by residual blocks of the form
    dense -> layer norm -> activation -> dense -> layer norm, each added to its input. A final layer norm
    follows the blocks, and a dense output layer is appended when `activate_final` is False.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        activations: Activation function.
        activate_final: If True, it works as an intermediate layer (one residual block per entry of
            `hidden_dims`, no output layer); if False, it works as a standalone network (the last entry of
            `hidden_dims` is the width of the output layer).
        kernel_init: Kernel initializer.
        layer_norm: Must be True; the module asserts on it.
    """

    hidden_dims: Sequence[int]
    activations: Any = nn.gelu
    activate_final: bool = False
    kernel_init: Any = default_init()
    layer_norm: bool = True

    @nn.compact
    def __call__(self, x):
        assert self.layer_norm

        # Input projection.
        x = nn.Dense(self.hidden_dims[0], kernel_init=self.kernel_init)(x)
        x = nn.LayerNorm()(x)
        x = self.activations(x)

        if self.activate_final:
            num_res_blocks = len(self.hidden_dims)
        else:
            # The last entry of hidden_dims is reserved for the output layer below.
            num_res_blocks = len(self.hidden_dims) - 1

        for width in self.hidden_dims[:num_res_blocks]:
            shortcut = x
            x = nn.Dense(width, kernel_init=self.kernel_init)(x)
            x = nn.LayerNorm()(x)
            x = self.activations(x)
            x = nn.Dense(width, kernel_init=self.kernel_init)(x)
            x = nn.LayerNorm()(x)
            x = x + shortcut
        x = nn.LayerNorm()(x)

        if self.activate_final:
            return x

        return nn.Dense(self.hidden_dims[-1], kernel_init=self.kernel_init)(x)


class LengthNormalize(nn.Module):
    """Length normalization layer.

    Rescales the input along its last dimension so that its Euclidean norm equals sqrt(dim), where dim is
    the size of that dimension.
    """

    @nn.compact
    def __call__(self, x):
        target_length = jnp.sqrt(x.shape[-1])
        current_length = jnp.linalg.norm(x, axis=-1, keepdims=True)
        return x / current_length * target_length


class Param(nn.Module):
    """Scalar parameter module.

    Attributes:
        init_value: Initial value of the scalar parameter.
    """

    init_value: float = 0.0

    @nn.compact
    def __call__(self):
        """Return the scalar parameter stored under the name 'value'."""

        def init_scalar(key):
            # The RNG key is ignored: the parameter starts at a fixed value.
            return jnp.full((), self.init_value)

        return self.param('value', init_scalar)


class LogParam(nn.Module):
    """Scalar parameter module with log scale.

    The parameter is stored as its logarithm (under the name 'log_value') and exponentiated on access.

    Attributes:
        init_value: Initial value of the exponentiated parameter.
    """

    init_value: float = 1.0

    @nn.compact
    def __call__(self):
        """Return exp(log_value)."""

        def init_log_scalar(key):
            # The RNG key is ignored: the parameter starts at log(init_value).
            return jnp.full((), jnp.log(self.init_value))

        stored_log = self.param('log_value', init_log_scalar)
        return jnp.exp(stored_log)


class TransformedWithMode(distrax.Transformed):
    """Transformed distribution with mode calculation."""

    def mode(self):
        """Return the base distribution's mode pushed through the bijector."""
        base_mode = self.distribution.mode()
        return self.bijector.forward(base_mode)


class RunningMeanStd(flax.struct.PyTreeNode):
    """Running mean and standard deviation.

    Attributes:
        eps: Epsilon value added to the standard deviation to avoid division by zero.
        mean: Running mean.
        std: Running standard deviation.
        alpha: Smoothing factor: the weight given to the new statistics in `update`.
    """

    eps: Any = 1e-6
    mean: Any = 0.0
    std: Any = 1.0
    alpha: Any = 0.01

    def normalize(self, batch, normalize_mean=True):
        """Scale `batch` by the running std, optionally subtracting the running mean first.

        Args:
            batch: Values to normalize.
            normalize_mean: Whether to subtract the running mean before scaling.

        Returns:
            The normalized batch.
        """
        centered = batch - self.mean if normalize_mean else batch
        return centered / (self.std + self.eps)

    def update(self, mean, std):
        """Return a copy whose statistics are moved towards `mean` and `std` by a factor of `alpha`.

        Args:
            mean: Newly observed mean.
            std: Newly observed standard deviation.
        """
        keep = 1 - self.alpha
        blended_mean = self.alpha * mean + keep * self.mean
        blended_std = self.alpha * std + keep * self.std
        return self.replace(mean=blended_mean, std=blended_std)

class GCDeterActor(nn.Module):
    """Goal-conditioned deterministic actor.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        action_dim: Action dimension.
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        tanh_squash: Whether to squash the output actions with tanh.
        final_fc_init_scale: Initial scale of the final fully-connected layer.
        gc_encoder: Optional GCEncoder module to encode the inputs.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    mlp_class: Any = MLP
    layer_norm: bool = False
    tanh_squash: bool = False
    final_fc_init_scale: float = 1e-2
    gc_encoder: nn.Module = None

    def setup(self):
        self.actor_net = self.mlp_class(self.hidden_dims, activate_final=True, layer_norm=self.layer_norm)
        self.mean_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))

    def __call__(
        self,
        observations,
        goals=None,
        goal_encoded=False,
        temperature=1.0,
    ):
        """Return the deterministic actions.

        Args:
            observations: Observations.
            goals: Goals (optional).
            goal_encoded: Whether the goals are already encoded (only forwarded to `gc_encoder`).
            temperature: Not used by this module.

        Returns:
            The output of the mean head, passed through tanh when `tanh_squash` is set.
        """
        if self.gc_encoder is None:
            # Without an encoder, the raw observations (and goals, if given) are concatenated.
            parts = [observations] if goals is None else [observations, goals]
            inputs = jnp.concatenate(parts, axis=-1)
        else:
            inputs = self.gc_encoder(observations, goals, goal_encoded=goal_encoded)

        features = self.actor_net(inputs)
        means = self.mean_net(features)
        return jnp.tanh(means) if self.tanh_squash else means

class GCActor(nn.Module):
    """Goal-conditioned actor.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        action_dim: Action dimension.
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        log_std_min: Minimum value of log standard deviation.
        log_std_max: Maximum value of log standard deviation.
        tanh_squash: Whether to squash the action distribution with tanh.
        mean_tanh_squash: Whether to squash the mean net with tanh (the action distribution can still be unbounded).
        state_dependent_std: Whether to use state-dependent standard deviation.
        const_std: Whether to use constant standard deviation.
        final_fc_init_scale: Initial scale of the final fully-connected layer.
        gc_encoder: Optional GCEncoder module to encode the inputs.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    mlp_class: Any = MLP
    layer_norm: bool = False
    log_std_min: Optional[float] = -5
    log_std_max: Optional[float] = 2
    tanh_squash: bool = False
    mean_tanh_squash: bool = False
    state_dependent_std: bool = False
    const_std: bool = True
    final_fc_init_scale: float = 1e-2
    gc_encoder: nn.Module = None

    def setup(self):
        self.actor_net = self.mlp_class(self.hidden_dims, activate_final=True, layer_norm=self.layer_norm)
        self.mean_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
        if self.state_dependent_std:
            self.log_std_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))
        elif not self.const_std:
            # State-independent but learnable log standard deviations.
            self.log_stds = self.param('log_stds', nn.initializers.zeros, (self.action_dim,))

    def __call__(
        self,
        observations,
        goals=None,
        goal_encoded=False,
        temperature=1.0,
    ):
        """Return the action distribution.

        Args:
            observations: Observations.
            goals: Goals (optional).
            goal_encoded: Whether the goals are already encoded.
            temperature: Scaling factor for the standard deviation.
        """
        if self.gc_encoder is None:
            parts = [observations] if goals is None else [observations, goals]
            inputs = jnp.concatenate(parts, axis=-1)
        else:
            inputs = self.gc_encoder(observations, goals, goal_encoded=goal_encoded)
        features = self.actor_net(inputs)

        means = self.mean_net(features)
        if self.mean_tanh_squash:
            means = jnp.tanh(means)

        if self.state_dependent_std:
            log_stds = self.log_std_net(features)
        elif self.const_std:
            # Constant log std of zero, i.e., unit standard deviation before temperature scaling.
            log_stds = jnp.zeros_like(means)
        else:
            log_stds = self.log_stds
        log_stds = jnp.clip(log_stds, self.log_std_min, self.log_std_max)

        scale_diag = jnp.exp(log_stds) * temperature
        distribution = distrax.MultivariateNormalDiag(loc=means, scale_diag=scale_diag)
        if not self.tanh_squash:
            return distribution

        return TransformedWithMode(distribution, distrax.Block(distrax.Tanh(), ndims=1))


class GCDiscreteActor(nn.Module):
    """Goal-conditioned actor for discrete actions.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        action_dim: Action dimension (number of logits).
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        final_fc_init_scale: Initial scale of the final fully-connected layer.
        gc_encoder: Optional GCEncoder module to encode the inputs.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    mlp_class: Any = MLP
    layer_norm: bool = False
    final_fc_init_scale: float = 1e-2
    gc_encoder: nn.Module = None

    def setup(self):
        self.actor_net = self.mlp_class(self.hidden_dims, activate_final=True, layer_norm=self.layer_norm)
        self.logit_net = nn.Dense(self.action_dim, kernel_init=default_init(self.final_fc_init_scale))

    def __call__(
        self,
        observations,
        goals=None,
        goal_encoded=False,
        temperature=1.0,
    ):
        """Return the action distribution.

        Args:
            observations: Observations.
            goals: Goals (optional).
            goal_encoded: Whether the goals are already encoded.
            temperature: Inverse scaling factor for the logits (set to 0 to get the argmax).
        """
        if self.gc_encoder is None:
            parts = [observations] if goals is None else [observations, goals]
            inputs = jnp.concatenate(parts, axis=-1)
        else:
            inputs = self.gc_encoder(observations, goals, goal_encoded=goal_encoded)

        features = self.actor_net(inputs)
        logits = self.logit_net(features)

        # The temperature is floored at 1e-6 so that a temperature of 0 does not divide by zero.
        safe_temperature = jnp.maximum(1e-6, temperature)
        return distrax.Categorical(logits=logits / safe_temperature)


class GCValue(nn.Module):
    """Goal-conditioned value/critic function.

    This module can be used for both value V(s, g) and critic Q(s, a, g) functions.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Output dimension (set to None for scalar output).
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        num_ensembles: Number of ensemble components.
        encoder: Optional encoder applied to the observations only.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    mlp_class: Any = MLP
    layer_norm: bool = True
    num_ensembles: int = 2
    encoder: nn.Module = None

    def setup(self):
        network_class = self.mlp_class
        if self.num_ensembles > 1:
            network_class = ensemblize(network_class, self.num_ensembles)

        # A scalar output is produced as a width-1 layer and squeezed in __call__.
        final_width = 1 if self.output_dim is None else self.output_dim
        self.value_net = network_class(
            (*self.hidden_dims, final_width), activate_final=False, layer_norm=self.layer_norm
        )

    def __call__(self, observations, goals=None, actions=None):
        """Return the value/critic function.

        Args:
            observations: Observations.
            goals: Goals (optional).
            actions: Actions (optional).
        """
        state_features = observations if self.encoder is None else self.encoder(observations)

        parts = [state_features]
        for optional_input in (goals, actions):
            if optional_input is not None:
                parts.append(optional_input)
        inputs = jnp.concatenate(parts, axis=-1)

        v = self.value_net(inputs)
        return v.squeeze(-1) if self.output_dim is None else v


class GCDiscreteCritic(GCValue):
    """Goal-conditioned critic for discrete actions.

    Attributes:
        action_dim: Number of discrete actions; used as the size of the one-hot action encoding.
    """

    action_dim: int = None

    def __call__(self, observations, goals=None, actions=None):
        """Encode `actions` as rows of an identity matrix and evaluate the parent critic."""
        one_hot_table = jnp.eye(self.action_dim)
        one_hot_actions = one_hot_table[actions]
        return super().__call__(observations, goals, one_hot_actions)


class GCBilinearValue(nn.Module):
    """Goal-conditioned bilinear value/critic function.

    This module computes the value function as V(s, g) = phi(s)^T psi(g) / sqrt(d) or the critic function as
    Q(s, a, g) = phi(s, a)^T psi(g) / sqrt(d), where phi and psi output d-dimensional vectors.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension.
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        num_ensembles: Number of ensemble components.
        value_exp: Whether to exponentiate the value. Useful for contrastive learning.
        state_encoder: Optional state encoder.
        goal_encoder: Optional goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    mlp_class: Any = MLP
    layer_norm: bool = True
    num_ensembles: int = 2
    value_exp: bool = False
    state_encoder: nn.Module = None
    goal_encoder: nn.Module = None

    def setup(self) -> None:
        network_class = self.mlp_class
        if self.num_ensembles > 1:
            network_class = ensemblize(network_class, self.num_ensembles)

        # phi and psi share the same architecture but have separate parameters.
        layer_dims = (*self.hidden_dims, self.latent_dim)
        self.phi = network_class(layer_dims, activate_final=False, layer_norm=self.layer_norm)
        self.psi = network_class(layer_dims, activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations, goals, actions=None, info=False):
        """Return the value/critic function.

        Args:
            observations: Observations.
            goals: Goals.
            actions: Actions (optional).
            info: Whether to additionally return the representations phi and psi.
        """
        if self.state_encoder is not None:
            observations = self.state_encoder(observations)
        if self.goal_encoder is not None:
            goals = self.goal_encoder(goals)

        if actions is not None:
            phi_inputs = jnp.concatenate([observations, actions], axis=-1)
        else:
            phi_inputs = observations

        phi = self.phi(phi_inputs)
        psi = self.psi(goals)

        # Scaled inner product over the latent dimension.
        v = (phi * psi / jnp.sqrt(self.latent_dim)).sum(axis=-1)
        if self.value_exp:
            v = jnp.exp(v)

        return (v, phi, psi) if info else v


class GCDiscreteBilinearCritic(GCBilinearValue):
    """Goal-conditioned bilinear critic for discrete actions.

    Attributes:
        action_dim: Number of discrete actions; used as the size of the one-hot action encoding.
    """

    action_dim: int = None

    def __call__(self, observations, goals=None, actions=None, info=False):
        """Encode `actions` as rows of an identity matrix and evaluate the parent critic."""
        one_hot_table = jnp.eye(self.action_dim)
        one_hot_actions = one_hot_table[actions]
        return super().__call__(observations, goals, one_hot_actions, info)


class GCMRNValue(nn.Module):
    """Metric residual network (MRN) value function.

    This module computes the value function as the sum of a symmetric Euclidean distance and an asymmetric
    L^infinity-based quasimetric. The first half of the latent vector feeds the symmetric term and the
    second half feeds the asymmetric term.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension.
        layer_norm: Whether to apply layer normalization.
        encoder: Optional state/goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    layer_norm: bool = True
    encoder: nn.Module = None

    def setup(self) -> None:
        self.phi = MLP((*self.hidden_dims, self.latent_dim), activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations, goals, is_phi=False, info=False):
        """Return the MRN value function.

        Args:
            observations: Observations.
            goals: Goals.
            is_phi: Whether the inputs are already encoded by phi.
            info: Whether to additionally return the representations phi_s and phi_g.
        """
        if not is_phi:
            if self.encoder is not None:
                observations = self.encoder(observations)
                goals = self.encoder(goals)
            phi_s = self.phi(observations)
            phi_g = self.phi(goals)
        else:
            phi_s, phi_g = observations, goals

        half = self.latent_dim // 2
        sym_s, asym_s = phi_s[..., :half], phi_s[..., half:]
        sym_g, asym_g = phi_g[..., :half], phi_g[..., half:]

        # Symmetric part: Euclidean distance, with the squared distance floored at 1e-12 before the sqrt.
        squared_dist = ((sym_s - sym_g) ** 2).sum(axis=-1)
        # Asymmetric part: largest coordinate-wise difference, clipped below at zero.
        quasi = jax.nn.relu((asym_s - asym_g).max(axis=-1))
        v = jnp.sqrt(jnp.maximum(squared_dist, 1e-12)) + quasi

        return (v, phi_s, phi_g) if info else v


class GCIQEValue(nn.Module):
    """Interval quasimetric embedding (IQE) value function.

    This module computes the value function as an IQE-based quasimetric. The latent vectors are reshaped
    into components of `dim_per_component` entries each, one value is computed per component, and the
    components are combined as a learned convex mix of their mean and their maximum.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        latent_dim: Latent dimension (reshaped into groups of `dim_per_component`).
        dim_per_component: Dimension of each component in IQE (i.e., number of intervals in each group).
        layer_norm: Whether to apply layer normalization.
        encoder: Optional state/goal encoder.
    """

    hidden_dims: Sequence[int]
    latent_dim: int
    dim_per_component: int
    layer_norm: bool = True
    encoder: nn.Module = None

    def setup(self) -> None:
        self.phi = MLP((*self.hidden_dims, self.latent_dim), activate_final=False, layer_norm=self.layer_norm)
        # Learned scalar; its sigmoid is the mean/max mixing weight in __call__.
        self.alpha = Param()

    def __call__(self, observations, goals, is_phi=False, info=False):
        """Return the IQE value function.

        Args:
            observations: Observations.
            goals: Goals.
            is_phi: Whether the inputs are already encoded by phi.
            info: Whether to additionally return the representations phi_s and phi_g.

        Returns:
            The value `v`, or the tuple `(v, phi_s, phi_g)` when `info` is True.
        """
        # Squash the learned scalar into (0, 1).
        alpha = jax.nn.sigmoid(self.alpha())
        if is_phi:
            phi_s = observations
            phi_g = goals
        else:
            if self.encoder is not None:
                observations = self.encoder(observations)
                goals = self.encoder(goals)
            phi_s = self.phi(observations)
            phi_g = self.phi(goals)

        # Split the last axis into (num_components, dim_per_component).
        x = jnp.reshape(phi_s, (*phi_s.shape[:-1], -1, self.dim_per_component))
        y = jnp.reshape(phi_g, (*phi_g.shape[:-1], -1, self.dim_per_component))
        # Only pairs with x < y contribute below.
        valid = x < y
        # Put the x and y entries of each component side by side (x first, then y) and sort them jointly.
        xy = jnp.concatenate(jnp.broadcast_arrays(x, y), axis=-1)
        ixy = xy.argsort(axis=-1)
        sxy = jnp.take_along_axis(xy, ixy, axis=-1)
        # For each sorted entry: -1 if it came from x, +1 if it came from y, and 0 if its pair is not valid.
        # `ixy % dim_per_component` maps a sorted position back to the index of its (x, y) pair.
        neg_inc_copies = jnp.take_along_axis(valid, ixy % self.dim_per_component, axis=-1) * jnp.where(
            ixy < self.dim_per_component, -1, 1
        )
        # Running sum along the sorted axis; it is negative while at least one valid pair has its x entry
        # already passed and its y entry still ahead.
        neg_inp_copies = jnp.cumsum(neg_inc_copies, axis=-1)
        # Indicator (as -1.0 / 0.0) of a negative running sum, followed by its first difference.
        neg_f = -1.0 * (neg_inp_copies < 0)
        neg_incf = jnp.concatenate([neg_f[..., :1], neg_f[..., 1:] - neg_f[..., :-1]], axis=-1)
        # NOTE(review): presumably this yields, per component, the total length of the union of the valid
        # intervals [x_i, y_i] -- confirm against the IQE reference.
        components = (sxy * neg_incf).sum(axis=-1)
        # Convex combination of the mean and the max over components.
        v = alpha * components.mean(axis=-1) + (1 - alpha) * components.max(axis=-1)

        if info:
            return v, phi_s, phi_g
        else:
            return v


class ActorVectorField(nn.Module):
    """Actor vector field for flow policies.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        action_dim: Action dimension.
        mlp_class: MLP class.
        activate_final: Declared for configuration; the MLP built in `setup` always uses
            `activate_final=False`.
        layer_norm: Whether to apply layer normalization.
        encoder: Optional encoder applied to the observations only.
    """

    hidden_dims: Sequence[int]
    action_dim: int
    mlp_class: Any = MLP
    activate_final: bool = False
    layer_norm: bool = False
    encoder: nn.Module = None

    def setup(self) -> None:
        layer_dims = (*self.hidden_dims, self.action_dim)
        self.mlp = self.mlp_class(layer_dims, activate_final=False, layer_norm=self.layer_norm)

    @nn.compact
    def __call__(self, observations, goals=None, actions=None, times=None, is_encoded=False):
        """Return the current vector.

        Args:
            observations: Observations.
            goals: Goals (optional).
            actions: Current actions.
            times: Current times (optional).
            is_encoded: Whether the inputs are already encoded.
        """
        apply_encoder = self.encoder is not None and not is_encoded
        state_features = self.encoder(observations) if apply_encoder else observations

        # Input layout along the last axis: [state, goals?, actions, times?].
        parts = [state_features]
        if goals is not None:
            parts.append(goals)
        parts.append(actions)
        if times is not None:
            parts.append(times)

        return self.mlp(jnp.concatenate(parts, axis=-1))

class ConvBlock(nn.Module):
    """Residual convolutional block.

    A 3x3 convolution with layer norm and ReLU produces the shortcut branch; two further 3x3 convolutions
    (each followed by layer norm, with a ReLU in between) are then added back to it.

    Attributes:
        ch: Number of output channels of every convolution in the block.
    """

    ch: int

    @nn.compact
    def __call__(self, x):
        kernel_size = (3, 3)

        # Stem: also serves as the shortcut for the residual sum.
        x = nn.Conv(self.ch, kernel_size, padding='SAME')(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        shortcut = x

        x = nn.Conv(self.ch, kernel_size, padding='SAME')(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        x = nn.Conv(self.ch, kernel_size, padding='SAME')(x)
        x = nn.LayerNorm()(x)

        return shortcut + x

class PixelDynamics(nn.Module):
    """Dynamics model for pixels.

    An encoder of `ConvBlock`s with 2x2 average pooling is followed by a decoder that upsamples by 2,
    concatenates the matching encoder output, applies a `ConvBlock`, and modulates the result with a
    per-channel scale and shift computed from the actions.

    Attributes:
        hidden_dims: Channel widths of the encoder blocks; the decoder uses all but the last, in reverse.
        output_dim: Number of channels of the final block.
        layer_norm: Not used by this module.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    layer_norm: bool = True

    def setup(self):
        # Decoder widths: every encoder width except the last, deepest first.
        decoder_widths = list(reversed(self.hidden_dims[:-1]))

        self.encoder_layers = [ConvBlock(width) for width in self.hidden_dims]
        self.decoder_layers = [ConvBlock(width) for width in decoder_widths]
        # Each action-conditioning MLP outputs 2 * width values: a scale and a shift per channel.
        self.film_layers = [
            nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(256), nn.relu, nn.Dense(2 * width)])
            for width in decoder_widths
        ]

        self.final_layer = ConvBlock(self.output_dim)

    def __call__(self, observations, actions):
        """Return the predicted next states.

        Args:
            observations: Observations.
            actions: Actions.
        """
        x = observations
        skip_features = []
        for block in self.encoder_layers[:-1]:
            x = block(x)
            skip_features.append(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

        x = self.encoder_layers[-1](x)

        # Skip connections are consumed in reverse order (deepest first).
        stages = zip(self.decoder_layers, self.film_layers, reversed(skip_features))
        for block, film, skip in stages:
            height, width = x.shape[-3], x.shape[-2]
            upsampled_shape = (*x.shape[:-3], height * 2, width * 2, x.shape[-1])
            x = jax.image.resize(x, upsampled_shape, method='nearest')
            x = jnp.concatenate([x, skip], axis=-1)
            x = block(x)
            scale, shift = jnp.split(film(actions), 2, axis=-1)
            # Broadcast the per-channel scale and shift over the two spatial axes.
            x = x * scale[..., None, None, :] + shift[..., None, None, :]

        return self.final_layer(x)

class Dynamics(nn.Module):
    """Dynamics model.

    Predicts the next state from the concatenated observations and actions, either as a point estimate
    or (when `stochastic` is set) as a diagonal Gaussian distribution.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Output dimension (a width of 1 is used when None).
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        num_ensembles: Number of ensemble components.
        delta_pred: Whether the network predicts a residual that is added to the observations.
        stochastic: Whether to return a distribution instead of a point prediction.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    mlp_class: Any = MLP
    layer_norm: bool = True
    num_ensembles: int = 1
    delta_pred: bool = True
    stochastic: bool = False

    def setup(self):
        mlp_class = self.mlp_class
        if self.num_ensembles > 1:
            mlp_class = ensemblize(mlp_class, self.num_ensembles)
        output_dim = self.output_dim if self.output_dim is not None else 1
        # The stochastic head emits a mean and a log std per output dimension.
        final_width = 2 * output_dim if self.stochastic else output_dim

        self.dynamics_net = mlp_class(
            (*self.hidden_dims, final_width), activate_final=False, layer_norm=self.layer_norm
        )

    def __call__(self, observations, actions):
        """Return the predicted next states.

        Args:
            observations: Observations.
            actions: Actions.

        Returns:
            A `distrax.MultivariateNormalDiag` when `stochastic` is set, otherwise the predicted array.
        """
        inputs = jnp.concatenate([observations, actions], axis=-1)
        pred = self.dynamics_net(inputs)

        if self.stochastic:
            # Split the output in half instead of slicing with `self.output_dim`: that attribute may be
            # None (width defaults to 1), in which case slicing would hand the full output to both parts.
            means, log_stds = jnp.split(pred, 2, axis=-1)
            # Squash the log std into the range [min_logstd, max_logstd].
            min_logstd, max_logstd = -5.0, 1.0
            log_stds = jax.nn.sigmoid(log_stds) * (max_logstd - min_logstd) + min_logstd
            if self.delta_pred:
                means = observations + means
            return distrax.MultivariateNormalDiag(loc=means, scale_diag=jnp.exp(log_stds))

        if self.delta_pred:
            pred = observations + pred

        return pred

class DynamicsVectorField(nn.Module):
    """Dynamics vector field for flow dynamics model.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Output dimension of the vector field.
        mlp_class: MLP class.
        activate_final: Declared for configuration; the MLP built in `setup` always uses
            `activate_final=False`.
        layer_norm: Whether to apply layer normalization.
    """

    hidden_dims: Sequence[int]
    output_dim: int
    mlp_class: Any = MLP
    activate_final: bool = False
    layer_norm: bool = False

    def setup(self) -> None:
        layer_dims = (*self.hidden_dims, self.output_dim)
        self.mlp = self.mlp_class(layer_dims, activate_final=False, layer_norm=self.layer_norm)

    @nn.compact
    def __call__(self, observations, actions, next_observations, times=None):
        """Return the current vector.

        Args:
            observations: Observations.
            actions: Actions.
            next_observations: Current next observations.
            times: Current times (optional).
        """
        # Input layout along the last axis: [observations, actions, next_observations, times?].
        parts = [observations, actions, next_observations]
        if times is not None:
            parts.append(times)

        return self.mlp(jnp.concatenate(parts, axis=-1))

class VQDynamics(nn.Module):
    """Dynamics model with categorical outputs.

    Maps concatenated observations and actions to logits, reshapes them into groups of `output_class`
    classes, and returns a categorical distribution per group.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Total number of logits (reshaped into groups of `output_class`).
        output_class: Number of classes per categorical group.
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        unimix: Weight of the uniform distribution mixed into the predicted probabilities.
        num_ensembles: Number of ensemble components.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    output_class: int = 32
    mlp_class: Any = MLP
    layer_norm: bool = True
    unimix: float = 0.01
    num_ensembles: int = 1

    def setup(self):
        mlp_class = self.mlp_class
        if self.num_ensembles > 1:
            mlp_class = ensemblize(mlp_class, self.num_ensembles)
        self.net = mlp_class(
            (*self.hidden_dims, self.output_dim), activate_final=False, layer_norm=self.layer_norm
        )

    def __call__(self, observations, actions):
        """Return the categorical distribution over the predicted next-state codes.

        Args:
            observations: Observations.
            actions: Actions.
        """
        inputs = jnp.concatenate([observations, actions], axis=-1)

        logit = self.net(inputs)
        # Split the last axis into (num_groups, output_class) and normalize within each group.
        grouped_logit = logit.reshape((*logit.shape[:-1], -1, self.output_class))
        prob = jax.nn.softmax(grouped_logit, axis=-1)
        # Mix in a uniform distribution so that every class keeps some probability mass.
        prob = prob * (1 - self.unimix) + (jnp.ones_like(prob) / self.output_class) * self.unimix
        # Pass the probabilities by keyword: the first positional parameter of Categorical is `logits`.
        return distrax.Categorical(probs=prob)

class Codebook(nn.Module):
    """Learnable set of codebooks, with optional exponential-moving-average (EMA) statistics.

    Attributes:
        num_codebooks: Number of codebooks (M).
        codebook_size: Number of entries per codebook (K).
        sub_dim: Dimension of each entry (D / M).
        use_ema: Whether to create the EMA statistics in the 'ema' collection.
        ema_decay: Decay factor of the EMA statistics.
    """

    num_codebooks: int    # M
    codebook_size: int    # K
    sub_dim: int          # D / M
    use_ema: bool = False
    ema_decay: float = 0.999

    def setup(self):
        # codebooks: [M, K, D / M]
        self.codebooks = self.param(
            "codebooks",
            nn.initializers.orthogonal(),
            (self.num_codebooks, self.codebook_size, self.sub_dim),
        )

        if self.use_ema:
            self.ema_dn = self.variable('ema', 'dn', lambda: jnp.zeros((self.num_codebooks, self.codebook_size)))
            self.ema_dw = self.variable('ema', 'dw', lambda: jnp.zeros((self.num_codebooks, self.codebook_size, self.sub_dim)))

    def __call__(self):
        """Return the codebooks parameter of shape [M, K, D / M]."""
        return self.codebooks

    def update_ema(self, inputs, indices):
        """Compute updated EMA statistics and the codebook they imply.

        The stored 'ema' variables are not modified; the new statistics are only returned.

        Args:
            inputs: Vectors assigned to the codebooks, shape [B, M, D / M].
            indices: Index of the selected entry for each input, shape [B, M].

        Returns:
            A tuple `(codebooks, ema_dw, ema_dn)`: the updated per-entry sums divided by the updated
            counts, the updated sums [M, K, D / M], and the updated counts [M, K].

        Raises:
            ValueError: If the module was created with `use_ema=False`.
        """
        if not self.use_ema:
            raise ValueError('update_ema requires the module to be created with use_ema=True.')

        decay = self.ema_decay

        indices_oh = jax.nn.one_hot(indices, self.codebook_size)  # [B, M, K]
        dn = jnp.sum(indices_oh, axis=0)  # [M, K]
        # [B, M, 1, D/M] * [B, M, K, 1] -> summed over B -> [M, K, D/M]
        dw = jnp.sum(inputs[..., None, :] * indices_oh[..., None], axis=0)

        # Read the stored arrays through `.value`; the attributes themselves are flax Variable objects.
        ema_dn = self.ema_dn.value * decay + dn * (1 - decay)
        ema_dw = self.ema_dw.value * decay + dw * (1 - decay)

        # Add a trailing axis to the counts so that they broadcast against the [M, K, D/M] sums.
        return ema_dw / (ema_dn[..., None] + 1e-5), ema_dw, ema_dn
        

class _VQEncoder(nn.Module):
    """Encoder that maps observations to grouped categorical distributions.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Size of the final layer of the network; the output is
            split into groups of `output_class` entries.
        output_class: Number of classes in each categorical group.
        mlp_class: MLP class.
        layer_norm: Whether to apply layer normalization.
        unimix: Weight of the uniform distribution mixed into the class probabilities.
    """

    hidden_dims: Sequence[int]
    output_dim: Optional[int] = None
    output_class: int = 32
    mlp_class: Any = MLP
    layer_norm: bool = True
    unimix: float = 0.01

    def setup(self):
        self.net = self.mlp_class(
            (*self.hidden_dims, self.output_dim),
            activate_final=False,
            layer_norm=self.layer_norm,
        )

    def __call__(self, observations):
        """Return a categorical distribution predicted from the observations.

        Args:
            observations: Observations.

        Returns:
            A `distrax.Categorical` parameterized by the mixed probabilities,
            whose last two axes are (groups, `output_class`).
        """
        logit = self.net(observations)
        # Split the flat network output into groups of `output_class` logits.
        logit = logit.reshape((*logit.shape[:-1], -1, self.output_class))
        prob = jax.nn.softmax(logit, axis=-1)
        # Blend with a uniform distribution; the result still sums to one per group.
        prob = prob * (1 - self.unimix) + (jnp.ones_like(prob) / self.output_class) * self.unimix
        # `prob` holds probabilities, so it must be passed by keyword: the first
        # positional parameter of `distrax.Categorical` is `logits`.
        return distrax.Categorical(probs=prob)

class Encoder(nn.Module):
    """MLP encoder with an optional upstream encoder and grouped softmax output.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Size of the final layer of the MLP.
        mlp_class: MLP class.
        simnorm: Whether to pass the MLP output through `_simnorm`.
        layer_norm: Whether to apply layer normalization.
        encoder: Optional module applied to the observations before the MLP.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    mlp_class: Any = MLP
    simnorm: bool = True
    layer_norm: bool = True
    encoder: nn.Module = None

    def setup(self):
        layer_sizes = (*self.hidden_dims, self.output_dim)
        self.net = self.mlp_class(layer_sizes, activate_final=False, layer_norm=self.layer_norm)

    def __call__(self, observations):
        """Encode the observations.

        Args:
            observations: Observations.

        Returns:
            The MLP output, passed through `_simnorm` when `simnorm` is set.
        """
        # Use the upstream encoder's features when one is configured.
        features = observations if self.encoder is None else self.encoder(observations)
        features = jnp.concatenate([features], axis=-1)

        latent = self.net(features)
        return self._simnorm(latent) if self.simnorm else latent

    def _simnorm(self, x, D=8):
        """Apply a softmax over consecutive groups of `D` entries of the last axis."""
        batch_shape = x.shape[:-1]
        grouped = jax.nn.softmax(x.reshape((*batch_shape, -1, D)), axis=-1)
        # Flatten the groups back into a single feature axis.
        return grouped.reshape((*batch_shape, -1))

class PixelDecoder(nn.Module):
    """Convolutional decoder that upsamples latents into an image.

    Attributes:
        hidden_dims: Channel counts of the decoder blocks. Each block is
            preceded by a 2x nearest-neighbor upsampling.
        output_dim: Output image shape as (H, W, C).
        layer_norm: Whether to apply layer normalization (not referenced by
            this class).
    """

    hidden_dims: Sequence[int] = (32, 32, 16)
    output_dim: Optional[Sequence[int]] = None
    layer_norm: bool = True

    def setup(self):
        H, W, C = self.output_dim
        # Parenthesized on purpose: `H//8 * W//8` parses as `((H//8) * W) // 8`,
        # which gives the wrong size when W is not a multiple of 8.
        self.input_layer = nn.Dense((H // 8) * (W // 8) * self.hidden_dims[0])
        self.decoder_layers = [ConvBlock(ch) for ch in self.hidden_dims]

        self.final_layer = ConvBlock(C)

    def __call__(self, latents):
        """Decode latents into an image.

        Args:
            latents: Array whose last axis has size
                `(H // 8) * (W // 8) * hidden_dims[0]`; the leading axes are
                kept as batch axes.

        Returns:
            The output of the final `ConvBlock`, which is created with C channels.
        """
        H, W, _ = self.output_dim
        batch_shape = latents.shape[:-1]
        # NOTE: `input_layer` is not applied here, so the latents are reshaped
        # directly into the initial (H // 8, W // 8, hidden_dims[0]) feature map.
        x = latents.reshape((*batch_shape, H // 8, W // 8, self.hidden_dims[0]))
        for decoder_layer in self.decoder_layers:
            # Double the spatial resolution before each decoder block.
            upsampled_shape = (*x.shape[:-3], x.shape[-3] * 2, x.shape[-2] * 2, x.shape[-1])
            x = jax.image.resize(x, upsampled_shape, method='nearest')
            x = decoder_layer(x)

        return self.final_layer(x)


class Decoder(nn.Module):
    """MLP decoder with an optional downstream decoder module.

    Attributes:
        hidden_dims: Hidden layer dimensions.
        output_dim: Size of the final layer of the MLP.
        mlp_class: MLP class.
        decoder: Optional module applied to the MLP output.
    """

    hidden_dims: Sequence[int]
    output_dim: int = None
    mlp_class: Any = MLP
    decoder: nn.Module = None

    def setup(self):
        layer_sizes = (*self.hidden_dims, self.output_dim)
        # Layer normalization is always disabled for this network.
        self.net = self.mlp_class(layer_sizes, activate_final=False, layer_norm=False)

    def __call__(self, observations):
        """Decode the inputs.

        Args:
            observations: Inputs fed to the MLP.

        Returns:
            The MLP output, passed through `decoder` when one is configured.
        """
        features = jnp.concatenate([observations], axis=-1)

        decoded = self.net(features)
        return decoded if self.decoder is None else self.decoder(decoded)


