from functools import partial
from typing import Any, Callable, Sequence, Tuple, Optional, Union, Dict


from flax import linen as nn
from flax.core.frozen_dict import freeze

import jax
from jax import random, jit, vmap
import jax.numpy as jnp
from jax.nn.initializers import glorot_normal, normal, zeros, constant
from jax.experimental.ode  import odeint

from .state import RotLayer



# Lookup table mapping an activation name to the callable that implements it.
# Queried by _get_activation; "identity" maps to a pass-through function.
activation_fn = {
    "relu": nn.relu,
    "gelu": nn.gelu,
    "swish": nn.swish,
    "sigmoid": nn.sigmoid,
    "tanh": jnp.tanh,
    "sin": jnp.sin,
    "identity": lambda x: x,
}


def _get_activation(str):
    """Return the activation callable registered under the given name.

    Raises:
        NotImplementedError: if the name is not a key of ``activation_fn``.
    """
    # The parameter name shadows the builtin; it is kept so existing callers
    # (including keyword callers) keep working.
    fn = activation_fn.get(str)
    if fn is None:
        raise NotImplementedError(f"Activation {str} not supported yet!")
    return fn


def _weight_fact(init_fn, mean, std):
    def init(key, shape):
        key1, key2 = random.split(key)
        w = init_fn(key1, shape)
        g = mean + normal(std)(key2, (shape[-1],))
        g = jnp.exp(g)
        v = w / g
        return g, v

    return init


class PeriodEmbs(nn.Module):
    """Replace selected input entries with a (cos, sin) pair.

    Each configured period is multiplied by pi and is either registered as a
    scalar parameter or kept as a plain constant, depending on ``trainable``.
    """

    period: Tuple[float]  # Period for each embedded axis
    axis: Tuple[int]  # Positions of the input that receive the embedding
    trainable: Tuple[bool]  # Per embedded axis: register the period as a parameter?

    def setup(self):
        """Build the frozen table of scaled periods, keyed ``period_<k>``."""
        table = {}
        for k, learnable in enumerate(self.trainable):
            key = f"period_{k}"
            scaled = self.period[k] * jnp.pi
            table[key] = self.param(key, constant(scaled), ()) if learnable else scaled

        self.period_params = freeze(table)

    @nn.compact
    def __call__(self, x):
        """Embed the entries of ``x`` listed in ``axis`` and pass the rest through.

        Iterates over the leading axis of ``x`` (presumably ``x`` is a single
        coordinate vector -- confirm against callers) and horizontally stacks
        the resulting pieces.
        """
        pieces = []

        for pos, coord in enumerate(x):
            if pos not in self.axis:
                pieces.append(coord)
                continue

            # The k-th entry of `axis` uses the k-th period.
            omega = self.period_params[f"period_{self.axis.index(pos)}"]
            phase = omega * coord
            pieces.extend([jnp.cos(phase), jnp.sin(phase)])

        return jnp.hstack(pieces)


class FourierEmbs(nn.Module):
    """Fourier feature embedding driven by a ``kernel`` parameter.

    Attributes:
        scale: Standard deviation given to the normal initializer of the kernel.
        dims: Embedding width; the kernel has ``dims // 2`` columns, used once
            for the cosine half and once for the sine half.
    """

    scale: float
    dims: int

    @nn.compact
    def __call__(self, x):
        """Return ``[cos(x @ kernel), sin(x @ kernel)]`` joined on the last axis."""
        half_width = self.dims // 2
        kernel = self.param("kernel", normal(self.scale), (x.shape[-1], half_width))
        projected = jnp.dot(x, kernel)
        return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1)





class Embedding(nn.Module):
    """Optional input embedding: periodic embedding first, then Fourier features.

    Each stage runs only when its configuration dict is truthy; the dict is
    expanded as keyword arguments of the corresponding module.
    """

    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        """Apply the configured embedding stages to ``x`` in order."""
        stages = (
            (self.periodicity, PeriodEmbs),
            (self.fourier_embeddings, FourierEmbs),
        )
        for config, module_cls in stages:
            if config:
                x = module_cls(**config)(x)

        return x


