import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax.linen.initializers import constant, orthogonal
from typing import Callable
import distrax
from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper # Use Gymnax wrappers

# Use a smaller network for MinAtar
class ActorCritic(nn.Module):
    """Actor-critic MLP with separate (non-shared) policy and value towers.

    Each tower applies two 64-unit ``Dense`` layers, each followed by
    ``activation``. The policy tower ends in an ``action_dim``-wide logits
    layer (orthogonal init, gain 0.01) wrapped in ``distrax.Categorical``;
    the value tower ends in a single unit (orthogonal init, gain 1.0).
    """

    action_dim: int  # number of discrete actions (width of the logits layer)
    activation: Callable = nn.relu  # hidden-layer nonlinearity, e.g. nn.relu or nn.tanh

    @nn.compact
    def __call__(self, x):
        """Return ``(pi, value)`` for input ``x``.

        ``pi`` is a ``distrax.Categorical`` built from the actor logits;
        ``value`` is the critic output with its trailing size-1 axis squeezed.
        """
        zero_bias = constant(0.0)

        def hidden_tower(h):
            # Two Dense(64) layers, each followed by the configured activation.
            for _ in range(2):
                h = nn.Dense(64, kernel_init=orthogonal(np.sqrt(2)), bias_init=zero_bias)(h)
                h = self.activation(h)
            return h

        # Flax auto-names Dense_0, Dense_1, ... in construction order, so the
        # actor layers are created before the critic layers (and each tower
        # before its output head) to keep parameter names stable.
        actor_features = hidden_tower(x)
        logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=zero_bias
        )(actor_features)

        critic_features = hidden_tower(x)
        value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=zero_bias)(
            critic_features
        )

        return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)

class Discriminator(nn.Module):
    """Two-layer MLP that maps each input to a single unnormalized score.

    Architecture: ``Dense(hidden_dim) -> activation -> Dense(1)``. No output
    nonlinearity is applied, so the result has a trailing axis of size 1
    holding raw scores (logits).
    """

    activation: Callable  # nonlinearity applied after the hidden layer
    # When True, both Dense layers are wrapped in ``nn.SpectralNorm``. Its
    # power-iteration estimates are stored in a mutable variable collection
    # (``batch_stats`` by default in Flax), which the caller must mark mutable
    # when calling with ``train=True``.
    use_spectral_norm: bool = False
    # Width of the hidden layer; 64 matches the previously hard-coded value.
    hidden_dim: int = 64

    @nn.compact
    def __call__(self, x, train: bool = True):
        """Return the discriminator score for ``x``.

        Args:
            x: Input features; the last axis is consumed by the first Dense.
            train: Only used when ``use_spectral_norm`` is True. It is forwarded
                as ``update_stats`` so the spectral-norm estimates are
                refreshed only during training.

        Returns:
            Array whose last axis has size 1 (unnormalized scores).
        """
        activation = self.activation
        if self.use_spectral_norm:
            # Construction order (Dense, SpectralNorm, Dense, SpectralNorm) is
            # kept as before so Flax's auto-generated submodule names match.
            layer1 = nn.SpectralNorm(
                nn.Dense(self.hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))
            )
            layer2 = nn.SpectralNorm(
                nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))
            )

            x = layer1(x, update_stats=train)
            x = activation(x)
            x = layer2(x, update_stats=train)
        else:
            x = nn.Dense(self.hidden_dim, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
            x = activation(x)
            x = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(x)

        return x