import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Callable, Optional, Tuple, Union

# Annotation alias for JAX arrays, used in the type hints of the modules below.
Array = jax.Array


class MLP(nn.Module):
    """
    Fully-connected network applied along the last axis.

    Maps inputs of shape (..., in_dim) to outputs of shape (..., out_dim) through
    `num_hidden_layers` Dense+activation stages followed by a linear output layer.
    The default activation is tanh (smooth, which is typically friendlier for
    jacfwd/jacrev).
    """
    # NOTE: in_dim is declared for callers but is not read in __call__.
    in_dim: int
    hidden_dim: int
    out_dim: int
    num_hidden_layers: int = 2
    activation: Callable[[Array], Array] = nn.tanh

    @nn.compact
    def __call__(self, x: Array) -> Array:
        features = x
        # Hidden stack: each stage is an affine map on the last axis plus the
        # nonlinearity; leading dims pass through untouched.
        for layer_idx in range(self.num_hidden_layers):
            dense = nn.Dense(self.hidden_dim, use_bias=True, name=f"linear_{layer_idx}")
            features = self.activation(dense(features))
        # Output projection has no activation.
        head = nn.Dense(self.out_dim, use_bias=True, name="linear_out")
        return head(features)


class CompleteGraphLayer(nn.Module):
    """
    One message-passing layer on a complete graph (optionally without self-loops).

    Call arguments:
      h: (..., N, node_dim) node features
      e: (..., N, N, edge_dim) edge features, indexed [receiver i, sender j]
      node_mask: (..., N) float {0,1} (1 for valid node indices < n_nodes)

    Returns:
      (..., N, node_dim) updated node features, multiplied by node_mask so that
      padded rows are zeroed.

    The self-exclusion mask is derived from the runtime N rather than sliced out
    of an (n_max, n_max) buffer, so the layer gives correct results for any N.
    n_max is kept as a field for interface compatibility.
    """
    n_max: int
    node_dim: int
    edge_dim: int
    hidden_dim: int
    msg_dim: Optional[int] = None
    num_mlp_hidden_layers: int = 2
    exclude_self: bool = True
    residual: bool = True

    def setup(self):
        # Message width defaults to the node width when msg_dim is not given.
        md = self.node_dim if self.msg_dim is None else self.msg_dim
        self._msg_dim = md

        # message: m_{i<-j} = MLP([h_i, h_j, e_{ij}])
        self.msg_mlp = MLP(
            in_dim=2 * self.node_dim + self.edge_dim,
            hidden_dim=self.hidden_dim,
            out_dim=md,
            num_hidden_layers=self.num_mlp_hidden_layers,
            activation=nn.tanh,
        )

        # update: delta_i = MLP([h_i, sum_j m_{i<-j}])
        self.upd_mlp = MLP(
            in_dim=self.node_dim + md,
            hidden_dim=self.hidden_dim,
            out_dim=self.node_dim,
            num_hidden_layers=self.num_mlp_hidden_layers,
            activation=nn.tanh,
        )

    def __call__(self, h: Array, e: Array, node_mask: Array) -> Array:
        # h: (..., N, D)
        # e: (..., N, N, E)
        # node_mask: (..., N)
        N = h.shape[-2]
        D = h.shape[-1]

        # pair_mask: (..., N, N), nonzero only where both endpoints are valid
        pair_mask = node_mask[..., :, None] * node_mask[..., None, :]

        # Drop i == j pairs if requested. The mask is built at the actual N
        # (a static shape, so this is a constant under jit); slicing a fixed
        # (n_max, n_max) buffer would silently clamp when N > n_max.
        if self.exclude_self:
            off_diagonal = jnp.logical_not(jnp.eye(N, dtype=bool))
            pair_mask = pair_mask * off_diagonal.astype(pair_mask.dtype)

        # Zero out invalid edges/features pre-MLP (helps prevent bias leakage upstream)
        e = e * pair_mask[..., None]  # (..., N, N, E)

        # Expand h_i and h_j to (..., N, N, D) for concatenation
        batch_shape = h.shape[:-2]
        h_i = jnp.broadcast_to(h[..., :, None, :], batch_shape + (N, N, D))
        h_j = jnp.broadcast_to(h[..., None, :, :], batch_shape + (N, N, D))

        # Build pair features and compute messages
        pair_feat = jnp.concatenate([h_i, h_j, e], axis=-1)  # (..., N, N, 2D+E)
        msg = self.msg_mlp(pair_feat)                        # (..., N, N, msg_dim)

        # The MLP biases make masked pairs nonzero again, so mask the output too
        msg = msg * pair_mask[..., None]                     # (..., N, N, msg_dim)

        # Aggregate incoming messages by summation over senders j
        agg = msg.sum(axis=-2)                               # (..., N, msg_dim)

        # Update nodes
        upd_in = jnp.concatenate([h, agg], axis=-1)          # (..., N, D+msg_dim)
        delta = self.upd_mlp(upd_in)                         # (..., N, D)

        h_new = (h + delta) if self.residual else delta

        # Re-mask so padded nodes stay zero after the update
        h_new = h_new * node_mask[..., None]                 # (..., N, D)
        return h_new


