import jax.numpy as jnp
import haiku as hk
import jax
import chex 
from policies.layers import SNLinear


# Shared initializer keyword arguments for the dense layers in this file:
# weights use variance scaling (scale 2.0, fan-in mode, truncated normal),
# i.e. He/Kaiming-normal initialization; biases start at zero.
kaimin_init = dict(
    w_init=hk.initializers.VarianceScaling(2.0, "fan_in", "truncated_normal"),
    b_init=hk.initializers.Constant(0.),
)

class SkipConnection(hk.Module):
  """Two-layer dense block with an optional residual (skip) connection.

  Each of the two stages applies, in order: the activation, an optional
  LayerNorm over the last axis, and a linear layer of width `dims`. When
  `use_skip` is set, the block input is added to the output of the second
  linear layer. No normalization is applied after the addition.
  """

  def __init__(self, dims, use_spectral_norm=False, use_layer_norm=False, use_bias=True, use_skip=True, activation=jax.nn.relu):
    """Builds the two linear layers of the block.

    Args:
      dims: output width of both linear layers.
      use_spectral_norm: if true, the layers are `SNLinear` instead of
        `hk.Linear`.
      use_layer_norm: if true, a LayerNorm is applied before each linear layer.
      use_bias: forwarded to the linear layers as `with_bias`.
      use_skip: if true, the block input is added to the block output.
      activation: callable applied before each stage.
    """
    super().__init__()
    self.use_skip = use_skip
    self.activation = activation
    self.use_layer_norm = use_layer_norm

    layer_cls = SNLinear if use_spectral_norm else hk.Linear
    layer_kwargs = dict(with_bias=use_bias, **kaimin_init)
    self.linear = layer_cls(dims, **layer_kwargs)
    self.linear2 = layer_cls(dims, **layer_kwargs)

  def __call__(self, x):
    """Applies both stages to `x` and, if enabled, adds `x` back.

    The output must have the same shape as the input; this is asserted even
    when the skip connection is disabled.
    """
    block_input = x
    for dense in (self.linear, self.linear2):
      x = self.activation(x)
      if self.use_layer_norm:
        # A fresh LayerNorm is created for each stage; both are requested
        # under the same name.
        x = hk.LayerNorm(-1, create_scale=True, create_offset=True, name='nosn_ln')(x)
      x = dense(x)

    chex.assert_equal_shape([x, block_input])
    if not self.use_skip:
      return x
    return x + block_input

