import functools
from typing import Any, Callable, Dict, Iterable, Tuple, Mapping, Union

import jax
import jax.numpy as jnp
from flax import linen as nn

# Alias for a single JAX array.
Array = jnp.ndarray

# Recursive alias for a nested structure of arrays: a bare array, an
# iterable of such structures, or a string-keyed mapping of them.
ArrayTree = Union[
    Array,
    Iterable["ArrayTree"],
    Mapping[str, "ArrayTree"],
]

class ImageModel(nn.Module):
  """Image model: runs a sampler over images, then decodes its last output.

  Attributes:
    sampler: Zero-argument factory that builds the sampler submodule. The
      sampler is called as ``sampler(images, train=train)`` and its result
      is indexed with ``[-1]`` (presumably a sequence of latent codes —
      confirm against the concrete sampler).
    decoder: Zero-argument factory that builds the decoder submodule, which
      is applied to the last element of the sampler's output.
  """

  sampler: Callable[[], nn.Module]
  decoder: Callable[[], nn.Module]

  @nn.compact
  def __call__(self, images: Array, train: bool = False) -> ArrayTree:
    """Sample latent codes from ``images`` and decode the final one.

    Args:
      images: Input images passed straight to the sampler submodule.
      train: Training-mode flag; forwarded to the sampler only (the decoder
        is called without it).

    Returns:
      A dict with a single key, ``"outputs"``, holding the decoder result.
    """
    # Construct the sampler before the decoder: Flax auto-names compact
    # submodules in construction order, so this order fixes the param tree.
    sampler_module = self.sampler()
    latent_codes = sampler_module(images, train=train)
    final_code = latent_codes[-1]

    decoder_module = self.decoder()
    decoded = decoder_module(final_code)

    return {"outputs": decoded}