# SPDX-License-Identifier: MIT
# JAX/Flax port of the SPD-Net (PyTorch) architecture.

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Tuple

from jax import config
config.update("jax_enable_x64", True)  # you used torch.double everywhere

import jax
import jax.numpy as jnp
from flax import linen as nn


def symmetric(a: jnp.ndarray) -> jnp.ndarray:
    """Return the symmetric part of ``a`` taken over its last two axes.

    Computes ``0.5 * (a + a^T)``, where the transpose swaps the two trailing
    axes, so any leading axes simply act as batch dimensions.
    """
    transposed = jnp.swapaxes(a, -1, -2)
    return 0.5 * (a + transposed)


# -------------------------
# SPDRectified: custom VJP
# -------------------------

@jax.custom_vjp
def spd_rectify_one(x: jnp.ndarray, epsilon: jnp.ndarray) -> jnp.ndarray:
    """
    Rectify a single symmetric matrix by flooring its eigenvalues at ``epsilon``.

    The matrix is eigendecomposed with ``jnp.linalg.eigh`` and rebuilt as
    ``U diag(max(s, epsilon)) U^T``.

    x: one (N, N) matrix (the code uses the 2-D ``.T`` transpose)
    epsilon: eigenvalue floor, broadcast against the eigenvalue vector
    """
    eigvals, eigvecs = jnp.linalg.eigh(x)
    floored = jnp.maximum(eigvals, epsilon)
    # Scaling the columns of U by the floored spectrum equals U @ diag(floored).
    scaled_vecs = eigvecs * floored
    return scaled_vecs @ eigvecs.T


def _spd_rectify_one_fwd(x: jnp.ndarray, epsilon: jnp.ndarray):
    """Forward rule for the custom VJP of ``spd_rectify_one``.

    Returns the rectified matrix together with the residuals consumed by the
    backward rule: ``(eigenvectors, eigenvalues, floored eigenvalues, epsilon)``.
    """
    eigvals, eigvecs = jnp.linalg.eigh(x)
    floored = jnp.maximum(eigvals, epsilon)
    # Everything the backward pass needs, so it never has to redo the eigh.
    residuals = (eigvecs, eigvals, floored, epsilon)
    rectified = (eigvecs * floored) @ eigvecs.T
    return rectified, residuals


def _spd_rectify_one_bwd(res, g: jnp.ndarray):
    """
    Backward rule for ``spd_rectify_one``.

    With ``f(s) = max(s, epsilon)`` and ``X = U diag(s) U^T`` the gradient is

        grad_X = U (L * (U^T sym(g) U)) U^T

    where ``L`` is the Loewner (divided-difference) matrix of ``f``:

        L_ij = (f(s_i) - f(s_j)) / (s_i - s_j)   if s_i != s_j
        L_ij = f'(s_i) = 1[s_i > epsilon]        if s_i == s_j  (incl. i == j)

    For distinct eigenvalues this is algebraically the same as the
    ``symmetric(P^T * (U^T dLdV)) + dLdS`` formulation, but it (a) gives the
    correct coefficient for repeated eigenvalues instead of 0, and (b) avoids
    cancelling two terms of size ~1/(s_i - s_j) when eigenvalues are close.

    res: residuals ``(u, s, s_clamped, epsilon)`` saved by the forward rule
    g:   cotangent of the rectified (N, N) output

    Returns grads w.r.t. ``(x, epsilon)``; the epsilon gradient is defined as 0.
    """
    u, s, s_clamped, epsilon = res

    # Cotangent expressed in the eigenbasis. Symmetrizing here is equivalent to
    # symmetrizing g first (U^T sym(g) U == sym(U^T g U)) and also removes
    # rounding asymmetry from the two matmuls.
    m = u.T @ g @ u
    m = 0.5 * (m + m.T)

    # Pairwise eigenvalue gaps; exact ties (including the diagonal) are
    # handled through the derivative branch below.
    gap = s[:, None] - s[None, :]
    repeated = gap == 0
    safe_gap = jnp.where(repeated, 1.0, gap)  # never divide by zero

    # Divided differences of the clamped spectrum.
    slope = (s_clamped[:, None] - s_clamped[None, :]) / safe_gap

    # f'(s): 1 where the eigenvalue passes through unchanged, 0 where clamped.
    active = (s > epsilon).astype(u.dtype)

    # For tied entries s_i == s_j, so active[i] == active[j]; broadcasting the
    # row value is enough.
    loewner = jnp.where(repeated, active[:, None], slope)

    grad_x = u @ (loewner * m) @ u.T

    # epsilon is treated as a constant hyper-parameter.
    grad_eps = jnp.zeros_like(epsilon)

    return (grad_x, grad_eps)


# Attach the forward/backward rules so differentiating spd_rectify_one uses
# _spd_rectify_one_bwd instead of JAX's automatic derivative of eigh.
spd_rectify_one.defvjp(_spd_rectify_one_fwd, _spd_rectify_one_bwd)

