"""EBM output transformation lib."""

from typing import Optional, Sequence, Callable, Any, Tuple
import jax
import jax.numpy as jnp

from flax import linen as nn

# Type aliases used in the annotations below.
Array = jnp.ndarray
# Array shapes have arbitrary rank, so this is a variable-length tuple of ints
# (`Tuple[int]` would mean a tuple containing exactly one int).
Shape = Tuple[int, ...]
PRNGKey = Any
Dtype = Any


class EBMOutputModule(nn.Module):
  """EBM output module: reduces a token sequence to one value per example.

  Optionally applies a stack of Dense + GELU layers to every token, then
  mean-pools over axis 1 and projects the pooled features to a single output
  with a Dense head named "ebm_head".

  Attributes:
    mlp_dims: Output widths of the optional per-token MLP layers, applied in
      order and named "mlp_0", "mlp_1", ... . `None` or an empty sequence
      skips the MLP entirely.
    mlp_kernel_init: Kernel initializer for the MLP layers.
    mlp_bias_init: Bias initializer for the MLP layers.
    head_kernel_init: Kernel initializer for the 1-unit head. Together with
      the default `head_bias_init`, the head outputs 0 at initialization.
    head_bias_init: Bias initializer for the 1-unit head.
  """

  mlp_dims: Optional[Sequence[int]] = None
  mlp_kernel_init: Callable[[PRNGKey, Shape, Dtype],
                            Array] = nn.initializers.xavier_uniform()
  mlp_bias_init: Callable[[PRNGKey, Shape, Dtype],
                          Array] = nn.initializers.normal(stddev=1e-6)
  head_kernel_init: Callable[[PRNGKey, Shape, Dtype],
                             Array] = nn.initializers.zeros
  head_bias_init: Callable[[PRNGKey, Shape, Dtype],
                           Array] = nn.initializers.constant(0.0)

  @nn.compact
  def __call__(self, inputs: Array, *, train: bool = False) -> Array:
    """Computes one output value per batch element.

    Args:
      inputs: Rank-3 array laid out as (batch, h * w, dim); axis 1 is the
        axis that gets mean-pooled.
      train: Accepted for interface compatibility; not used by this module.

    Returns:
      Array of shape (batch, 1).

    Raises:
      ValueError: If `inputs` is not rank 3.
    """
    if inputs.ndim != 3:
      # Explicit check instead of `assert`, so validation survives `python -O`.
      raise ValueError(
          "EBMOutputModule expects rank-3 inputs (batch, h * w, dim), "
          f"got shape {inputs.shape}.")
    x = inputs
    # Optional per-token MLP; `None` and an empty sequence both mean "no MLP".
    for i, features in enumerate(self.mlp_dims or ()):
      x = nn.Dense(
          features=features,
          kernel_init=self.mlp_kernel_init,
          bias_init=self.mlp_bias_init,
          name=f"mlp_{i}")(  # pytype: disable=wrong-arg-types
              x)
      x = nn.gelu(x)
    # Mean-pool over axis 1: (batch, h * w, d) -> (batch, d).
    x = jnp.mean(x, axis=1)
    x = nn.Dense(
        features=1,
        name="ebm_head",
        kernel_init=self.head_kernel_init,
        bias_init=self.head_bias_init)(
            x)
    return x


# def __test_ebm_output_module():
#   from functools import partial
#   batch_size = 3
#   input_shape = [64 * 64, 64]
#   slots_shape = [16, 64]
#   rng = jax.random.PRNGKey(42)
#   batch_x = jax.random.uniform(rng, [batch_size, ] + input_shape)
#   batch_z = jax.random.uniform(rng, [batch_size, ] + slots_shape)
#   cfg = ml_collections.ConfigDict({
#       "mlp_dims": [128]
#   })
#   output_model = EBMOutputModule(**cfg)
#   variables = output_model.init({"params": rng}, batch_x, train=False)
#   outputs = output_model.apply(variables, batch_x, train=False)
#   return outputs

# outputs = __test_ebm_output_module()
