import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import Sequence, Callable
import distrax

class ActorCritic(nn.Module):
    """Separate actor and critic MLPs evaluated on the same input.

    The actor produces the mean of a diagonal Gaussian over actions; the
    standard deviation comes from a learned ``log_std`` parameter that does
    not depend on the input. The critic produces one value per input row.
    """

    # Number of action dimensions. It is used directly as the width of the
    # actor's output layer and as the length of ``log_std``, so it must be a
    # single int rather than a sequence.
    action_dim: int
    # Nonlinearity applied after each hidden Dense layer.
    activation: Callable

    @nn.compact
    def __call__(self, x):
        """Run both networks on ``x``.

        Args:
            x: Input array; the Dense layers act on its last axis.

        Returns:
            A tuple ``(pi, value)`` where ``pi`` is a
            ``distrax.MultivariateNormalDiag`` with the actor output as its
            mean and ``exp(log_std)`` as its scale, and ``value`` is the
            critic output with its trailing size-1 axis squeezed away.
        """
        act = self.activation

        # Flax names unnamed submodules by creation order, so the actor
        # layers must keep being created before the critic layers to leave
        # the parameter tree unchanged.
        policy_hidden = x
        for _ in range(2):
            policy_hidden = nn.Dense(
                256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(policy_hidden)
            policy_hidden = act(policy_hidden)
        # Small init scale keeps the initial action means close to zero.
        action_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(policy_hidden)
        # Zero-initialised log std, i.e. an initial standard deviation of 1.
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(action_mean, jnp.exp(log_std))

        value_hidden = x
        for _ in range(2):
            value_hidden = nn.Dense(
                256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(value_hidden)
            value_hidden = act(value_hidden)
        value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            value_hidden
        )

        return pi, jnp.squeeze(value, axis=-1)
    
class Discriminator(nn.Module):
    """Two-layer MLP that maps each input row to a single output value.

    When ``use_spectral_norm`` is set, both Dense layers are wrapped in
    ``nn.SpectralNorm``; otherwise they are applied directly.
    """

    # Nonlinearity applied after the hidden Dense layer.
    activation: Callable
    # Wrap each Dense layer in nn.SpectralNorm when True.
    use_spectral_norm: bool = False

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Apply the network to ``x``.

        Args:
            x: Input array; the Dense layers act on its last axis.
            train: Forwarded as ``update_stats`` to ``nn.SpectralNorm``.
                Unused when ``use_spectral_norm`` is False.

        Returns:
            The output of the final Dense layer (last axis of size 1); no
            output activation or squeeze is applied.
        """

        def run(layer, inputs):
            # Optionally route the layer through spectral normalisation.
            if self.use_spectral_norm:
                return nn.SpectralNorm(layer)(inputs, update_stats=train)
            return layer(inputs)

        hidden = run(
            nn.Dense(128, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)),
            x,
        )
        hidden = self.activation(hidden)
        return run(
            nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0)),
            hidden,
        )

class Actor(nn.Module):
    """MLP policy that outputs a diagonal Gaussian over actions.

    The network produces the mean; the standard deviation comes from a
    learned ``log_std`` parameter that does not depend on the input.
    """

    # Number of action dimensions. It is used directly as the width of the
    # output layer and as the length of ``log_std``, so it must be a single
    # int rather than a sequence.
    action_dim: int
    # Nonlinearity applied after each hidden Dense layer.
    activation: Callable

    @nn.compact
    def __call__(self, x):
        """Build the action distribution for ``x``.

        Args:
            x: Input array; the Dense layers act on its last axis.

        Returns:
            A ``distrax.MultivariateNormalDiag`` whose mean is the network
            output and whose scale is ``exp(log_std)``.
        """
        act = self.activation

        hidden = x
        for _ in range(2):
            hidden = nn.Dense(
                512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
            )(hidden)
            hidden = act(hidden)
        # Small init scale keeps the initial action means close to zero.
        action_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(hidden)
        # Zero-initialised log std, i.e. an initial standard deviation of 1.
        log_std = self.param("log_std", nn.initializers.zeros, (self.action_dim,))

        return distrax.MultivariateNormalDiag(action_mean, jnp.exp(log_std))
    

    