"""Energy-based model (EBM) modules implemented with Flax Linen."""

from typing import Callable

import jax
import jax.numpy as jnp

from flax import linen as nn

from ebm_obj.modules import misc

# Alias used in this module's type annotations for JAX arrays.
Array = jnp.ndarray


class SimpleEBM(nn.Module):
  """Energy-based model E(x, z) over images x and slots z.

  The energy is produced by a pipeline of four configurable sub-modules.
  Each attribute is a zero-argument factory that returns a Flax module, and
  each defaults to ``misc.Identity``:

    h_x = image_transform()(images)
    h_z = slot_transform()(slots)
    h   = fuse_transform()([h_x, h_z])
    e   = output_transform()(h)

  Every sub-module call also receives the ``train`` keyword argument.

  Attributes:
    image_transform: Factory for the module applied to ``images``.
    slot_transform: Factory for the module applied to ``slots``.
    fuse_transform: Factory for the module that receives the two-element
      list ``[h_x, h_z]``.
    output_transform: Factory for the module that produces the raw energy.
  """

  image_transform: Callable[[], nn.Module] = misc.Identity
  slot_transform: Callable[[], nn.Module] = misc.Identity
  fuse_transform: Callable[[], nn.Module] = misc.Identity
  output_transform: Callable[[], nn.Module] = misc.Identity

  @nn.compact
  def __call__(self, images: Array, slots: Array, train: bool = False) -> Array:
    """Computes the energy for the given images and slots.

    Args:
      images: Input passed to ``image_transform``.
      slots: Input passed to ``slot_transform``.
      train: Forwarded as ``train=`` to every sub-module.

    Returns:
      The energy computed by ``output_transform``. If that output has rank 2
      with a trailing dimension of size 1, the trailing dimension is squeezed
      away. Outputs of rank 0 or 1 are returned unchanged.

    Raises:
      ValueError: If the output has rank greater than 2, or has rank 2 with
        a trailing dimension other than 1.
    """
    h_x = self.image_transform()(images, train=train)
    h_z = self.slot_transform()(slots, train=train)
    h = self.fuse_transform()([h_x, h_z], train=train)
    e = self.output_transform()(h, train=train)

    # Array shapes are static under JAX tracing, so these are ordinary Python
    # checks. They raise explicitly instead of using `assert`, which is
    # removed when Python runs with -O.
    if e.ndim > 2:
      raise ValueError(
          f"Expected energy of rank <= 2, got shape {e.shape}.")
    if e.ndim == 2:
      if e.shape[1] != 1:
        raise ValueError(
            f"Expected rank-2 energy with trailing dimension 1, "
            f"got shape {e.shape}.")
      e = jnp.squeeze(e, axis=1)
    return e


# class GradEBM(nn.Module):
#   ebm: Callable[[], nn.Module]

#   @nn.compact
#   def __call__(self, xs, zs, train: bool = False):
#     E = self.ebm()
#     xs = jnp.expand_dims(xs, 0)
#     zs = jnp.expand_dims(zs, 0)

#     def _fn(zs):
#       return E(xs, zs).sum()

#     grad_fn = jax.grad(_fn, argnums=0)
#     return jnp.squeeze(grad_fn(zs), axis=0)


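# NOTE: Disabled smoke test for SimpleEBM. Re-enabling it requires importing
# `ml_collections` and the SimpleCNN, CNNPosEmbTransform, FuseModule,
# CrossAttention1DBlock and EBMOutputModule modules, none of which are
# imported in this file.
# def __test_ebm():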
#   from functools import partial
#   batch_size = 3
#   input_shape = [128, 128, 3]
#   slots_shape = [11, 64]
#   rng = jax.random.PRNGKey(42)
#   batch_x = jax.random.uniform(rng, [
#       batch_size,
#   ] + input_shape)
#   batch_z = jax.random.uniform(rng, [
#       batch_size,
#   ] + slots_shape)
#   simple_cnn_cfg = ml_collections.ConfigDict({
#       "features": [64, 64, 64, 64],
#       "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
#       "strides": [(2, 2), (1, 1), (1, 1), (1, 1)],
#       "layer_transpose": [False, False, False, False]
#   })
#   pos_emb_cfg = ml_collections.ConfigDict({
#       "embedding_type": "linear",
#       "update_type": "project_add",
#   })
#   image_transform_cfg = ml_collections.ConfigDict({
#       "backbone": lambda: SimpleCNN(**simple_cnn_cfg),
#       "pos_emb": lambda: misc.PositionEmbedding(**pos_emb_cfg),
#       "reduction": "spatial_flatten"
#   })
#   cross_attention_cfg = ml_collections.ConfigDict({
#       "mlp_dim": 256,
#       "num_heads": 4
#   })
#   fuse_module_cfg = ml_collections.ConfigDict({
#       "attention_block": partial(CrossAttention1DBlock, **cross_attention_cfg),
#       "num_blocks": 2
#   })
#   output_cfg = ml_collections.ConfigDict({"mlp_dims": [128]})
#   cfg = ml_collections.ConfigDict({
#       "image_transform": partial(CNNPosEmbTransform, **image_transform_cfg),
#       "slot_transform": misc.Identity,
#       "fuse_transform": partial(FuseModule, **fuse_module_cfg),
#       "output_transform": partial(EBMOutputModule, **output_cfg)
#   })
#   ebm = SimpleEBM(**cfg)
#   variables = ebm.init({"params": rng}, batch_x, batch_z, train=False)
#   outputs = ebm.apply(variables, batch_x, batch_z, train=False)
#   return outputs


# outputs = __test_ebm()