class Dense(nn.Module):
    """Fully connected layer ``x @ kernel + bias`` with optional weight factorization.

    Attributes:
        features: Number of output features.
        kernel_init: Initializer for the kernel (or for the kernel that is
            factorized when ``reparam`` is used).
        bias_init: Initializer for the bias.
        reparam: ``None`` for a plain kernel, or a dict with ``"type"`` equal to
            ``"weight_fact"`` plus ``"mean"`` and ``"std"`` entries, in which
            case the ``"kernel"`` parameter is stored as a ``(g, v)`` pair and
            the effective kernel is ``g * v``.

    Raises:
        NotImplementedError: if ``reparam`` is given with an unknown ``"type"``.
    """

    features: int
    kernel_init: Callable = glorot_normal()
    bias_init: Callable = zeros
    reparam: Union[None, Dict] = None

    @nn.compact
    def __call__(self, x):
        kernel_shape = (x.shape[-1], self.features)

        if self.reparam is None:
            kernel = self.param("kernel", self.kernel_init, kernel_shape)

        elif self.reparam["type"] == "weight_fact":
            g, v = self.param(
                "kernel",
                _weight_fact(
                    self.kernel_init,
                    mean=self.reparam["mean"],
                    std=self.reparam["std"],
                ),
                kernel_shape,
            )
            kernel = g * v

        else:
            # Without this branch `kernel` would be unbound and the failure
            # would surface later as an opaque UnboundLocalError.
            raise NotImplementedError(
                f"Reparameterization {self.reparam['type']!r} not supported yet!"
            )

        bias = self.param("bias", self.bias_init, (self.features,))

        return jnp.dot(x, kernel) + bias

class RotDense(nn.Module):
    """Dense layer whose weight is ``QL @ S @ QR.T`` when ``use_rot`` is set.

    ``S`` (and the optional bias) live in the parameter collection; ``QL``,
    ``QR``, ``L`` and ``R`` are variables of the ``"rot_state"`` collection,
    initialized to identity / zero matrices. With ``use_rot=False`` the weight
    is simply ``S``.
    """

    features: int
    use_rot: bool = True
    use_bias: bool = True
    dtype: jnp.dtype = jnp.float32
    kernel_init: callable = nn.initializers.lecun_normal()

    @nn.compact
    def __call__(self, x):
        """Return ``x @ W.T`` plus the bias, if one is used."""
        n_in, n_out = x.shape[-1], self.features
        dtype = self.dtype

        core = self.param("S", self.kernel_init, (n_out, n_in), dtype)
        bias = None
        if self.use_bias:
            bias = self.param("bias", nn.initializers.zeros, (n_out,), dtype)

        weight = core  # plain Dense kernel unless rotations are enabled
        if self.use_rot:
            left = self.variable(
                "rot_state", "QL", lambda: jnp.eye(n_out, dtype=dtype)
            ).value
            right = self.variable(
                "rot_state", "QR", lambda: jnp.eye(n_in, dtype=dtype)
            ).value
            # "L" and "R" are not used in the forward pass; they are declared
            # here so that they exist in the "rot_state" collection.
            self.variable(
                "rot_state", "L", lambda: jnp.zeros((n_out, n_out), dtype=dtype)
            )
            self.variable(
                "rot_state", "R", lambda: jnp.zeros((n_in, n_in), dtype=dtype)
            )
            weight = left @ core @ right.T

        out = x @ weight.T
        if bias is None:
            return out
        return out + bias




