from typing import Callable

import jax
import jax.numpy as jnp
from flax import linen as nn

def init_params_opt(args, model):
    """Initialise ``model`` and return only its ``'params'`` collection.

    One all-ones sample of width ``args.input_dim`` (shape ``(1, input_dim)``)
    is traced through ``model.init`` using the PRNG key ``args.key``. Any other
    variable collections produced by ``init`` are discarded.

    Args:
        args: Object exposing ``key`` (PRNG key for ``model.init``) and
            ``input_dim`` (width of the dummy input row).
        model: Object with a Flax-style ``init(key, x)`` method returning a
            mapping that contains ``'params'``.

    Returns:
        The ``'params'`` entry of the mapping returned by ``model.init``.
    """
    rng = args.key
    width = args.input_dim
    sample_input = jnp.ones((1, width))
    return model.init(rng, sample_input)["params"]

class JAXMLP(nn.Module):
    """Fully connected network: two ReLU hidden layers and a linear head.

    Every input is flattened to ``(-1, input_dim)``. Both hidden ``Dense``
    layers have width ``3 * input_dim``; the final ``Dense`` layer projects
    to ``output_dim`` with no activation.

    Attributes:
        input_dim: Feature width each flattened sample is reshaped to.
        output_dim: Number of output features.
        N: Not read by this module. NOTE(review): presumably kept so the
            constructor mirrors ``JAXCNN`` — confirm with callers.
    """

    input_dim: int
    output_dim: int
    N: int

    @nn.compact
    def __call__(self, x):
        hidden_width = int(3 * self.input_dim)
        h = x.reshape(-1, self.input_dim)
        # Two identical hidden blocks; creation order keeps the param names
        # Dense_0 and Dense_1 for the hidden layers and Dense_2 for the head.
        for _ in range(2):
            h = nn.relu(nn.Dense(features=hidden_width)(h))
        return nn.Dense(features=self.output_dim)(h)


class JAXCNN(nn.Module):
    """Small convolutional network over a square single-channel grid.

    Each sample is reshaped to an ``(N+1, N+1, 1)`` image. It passes through
    two 3x3 "SAME"-padded convolutions (4 then 16 channels; the second has
    stride 2), is flattened, and goes through dense layers of width 256 and
    512. A final linear projection maps to ``output_dim``.

    Attributes:
        N: Grid-size parameter; each sample must contain ``(N+1)**2`` values.
        output_dim: Number of output features.
        activation: Elementwise nonlinearity applied after every layer except
            the final projection. Defaults to ``nn.gelu`` (the original
            behaviour); pass e.g. ``nn.relu`` instead of editing the body.
    """

    N: int
    output_dim: int
    activation: Callable = nn.gelu

    @nn.compact
    def __call__(self, x):
        """Apply the network.

        Args:
            x: Array whose total size is a multiple of ``(N+1)**2``, e.g.
                ``(batch, N+1, N+1, 1)`` or ``(batch, (N+1)**2)``. It is
                reshaped to ``(-1, N+1, N+1, 1)``, i.e. treated as a batch of
                grayscale images with an explicit channel dimension.

        Returns:
            Array of shape ``(batch, output_dim)``.
        """
        side = self.N + 1
        act = self.activation
        x = x.reshape(-1, side, side, 1)

        # Conv blocks (layer creation order fixes param names Conv_0, Conv_1).
        x = act(nn.Conv(features=4, kernel_size=(3, 3), strides=(1, 1), padding="SAME")(x))
        x = act(nn.Conv(features=16, kernel_size=(3, 3), strides=(2, 2), padding="SAME")(x))

        # Flatten spatial and channel dims per sample.
        x = x.reshape((x.shape[0], -1))

        # Dense layers (Dense_0, Dense_1), then the linear head (Dense_2).
        x = act(nn.Dense(features=256)(x))
        x = act(nn.Dense(features=512)(x))
        return nn.Dense(features=self.output_dim)(x)