class CompleteGraphGNN(nn.Module):
    """
    Encoder -> stacked complete-graph message passing -> decoder.

    Inputs:
      node_x: (..., n_max, k)
      edge_x: (..., n_max, n_max, l)
      n_nodes: (...) int tensor or python int, number of valid nodes (<= n_max)

    Outputs:
      node_out: (..., n_max, out_dim)  (rows of padded nodes are multiplied by 0)
      graph_out: (..., out_dim)        (sum of node_out over the node axis)
    """
    n_max: int
    # NOTE: node_in_dim / edge_in_dim are declared for callers but are not read
    # by setup() or __call__.
    node_in_dim: int
    edge_in_dim: int
    hidden_dim: int = 128
    num_layers: int = 3
    out_dim: int = 128
    msg_dim: Optional[int] = None
    num_mlp_hidden_layers: int = 2
    exclude_self: bool = True
    residual: bool = True

    def setup(self):
        width = self.hidden_dim

        # Input encoders: nodes and edges are both lifted to the hidden width.
        self.node_enc = nn.Dense(width, use_bias=True, name="node_enc")
        self.edge_enc = nn.Dense(width, use_bias=True, name="edge_enc")

        # Every message-passing layer shares the same configuration; only the
        # name differs.
        layer_kwargs = dict(
            n_max=self.n_max,
            node_dim=width,
            edge_dim=width,
            hidden_dim=width,
            msg_dim=self.msg_dim,
            num_mlp_hidden_layers=self.num_mlp_hidden_layers,
            exclude_self=self.exclude_self,
            residual=self.residual,
        )
        self.layers = [
            CompleteGraphLayer(name=f"layer_{idx}", **layer_kwargs)
            for idx in range(self.num_layers)
        ]

        # Output decoder applied per node.
        self.node_dec = nn.Dense(self.out_dim, use_bias=True, name="node_dec")

    def _make_node_mask(
        self,
        n_nodes: Union[int, Array],
        *,
        dtype: jnp.dtype,
        n_max: int,
    ) -> Array:
        """
        Return a (..., n_max) mask in `dtype`: 1 where the slot index is < n_nodes,
        0 elsewhere. n_nodes may be a scalar or a batched (...) array.
        """
        counts = jnp.asarray(n_nodes)[..., None]   # (..., 1)
        slots = jnp.arange(n_max)                  # (n_max,)
        return (slots < counts).astype(dtype)

    def __call__(
        self,
        node_x: Array,
        edge_x: Array,
        n_nodes: Union[int, Array],
    ) -> Tuple[Array, Array]:
        """
        node_x: (..., N=n_max, k)
        edge_x: (..., N=n_max, N=n_max, l)
        n_nodes: (...) number of real nodes (<= n_max)

        Returns (node_out, graph_out) as described on the class.
        """
        # The mask length follows the node axis of the input actually received.
        n_slots = node_x.shape[-2]
        node_mask = self._make_node_mask(n_nodes, dtype=node_x.dtype, n_max=n_slots)

        node_keep = node_mask[..., None]                                # (..., N, 1)
        both_valid = node_mask[..., :, None] * node_mask[..., None, :]  # (..., N, N)
        edge_keep = both_valid[..., None]                               # (..., N, N, 1)

        # Mask right after encoding so the Dense biases don't leak into padding.
        h = self.node_enc(node_x) * node_keep    # (..., N, hidden)
        e = self.edge_enc(edge_x) * edge_keep    # (..., N, N, hidden)

        for layer in self.layers:
            h = layer(h, e, node_mask)

        # Decode, then mask again because the decoder bias is nonzero on padding.
        node_out = self.node_dec(h) * node_keep  # (..., N, out_dim)

        # Graph embedding: sum over the node axis of the masked node outputs.
        graph_out = jnp.sum(node_out, axis=-2)   # (..., out_dim)

        return node_out, graph_out


import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Tuple

# CompleteGraphGNN is defined earlier in this file. If this section is moved
# into its own module, import it instead, e.g.:
# from complete_graph_gnn_jax import CompleteGraphGNN