class FastDense(nn.Module):
    """Dense layer with weight ``(QL @ S) @ QR.T``, rotation state kept in one variable.

    ``S`` and the optional bias are parameters. When ``use_rot`` is set, a
    single ``RotLayer`` value is stored under ``"rot"`` in the ``"rot_state"``
    collection and its ``QL`` / ``QR`` fields are used to form the weight;
    otherwise the weight is ``S``.
    """

    features: int
    use_rot: bool = True
    use_bias: bool = True
    dtype: jnp.dtype = jnp.float32
    kernel_init: callable = nn.initializers.lecun_normal()

    def _init_rot(self, out_features: int, in_features: int) -> RotLayer:
        """Create the initial rotation state: identity bases and zero ``Lb`` / ``Rb``."""
        dtype = self.dtype
        square_out = (out_features, out_features)
        square_in = (in_features, in_features)
        return RotLayer(
            QL=jnp.eye(out_features, dtype=dtype),
            QR=jnp.eye(in_features, dtype=dtype),
            Lb=jnp.zeros(square_out, dtype=dtype),
            Rb=jnp.zeros(square_in, dtype=dtype),
        )

    @nn.compact
    def __call__(self, x):
        """Return ``x @ W.T`` plus the bias, if one is used."""
        n_in, n_out = x.shape[-1], self.features

        core = self.param("S", self.kernel_init, (n_out, n_in), self.dtype)
        bias = None
        if self.use_bias:
            bias = self.param("bias", nn.initializers.zeros, (n_out,), self.dtype)

        weight = core
        if self.use_rot:
            # Single variable leaf (matches FastRotTrainState's is_rot_layer)
            state = self.variable(
                "rot_state", "rot", lambda: self._init_rot(n_out, n_in)
            )
            bases = state.value
            weight = (bases.QL @ core) @ bases.QR.T

        out = x @ weight.T
        if bias is None:
            return out
        return out + bias



class RotMlp(nn.Module):
    """MLP built from ``RotDense`` layers.

    No input embedding is applied here: ``periodicity`` and
    ``fourier_embeddings`` are accepted but not read by this class, and
    ``reparam`` is likewise unused.

    Returns a tuple ``(hidden, y)``: the activations after the last hidden
    layer and the output. When ``pi_init`` is given it initializes a
    ``"pi_init"`` kernel that replaces the final ``RotDense`` layer.
    """

    arch_name: Optional[str] = "Mlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None
    use_rot: bool = True

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        hidden = x
        for _ in range(self.num_layers):
            layer = RotDense(features=self.hidden_dim, use_rot=self.use_rot)
            hidden = self.activation_fn(layer(hidden))

        if self.pi_init is None:
            out = RotDense(features=self.out_dim, use_rot=self.use_rot)(hidden)
        else:
            kernel = self.param("pi_init", constant(self.pi_init), self.pi_init.shape)
            out = jnp.dot(hidden, kernel)

        return hidden, out

class FastMlp(nn.Module):
    """MLP built from ``FastDense`` layers on top of the optional input embedding.

    Returns a tuple ``(hidden, y)``: the activations after the last hidden
    layer and the output of a final ``FastDense`` layer with ``out_dim``
    features. ``reparam`` and ``pi_init`` are accepted but not read by this
    class.

    NOTE: the output head is created once, after the hidden stack. It used to
    be created inside the hidden-layer loop, which allocated an unused head per
    hidden layer and left the output undefined for ``num_layers == 0``.
    Parameter trees saved with the old layout have different auto-generated
    ``FastDense_<k>`` names and will not line up with this module.
    """

    arch_name: Optional[str] = "Mlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None
    use_rot: bool = True

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        x = Embedding(periodicity=self.periodicity, fourier_embeddings=self.fourier_embeddings)(x)

        for _ in range(self.num_layers):
            x = FastDense(features=self.hidden_dim, use_rot=self.use_rot)(x)
            x = self.activation_fn(x)

        # Single output head, applied to the final hidden activations.
        y = FastDense(features=self.out_dim, use_rot=self.use_rot)(x)
        return x, y
    

class Mlp(nn.Module):
    """Plain MLP: optional input embedding, ``num_layers`` hidden layers, output layer.

    Returns a tuple ``(hidden, y)``: the activations after the last hidden
    layer and the output. When ``pi_init`` is given it initializes a
    ``"pi_init"`` kernel that replaces the final ``Dense`` layer.
    """

    arch_name: Optional[str] = "Mlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        hidden = embed(x)

        for _ in range(self.num_layers):
            layer = Dense(features=self.hidden_dim, reparam=self.reparam)
            hidden = self.activation_fn(layer(hidden))

        if self.pi_init is None:
            out = Dense(features=self.out_dim, reparam=self.reparam)(hidden)
        else:
            kernel = self.param("pi_init", constant(self.pi_init), self.pi_init.shape)
            out = jnp.dot(hidden, kernel)

        return hidden, out

