"""Fuse transformations lib."""

from typing import Callable, Any, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp

from flax import linen as nn

# Type aliases used for annotations in this module.
Array = jnp.ndarray
PRNGKey = Any
# Variable-length tuple of ints; `Tuple[int]` would denote a 1-tuple only.
Shape = Tuple[int, ...]
Dtype = Any


class FuseModule(nn.Module):
  """Fuses image features with slot representations via stacked attention.

  The image features form a running stream. Each of the `num_blocks`
  attention blocks is applied in sequence to that stream together with the
  slots. A final LayerNorm is applied to the result.

  Attributes:
    attention_block: Factory that builds one attention submodule when called
      with a `name` keyword argument. The built submodule is invoked as
      `submodule(x, slots, deterministic=...)`.
    qkv_features: Stored on the module; not read by `__call__`.
    num_blocks: Number of attention blocks applied in sequence.
  """
  attention_block: Callable[[Optional[str]], nn.Module]
  # NOTE(review): not read anywhere in this class; confirm whether it should
  # be forwarded to `attention_block`.
  qkv_features: int = 128
  num_blocks: int = 1

  @nn.compact
  def __call__(self, inputs: Sequence[Array], train: bool = False) -> Array:
    """Runs the attention stack and normalizes its output.

    Args:
      inputs: A two-element sequence `(images, slots)`. Shapes are not
        checked here; they must suit the configured attention block.
      train: When True, each attention block is called with
        `deterministic=False`.

    Returns:
      The LayerNorm-normalized output of the last attention block. When
      `num_blocks` is not positive, this is the normalized `images` input.
    """
    hidden, context = inputs
    for block_idx in range(self.num_blocks):
      # Submodule names must be unique within a compact module.
      block = self.attention_block(name=f"attn_block_{block_idx}")
      hidden = block(hidden, context, deterministic=not train)
    return nn.LayerNorm(name="fuse_module_norm")(hidden)



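# Example usage (not executed). It requires `ml_collections` and a
# `CrossAttention1DBlock` implementation, neither of which is imported here.
# def __test__fuse_module():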
#   from functools import partial
#   batch_size = 3
#   input_shape = [64 * 64, 64]
#   slots_shape = [16, 64]
#   rng = jax.random.PRNGKey(42)
#   batch_x = jax.random.uniform(rng, [batch_size, ] + input_shape)
#   batch_z = jax.random.uniform(rng, [batch_size, ] + slots_shape)
#   cross_attention_cfg = ml_collections.ConfigDict({
#       "mlp_dim": 256,
#       "num_heads": 4
#   })
#   fuse_module_cfg = ml_collections.ConfigDict({
#       "attention_block" : partial(CrossAttention1DBlock, **cross_attention_cfg),
#       "num_blocks" : 2
#   })
#   fuse_model = FuseModule(**fuse_module_cfg)
#   variables = fuse_model.init({"params": rng}, [batch_x, batch_z], train=False)
#   outputs = fuse_model.apply(variables, [batch_x, batch_z], train=False)
#   return outputs

# outputs = __test__fuse_module()
