import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from flax import linen as nn
from flax.linen import tabulate
import re
import os

def init_params_opt(args, model):
    """Initialise ``model``'s parameters and write its structure to disk.

    Two text files are written into ``args.RESULT_PATH``:

    * ``angle_net_struc.txt`` -- the ``flax.linen.tabulate`` summary of the
      model, with ANSI colour escape sequences removed;
    * ``angle_net_jaxpr.txt`` -- the jaxpr of ``model.apply`` traced on the
      same dummy inputs.

    Args:
        args: Object exposing ``nn_input_dim`` (last-axis size of the second
            model input), ``key`` (RNG key passed to ``model.init`` and
            ``tabulate``) and ``RESULT_PATH`` (output directory; it is not
            created here, so it must already exist).
        model: Flax module called as ``model(x, y)``; it is traced with ``x`` of
            shape ``(1, 1)`` and ``y`` of shape ``(1, nn_input_dim)``.

    Returns:
        The ``'params'`` collection returned by ``model.init``.
    """
    nn_input_dim = args.nn_input_dim
    key = args.key
    result_path = args.RESULT_PATH

    # Single-row dummy inputs; only their shapes/dtypes matter for init/tracing.
    dummy_x = jnp.ones((1, 1))
    dummy_y = jnp.ones((1, nn_input_dim))

    variables = model.init(key, dummy_x, dummy_y)
    nn_params = variables['params']

    summary = tabulate(model, rngs={"params": key})(dummy_x, dummy_y)
    # Strip ANSI colour escape sequences so the saved file is plain text.
    clean_summary = re.sub(r'\x1b\[[0-9;]*m', '', summary)
    # The summary table may contain non-ASCII (box-drawing) characters; write as
    # UTF-8 explicitly instead of relying on the platform's locale encoding.
    struc_path = os.path.join(result_path, "angle_net_struc.txt")
    with open(struc_path, "w", encoding="utf-8") as f:
        f.write(clean_summary)

    jaxpr = jax.make_jaxpr(model.apply)({"params": nn_params}, dummy_x, dummy_y)
    jaxpr_path = os.path.join(result_path, "angle_net_jaxpr.txt")
    with open(jaxpr_path, "w", encoding="utf-8") as f:
        f.write(str(jaxpr))

    return nn_params

class JAXMLP(nn.Module):
    """ReLU multilayer perceptron over ``(pde_param, forcing)``.

    ``sqrt(pde_param)`` and ``forcing`` are concatenated along the last axis.
    The result then goes through one ``Dense`` + ReLU layer per entry of
    ``widths``, followed by a final linear ``Dense`` layer with
    ``output_dim`` features.

    Attributes:
        output_dim: Number of features produced by the final ``Dense`` layer.
        N: Width multiplier; hidden layer ``i`` has ``width_factors[i] * N``
            units.
        width_factors: Base width of each hidden layer. The default reproduces
            the 16-128-256-256-128-16 stack (times ``N``).
        dtype: Computation and parameter dtype of every ``Dense`` layer.
    """
    output_dim: int
    N: int
    # Defaulted fields come after the required ones, so existing
    # ``JAXMLP(output_dim, N)`` / keyword constructions keep working.
    width_factors: tuple = (16, 128, 256, 256, 128, 16)
    dtype: type = jnp.float64

    @property
    def widths(self):
        """Hidden-layer widths: each entry of ``width_factors`` scaled by ``N``."""
        return tuple(factor * self.N for factor in self.width_factors)

    @nn.compact
    def __call__(self, pde_param, forcing):
        """Apply the network.

        Args:
            pde_param: Array concatenated (after element-wise ``sqrt``) with
                ``forcing`` along the last axis; all other axes must match.
            forcing: Array concatenated along the last axis.

        Returns:
            Array whose last axis has size ``output_dim``.
        """
        # NOTE(review): jnp.sqrt yields NaN for negative entries; presumably
        # pde_param is non-negative — confirm with callers.
        x = jnp.concatenate([jnp.sqrt(pde_param), forcing], axis=-1)
        for width in self.widths:
            x = nn.Dense(features=width, dtype=self.dtype, param_dtype=self.dtype)(x)
            x = nn.relu(x)
        return nn.Dense(
            features=self.output_dim, dtype=self.dtype, param_dtype=self.dtype
        )(x)