def frob_normalize(A: jnp.ndarray, eps: float = 1e-6) -> jnp.ndarray:
    """Rescale ``A`` so its Frobenius norm is approximately ``sqrt(m)``.

    ``m`` is the size of the last axis and the norm is taken over the last two
    axes, so leading axes are batch dimensions. ``eps`` is added to the norm
    before dividing, which keeps an all-zero matrix finite.
    """
    size = A.shape[-1]
    norms = jnp.linalg.norm(A, axis=(-2, -1), keepdims=True)
    scale = jnp.sqrt(size) / (norms + eps)
    return A * scale

# class SPDRectified(nn.Module):
#     epsilon: float = 1e-4
#     dtype: jnp.dtype = jnp.float64

#     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
#         eps = jnp.asarray(self.epsilon, dtype=self.dtype)
#         # vmap over batch: x is (B, N, N)
#         return jax.vmap(lambda m: spd_rectify_one(m, eps))(x.astype(self.dtype))
class SPDRectified(nn.Module):
    """Eigenvalue rectification layer: floors the spectrum at ``epsilon``.

    Applies ``spd_rectify_one`` to every ``(n, n)`` matrix in the input.
    Any number of leading batch axes is accepted, ``(..., n, n)``, matching
    the other layers in this file that broadcast over leading dimensions.
    """
    # Lower bound applied to each eigenvalue.
    epsilon: float = 1e-4
    # dtype the input is cast to before the eigendecomposition.
    dtype: jnp.dtype = jnp.float64

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Rectify ``x`` of shape ``(..., n, n)`` and return the same shape.

        Raises:
            ValueError: if ``x`` has fewer than two axes.
        """
        x = x.astype(self.dtype)
        eps = jnp.asarray(self.epsilon, dtype=self.dtype)

        if x.ndim < 2:
            raise ValueError(f"SPDRectified expects (..., n, n), got {x.shape}")
        if x.ndim == 2:
            return spd_rectify_one(x, eps)  # (n,n) -> (n,n)

        # Collapse all leading axes into one batch axis, map the single-matrix
        # kernel over it (epsilon is shared), then restore the original shape.
        flat = x.reshape((-1,) + x.shape[-2:])
        rectified = jax.vmap(spd_rectify_one, in_axes=(0, None))(flat, eps)
        return rectified.reshape(x.shape)


# -------------------------
# SPDIncreaseDim
# -------------------------

# class SPDIncreaseDim(nn.Module):
#     input_size: int
#     output_size: int
#     dtype: jnp.dtype = jnp.float64

#     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
#         """
#         Matches PyTorch:
#           output = add + eye @ input @ eye^T
#         where output has input in top-left and identity on added dims.
#         """
#         in_n = int(self.input_size)
#         out_n = int(self.output_size)
#         x = x.astype(self.dtype)

#         eye = jnp.eye(out_n, in_n, dtype=self.dtype)  # (out, in)
#         add_diag = jnp.concatenate(
#             [jnp.zeros((in_n,), dtype=self.dtype),
#              jnp.ones((out_n - in_n,), dtype=self.dtype)],
#             axis=0,
#         )
#         add = jnp.diag(add_diag)  # (out, out)

#         # Broadcasted matmul over batch:
#         # (out,in) @ (B,in,in) -> (B,out,in)
#         # (B,out,in) @ (in,out) -> (B,out,out)
#         return add[None, :, :] + (eye @ x) @ eye.T

class SPDIncreaseDim(nn.Module):
    """Embed an ``(in, in)`` matrix into a larger ``(out, out)`` one.

    The input occupies the top-left block of the result and the newly added
    diagonal entries are set to one:
    ``output = add + eye @ x @ eye^T``.
    """
    input_size: int
    output_size: int
    dtype: jnp.dtype = jnp.float64

    def setup(self):
        n_in, n_out = int(self.input_size), int(self.output_size)
        n_pad = n_out - n_in

        # Rectangular identity (out, in): places x in the top-left block.
        self.eye = jnp.eye(n_out, n_in, dtype=self.dtype)

        # Diagonal that is 0 on the original dims and 1 on the added dims.
        pad_diag = jnp.concatenate(
            (
                jnp.zeros((n_in,), dtype=self.dtype),
                jnp.ones((n_pad,), dtype=self.dtype),
            ),
            axis=0,
        )
        self.add = jnp.diag(pad_diag)  # (out, out)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the enlarged matrix; leading batch axes of ``x`` broadcast."""
        lifted = (self.eye @ x.astype(self.dtype)) @ self.eye.T
        return self.add + lifted


# -------------------------
# SPDTransform
# -------------------------

# class SPDTransform(nn.Module):
#     input_size: int
#     output_size: int
#     time_dim: int
#     y_dim: int  # kept for signature parity; your PyTorch forward doesn't use Y here
#     dtype: jnp.dtype = jnp.float64

#     def setup(self):
#         in_n = int(self.input_size)
#         out_n = int(self.output_size)