# Re-binds the same annotation alias that is already defined near the top of this file.
Array = jax.Array


class FourierEmbedder(nn.Module):
    """
    Sin/cos features at `n_fourier` frequencies log-spaced from scale_min to scale_max.

    Input x of shape (...,) -> output of shape (..., 2 * n_fourier), with all
    sines first and all cosines second along the last axis.
    """
    scale_min: float
    scale_max: float
    n_fourier: int

    def setup(self):
        # Space the frequencies evenly in log-space, then exponentiate.
        log_lo = jnp.log(self.scale_min)
        log_hi = jnp.log(self.scale_max)
        self.w = jnp.exp(jnp.linspace(log_lo, log_hi, self.n_fourier))  # (n_fourier,)

    def __call__(self, x: Array) -> Array:
        # Trailing axis of size 1 broadcasts against the frequency vector.
        phase = x[..., None] * self.w  # (..., n_fourier)
        return jnp.concatenate((jnp.sin(phase), jnp.cos(phase)), axis=-1)


class MoleculePINN(nn.Module):
    """
    Scalar-valued network over two point sets x, y and a time t.

    Pairwise squared distances between x and y and the time t are Fourier-embedded,
    assembled into node/edge features, passed through a CompleteGraphGNN, and the
    pooled graph vector is summed to one scalar per example.
    """
    n_max: int = 29
    n_fourier: int = 32
    r_fourier_min: float = 0.1
    r_fourier_max: float = 1.0
    t_fourier_min: float = 0.1
    t_fourier_max: float = 1.0
    apply_log: bool = False

    def setup(self):
        n_freq = self.n_fourier
        self.r_embedder = FourierEmbedder(
            scale_min=self.r_fourier_min,
            scale_max=self.r_fourier_max,
            n_fourier=n_freq,
        )
        self.t_embedder = FourierEmbedder(
            scale_min=self.t_fourier_min,
            scale_max=self.t_fourier_max,
            n_fourier=n_freq,
        )

        # Node and edge features are each [distance embedding, time embedding],
        # i.e. 2*n_freq + 2*n_freq wide.
        feat_dim = 4 * n_freq
        self.gnn = CompleteGraphGNN(
            n_max=self.n_max,
            node_in_dim=feat_dim,
            edge_in_dim=feat_dim,
            num_layers=2,
            hidden_dim=64,
            # out_dim is left at the CompleteGraphGNN default
        )

    def __call__(self, x: Array, y: Array, t: Array, N: Array) -> Array:
        """
        x: (..., n, d)
        y: (..., n, d)
        t: (...)  (scalar per example; can be batched)
        N: number of valid nodes, forwarded to the GNN as n_nodes
        returns: (...) scalar per example
        """
        # r[..., i, j] = squared distance between x_i and y_j: (..., n, n)
        delta = x[..., :, None, :] - y[..., None, :, :]
        r = jnp.sum(jnp.square(delta), axis=-1)
        if self.apply_log:
            # The small offset keeps the log finite at zero.
            r, t = jnp.log(r + 1e-8), jnp.log(t + 1e-8)

        r_embed = self.r_embedder(r)  # (..., n, n, 2*nf)
        t_embed = self.t_embedder(t)  # (..., 2*nf)
        t_width = t_embed.shape[-1]

        # Node features are the i == j entries of the pair embedding.
        # jnp.diagonal puts the diagonal axis last, so swap it back: (..., n, 2*nf)
        on_diag = jnp.diagonal(r_embed, axis1=-3, axis2=-2)  # (..., 2*nf, n)
        diag = jnp.swapaxes(on_diag, -1, -2)

        # Tile the time embedding over every node and every edge.
        node_shape = diag.shape[:-1] + (t_width,)      # (..., n, 2*nf)
        edge_shape = r_embed.shape[:-1] + (t_width,)   # (..., n, n, 2*nf)
        t_node = jnp.broadcast_to(t_embed[..., None, :], node_shape)
        t_edge = jnp.broadcast_to(t_embed[..., None, None, :], edge_shape)

        node_feat = jnp.concatenate((diag, t_node), axis=-1)     # (..., n, 4*nf)
        edge_feat = jnp.concatenate((r_embed, t_edge), axis=-1)  # (..., n, n, 4*nf)

        # Only the pooled graph vector is used; per-node outputs are discarded.
        _, graph_vec = self.gnn(node_feat, edge_feat, n_nodes=N)  # (..., out_dim)

        # Collapse the graph vector to a single scalar per example.
        return jnp.sum(graph_vec, axis=-1)
