import haiku as hk
import jax
import jax.numpy as np
from jax.nn import relu
from jax import Array
from jax.scipy.special import logsumexp

class FixedPolicy:
  """Policy defined by a fixed function with no trainable parameters.

  The wrapped function is stored as ``func`` for single inputs, and a
  ``jax.vmap``-ed version mapping over the leading axis of its one
  positional argument is stored as ``predict_batch``.
  """

  def __init__(self, fixedfunc: callable) -> None:
    """Store ``fixedfunc`` and build its batched counterpart once."""
    # in_axes=(0,) maps over axis 0 of the single positional argument.
    self.predict_batch = jax.vmap(fixedfunc, in_axes=(0,))
    self.func = fixedfunc

  def predict(self, input) -> Array:
    """Apply the wrapped function to a single (un-batched) input."""
    single_result = self.func(input)
    return single_result

  def __call__(self, s) -> Array:
    """Apply the wrapped function to every entry along axis 0 of ``s``."""
    batched_result = self.predict_batch(s)
    return batched_result

class NNPolicy(hk.Module):
  """Fully connected ReLU network whose output is a vector of log-probabilities.

  ``sizes`` lists the layer widths: ``sizes[0]`` is the input width,
  ``sizes[-1]`` the output width, and everything in between is a hidden
  layer. Layer ``k`` owns the Haiku parameters ``w_k`` (shape ``[out, in]``)
  and ``b_k`` (shape ``[out]``), all initialised with
  ``hk.initializers.RandomNormal()``.

  Args:
    sizes: sequence of at least two layer widths.
    sigma_weights: multiplier applied to hidden-layer weights (together with
      a ``1 / sqrt(width)`` factor, see ``__call__``).
    sigma_biases: multiplier applied to hidden-layer biases.
    name: optional Haiku module name.

  Raises:
    ValueError: if ``sizes`` has fewer than two entries.
  """

  def __init__(self, sizes, sigma_weights, sigma_biases, name=None):
    super().__init__(name=name)
    if len(sizes) < 2:
      raise ValueError(
          "sizes must contain at least an input and an output width, "
          f"got {len(sizes)} entries")
    self.sizes = sizes
    self.sigma_weights = sigma_weights
    self.sigma_biases = sigma_biases

  def __call__(self, x) -> Array:
    """Run the network on one input and return normalised log-probabilities.

    Args:
      x: input array; its dtype is used for every parameter. The matrix
        products below presumably expect a 1-D vector of length
        ``sizes[0]`` (batching is done externally via ``jax.vmap``) --
        TODO confirm against callers.

    Returns:
      ``logits - logsumexp(logits)``, i.e. the log-softmax of the output
      layer taken over all of its elements.
    """
    # Number of hidden layers. It is also the index of the output layer's
    # parameters. Computing it here (instead of reusing the loop variable)
    # keeps the network usable when there are no hidden layers at all,
    # where the loop below never binds ``i``.
    n_hidden = len(self.sizes) - 2
    init = hk.initializers.RandomNormal()

    for i, (shapein, shapeout) in enumerate(zip(self.sizes[:-2], self.sizes[1:-1])):
      w = hk.get_parameter(f"w_{i}", shape=[shapeout, shapein], dtype=x.dtype, init=init)
      b = hk.get_parameter(f"b_{i}", shape=[shapeout], dtype=x.dtype, init=init)

      # NOTE(review): weights are scaled by 1/sqrt(shapeout) -- the layer's
      # output width, not its fan-in. Kept as in the original; confirm this
      # is the intended parameterisation.
      h = np.dot(self.sigma_weights / np.sqrt(shapeout) * w, x) + self.sigma_biases * b
      x = relu(h)

    # NOTE(review): the output layer uses its raw parameters, without the
    # sigma scaling applied to hidden layers. Kept as in the original.
    final_w = hk.get_parameter(f"w_{n_hidden}", shape=[self.sizes[-1], self.sizes[-2]], dtype=x.dtype, init=init)
    final_b = hk.get_parameter(f"b_{n_hidden}", shape=[self.sizes[-1]], dtype=x.dtype, init=init)

    logits = np.dot(final_w, x) + final_b
    return logits - logsumexp(logits)

def get_nnpolicy(network_size, sigmaweights, sigmabiases):
  """Build a Haiku-transformed ``NNPolicy`` plus a batched ``apply``.

  Args:
    network_size: layer widths forwarded to ``NNPolicy`` as ``sizes``.
    sigmaweights: forwarded to ``NNPolicy`` as ``sigma_weights``.
    sigmabiases: forwarded to ``NNPolicy`` as ``sigma_biases``.

  Returns:
    A pair ``(transformed, batched_apply)``: the result of ``hk.transform``
    on the forward function, and its ``apply`` wrapped in ``jax.vmap`` so
    that only the third positional argument is mapped over its leading
    axis while the first two are shared across the batch.
  """
  def _forward(x):
    # The module has to be created inside the function handed to
    # hk.transform, so it is built on every call.
    return NNPolicy(network_size, sigmaweights, sigmabiases)(x)

  transformed = hk.transform(_forward)
  batched_apply = jax.vmap(transformed.apply, in_axes=(None, None, 0))
  return transformed, batched_apply