###### PirateNet implementation by Wang et al. (2024) ######

class Bottleneck(nn.Module):
    """Residual block: three ``Dense`` layers, skip connection, then activation.

    The skip connection is added *before* the final activation, as in the
    original ResNet.
    """

    hidden_dim: int
    output_dim: int
    activation: str
    reparam: Union[None, Dict]

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        h = x
        # Two activated hidden layers ...
        for width in (self.hidden_dim, self.hidden_dim):
            h = self.activation_fn(Dense(features=width, reparam=self.reparam)(h))

        # ... then a linear projection whose result is summed with the input.
        h = Dense(features=self.output_dim, reparam=self.reparam)(h)
        return self.activation_fn(h + x)


class PIBottleneck(nn.Module):
    """Physics-informed bottleneck block with a gated skip connection.

    Unlike the original ResNet block, the skip connection is added *after* the
    activation, as ``alpha * f(x) + (1 - alpha) * x`` where ``alpha`` is a
    parameter initialized to ``nonlinearity``. With ``nonlinearity == 0`` the
    block is therefore an identity mapping at initialization.
    """

    hidden_dim: int
    output_dim: int
    activation: str
    nonlinearity: float
    reparam: Union[None, Dict]

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        h = x
        for width in (self.hidden_dim, self.hidden_dim, self.output_dim):
            h = self.activation_fn(Dense(features=width, reparam=self.reparam)(h))

        alpha = self.param("alpha", constant(self.nonlinearity), (1,))
        return alpha * h + (1 - alpha) * x


class PIModifiedBottleneck(nn.Module):
    """Gated bottleneck block with a learnable skip connection.

    The first two activated ``Dense`` layers are each followed by the gate
    ``h * u + (1 - h) * v``; the third produces ``output_dim`` features. The
    result is blended with the input as ``alpha * f(x) + (1 - alpha) * x``,
    where ``alpha`` is a parameter initialized to ``nonlinearity``.
    """

    hidden_dim: int
    output_dim: int
    activation: str
    nonlinearity: float
    reparam: Union[None, Dict]

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x, u, v):
        h = x
        for _ in range(2):
            h = self.activation_fn(
                Dense(features=self.hidden_dim, reparam=self.reparam)(h)
            )
            # Interpolate between the two gate tensors using the activations.
            h = h * u + (1 - h) * v

        h = self.activation_fn(
            Dense(features=self.output_dim, reparam=self.reparam)(h)
        )

        alpha = self.param("alpha", constant(self.nonlinearity), (1,))
        return alpha * h + (1 - alpha) * x


class ResNet(nn.Module):
    """Stack of ``Bottleneck`` blocks on the embedded input, plus a ``Dense`` head.

    Every block keeps the feature width of its input. ``pi_init`` is accepted
    but not read by this class. Returns ``(features, y)``.
    """

    arch_name: Optional[str] = "ResNet"
    num_layers: int = 2
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        feats = embed(x)

        block_kwargs = dict(
            hidden_dim=self.hidden_dim,
            activation=self.activation,
            reparam=self.reparam,
        )
        for _ in range(self.num_layers):
            feats = Bottleneck(output_dim=feats.shape[-1], **block_kwargs)(feats)

        out = Dense(features=self.out_dim, reparam=self.reparam)(feats)
        return feats, out


class PIResNet(nn.Module):
    """Stack of ``PIBottleneck`` blocks on the embedded input, plus an output map.

    Every block keeps the feature width of its input. The output is produced
    by a ``"pi_init"`` kernel when ``pi_init`` is given, otherwise by a
    ``Dense`` layer. Returns ``(features, y)``.
    """

    arch_name: Optional[str] = "PIResNet"
    num_layers: int = 2
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    nonlinearity: float = 0.0
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        feats = embed(x)

        block_kwargs = dict(
            hidden_dim=self.hidden_dim,
            activation=self.activation,
            nonlinearity=self.nonlinearity,
            reparam=self.reparam,
        )
        for _ in range(self.num_layers):
            feats = PIBottleneck(output_dim=feats.shape[-1], **block_kwargs)(feats)

        if self.pi_init is None:
            out = Dense(features=self.out_dim, reparam=self.reparam)(feats)
        else:
            kernel = self.param("pi_init", constant(self.pi_init), self.pi_init.shape)
            out = jnp.dot(feats, kernel)

        return feats, out


