import jax
import jax.numpy as jnp

from flax import linen as nn


import functools
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union

Array = jnp.ndarray

class GaussianStateInit(nn.Module):
  """Draws an initial state from a standard normal distribution.

  Note: There are no trainable parameters here; an RNG stream named
    "z_state_init" must be supplied on every call, whether training or
    evaluating. Any conditioning input is deliberately discarded.

  Attributes:
    shape: Shape of a single state, excluding the leading batch axis.
  """

  shape: Sequence[int]

  @nn.compact
  def __call__(self, inputs: Optional[Array], batch_size: int,
               train: bool = False) -> Array:
    """Samples a `[batch_size, *shape]` array of N(0, 1) values.

    Args:
      inputs: Ignored.
      batch_size: Size of the leading axis of the returned array.
      train: Ignored.

    Returns:
      Array of shape `[batch_size, *self.shape]`.
    """
    # Neither argument takes part in the sampling.
    del train, inputs
    sample_shape = [batch_size, *self.shape]
    key = self.make_rng("z_state_init")
    return jax.random.normal(key, shape=sample_shape)

class ULASampler(nn.Module):
  """MCMC sampling z ~ E(z | x) with unadjusted Langevin updates.

  Starting from the state produced by `z_initializer`, applies `num_steps`
  updates of the form

    z <- z - dt * dE(x, z)/dz + wn * sqrt(2 * dt) * eps,   eps ~ N(0, I),

  where E is the module built by `ebm`.

  Note: This module requires an RNG stream named "eps" on every call, in
    addition to whatever streams the sub-modules themselves need.

  Attributes:
    z_initializer: Zero-argument factory for the module that produces the
      initial state; it is called as
      `module(images, batch_size=..., train=...)`.
    ebm: Zero-argument factory for the energy module; it is called as
      `module(xs, zs, train=...)` and its output is summed to a scalar.
    dt: Step size of each update.
    wn: Multiplier applied to the injected noise term.
    num_steps: Number of updates to run.
  """
  z_initializer: Callable[[], nn.Module]
  ebm: Callable[[], nn.Module]
  dt: float = 1e-2
  wn:  float = 1.0
  num_steps: int = 5

  @nn.compact
  def __call__(self, images: Array, train: bool = False) -> Sequence[Array]:
    """Runs the sampler and returns every intermediate state.

    Args:
      images: Batch of conditioning inputs; axis 0 is the batch axis.
      train: Forwarded to the initializer and to the energy module.

    Returns:
      A list of `num_steps + 1` arrays: the initial state followed by the
      state after each update.
    """
    noise_key = self.make_rng("eps")
    energy = self.ebm()

    def grad_single(x: Array, z: Array) -> Array:
      """Gradient of the summed energy w.r.t. `z` for one (x, z) pair."""
      # The energy module is called with a leading axis of size 1, which is
      # stripped from the gradient again before returning.
      x_b = jnp.expand_dims(x, 0)
      z_b = jnp.expand_dims(z, 0)

      def total_energy(latent: Array) -> Array:
        return energy(x_b, latent, train=train).sum()

      return jnp.squeeze(jax.grad(total_energy, argnums=0)(z_b), axis=0)

    # NOTE(review): the energy module is invoked inside raw jax.grad/jax.vmap
    # rather than Flax's lifted transforms (e.g. nn.vmap) - confirm this is
    # supported by the Flax version in use.
    grad_batched = jax.vmap(grad_single, in_axes=0, out_axes=0)

    # Constant across steps, so compute it once.
    noise_scale = self.wn * jnp.sqrt(2 * self.dt)

    def ula_kernel(state: Array, eps: Array) -> Array:
      """One Langevin update of `state` using the noise draw `eps`."""
      return state - self.dt * grad_batched(images, state) + noise_scale * eps

    zs = self.z_initializer()(images, batch_size=images.shape[0], train=train)
    # Draw the noise for all steps up front from the single "eps" key.
    all_eps = jax.random.normal(noise_key, shape=(self.num_steps,) + zs.shape)
    all_zs = [zs]
    for step in range(self.num_steps):
      zs = ula_kernel(zs, all_eps[step])
      all_zs.append(zs)
    return all_zs

  # @nn.compact
  # def __call__(self, images: Array, train: bool = False) -> Array:
  #   rng_key = self.make_rng("eps")

  #   grad_e_fn = nn.vmap(GradEBM, variable_axes={"params": None}, split_rngs={"params": False}, in_axes=0, out_axes=0)
  #   grad_e = grad_e_fn(self.ebm)

  #   def ula_kernel(zs: Array, eps: Array):
  #     grad_zs = grad_e(images, zs)
  #     print("grad_zs", grad_zs.shape)
  #     zs = zs - self.dt * grad_zs + jnp.sqrt(2 * self.dt) * eps
  #     return zs

  #   zs = self.z_initializer()(images, batch_size=images.shape[0], train=train)
  #   all_eps = jax.random.normal(rng_key, shape=(self.num_steps, ) + zs.shape)
  #   all_zs = [zs]
  #   for i in range(self.num_steps):
  #     zs = ula_kernel(zs, all_eps[i])
  #     all_zs.append(zs)
  #   return all_zs


# def __test__sampler():
#   key = jax.random.PRNGKey(42)
#   batch_size = 3
#   image_shape = tuple((128, 128, 3))
#   z_shape = tuple((8, 32))
#   z_initializer_c = lambda: GaussianStateInit(shape=z_shape)
#   ebm_c = lambda: EBM()
#   sampler = ULASampler(z_initializer_c, ebm_c)
#   key, *subkeys = jax.random.split(key, 3)
#   batch_x = jax.random.uniform(subkeys[0],
#                                tuple([
#                                    batch_size,
#                                ]) + image_shape)
#   key, model_key, z_key, eps_key = jax.random.split(key, 4)
#   batch_z = jax.random.normal(subkeys[1],
#                               tuple([
#                                   batch_size,
#                               ]) + z_shape)
#   variables = sampler.init(
#       {
#           "params": model_key,
#           "z_state_init": z_key,
#           "eps": eps_key
#       },
#       batch_x,
#       train=False)
#   params = variables["params"]

#   @jax.jit
#   def loss_fn(params, batch_x):
#     outputs = sampler.apply({"params": params},
#                             batch_x,
#                             rngs={
#                                 "z_state_init": z_key,
#                                 "eps": eps_key
#                             },
#                             train=True)
#     return jnp.square(outputs[-1] - batch_z).mean()

#   # print(outputs)
#   for _ in range(10):
#     outputs = jax.value_and_grad(loss_fn)(params, batch_x)
#   return outputs