#         if out_n > in_n:
#             self.increase_dim = SPDIncreaseDim(in_n, out_n, dtype=self.dtype)
#             weight_in = out_n
#         else:
#             self.increase_dim = None
#             weight_in = in_n

#         # Stiefel-like weight, orthogonal init (like nn.init.orthogonal_)
#         self.weight = self.param(
#             "weight",
#             nn.initializers.orthogonal(),
#             (weight_in, out_n),
#             self.dtype,
#         )

#         # emb_layer: Linear(time_dim -> out_n*out_n)
#         self.emb_layer = nn.Dense(
#             out_n * out_n,
#             dtype=self.dtype,
#             param_dtype=self.dtype,
#         )

#         # In your PyTorch code Y_layer exists but is unused in forward.
#         # If you want it in params anyway, you'd need to call it somewhere.
#         self.Y_layer = nn.Dense(
#             out_n * out_n,
#             dtype=self.dtype,
#             param_dtype=self.dtype,
#         )

#     def __call__(self, x: jnp.ndarray, t: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
#         x = x.astype(self.dtype)
#         t = t.astype(self.dtype)
#         # Y unused here (matches your code)

#         if self.increase_dim is not None:
#             x = self.increase_dim(x)

#         W = self.weight  # (in, out) or (out, out)
#         # W^T X W (broadcasted over batch)
#         x = (W.T @ (x @ W))

#         # Time embedding -> (B, out*out) -> reshape to (B, out, out)
#         emb = self.emb_layer(t)
#         B = x.shape[0]
#         m = x.shape[-1]
#         t_emb = emb.reshape((B, m, m))

#         x = t_emb @ x @ jnp.swapaxes(t_emb, -1, -2)
#         return x
# class SPDTransform(nn.Module):
#     input_size: int
#     output_size: int
#     time_dim: int
#     y_dim: int
#     dtype: jnp.dtype = jnp.float64

#     def setup(self):
#         in_n = int(self.input_size)
#         out_n = int(self.output_size)

#         if out_n > in_n:
#             self.increase_dim = SPDIncreaseDim(in_n, out_n, dtype=self.dtype)
#             weight_in = out_n
#         else:
#             self.increase_dim = None
#             weight_in = in_n

#         self.weight = self.param(
#             "weight",
#             nn.initializers.orthogonal(),
#             (weight_in, out_n),
#             self.dtype,
#         )

#         self.emb_layer = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)

#         # exists in your PyTorch, but unused in forward; keep only if you need it
#         self.Y_layer = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)

#     def __call__(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
#         x = x.astype(self.dtype)
#         t_emb = t_emb.astype(self.dtype)

#         if self.increase_dim is not None:
#             x = self.increase_dim(x)

#         W = self.weight
#         x = W.T @ (x @ W)  # works for (n,n) and (B,n,n) by broadcasting

#         m = x.shape[-1]
#         emb = self.emb_layer(t_emb)  # (..., m*m)
#         t_mat = emb.reshape(emb.shape[:-1] + (m, m))  # (..., m, m)

#         return t_mat @ x @ jnp.swapaxes(t_mat, -1, -2)


