import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
import distrax


class ActorCritic(nn.Module):
    """MLP with separate policy and value branches over the same input.

    The policy branch yields the mean of a diagonal Gaussian whose scale is
    ``exp(log_std)``, where ``log_std`` is a learned vector that does not
    depend on the input. The value branch yields one scalar per input row.
    """

    # Width of the policy mean output and length of the ``log_std`` parameter.
    action_dim: int
    # "relu" selects ``nn.relu``; any other value selects ``nn.tanh``.
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Run both branches on ``x``.

        Args:
            x: Input array fed unchanged to both the policy and value branch.

        Returns:
            A ``(pi, value)`` tuple: ``pi`` is a
            ``distrax.MultivariateNormalDiag`` and ``value`` is the value
            branch output with its trailing size-1 axis squeezed away.
        """
        if self.activation == "relu":
            nonlinearity = nn.relu
        else:
            nonlinearity = nn.tanh

        def dense(features, gain):
            # All layers share the same recipe: orthogonal kernel, zero bias.
            return nn.Dense(
                features,
                kernel_init=orthogonal(gain),
                bias_init=constant(0.0),
            )

        hidden_gain = np.sqrt(2)

        # Construction order fixes flax's auto-generated submodule names
        # (Dense_0, Dense_1, ...), so the policy layers are built first and
        # the value layers second.
        actor = x
        for _ in range(2):
            actor = nonlinearity(dense(256, hidden_gain)(actor))
        mean = dense(self.action_dim, 0.01)(actor)

        # Input-independent log standard deviation, initialised to zeros.
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(mean, jnp.exp(log_std))

        critic = x
        for _ in range(2):
            critic = nonlinearity(dense(256, hidden_gain)(critic))
        value = dense(1, 1.0)(critic)

        return pi, jnp.squeeze(value, axis=-1)


class AdvNet(nn.Module):
    """Small tanh MLP producing one non-negative scalar per input row.

    The input is either ``obs`` alone or ``obs`` concatenated with ``act``
    along the last axis, depending on ``use_action``. The final layer output
    is passed through ``softplus`` and its trailing size-1 axis is squeezed.
    """

    # When True, ``act`` is concatenated to ``obs``; when False, ``act`` is
    # ignored entirely.
    use_action: bool = True

    @nn.compact
    def __call__(self, obs, act=None):
        """Evaluate the network.

        Args:
            obs: Input array.
            act: Array concatenated to ``obs`` on the last axis. Required
                when ``use_action`` is True; unused otherwise.

        Returns:
            The softplus-activated output with the last axis removed.

        Raises:
            TypeError: If ``use_action`` is True and ``act`` is None.
        """
        if self.use_action:
            if act is None:
                # Fail early with a clear message instead of letting
                # jnp.concatenate choke on a None entry.
                raise TypeError(
                    "AdvNet with use_action=True requires `act`, but got None"
                )
            x = jnp.concatenate([obs, act], -1)
        else:
            x = obs
        h = nn.tanh(nn.Dense(128, kernel_init=nn.initializers.lecun_normal())(x))
        h = nn.tanh(nn.Dense(128, kernel_init=nn.initializers.lecun_normal())(h))
        eta = nn.softplus(nn.Dense(1, kernel_init=nn.initializers.lecun_normal())(h))
        return jnp.squeeze(eta, -1)