"""Convolutional module library."""

import functools
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union

from flax import linen as nn
import jax.numpy as jnp

from ebm_obj.modules import misc

# Variable-length tuple of ints. (`Tuple[int]` would only describe a 1-tuple,
# so the ellipsis form is used.)
Shape = Tuple[int, ...]

# Deliberately unconstrained alias for dtype specifications.
DType = Any
Array = jnp.ndarray
# Recursive container type: an array, an iterable of trees, or a str-keyed
# mapping of trees.
ArrayTree = Union[Array, Iterable["ArrayTree"], Mapping[str, "ArrayTree"]]  # pytype: disable=not-supported-yet
ProcessorState = ArrayTree
PRNGKey = Array
NestedDict = Dict[str, Any]


class CNNPosEmbTransform(nn.Module):
  """Image pre-processing: backbone, then position embedding, then reduction.

  Attributes:
    backbone: Zero-argument factory for the module applied to the input
      images; the built module is called as `module(images, train=train)`.
    pos_emb: Zero-argument factory for the module applied to the backbone
      output; the built module is called with that output as its only argument.
    reduction: Optional name of the final reduction. A falsy value (e.g.
      `None`) returns the embedded features as-is; `"spatial_flatten"` merges
      the two middle axes of a rank-4 feature map into one.
  """
  backbone: Callable[[], nn.Module] = misc.Identity
  pos_emb: Callable[[], nn.Module] = misc.Identity
  reduction: Optional[str] = None

  @nn.compact
  def __call__(self, images: Array, train: bool = False) -> Array:
    """Runs the backbone and position embedding, then the optional reduction.

    Args:
      images: Input passed to the backbone module.
      train: Forwarded to the backbone module as the `train` keyword.

    Returns:
      The position-embedded features. With `reduction="spatial_flatten"` the
      4-axis result `(b, h, w, c)` is reshaped to `(b, h * w, c)`.

    Raises:
      ValueError: If `reduction` is truthy but not `"spatial_flatten"`, or if
        the features do not have exactly four axes when flattening.
    """
    # Submodules are instantiated in this order: backbone first, then pos_emb.
    features = self.backbone()(images, train=train)
    features = self.pos_emb()(features)

    # No reduction requested: hand back the embedded features untouched.
    if not self.reduction:
      return features

    if self.reduction != "spatial_flatten":
      raise ValueError("{} reduction method not found".format(self.reduction))

    # Collapse the two middle axes into a single one.
    n_batch, rows, cols, channels = features.shape
    return jnp.reshape(features, (n_batch, rows * cols, channels))


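# Disabled manual smoke test, kept as a usage example. To run it, also import
# `jax` and `ml_collections` and provide a `SimpleCNN` module; none of these
# are imported or defined in the visible part of this file.
# def __test_cnn_transform():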
#   batch_size = 3
#   input_shape = [128, 128, 3]
#   rng = jax.random.PRNGKey(42)
#   batch_x = jax.random.uniform(rng, [batch_size, ] + input_shape)

#   simple_cnn_cfg = ml_collections.ConfigDict({
#           "features": [64, 64, 64, 64],
#           "kernel_size": [(5, 5), (5, 5), (5, 5), (5, 5)],
#           "strides": [(2, 2), (1, 1), (1, 1), (1, 1)],
#           "layer_transpose": [False, False, False, False]
#         })
#   pos_emb_cfg =  ml_collections.ConfigDict({
#               "embedding_type": "linear",
#               "update_type": "project_add",
#           })
#   cfg = ml_collections.ConfigDict({
#           "backbone": lambda : SimpleCNN(**simple_cnn_cfg),
#           "pos_emb": lambda : misc.PositionEmbedding(**pos_emb_cfg),
#           "reduction": "spatial_flatten"
#         })
#   cnn_model = CNNPosEmbTransform(**cfg)
#   variables = cnn_model.init({"params": rng}, batch_x, train=False)
#   outputs = cnn_model.apply(variables, batch_x, train=False)
#   return outputs

# outputs = __test_cnn_transform()