class PirateNet(nn.Module):
    """Embedded input passed through gated ``PIModifiedBottleneck`` blocks.

    Two activated ``Dense`` projections of the embedding, ``u`` and ``v``, are
    shared by all blocks as gates. ``pi_init`` is accepted but not read by this
    class. Returns ``(features, y)`` with ``y`` from a final ``Dense`` layer.
    """

    arch_name: Optional[str] = "PirateNet"
    num_layers: int = 2
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    nonlinearity: float = 0.0
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        feats = embed(x)

        # Gate tensors, both computed from the embedded input.
        u = self.activation_fn(
            Dense(features=self.hidden_dim, reparam=self.reparam)(feats)
        )
        v = self.activation_fn(
            Dense(features=self.hidden_dim, reparam=self.reparam)(feats)
        )

        block_kwargs = dict(
            hidden_dim=self.hidden_dim,
            activation=self.activation,
            nonlinearity=self.nonlinearity,
            reparam=self.reparam,
        )
        for _ in range(self.num_layers):
            block = PIModifiedBottleneck(output_dim=feats.shape[-1], **block_kwargs)
            feats = block(feats, u, v)

        out = Dense(features=self.out_dim, reparam=self.reparam)(feats)
        return feats, out


class ModifiedMlp(nn.Module):
    """MLP whose hidden activations are gated by two projections of the input.

    After each activated hidden layer the state becomes
    ``h * u + (1 - h) * v``, where ``u`` and ``v`` are activated ``Dense``
    projections of the embedded input. The output comes from a ``"pi_init"``
    kernel when ``pi_init`` is given, otherwise from a ``Dense`` layer.
    Returns ``(hidden, y)``.
    """

    arch_name: Optional[str] = "ModifiedMlp"
    num_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    pi_init: Union[None, jnp.ndarray] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        hidden = embed(x)

        u_pre = Dense(features=self.hidden_dim, reparam=self.reparam)(hidden)
        v_pre = Dense(features=self.hidden_dim, reparam=self.reparam)(hidden)
        u = self.activation_fn(u_pre)
        v = self.activation_fn(v_pre)

        for _ in range(self.num_layers):
            layer = Dense(features=self.hidden_dim, reparam=self.reparam)
            hidden = self.activation_fn(layer(hidden))
            hidden = hidden * u + (1 - hidden) * v

        if self.pi_init is None:
            out = Dense(features=self.out_dim, reparam=self.reparam)(hidden)
        else:
            kernel = self.param("pi_init", constant(self.pi_init), self.pi_init.shape)
            out = jnp.dot(hidden, kernel)

        return hidden, out
    

def _act_or_identity(name: str):
    return (lambda x: x) if name == "identity" else _get_activation(name)



class PIRFBlock(nn.Module):
    """Residual block: product of two activated projections, projected back.

    Computes ``probe(act(x_proj(x)) * act(y_proj(x))) + x``. The ``y`` argument
    of ``__call__`` is accepted but its value is not used. ``hidden_dim`` falls
    back to ``d_model`` when falsy.
    """

    d_model: int
    hidden_dim: Optional[int] = None
    activation: str = "tanh"
    reparam: Union[None, Dict] = None

    def setup(self):
        self.act1 = _get_activation(self.activation)
        self.act2 = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x, y):
        width = self.hidden_dim or self.d_model

        # Layers are constructed in this fixed order (x_proj, y_proj, probe).
        x_proj = Dense(features=width, reparam=self.reparam)
        y_proj = Dense(features=width, reparam=self.reparam)
        probe = Dense(features=self.d_model, reparam=self.reparam)

        gate = self.act1(y_proj(x))
        value = self.act2(x_proj(x))
        return probe(value * gate) + x

