import distrax
import numpy as np
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from typing import Sequence



class ActorCriticCont(nn.Module):
    """Actor-critic network for continuous actions with optional input normalization.

    The actor is a 256-256 MLP producing the mean of a diagonal Gaussian whose
    log standard deviation is a learned, state-independent parameter
    (``log_std``). The critic is a separate 256-256 MLP producing one scalar
    value per input.

    Running input statistics (mean, variance, sample count) live in the
    ``"norm_stats"`` variable collection. They are updated only when
    ``calculate_norm=True`` outside of initialization, which requires that
    collection to be mutable in ``apply``. They are applied to the input only
    when ``normalize`` is True.
    """

    # Number of action dimensions; used as the actor output width and as the
    # length of the ``log_std`` parameter.
    action_dim: int
    # "relu" selects ReLU; any other value falls back to tanh.
    activation: str = "tanh"
    # Whether inputs are standardized with the running statistics.
    normalize: bool = False

    @nn.compact
    def __call__(self, x, calculate_norm=False):
        """Run the actor and critic heads on ``x``.

        Args:
            x: Input array. Axis 0 is treated as the sample axis when the
                running statistics are updated.
            calculate_norm: If True (and the module is not initializing),
                merge the statistics of ``x`` into the running statistics
                before they are used.

        Returns:
            A tuple ``(pi, value)`` where ``pi`` is a
            ``distrax.MultivariateNormalDiag`` over actions and ``value`` is
            the critic output with its trailing singleton axis squeezed away.
        """
        # NOTE: the statistics are allocated with the full shape of the input
        # seen at init time; batch statistics (reduced over axis 0) broadcast
        # against them in the update below.
        norm_mean = self.variable(
            "norm_stats", "mean", lambda s: jnp.zeros(s, jnp.float32), x.shape
        )
        norm_var = self.variable(
            "norm_stats", "var", lambda s: jnp.ones(s, jnp.float32), x.shape
        )
        norm_count = self.variable("norm_stats", "count", lambda s: jnp.array([s]), 0.0)

        if not self.is_initializing() and calculate_norm:
            batch_mean = jnp.mean(x, axis=0)
            batch_var = jnp.var(x, axis=0)
            batch_count = x.shape[0]

            # Merge the batch moments into the running moments using the
            # pairwise (parallel) mean/variance combination formula.
            delta = batch_mean - norm_mean.value
            tot_count = norm_count.value + batch_count

            norm_mean.value = norm_mean.value + delta * batch_count / tot_count
            m_a = norm_var.value * norm_count.value
            m_b = batch_var * batch_count
            M2 = (
                m_a
                + m_b
                + jnp.square(delta) * norm_count.value * batch_count / tot_count
            )
            norm_var.value = M2 / tot_count
            norm_count.value = tot_count

        if self.normalize:
            # The epsilon keeps the division finite for zero-variance features.
            norm_x = (x - norm_mean.value) / jnp.sqrt(norm_var.value + 1e-8)
        else:
            norm_x = x

        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh

        # Actor head. Submodule creation order is kept stable so the
        # auto-generated parameter names (Dense_0 ... Dense_5) do not change.
        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(norm_x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        actor_logstd = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logstd))

        # Critic head. It consumes the same (optionally normalized) input as
        # the actor; previously it was fed the raw ``x``, so ``normalize=True``
        # only affected the policy and left the value function on unscaled
        # inputs.
        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(norm_x)
        critic = activation(critic)
        critic = nn.Dense(
            256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)
