import jax
import jax.numpy as jnp
from flax import linen as nn

def init_params_opt(args, model):
    """Initialise ``model`` on a dummy batch and return its ``'params'`` collection.

    Args:
        args: Object exposing ``nn_input_dim`` (feature size of the dummy
            input) and ``key`` (the PRNG key forwarded to ``model.init``).
        model: Object with an ``init(key, x)`` method returning a mapping
            that contains a ``'params'`` entry (e.g. a Flax linen module).

    Returns:
        The value stored under ``'params'`` in the result of ``model.init``.
    """
    # A single all-ones row is enough for shape inference during init.
    sample_batch = jnp.ones((1, args.nn_input_dim))
    return model.init(args.key, sample_batch)['params']



class JAXMLP(nn.Module):
    """Fully connected network with two ReLU hidden layers and a linear head.

    Architecture::

        Dense(hidden) -> ReLU -> Dense(hidden) -> ReLU -> Dense(output_dim)

    where ``hidden = int(hidden_factor * N)``. The last layer has no
    activation.

    NOTE(review): the previous docstring listed PennyLane weight shapes
    (StronglyEntanglingLayers: (3, n_qubits, n_layers); BasicEntanglerLayers:
    (n_layers, n_qubits)). Presumably ``output_dim`` is chosen to match the
    flattened size of one of those weight tensors -- confirm against callers.

    Attributes:
        output_dim: Number of features produced by the final Dense layer.
        N: Size parameter that sets the hidden-layer width.
        hidden_factor: Multiplier applied to ``N`` for the hidden width. The
            default of 3 reproduces the original ``3 * N`` hidden layers.
    """
    output_dim: int
    N: int
    hidden_factor: float = 3

    @nn.compact
    def __call__(self, x):
        """Apply the MLP to ``x`` and return the ``output_dim``-feature result."""
        hidden = int(self.hidden_factor * self.N)
        x = nn.Dense(features=hidden)(x)
        x = nn.relu(x)
        x = nn.Dense(features=hidden)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim)(x)
        return x


class ClassicMLP(nn.Module):
    """Fully connected network producing ``N - 1`` output features.

    Two hidden Dense layers of width ``int(3 * N)``, each followed by ReLU,
    then a linear Dense layer with ``N - 1`` features.

    NOTE(review): the previous docstring listed PennyLane weight shapes
    (StronglyEntanglingLayers: (3, n_qubits, n_layers); BasicEntanglerLayers:
    (n_layers, n_qubits)), which do not describe this module's output;
    the reason for the ``N - 1`` output size is not visible here -- confirm
    with callers.

    Attributes:
        N: Size parameter setting both hidden width and output size.
    """
    N: int

    @nn.compact
    def __call__(self, x):
        """Apply the MLP to ``x`` and return an array with ``N - 1`` features."""
        width = int(3 * self.N)
        # Layers are created in the same order as before (Dense_0, Dense_1,
        # then the head), so parameter names are unchanged.
        for _ in range(2):
            x = nn.relu(nn.Dense(features=width)(x))
        return nn.Dense(features=self.N - 1)(x)