class PRFScaleShiftBlock(nn.Module):
    """Residual block applying a learned scale and shift to a projected input.

    Computes ``out_head(h * s + b) + x`` where ``h`` is an activated projection
    of ``x`` and ``s`` / ``b`` each come from an activated projection followed
    by a linear head. The ``y`` argument of ``__call__`` is accepted but its
    value is not used. ``hidden_dim`` falls back to ``d_model`` when falsy.
    """

    d_model: int
    hidden_dim: Optional[int] = None
    activation: str = "tanh"
    reparam: Union[None, Dict] = None

    def setup(self):
        self.act1 = _get_activation(self.activation)
        self.act2 = _get_activation(self.activation)
        self.act3 = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x, y=None):
        width = self.hidden_dim or self.d_model

        def make(features, name):
            return Dense(features=features, reparam=self.reparam, name=name)

        s_proj = make(width, "s_proj")
        b_proj = make(width, "b_proj")
        x_proj = make(width, "x_proj")
        s_head = make(self.d_model, "s_head")
        b_head = make(self.d_model, "b_head")
        out_head = make(self.d_model, "out_head")

        scale_hidden = self.act1(s_proj(x))
        shift_hidden = self.act2(b_proj(x))
        h = self.act3(x_proj(x))

        scale = s_head(scale_hidden)
        shift = b_head(shift_hidden)

        return out_head(h * scale + shift) + x


class PIRF(nn.Module):
    """Residual stack of PIRF blocks followed by a basis function and a head.

    Inputs (``x`` plus any extra positional arrays, joined on the last axis)
    are embedded, projected to ``d_model`` features and passed through
    ``depth`` blocks, each scaled by ``beta``. The result goes through the
    basis function and a two-layer head. Returns ``(features, out)``.

    ``kernel_config`` may provide ``"alpha"``, ``"beta"``, ``"basis"`` and
    ``"preactivation"``. ``in_dim`` and ``pi_init`` are accepted but not read
    by this class; ``preact_fn`` is set up but not used in ``__call__``.
    """

    arch_name: Optional[str] = "pirf"
    in_dim: int = 2
    out_dim: int = 1
    d_model: int = 128
    hidden_dim: Optional[int] = None
    depth: int = 4
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    kernel_config: Optional[Dict] = None
    pi_init: Union[None, jnp.ndarray] = None
    block_type: str = "scale"  # "scale" -> PIRFBlock; anything else -> PRFScaleShiftBlock

    def setup(self):
        cfg = self.kernel_config if self.kernel_config else {}
        preact_name = cfg.get("preactivation", "identity")

        self.alpha = cfg.get("alpha", 1.0)
        self.beta = cfg.get("beta", 1.0)
        self.activation_fn = _get_activation(self.activation)
        self.basis_fn = _get_activation(cfg.get("basis", self.activation))
        self.preact_fn = _act_or_identity(preact_name)
        assert (self.alpha > 0.0) or (preact_name != "identity")

    @nn.compact
    def __call__(self, x, *args):
        if args:
            extra = jnp.concatenate(args, axis=-1)
            feats = jnp.concatenate([x, extra], axis=-1)
        else:
            feats = x

        embed = Embedding(
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
        )
        feats = Dense(features=self.d_model, reparam=self.reparam)(embed(feats))

        block_cls = PIRFBlock if self.block_type == "scale" else PRFScaleShiftBlock
        for index in range(self.depth):
            block = block_cls(
                d_model=self.d_model,
                hidden_dim=self.hidden_dim or self.d_model,
                activation=self.activation,
                reparam=self.reparam,
                name=f"block_{index}",
            )
            feats = self.beta * block(feats, feats)

        feats = self.basis_fn(feats)

        head_width = self.hidden_dim or self.d_model
        h = Dense(features=head_width, reparam=self.reparam, name="out_0")(feats)
        h = self.activation_fn(h)
        out = Dense(features=self.out_dim, reparam=self.reparam, name="out_1")(h)
        return feats, out
    