class SPDTransform(nn.Module):
    """Bilinear SPD layer followed by a time-conditioned congruence.

    Steps performed in ``__call__``:
      1. If ``output_size > input_size`` the input is first enlarged with
         ``SPDIncreaseDim``.
      2. ``X <- W^T X W`` with an orthogonally initialised ``weight``.
      3. ``E = I + cond_scale * reshape(Dense(t_emb), (m, m))``, optionally
         rescaled by ``frob_normalize``.
      4. The result is ``E X E^T``.

    ``time_dim`` and ``y_dim`` are declared fields but are not read by this
    class, and the ``Y`` argument of ``__call__`` is not used.
    """
    input_size: int
    output_size: int
    time_dim: int
    y_dim: int
    dtype: jnp.dtype = jnp.float64

    # Scale applied to the learned perturbation of the identity.
    cond_scale: float = 1e-2
    # Whether E is passed through frob_normalize before use.
    normalize_cond: bool = True

    def setup(self):
        n_in, n_out = int(self.input_size), int(self.output_size)
        grows = n_out > n_in

        self.increase_dim = (
            SPDIncreaseDim(n_in, n_out, dtype=self.dtype) if grows else None
        )
        # After an enlargement the matrix is already (out, out).
        weight_rows = n_out if grows else n_in

        self.weight = self.param(
            "weight",
            nn.initializers.orthogonal(),
            (weight_rows, n_out),
            self.dtype,
        )

        # Tiny kernel and zero bias so the perturbation delta starts near zero.
        self.emb_layer = nn.Dense(
            features=n_out * n_out,
            dtype=self.dtype,
            param_dtype=self.dtype,
            kernel_init=nn.initializers.normal(1e-3),
            bias_init=nn.initializers.zeros,
        )

    def __call__(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
        """Transform ``x`` conditioned on ``t_emb``; ``Y`` is accepted but unused."""
        h = x.astype(self.dtype)
        cond = t_emb.astype(self.dtype)

        if self.increase_dim is not None:
            h = self.increase_dim(h)

        h = self.weight.T @ (h @ self.weight)

        size = h.shape[-1]
        raw = self.emb_layer(cond)
        # One (size, size) matrix per leading index of the conditioning input.
        delta = raw.reshape(cond.shape[:-1] + (size, size))
        E = jnp.eye(size, dtype=self.dtype) + self.cond_scale * delta

        if self.normalize_cond:
            E = frob_normalize(E)

        return E @ h @ jnp.swapaxes(E, -1, -2)

# -------------------------
# SPDTransform1
# -------------------------

# class SPDTransform1(nn.Module):
#     input_size: int
#     output_size: int
#     time_dim: int  # kept for parity; your PyTorch forward doesn't use emb_layer(t)
#     y_dim: int
#     dtype: jnp.dtype = jnp.float64

#     def setup(self):
#         in_n = int(self.input_size)
#         out_n = int(self.output_size)

#         if out_n > in_n:
#             self.increase_dim = SPDIncreaseDim(in_n, out_n, dtype=self.dtype)
#             weight_in = out_n
#         else:
#             self.increase_dim = None
#             weight_in = in_n

#         self.weight = self.param(
#             "weight",
#             nn.initializers.orthogonal(),
#             (weight_in, out_n),
#             self.dtype,
#         )

#         # emb_layer exists in PyTorch but is commented out in forward
#         self.t_dense1 = nn.Dense(self.time_dim, dtype=self.dtype, param_dtype=self.dtype)
#         self.t_dense2 = nn.Dense(out_n * 2, dtype=self.dtype, param_dtype=self.dtype)

#         # Y_layer MLP: Y_dim -> out*out -> SiLU -> out*out
#         self.y_dense1 = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)
#         self.y_dense2 = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)

#     def __call__(self, x: jnp.ndarray, t: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
#         x = x.astype(self.dtype)
#         Y = Y.astype(self.dtype)

#         # emb = self.emb_layer(t)  # (unused in your PyTorch forward)

#         emb_Y = self.y_dense2(jax.nn.silu(self.y_dense1(Y)))  # (B, out*out)

#         if self.increase_dim is not None:
#             x = self.increase_dim(x)

#         W = self.weight
#         x = (W.T @ (x @ W))

#         B = x.shape[0]
#         m = x.shape[-1]
#         emb_Y = emb_Y.reshape((B, m, m))

#         x = emb_Y @ x @ jnp.swapaxes(emb_Y, -1, -2)
#         return x
# class SPDTransform1(nn.Module):
#     input_size: int
#     output_size: int
#     time_dim: int
#     y_dim: int
#     dtype: jnp.dtype = jnp.float64

#     def setup(self):
#         in_n = int(self.input_size)
#         out_n = int(self.output_size)

#         if out_n > in_n:
#             self.increase_dim = SPDIncreaseDim(in_n, out_n, dtype=self.dtype)
#             weight_in = out_n
#         else:
#             self.increase_dim = None
#             weight_in = in_n

#         self.weight = self.param(
#             "weight",
#             nn.initializers.orthogonal(),
#             (weight_in, out_n),
#             self.dtype,
#         )

#         self.y_dense1 = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)
#         self.y_dense2 = nn.Dense(out_n * out_n, dtype=self.dtype, param_dtype=self.dtype)

#     def __call__(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
#         x = x.astype(self.dtype)
#         Y = Y.astype(self.dtype)

#         emb_Y = self.y_dense2(jax.nn.silu(self.y_dense1(Y)))  # (..., m*m)

#         if self.increase_dim is not None:
#             x = self.increase_dim(x)

#         W = self.weight
#         x = W.T @ (x @ W)

#         m = x.shape[-1]
#         y_mat = emb_Y.reshape(emb_Y.shape[:-1] + (m, m))  # (..., m, m)

#         return y_mat @ x @ jnp.swapaxes(y_mat, -1, -2)


class SPDTransform1(nn.Module):
    """Bilinear SPD layer followed by a congruence conditioned on ``Y``.

    Same structure as the time-conditioned transform, except that the
    perturbation of the identity comes from a two-layer MLP on ``Y``:
    ``delta = Dense(silu(Dense(Y)))`` reshaped to ``(m, m)``, then
    ``E = I + cond_scale * delta`` (optionally ``frob_normalize``-d) and the
    output is ``E (W^T X W) E^T``.

    ``time_dim`` and ``y_dim`` are declared fields but are not read by this
    class, and the ``t_emb`` argument of ``__call__`` is not used.
    """
    input_size: int
    output_size: int
    time_dim: int
    y_dim: int
    dtype: jnp.dtype = jnp.float64

    # Scale applied to the learned perturbation of the identity.
    cond_scale: float = 1e-2
    # Whether E is passed through frob_normalize before use.
    normalize_cond: bool = True

    def setup(self):
        n_in, n_out = int(self.input_size), int(self.output_size)
        grows = n_out > n_in

        self.increase_dim = (
            SPDIncreaseDim(n_in, n_out, dtype=self.dtype) if grows else None
        )
        # After an enlargement the matrix is already (out, out).
        weight_rows = n_out if grows else n_in

        self.weight = self.param(
            "weight",
            nn.initializers.orthogonal(),
            (weight_rows, n_out),
            self.dtype,
        )

        # Both MLP layers start tiny (small kernel, zero bias) so that the
        # produced delta is close to zero at initialisation.
        self.y_dense1 = nn.Dense(
            features=n_out * n_out,
            dtype=self.dtype,
            param_dtype=self.dtype,
            kernel_init=nn.initializers.normal(1e-3),
            bias_init=nn.initializers.zeros,
        )
        self.y_dense2 = nn.Dense(
            features=n_out * n_out,
            dtype=self.dtype,
            param_dtype=self.dtype,
            kernel_init=nn.initializers.normal(1e-3),
            bias_init=nn.initializers.zeros,
        )

    def __call__(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
        """Transform ``x`` conditioned on ``Y``; ``t_emb`` is accepted but unused."""
        h = x.astype(self.dtype)
        cond = Y.astype(self.dtype)

        if self.increase_dim is not None:
            h = self.increase_dim(h)

        h = self.weight.T @ (h @ self.weight)

        size = h.shape[-1]
        hidden = jax.nn.silu(self.y_dense1(cond))
        raw = self.y_dense2(hidden)
        # One (size, size) matrix per leading index of the conditioning input.
        delta = raw.reshape(cond.shape[:-1] + (size, size))
        E = jnp.eye(size, dtype=self.dtype) + self.cond_scale * delta

        if self.normalize_cond:
            E = frob_normalize(E)

        return E @ h @ jnp.swapaxes(E, -1, -2)


# -------------------------
# SPD_NET (U-Net style)
# -------------------------


class FourierTimeEmbedding(nn.Module):
    """
    Fourier embedding with fixed (non-trainable) frequencies ``w``:
      emb(t) = concat(sin(w t), cos(w t))

    Frequencies:
      - default:                w = linspace(0, fourier_scale, n_fourier+1)[1:]
      - fourier_log_scale=True: w = logspace(0, log10(fourier_scale), n_fourier)

    Notes:
      - t must have shape (..., 1); any other shape raises ValueError.
        Output is (..., 2*n_fourier).
      - With log_t=True, log(t + 1e-10) is embedded instead of t.
      - If you want 2π scaling, multiply ``wt`` by ``2*jnp.pi`` in ``__call__``.
    """
    n_fourier: int
    fourier_scale: float
    dtype: jnp.dtype = jnp.float32
    fourier_log_scale: bool = False  # if True, use log-spaced frequencies
    log_t: bool = False               # if True, embed log(t) instead of t

    def setup(self):
        n_freq = int(self.n_fourier)
        scale = float(self.fourier_scale)

        # Build the grid in at least float32 but never narrower than the
        # module dtype, so a float64 module does not inherit float32 rounding.
        grid_dtype = jnp.promote_types(jnp.float32, self.dtype)

        if self.fourier_log_scale:
            w = jnp.logspace(0.0, jnp.log10(scale), n_freq, dtype=grid_dtype)
        else:
            # One extra point so the zero frequency can be dropped.
            w = jnp.linspace(0.0, scale, n_freq + 1, dtype=grid_dtype)[1:]

        self.w = jnp.asarray(w, dtype=self.dtype)  # non-trainable constant

    def __call__(self, t: jnp.ndarray) -> jnp.ndarray:
        """Embed ``t`` of shape ``(..., 1)`` into ``(..., 2*n_fourier)``.

        Raises:
            ValueError: if ``t`` is a scalar or its last axis is not of size 1.
        """
        t = jnp.asarray(t, dtype=self.dtype)
        if self.log_t:
            t = jnp.log(t + 1e-10)

        # Enforce (..., 1)
        if t.ndim < 1 or t.shape[-1] != 1:
            raise ValueError(f"FourierTimeEmbedding expects t shape (..., 1), got {t.shape}")

        t = jnp.squeeze(t, axis=-1)    # (...,)
        wt = t[..., None] * self.w     # (..., n_fourier)
        return jnp.concatenate([jnp.sin(wt), jnp.cos(wt)], axis=-1)


def normalize_spd(A, eps = 1e-3):
    """Normalise SPD matrices in the log-eigenvalue domain.

    For each matrix ``A = Q diag(w) Q^T`` the log-eigenvalues ``log(w)`` are
    divided by ``(||log(w)||_2 + eps)`` and exponentiated again.

    The norm is taken per matrix (over the eigenvalue axis only), so a batched
    input ``(..., n, n)`` is normalised sample by sample.

    Args:
        A: symmetric matrix or batch of matrices, shape ``(..., n, n)``.
           Eigenvalues must be positive; ``log`` of a non-positive eigenvalue
           yields NaN/-inf.
        eps: added to the norm before dividing.

    Returns:
        ``(norm, A_normalized)`` where ``norm`` has shape ``(...)`` (a scalar
        for a single matrix) and ``A_normalized`` has the shape of ``A``.
    """
    w, Q = jnp.linalg.eigh(A)
    w_log = jnp.log(w)
    # axis=-1: one norm per matrix. Without it a (B, n) batch of eigenvalues
    # collapses to a single norm shared by the whole batch.
    norm = jnp.linalg.norm(w_log, axis=-1)
    w_log_normalized = w_log / (norm[..., None] + eps)
    w_normalized = jnp.exp(w_log_normalized)
    A_normalized = (Q * w_normalized[..., None, :]) @ jnp.swapaxes(Q, -1, -2)
    return norm, A_normalized


class SPDNet(nn.Module):
    """U-Net style stack of SPD transforms with a fixed 10 -> 6 -> 3 ladder.

    Each stage is a transform (``trans*``) followed by an eigenvalue
    rectifier (``rect*``). The encoder goes ``spd_size -> 10 -> 6 -> 3``, the
    decoder mirrors it back to 10, and encoder activations are blended into
    the decoder as skip connections. The last transform (``trans10``) has no
    rectifier.

    The time input is Fourier-embedded (on ``log t``). When ``spd_normalize``
    is set, the input is additionally passed through ``normalize_spd`` and an
    embedding of the returned norm is concatenated to the time embedding.
    """
    spd_size: int
    y_dim: int
    dtype: jnp.dtype = jnp.float64
    epsilon: float = 1e-4
    n_fourier: int = 64
    fourier_scale: float = 16.0
    spd_normalize: bool = False

    def setup(self):
        yd = int(self.y_dim)
        emb_width = int(self.n_fourier * 2)

        def make_embedding():
            # Both embeddings share the same configuration.
            return FourierTimeEmbedding(
                n_fourier=self.n_fourier,
                fourier_scale=self.fourier_scale,
                fourier_log_scale=False,
                log_t=True,
                dtype=self.dtype,
            )

        self.time_emb = make_embedding()
        if self.spd_normalize:
            self.spd_norm_emb = make_embedding()
            # The norm embedding is concatenated to the time embedding.
            emb_width += int(self.n_fourier * 2)
        td = emb_width

        # (attribute suffix, block class, input size, output size)
        layout = (
            ("1", SPDTransform, self.spd_size, 10),
            ("1_5", SPDTransform1, 10, 10),
            ("2", SPDTransform, 10, 10),
            ("2_5", SPDTransform1, 10, 10),
            ("3", SPDTransform, 10, 6),
            ("3_5", SPDTransform1, 6, 6),
            ("4", SPDTransform, 6, 6),
            ("4_5", SPDTransform1, 6, 6),
            ("5", SPDTransform, 6, 3),
            ("5_5", SPDTransform1, 3, 3),
            ("5_8", SPDTransform, 3, 3),
            ("5_9", SPDTransform1, 3, 3),
            ("6", SPDTransform, 3, 6),
            ("6_5", SPDTransform1, 6, 6),
            ("7", SPDTransform, 6, 6),
            ("7_5", SPDTransform1, 6, 6),
            ("8", SPDTransform, 6, 10),
            ("8_5", SPDTransform1, 10, 10),
            ("9", SPDTransform, 10, 10),
            ("9_5", SPDTransform1, 10, 10),
            ("10", SPDTransform, 10, 10),
        )

        # Registers self.trans1, self.trans1_5, ..., self.trans10.
        for tag, block_cls, n_in, n_out in layout:
            setattr(self, f"trans{tag}", block_cls(n_in, n_out, td, yd, dtype=self.dtype))

        # Registers self.rect1 ... self.rect9_5 (the final transform has none).
        for tag, *_ in layout[:-1]:
            setattr(self, f"rect{tag}", SPDRectified(epsilon=self.epsilon, dtype=self.dtype))

    def unet_forward(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
        """Run the encoder/decoder stack and return the squeezed output.

        NOTE(review): ``jnp.squeeze`` removes every size-1 axis, so a batch of
        one loses its batch axis — confirm callers expect that.
        """
        r = 0.5

        def blend(decoder_act, skip_act):
            # Skip-connection merge: weighted sum of the two, then halved.
            return (r * decoder_act + (1.0 - r) * skip_act) / 2.0

        # Encoder at full resolution.
        e1 = self.rect1(self.trans1(x, t_emb, Y))
        e1b = self.rect1_5(self.trans1_5(e1, t_emb, Y))
        e2 = self.rect2(self.trans2(e1b, t_emb, Y))
        e2b = self.rect2_5(self.trans2_5(e2, t_emb, Y))

        # Encoder at mid resolution.
        e3 = self.rect3(self.trans3(e2b, t_emb, Y))
        e3b = self.rect3_5(self.trans3_5(e3, t_emb, Y))
        e4 = self.rect4(self.trans4(e3b, t_emb, Y))
        e4b = self.rect4_5(self.trans4_5(e4, t_emb, Y))

        # Bottleneck.
        e5 = self.rect5(self.trans5(e4b, t_emb, Y))
        e5b = self.rect5_5(self.trans5_5(e5, t_emb, Y))
        mid = self.rect5_8(self.trans5_8(e5b, t_emb, Y))
        mid_b = self.rect5_9(self.trans5_9(mid, t_emb, Y))

        # Decoder, each level fed by a blend with the matching encoder output.
        d6 = self.rect6(self.trans6(blend(mid_b, e5b), t_emb, Y))
        d6b = self.rect6_5(self.trans6_5(d6, t_emb, Y))

        d7 = self.rect7(self.trans7(blend(d6b, e4b), t_emb, Y))
        d7b = self.rect7_5(self.trans7_5(d7, t_emb, Y))

        d8 = self.rect8(self.trans8(blend(d7b, e3b), t_emb, Y))
        d8b = self.rect8_5(self.trans8_5(d8, t_emb, Y))

        d9 = self.rect9(self.trans9(blend(d8b, e2b), t_emb, Y))
        d9b = self.rect9_5(self.trans9_5(d9, t_emb, Y))

        result = self.trans10(d9b, t_emb, Y)
        return jnp.squeeze(result)

    def __call__(self, x: jnp.ndarray, t: jnp.ndarray, condition: jnp.ndarray) -> jnp.ndarray:
        """
        x: (B, spd_size, spd_size)
        t: (B,) or (B,1)
        condition: (B, y_dim)
        """
        inputs = x.astype(self.dtype)
        cond = condition.astype(self.dtype)

        # The embedding layer requires a trailing axis of size 1.
        t_col = t.reshape((-1, 1)).astype(self.dtype)
        emb = self.time_emb(t_col)

        if self.spd_normalize:
            norm, inputs = normalize_spd(inputs)
            norm_col = norm.reshape((-1, 1)).astype(self.dtype)
            emb = jnp.concatenate([emb, self.spd_norm_emb(norm_col)], axis=-1)

        return self.unet_forward(inputs, emb, cond)
    
class SPDNetNDim(nn.Module):
    """U-Net style stack of SPD transforms with a size-dependent ladder.

    Same topology as the fixed 10 -> 6 -> 3 network, but the three
    resolutions are ``L`` (= ``spd_size``), ``M`` (mid) and ``S``
    (bottleneck). Defaults for ``M`` and ``S`` come from a small table for
    ``L`` in {10, 13, 15} and from ``round(0.6 L)`` / ``round(0.3 L)``
    otherwise. ``mid_dim`` and ``bottleneck_dim`` override ``M`` and ``S``
    independently; whichever is left as ``None`` keeps its default. No check
    is made that ``S <= M <= L``.
    """
    spd_size: int
    y_dim: int
    dtype: jnp.dtype = jnp.float64
    epsilon: float = 1e-4
    n_fourier: int = 64
    fourier_scale: float = 16.0
    spd_normalize: bool = False

    # Optional overrides (leave None to use the defaults below)
    mid_dim: Optional[int] = None
    bottleneck_dim: Optional[int] = None

    def setup(self):
        td = int(self.n_fourier * 2)
        yd = int(self.y_dim)

        def make_embedding():
            # Both embeddings share the same configuration.
            return FourierTimeEmbedding(
                n_fourier=self.n_fourier,
                fourier_scale=self.fourier_scale,
                fourier_log_scale=False,
                log_t=True,
                dtype=self.dtype,
            )

        self.time_emb = make_embedding()

        if self.spd_normalize:
            self.spd_norm_emb = make_embedding()
            # The norm embedding is concatenated to the time embedding.
            td += int(self.n_fourier * 2)

        # -----------------------------
        # Dimension schedule
        # L = full SPD size (input/output)
        # M = mid resolution
        # S = bottleneck resolution
        # -----------------------------
        L = int(self.spd_size)

        # Defaults, tuned to match the original 10 -> 6 -> 3 ladder.
        if L == 10:
            M, S = 6, 3
        elif L == 13:
            M, S = 8, 4
        elif L == 15:
            M, S = 9, 5
        else:
            # Reasonable fallback for any other size.
            M = int(round(0.6 * L))
            S = int(round(0.3 * L))
            M = max(3, min(M, L))
            S = max(2, min(S, M - 1))

        # Apply each override on its own, so supplying only one of them is
        # honoured rather than silently dropped.
        if self.mid_dim is not None:
            M = int(self.mid_dim)
        if self.bottleneck_dim is not None:
            S = int(self.bottleneck_dim)

        # (Optional) store for debugging
        self.L, self.M, self.S = L, M, S

        # -----------------------------
        # U-Net blocks: (attribute suffix, block class, input size, output size)
        # -----------------------------
        layout = (
            ("1", SPDTransform, L, L),
            ("1_5", SPDTransform1, L, L),
            ("2", SPDTransform, L, L),
            ("2_5", SPDTransform1, L, L),
            ("3", SPDTransform, L, M),
            ("3_5", SPDTransform1, M, M),
            ("4", SPDTransform, M, M),
            ("4_5", SPDTransform1, M, M),
            ("5", SPDTransform, M, S),
            ("5_5", SPDTransform1, S, S),
            ("5_8", SPDTransform, S, S),
            ("5_9", SPDTransform1, S, S),
            ("6", SPDTransform, S, M),
            ("6_5", SPDTransform1, M, M),
            ("7", SPDTransform, M, M),
            ("7_5", SPDTransform1, M, M),
            ("8", SPDTransform, M, L),
            ("8_5", SPDTransform1, L, L),
            ("9", SPDTransform, L, L),
            ("9_5", SPDTransform1, L, L),
            ("10", SPDTransform, L, L),
        )

        # Registers self.trans1, self.trans1_5, ..., self.trans10.
        for tag, block_cls, n_in, n_out in layout:
            setattr(self, f"trans{tag}", block_cls(n_in, n_out, td, yd, dtype=self.dtype))

        # Registers self.rect1 ... self.rect9_5 (the final transform has none).
        for tag, *_ in layout[:-1]:
            setattr(self, f"rect{tag}", SPDRectified(epsilon=self.epsilon, dtype=self.dtype))

    def unet_forward(self, x: jnp.ndarray, t_emb: jnp.ndarray, Y: jnp.ndarray) -> jnp.ndarray:
        """Run the encoder/decoder stack and return the squeezed output.

        NOTE(review): ``jnp.squeeze`` removes every size-1 axis, so a batch of
        one loses its batch axis — confirm callers expect that.
        """
        r = 0.5

        # Encoder at full resolution (L).
        x1 = self.rect1(self.trans1(x, t_emb, Y))
        x1_5 = self.rect1_5(self.trans1_5(x1, t_emb, Y))

        x2 = self.rect2(self.trans2(x1_5, t_emb, Y))
        x2_5 = self.rect2_5(self.trans2_5(x2, t_emb, Y))

        # Encoder at mid resolution (M).
        x3 = self.rect3(self.trans3(x2_5, t_emb, Y))
        x3_5 = self.rect3_5(self.trans3_5(x3, t_emb, Y))

        x4 = self.rect4(self.trans4(x3_5, t_emb, Y))
        x4_5 = self.rect4_5(self.trans4_5(x4, t_emb, Y))

        # Bottleneck (S).
        x5 = self.rect5(self.trans5(x4_5, t_emb, Y))
        x5_5 = self.rect5_5(self.trans5_5(x5, t_emb, Y))

        x5_8 = self.rect5_8(self.trans5_8(x5_5, t_emb, Y))
        x5_9 = self.rect5_9(self.trans5_9(x5_8, t_emb, Y))

        # Decoder: each level starts from a blend of the previous decoder
        # output with the encoder activation of matching size.
        x6 = (r * x5_9 + (1.0 - r) * x5_5) / 2.0
        x6 = self.rect6(self.trans6(x6, t_emb, Y))
        x6_5 = self.rect6_5(self.trans6_5(x6, t_emb, Y))

        x7 = (r * x6_5 + (1.0 - r) * x4_5) / 2.0
        x7 = self.rect7(self.trans7(x7, t_emb, Y))
        x7_5 = self.rect7_5(self.trans7_5(x7, t_emb, Y))

        x8 = (r * x7_5 + (1.0 - r) * x3_5) / 2.0
        x8 = self.rect8(self.trans8(x8, t_emb, Y))
        x8_5 = self.rect8_5(self.trans8_5(x8, t_emb, Y))

        x9 = (r * x8_5 + (1.0 - r) * x2_5) / 2.0
        x9 = self.rect9(self.trans9(x9, t_emb, Y))
        x9_5 = self.rect9_5(self.trans9_5(x9, t_emb, Y))

        x_out = self.trans10(x9_5, t_emb, Y)
        return jnp.squeeze(x_out)

    def __call__(self, x: jnp.ndarray, t: jnp.ndarray, condition: jnp.ndarray) -> jnp.ndarray:
        """
        x: (B, spd_size, spd_size)
        t: (B,) or (B,1)
        condition: (B, y_dim)
        """
        x = x.astype(self.dtype)
        condition = condition.astype(self.dtype)
        # The embedding layer requires a trailing axis of size 1.
        t = t.reshape((-1, 1)).astype(self.dtype)

        t_emb = self.time_emb(t)

        if self.spd_normalize:
            norm, x = normalize_spd(x)
            norm_emb = self.spd_norm_emb(norm.reshape((-1, 1)).astype(self.dtype))
            t_emb = jnp.concatenate([t_emb, norm_emb], axis=-1)

        return self.unet_forward(x, t_emb, condition)