from dataclasses import dataclass
import jax
import jax.numpy as jnp
from flax import linen as nn

# NOTE(review): Flax already turns nn.Module subclasses into dataclasses, so
# this explicit decorator looks redundant -- confirm before removing it.
@dataclass
class TimeEmbedding(nn.Module):
    """Small MLP that maps a time input to one learned feature.

    The raw time is concatenated with the learned feature, so the module
    returns both the original value and its embedding.
    """

    # Width of every hidden Dense layer.
    embedding_layers: int
    # Number of hidden (Dense + swish) layers stacked before the projection.
    n_layers: int

    @nn.compact
    def __call__(self, t):
        """Embed the time input.

        Args:
            t: time array; the shape comments below assume (B, 1), which is
                what MLP passes in.

        Returns:
            ``t`` concatenated along the last axis with a single learned
            feature, i.e. (B, 2) for a (B, 1) input.
        """
        # Thread the activation through the stack. The previous version fed
        # the raw ``t`` into every Dense layer, so only the last hidden layer
        # contributed to the output (the others were dead parameters), and
        # ``n_layers == 0`` crashed on an unbound variable.
        hidden = t
        for _ in range(self.n_layers):
            hidden = nn.Dense(self.embedding_layers)(hidden)  # (B, dim)
            hidden = nn.swish(hidden)
        t_features = nn.Dense(1)(hidden)  # (B, 1)

        # Keep the raw time next to the learned feature.
        return jnp.concatenate([t, t_features], axis=-1)  # (B, 2)

@dataclass
class MLP(nn.Module):
    """Time-conditioned feed-forward network.

    The position input is concatenated with the output of ``TimeEmbedding``
    and passed through ``n_layers`` (Dense + swish) layers, followed by a
    final Dense projection back to the position dimensionality.
    """

    # Width of every hidden Dense layer of the main network.
    hidden_layers: int
    # Hidden width forwarded to TimeEmbedding.
    embedding_layers: int
    # Number of hidden layers, used for both this network and TimeEmbedding.
    n_layers: int

    def reshape_inputs(self, pos, t):
        """Normalize ``pos`` to rank 2 (B, D) and ``t`` to rank 2 (B, 1).

        Args:
            pos: array-like of shape (), (D,) or (BS, D).
            t: array-like of shape (), (BS,) or (BS, 1).

        Returns:
            Tuple ``(pos, t)`` of rank-2 arrays. When ``t`` holds a single
            time value and ``pos`` has another batch size, ``t`` is
            broadcast to that batch size so the two can be concatenated
            along the feature axis.

        Raises:
            ValueError: if ``pos`` or ``t`` has more than two dimensions.
        """
        # Accept Python scalars / lists as well as arrays.
        pos = jnp.asarray(pos)
        t = jnp.asarray(t)

        # Normalize pos to shape (B, D).
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if pos.ndim not in (0, 1, 2):
            raise ValueError("Input pos should be of shape (BS, D) or (D,) or ()")
        if pos.ndim == 0:
            pos = pos[None, None]  # () -> (1, 1)
        elif pos.ndim == 1:
            pos = pos[None, :]  # (D,) -> (1, D)

        # Normalize t to shape (B, 1).
        if t.ndim not in (0, 1, 2):
            raise ValueError("Input t should be of shape (), (BS, 1) or (BS,)")
        if t.ndim == 0:
            t = t[None, None]  # () -> (1, 1)
        elif t.ndim == 1:
            t = t[:, None]  # (BS,) -> (BS, 1)

        # A single time value shared by a whole batch of positions:
        # jnp.concatenate does not broadcast, so expand t explicitly.
        if t.shape[0] == 1 and pos.shape[0] != 1:
            t = jnp.broadcast_to(t, (pos.shape[0], t.shape[1]))

        return pos, t

    @nn.compact
    def __call__(self, x, t):
        """Evaluate the network at positions ``x`` and times ``t``.

        Args:
            x: positions, shape (), (D,) or (BS, D).
            t: times, shape (), (BS,) or (BS, 1).

        Returns:
            Array of shape (B, D), where D is the last dimension of the
            normalized ``x``. Unbatched inputs come back with a leading
            batch axis of size 1.
        """
        x, t = self.reshape_inputs(x, t)

        # Condition on time by appending the time embedding to the features.
        t_emb = TimeEmbedding(self.embedding_layers, self.n_layers)(t)
        h = jnp.concatenate([x, t_emb], axis=-1)

        # Feed forward.
        for _ in range(self.n_layers):
            h = nn.swish(nn.Dense(self.hidden_layers)(h))

        # Project back to the dimensionality of the position input.
        return nn.Dense(x.shape[-1])(h)