class _Dynamics(nn.Module):
    """Drift network: an ``MlpBlock`` evaluated on the state and, optionally, time.

    When ``use_time`` is set, the (pre-activated) time is appended to ``z`` as
    one extra feature before the ``MlpBlock`` named ``"drift"`` is applied.
    """

    d_model: int
    hidden_dim: int
    mlp_layers: int
    activation: str
    reparam: Union[None, Dict]
    use_time: bool
    time_preactivation: str

    def setup(self):
        self.time_preact = _act_or_identity(self.time_preactivation)

    @nn.compact
    def __call__(self, z, t):
        net_input = z
        if self.use_time:
            t_val = self.time_preact(jnp.asarray(t, dtype=z.dtype))
            if z.ndim == 1:
                t_feat = jnp.asarray([t_val])
            else:
                # One time column, repeated along the leading axis of z.
                t_feat = jnp.broadcast_to(t_val, (z.shape[0], 1))
            net_input = jnp.concatenate([z, t_feat], axis=-1)

        drift_net = MlpBlock(
            num_layers=self.mlp_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self.d_model,
            activation=self.activation,
            reparam=self.reparam,
            final_activation=True,
            name="drift",
        )
        return drift_net(net_input)




class OPRF(nn.Module):
    """Network that evolves embedded features with a fixed-step ODE integration.

    The input (plus any extra positional arrays, joined on the last axis) is
    embedded and projected to ``d_model`` features. Those features are then
    advanced for ``num_steps`` steps of size ``(t1 - t0) / num_steps``, using a
    ``"drift"`` ``MlpBlock`` as the vector field. The evolved features are
    mapped to the output either by a ``"pi_init"`` kernel or by a two-layer
    head.

    Returns a tuple ``(z_T, y)``: the evolved features and the output.
    """

    arch_name: Optional[str] = "oprf"
    in_dim: int = 2  # not read by this class
    out_dim: int = 1
    d_model: int = 128
    hidden_dim: Optional[int] = None  # falls back to d_model when falsy (see setup)
    mlp_layers: int = 2
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None
    t0: float = 0.0
    t1: float = 1.0
    num_steps: int = 8
    use_time: bool = True  # append the time as an extra drift input feature
    time_preactivation: str = "identity"
    pi_init: Union[None, jnp.ndarray] = None
    method: str = "rk4"  # "rk4" selects RK4; any other value selects the Euler step

    def setup(self):
        self.hd = self.hidden_dim or self.d_model
        self.activation_fn = _get_activation(self.activation)
        self.time_preact = (lambda x: x) if self.time_preactivation == "identity" else _get_activation(self.time_preactivation)
        # Not read elsewhere in this class; __call__ selects the step from self.method.
        self.use_euler = (self.method != "rk4")

    @nn.compact
    def __call__(self, x, *args):
        # Join any extra positional inputs with x along the last axis.
        if len(args) > 0:
            t_extra = jnp.concatenate(args, axis=-1)
            feats = jnp.concatenate([x, t_extra], axis=-1)
        else:
            feats = x

        feats = Embedding(periodicity=self.periodicity, fourier_embeddings=self.fourier_embeddings)(feats)
        feats = Dense(features=self.d_model, reparam=self.reparam)(feats)

        # Vector field of the ODE; maps (state [+ time]) back to d_model features.
        drift = MlpBlock(
            num_layers=self.mlp_layers,
            hidden_dim=self.hd,
            out_dim=self.d_model,
            activation=self.activation,
            reparam=self.reparam,
            final_activation=True,
            name="drift",
        )

        def append_time(z, t_scalar):
            # Returns z unchanged when use_time is off; otherwise appends the
            # (pre-activated) time as one extra trailing feature.
            if not self.use_time:
                return z
            t_val = self.time_preact(jnp.asarray(t_scalar, dtype=z.dtype))
            if z.ndim == 1:
                t_feat = jnp.asarray([t_val], dtype=z.dtype)
            else:
                t_feat = jnp.broadcast_to(t_val, (z.shape[0], 1))
            return jnp.concatenate([z, t_feat], axis=-1)

        def f(params, z, t_scalar):
            # Evaluate the drift with an explicitly supplied parameter tree.
            return drift.apply({"params": params}, append_time(z, t_scalar))

        # No "drift" entry in this module's params yet -- presumably this is the
        # `init` pass. Call drift once directly (presumably this is what creates
        # its parameters) and skip the integration for this pass.
        initializing = ("drift" not in self.variables.get("params", {}))
        if initializing:
            _ = drift(append_time(feats, self.t0))
            z_T = feats
        else:
            drift_params = self.variables["params"]["drift"]
            dt = (self.t1 - self.t0) / float(self.num_steps)

            def euler_step(params, z, t):
                return z + dt * f(params, z, t)

            def rk4_step(params, z, t):
                # Classical fourth-order Runge-Kutta update.
                k1 = f(params, z, t)
                k2 = f(params, z + 0.5 * dt * k1, t + 0.5 * dt)
                k3 = f(params, z + 0.5 * dt * k2, t + 0.5 * dt)
                k4 = f(params, z + dt * k3, t + dt)
                return z + (dt / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

            step = rk4_step if self.method == "rk4" else euler_step
            # Rematerialize each step in the backward pass (presumably to limit
            # the memory held for intermediate activations).
            step = jax.checkpoint(step)

            def body(i, z):
                # Time at the start of step i.
                t_i = self.t0 + dt * i
                return step(drift_params, z, t_i)

            z_T = jax.lax.fori_loop(0, self.num_steps, body, feats)

        if self.pi_init is not None:
            kernel = self.param("pi_init", constant(self.pi_init), self.pi_init.shape)
            y = jnp.dot(z_T, kernel)
        else:
            h = Dense(features=self.hd, reparam=self.reparam, name="out_0")(z_T)
            h = self.activation_fn(h)
            y = Dense(features=self.out_dim, reparam=self.reparam, name="out_1")(h)

        return z_T, y




#################################################################################################
#################################### neural operators ###########################################
#################################################################################################

class MlpBlock(nn.Module):
    """Stack of activated ``Dense`` layers followed by an output ``Dense`` layer.

    The output layer is passed through the activation only when
    ``final_activation`` is set. Returns the output array alone.
    """

    num_layers: int
    hidden_dim: int
    out_dim: int
    activation: str
    reparam: Union[None, Dict]
    final_activation: bool

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, x):
        h = x
        for _ in range(self.num_layers):
            layer = Dense(features=self.hidden_dim, reparam=self.reparam)
            h = self.activation_fn(layer(h))

        out = Dense(features=self.out_dim, reparam=self.reparam)(h)
        return self.activation_fn(out) if self.final_activation else out