class energy_network_:
  """Energy model built from an embedding layer, residual blocks and a projection.

  The network embeds its inputs to width `dims`, applies `num_layers // 2`
  `SkipConnection` blocks (each holding two linear layers), and projects to
  `action_dims + 1` outputs which are reduced to a single energy value per
  example.
  """

  def __init__(self, num_layers=8, dims=512, use_skip=True, use_spectral_norm=False, use_layer_norm=True, use_bias=True, action_dims=1, activation=jax.nn.relu, rtg=True):
    """Builds the layers.

    Args:
      num_layers: total number of hidden linear layers; `num_layers // 2`
        residual blocks are created.
      dims: hidden width.
      use_skip: whether the residual blocks add their input to their output.
        When false, an extra activation is applied before the projection.
      use_spectral_norm: if true, `SNLinear` is used instead of `hk.Linear`.
      use_layer_norm: forwarded to the residual blocks.
      use_bias: forwarded to the residual blocks (the embedding layer keeps
        its default bias setting and the projection has no bias).
      action_dims: the projection outputs `action_dims + 1` values.
      activation: activation used inside the residual blocks and, when
        `use_skip` is false, before the projection.
      rtg: if true, the flattened observation is part of the network input;
        otherwise only the action and returns are used.
    """
    self.layers = []
    self.activation = activation
    self.use_layer_norm = use_layer_norm
    self.dims = dims
    self.use_skip = use_skip
    self.rtg = rtg
    Dense = SNLinear if use_spectral_norm else hk.Linear

    self.embed_layer = Dense(dims,
                             **kaimin_init,
                             )

    for _ in range(num_layers//2):
      # Keyword arguments are required here: SkipConnection's fifth positional
      # parameter is `use_skip`, so passing `activation` positionally would
      # silently bind it to `use_skip` and leave the block's activation at its
      # default.
      self.layers.append(SkipConnection(dims,
                                        use_spectral_norm=use_spectral_norm,
                                        use_layer_norm=use_layer_norm,
                                        use_bias=use_bias,
                                        use_skip=use_skip,
                                        activation=activation))

    self.projection_layer = Dense(action_dims + 1,
                                  with_bias=False,
                                  **kaimin_init,
                                  )

  def __call__(self, obs: jnp.ndarray, action: jnp.ndarray, returns: jnp.ndarray, u_net=False):
    """Computes the energy for a batch of (obs, action, returns).

    Args:
      obs: observations; flattened with `hk.Flatten` before use.
      action: actions, concatenated with `returns` along the last axis.
      returns: returns; the last axis must have size 1.
      u_net: if true, the energy is the squared difference between the
        projection output and the concatenated (action, returns) input.
        NOTE(review): this presumably requires `action.shape[-1]` to equal
        `action_dims` so the shapes line up -- confirm against callers.

    Returns:
      The energy summed over the last axis with `keepdims=True`.
    """
    flat_obs = hk.Flatten()(obs)
    assert returns.shape[-1] == 1
    delta_inputs = jnp.concatenate((action, returns), -1)
    if self.rtg:
      inputs = jnp.concatenate((flat_obs, action, returns), -1)
    else:
      inputs = delta_inputs

    x = self.embed_layer(inputs)

    for layer in self.layers:
      x = layer(x)

    # Each block applies its activation before its linear layers, so without
    # skip connections an explicit activation is applied ahead of the
    # projection.
    if not self.use_skip:
      x = self.activation(x)

    energy = self.projection_layer(x)

    if u_net:
      energy = (energy - delta_inputs)**2

    energy = energy.sum(axis=-1, keepdims=True)
    return energy


def energy_network(obs: jnp.ndarray,
                   action: jnp.ndarray, 
                   returns: jnp.ndarray,
                   use_skip=True, 
                   u_net=False, 
                   num_layers=8, 
                   dims=512, 
                   use_spectral_norm=False, 
                   use_layer_norm=True, 
                   action_dims=1, 
                   use_bias=True, 
                   temperature=1.,
                   activation=jax.nn.swish,
                   rtg=True,):
  """Functional entry point: builds an `energy_network_` and evaluates it.

  All configuration arguments are forwarded to the `energy_network_`
  constructor; `u_net` is forwarded to the call. `temperature` is accepted
  but not used by this function.

  Returns:
    The value returned by `energy_network_.__call__` for
    `(obs, action, returns)`.
  """
  config = dict(
      num_layers=num_layers,
      dims=dims,
      use_skip=use_skip,
      use_spectral_norm=use_spectral_norm,
      use_layer_norm=use_layer_norm,
      use_bias=use_bias,
      action_dims=action_dims,
      activation=activation,
      rtg=rtg,
  )
  model = energy_network_(**config)
  return model(obs, action, returns, u_net=u_net)


def policy_network(obs: jnp.ndarray, dims=512, out_dims=1, use_layer_norm=True) -> jnp.ndarray:
    """MLP policy: two ReLU hidden layers of width `dims`, then a linear output.

    The observation is flattened first. All weights are drawn from a normal
    distribution with stddev 0.05; the hidden-layer biases start at zero and
    the output-layer bias is drawn from the same normal distribution.
    `use_layer_norm` is accepted but not used by this function.

    Returns:
      The output of the final linear layer, of width `out_dims`.
    """
    normal_init = hk.initializers.RandomNormal(0.05)
    zero_init = hk.initializers.Constant(0.)

    hidden = hk.Flatten()(obs)
    for _ in range(2):
        hidden_layer = hk.Linear(dims, w_init=normal_init, b_init=zero_init, name='nosn')
        hidden = jax.nn.relu(hidden_layer(hidden))

    output_layer = hk.Linear(out_dims, w_init=normal_init, b_init=normal_init, name='nosn')
    return output_layer(hidden)

def q_network(obs: jnp.ndarray, action: jnp.ndarray, dims=512, out_dims=1, use_layer_norm=True) -> jnp.ndarray:
    """Q-function: embeds (obs, action), applies four residual blocks, projects.

    Args:
      obs: observations; flattened with `hk.Flatten` before use.
      action: actions, concatenated to the flattened observation along the
        last axis.
      dims: width of the embedding and of the residual blocks.
      out_dims: width of the output layer.
      use_layer_norm: if true, LayerNorm is used inside the residual blocks
        and once more before the output layer.

    Returns:
      The output of the final linear layer, of width `out_dims`.
    """
    flat_obs = hk.Flatten()(obs)
    inputs = jnp.concatenate((flat_obs, action), -1)
    x = hk.Linear(dims,
                  w_init=hk.initializers.RandomNormal(0.05),
                  b_init=hk.initializers.RandomNormal(0.05),
                  name='nosn')(inputs)
    for _ in range(4):
        # Pass by keyword: the second positional parameter of SkipConnection
        # is `use_spectral_norm`, so a positional `use_layer_norm` would
        # toggle spectral normalization instead of layer normalization.
        x = SkipConnection(dims, use_layer_norm=use_layer_norm)(x)
    if use_layer_norm:
        x = hk.LayerNorm(-1, True, True, name='nosn_ln')(x)
    x = hk.Linear(out_dims,
                  w_init=hk.initializers.RandomNormal(0.05),
                  b_init=hk.initializers.Constant(0.0),
                  name='nosn')(x)
    return x


def rcp_network(obs: jnp.ndarray, returns: jnp.ndarray, dims=512, out_dims=1, use_layer_norm=False, rtg=True) -> jnp.ndarray:
    """MLP over the flattened observation, optionally concatenated with `returns`.

    When `rtg` is true the input is the flattened observation joined with
    `returns` along the last axis; otherwise it is the flattened observation
    alone. The network has two ReLU hidden layers of width `dims` and a linear
    output of width `out_dims`. Every weight and bias is drawn from a normal
    distribution with stddev 0.05. `use_layer_norm` is accepted but not used
    by this function.

    Returns:
      The output of the final linear layer.
    """
    normal_init = hk.initializers.RandomNormal(0.05)

    def dense(width):
        # Every layer shares the same initializer settings and requested name.
        return hk.Linear(width, w_init=normal_init, b_init=normal_init, name='nosn')

    flat_obs = hk.Flatten()(obs)
    features = jnp.concatenate((flat_obs, returns), -1) if rtg else flat_obs

    hidden = jax.nn.relu(dense(dims)(features))
    hidden = jax.nn.relu(dense(dims)(hidden))
    return dense(out_dims)(hidden)
