from typing import Callable, Sequence
import haiku as hk
import jax.numpy as jnp
import numpy as np


class SuperMLP(hk.Module):
  """MLP with optional conditioning, LayerNorm and residual connections.

  Every entry of ``hidden`` produces one ``hk.Linear(size)`` layer. On all
  layers except the last (or on all layers when ``activate_final`` is set)
  the linear output is optionally LayerNorm-ed and then passed through
  ``activation``. When a ``conditional`` array is supplied it is concatenated
  to the layer input along the last axis before *every* layer. With
  ``residual=True`` each layer output is added to its (possibly
  concatenated) input; that skip path is first linearly projected when its
  width differs from the layer width.

  Must be constructed and called inside a Haiku transform.
  """

  def __init__(self, hidden: Sequence[int],
               activation: Callable[[jnp.ndarray], jnp.ndarray],
               activate_final: bool = False,
               normalize: bool = False,
               spectral_norm: bool = False,
               residual: bool = False, name=None):
    """Configure the MLP.

    Args:
      hidden: output width of each linear layer, in order.
      activation: elementwise nonlinearity for activated layers.
      activate_final: also normalize/activate the output of the last layer.
      normalize: apply ``hk.LayerNorm`` (last axis, learned scale and
        offset) before the activation.
      spectral_norm: accepted for interface compatibility but NOT
        implemented; passing ``True`` emits a ``UserWarning`` and has no
        effect.
      residual: add a skip connection around every layer.
      name: optional Haiku module name.
    """
    super().__init__(name=name)

    # Snapshot the widths so later mutation by the caller (or a one-shot
    # iterable) cannot change the architecture between calls.
    self._hidden = tuple(hidden)
    self._activation = activation
    self._activate_final = activate_final
    self._normalize = normalize
    self._spectral_norm = spectral_norm
    self._residual = residual

    if spectral_norm:
      # Previously this flag was dropped silently. Keep accepting it so
      # existing callers keep working, but make the no-op visible.
      import warnings
      warnings.warn("SuperMLP: spectral_norm=True is not implemented and "
                    "is ignored.", stacklevel=2)

  def __call__(self, x, conditional=None, is_training=True):
    """Apply the MLP to ``x``.

    Args:
      x: input array; features are taken from the last axis.
      conditional: optional array concatenated to the input of every layer
        along the last axis. NOTE(review): ``jnp.concatenate`` requires its
        leading dimensions to equal those of ``x``.
      is_training: unused; kept for interface compatibility.

    Returns:
      The output of the last layer (last-axis width ``hidden[-1]``), or
      ``x`` unchanged if ``hidden`` is empty.
    """
    num_layers = len(self._hidden)
    for i, size in enumerate(self._hidden):
      if conditional is not None:
        x = jnp.concatenate([x, conditional], axis=-1)
      h = hk.Linear(size)(x)

      is_last = i == num_layers - 1
      if not is_last or self._activate_final:
        if self._normalize:
          h = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
        h = self._activation(h)

      if self._residual:
        # Compare against the feature (last) axis -- the axis that
        # concatenate and Linear operate on -- so inputs with extra leading
        # dimensions are handled correctly.
        if x.shape[-1] != size:
          x = hk.Linear(size)(x)
        x = x + h
      else:
        x = h

    return x
    
def potential_net(x, y, hidden_size):
    """Compute a non-negative per-example score of ``(x, y)`` with an MLP.

    ``x`` and ``y`` are concatenated along the last axis and fed through a
    residual ``SuperMLP``. The network has widths ``2 * hidden_size``,
    ``hidden_size`` and 32, uses tanh, applies LayerNorm, and leaves the
    final layer unactivated. The result is the mean of the squared
    32-dimensional output, so it is always >= 0.

    Must be called inside a Haiku transform, because it creates a
    ``SuperMLP`` module.

    Args:
      x: first input; NOTE(review): presumably (batch, features) -- confirm
        with callers.
      y: second input, with the same leading dimensions as ``x``.
      hidden_size: base width of the hidden layers.

    Returns:
      Array with the inputs' leading shape (the feature axis is reduced).
    """
    z = jnp.concatenate([x, y], axis=-1)
    # SuperMLP is defined in this module. jnp.tanh is the same function as
    # jax.nn.tanh but needs no separate `import jax`.
    h = SuperMLP([hidden_size * 2, hidden_size, 32], activation=jnp.tanh,
                 activate_final=False, residual=True,
                 normalize=True)(z)
    h = jnp.square(h)
    # Reduce over the feature axis (the concatenation axis above).
    return jnp.mean(h, axis=-1)