class DeepONet(nn.Module):
    """DeepONet: branch and trunk networks merged by an elementwise product.

    The branch ``MlpBlock`` encodes ``u`` and the trunk ``Mlp`` encodes ``x``,
    both to ``hidden_dim`` features. Their product is passed through the
    activation and a final ``Dense`` layer with ``out_dim`` features.

    Args (of ``__call__``):
        u: Input to the branch network.
        x: Input to the trunk network (embedded according to ``periodicity`` /
            ``fourier_embeddings``).

    Returns:
        The output of the final ``Dense`` layer.
    """

    arch_name: Optional[str] = "DeepONet"
    num_branch_layers: int = 4
    num_trunk_layers: int = 4
    hidden_dim: int = 256
    out_dim: int = 1
    activation: str = "tanh"
    periodicity: Union[None, Dict] = None
    fourier_embeddings: Union[None, Dict] = None
    reparam: Union[None, Dict] = None

    def setup(self):
        self.activation_fn = _get_activation(self.activation)

    @nn.compact
    def __call__(self, u, x):
        u = MlpBlock(
            num_layers=self.num_branch_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self.hidden_dim,
            activation=self.activation,
            final_activation=False,
            reparam=self.reparam,
        )(u)

        # Mlp returns a (hidden, output) tuple; the trunk encoding is the
        # output projection. Using the tuple directly in `u * x` would fail.
        _, x = Mlp(
            num_layers=self.num_trunk_layers,
            hidden_dim=self.hidden_dim,
            out_dim=self.hidden_dim,
            activation=self.activation,
            periodicity=self.periodicity,
            fourier_embeddings=self.fourier_embeddings,
            reparam=self.reparam,
        )(x)

        y = u * x
        y = self.activation_fn(y)
        y = Dense(features=self.out_dim, reparam=self.reparam)(y)
        return y
