#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Multigraph denoiser with directional edge features and noise/angle conditioning.
"""

import dataclasses
from typing import Any, Callable, Dict, Optional, Sequence

import chex
import flax.nnx as nnx
import jax
import jax.numpy as jnp

from gen_da import (
    deep_typed_graph_net,
    model_utils,
    typed_graph,
)

# ------------------------------------------------------------------------- #
# Encoders                                                                  #
# ------------------------------------------------------------------------- #
class FourierFeaturesMLP(nnx.Module):
    """Embed a per-sample scalar with Fourier features followed by an MLP.

    The input is optionally log-transformed, expanded with
    ``model_utils.fourier_features`` and passed through a stack of
    ``nnx.Linear`` layers, with ``activation`` applied between consecutive
    layers (not after the last one).

    Args:
        base_period: Base period forwarded to ``model_utils.fourier_features``.
        num_frequencies: Number of frequencies; the first linear layer expects
            ``2 * num_frequencies`` input features.
        output_sizes: Output width of each linear layer, in order.
        apply_log_first: If True, apply ``jnp.log`` to the input first (the
            input must then be strictly positive).
        w_init: Kernel initializer. Defaults to uniform variance scaling with
            scale 2.0 over fan-in (He-uniform).
        activation: Nonlinearity applied between linear layers.
        rngs: RNG streams for parameter initialisation. ``None`` creates a
            fresh ``nnx.Rngs(0)`` for this instance.
        mesh: When not None, kernels are annotated with
            ``PartitionSpec(None, "model")`` and biases with
            ``PartitionSpec("model")``.
        **mlp_kwargs: Extra keyword arguments forwarded to every
            ``nnx.Linear``.
    """

    def __init__(
        self,
        base_period: float,
        num_frequencies: int,
        output_sizes: Sequence[int],
        apply_log_first: bool = False,
        w_init: Optional[nnx.Initializer] = None,
        activation: Callable = jax.nn.gelu,
        rngs: Optional[nnx.Rngs] = None,
        mesh=None,
        **mlp_kwargs,
    ):
        # A signature default of ``nnx.Rngs(0)`` would be one stateful object
        # shared by every instance built without ``rngs``, making the
        # initialisation depend on construction order.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.base_period = base_period
        self.num_frequencies = num_frequencies
        self.apply_log_first = apply_log_first
        self.activation = activation

        if w_init is None:
            w_init = nnx.initializers.variance_scaling(
                2.0, mode="fan_in", distribution="uniform"
            )

        # The initializers are the same for every layer, so build them once.
        if mesh is not None:
            from jax.sharding import PartitionSpec as P

            kernel_init = nnx.with_partitioning(w_init, P(None, "model"))
            bias_init = nnx.with_partitioning(
                nnx.initializers.zeros_init(), P("model")
            )
        else:
            kernel_init = w_init
            bias_init = nnx.initializers.zeros_init()

        in_ch = 2 * num_frequencies  # fourier features: sin + cos per frequency
        self.linears = []
        for idx, out_ch in enumerate(output_sizes):
            lin = nnx.Linear(
                in_features=in_ch,
                out_features=out_ch,
                kernel_init=kernel_init,
                bias_init=bias_init,
                rngs=rngs,
                **mlp_kwargs,
            )
            # Each layer is also registered under a stable attribute name
            # (``linear_0``, ``linear_1``, ...).
            setattr(self, f"linear_{idx}", lin)
            self.linears.append(lin)
            in_ch = out_ch

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Return the MLP embedding of the Fourier features of ``x``."""
        if self.apply_log_first:
            x = jnp.log(x)
        feats = model_utils.fourier_features(
            x, self.base_period, self.num_frequencies
        )
        last = len(self.linears) - 1
        for i, lin in enumerate(self.linears):
            feats = lin(feats)
            if i < last:
                feats = self.activation(feats)
        return feats


class AngleHarmonicsMLP(nnx.Module):
    """Embed angles (radians) with angular harmonics followed by an MLP.

    For ``K = num_harmonics`` the raw features are
    ``[sin(k*theta) for k in 1..K] + [cos(k*theta) for k in 1..K]``. They are
    passed through ``nnx.Linear`` layers with ``activation`` between
    consecutive layers (not after the last one).

    Args:
        num_harmonics: Number of harmonics K; the first layer expects ``2*K``
            inputs.
        output_sizes: Output width of each linear layer, in order.
        w_init: Kernel initializer. Defaults to uniform variance scaling with
            scale 2.0 over fan-in (He-uniform).
        activation: Nonlinearity applied between linear layers.
        rngs: RNG streams for parameter initialisation. ``None`` creates a
            fresh ``nnx.Rngs(0)`` for this instance.
        mesh: When not None, kernels are annotated with
            ``PartitionSpec(None, "model")`` and biases with
            ``PartitionSpec("model")``.
        **mlp_kwargs: Extra keyword arguments forwarded to every
            ``nnx.Linear``.
    """

    def __init__(
        self,
        num_harmonics: int,
        output_sizes: Sequence[int],
        w_init: Optional[nnx.Initializer] = None,
        activation: Callable = jax.nn.gelu,
        rngs: Optional[nnx.Rngs] = None,
        mesh=None,
        **mlp_kwargs,
    ):
        # Avoid a shared, stateful ``nnx.Rngs`` default across instances.
        if rngs is None:
            rngs = nnx.Rngs(0)

        self.num_harmonics = num_harmonics
        self.activation = activation
        if w_init is None:
            w_init = nnx.initializers.variance_scaling(
                2.0, mode="fan_in", distribution="uniform"
            )

        # The initializers are the same for every layer, so build them once.
        if mesh is not None:
            from jax.sharding import PartitionSpec as P

            kernel_init = nnx.with_partitioning(w_init, P(None, "model"))
            bias_init = nnx.with_partitioning(
                nnx.initializers.zeros_init(), P("model")
            )
        else:
            kernel_init = w_init
            bias_init = nnx.initializers.zeros_init()

        in_ch = 2 * num_harmonics  # [sin kθ, cos kθ]
        self.linears = []
        for idx, out_ch in enumerate(output_sizes):
            lin = nnx.Linear(
                in_ch,
                out_ch,
                kernel_init=kernel_init,
                bias_init=bias_init,
                rngs=rngs,
                **mlp_kwargs,
            )
            # Each layer is also registered under a stable attribute name
            # (``linear_0``, ``linear_1``, ...).
            setattr(self, f"linear_{idx}", lin)
            self.linears.append(lin)
            in_ch = out_ch

    def __call__(self, angle_rad: jnp.ndarray) -> jnp.ndarray:
        """Embed ``angle_rad`` (radians) of any shape ``S``.

        The harmonics are laid out on a new trailing axis, so the result has
        shape ``S + (width,)``, where width is ``output_sizes[-1]`` (or ``2*K``
        when ``output_sizes`` is empty). A 1D input ``(B,)`` gives ``(B, width)``.
        """
        angle_rad = jnp.asarray(angle_rad)
        ks = jnp.arange(1, self.num_harmonics + 1, dtype=angle_rad.dtype)
        ktheta = angle_rad[..., None] * ks  # S + (K,)
        feats = jnp.concatenate([jnp.sin(ktheta), jnp.cos(ktheta)], axis=-1)
        last = len(self.linears) - 1
        for i, lin in enumerate(self.linears):
            feats = lin(feats)
            if i < last:
                feats = self.activation(feats)
        return feats

class ConditioningFusion(nnx.Module):
    """Fuse noise + angle embeddings into a shared conditioning vector.

    The two embeddings are concatenated on the last axis and mapped through
    ``Linear(in_dim -> hidden_mult*out_dim) -> activation ->
    Linear(hidden_mult*out_dim -> out_dim)``.

    Args:
        out_dim: Width of the fused conditioning vector.
        hidden_mult: The hidden width is ``hidden_mult * out_dim``.
        activation: Nonlinearity between the two linear layers.
        rngs: RNG streams for parameter initialisation. ``None`` creates a
            fresh ``nnx.Rngs(0)`` for this instance.
        mesh: Accepted for interface parity with the encoders; unused here
            (no partitioning annotations are applied).
        in_dim: Width of the concatenated input (noise width + angle width).
            Defaults to ``2 * out_dim``, i.e. both embeddings ``out_dim`` wide.
    """

    def __init__(self,
                 out_dim: int,
                 hidden_mult: int = 2,
                 activation: Callable = jax.nn.silu,
                 rngs: Optional[nnx.Rngs] = None,
                 mesh=None,
                 in_dim: Optional[int] = None):
        # Avoid a shared, stateful ``nnx.Rngs`` default across instances.
        if rngs is None:
            rngs = nnx.Rngs(0)
        if in_dim is None:
            in_dim = 2 * out_dim
        hdim = hidden_mult * out_dim
        self.net = nnx.Sequential(
            nnx.Linear(in_dim, hdim, rngs=rngs),
            activation,
            nnx.Linear(hdim, out_dim, rngs=rngs),
        )

    def __call__(self, noise_emb: jnp.ndarray, angle_emb: jnp.ndarray):
        """Concatenate both embeddings on the last axis and project to ``out_dim``."""
        fused = jnp.concatenate([noise_emb, angle_emb], axis=-1)
        return self.net(fused)

# ------------------------------------------------------------------------- #
# Configs                                                                   #
# ------------------------------------------------------------------------- #
@chex.dataclass(frozen=True)
class NoiseEncoderConfig:
    """Constructor arguments for the noise-level ``FourierFeaturesMLP``.

    Field names match that encoder's keyword parameters; ``Denoiser`` expands
    this config with ``dataclasses.asdict``.
    """

    # Apply ``log`` to the noise level before computing Fourier features.
    apply_log_first: bool = True
    # Base period forwarded to ``model_utils.fourier_features``.
    base_period: float = 16.0
    # Number of Fourier frequencies (the MLP input width is twice this).
    num_frequencies: int = 32
    # Width of each MLP layer; any number of layers is accepted and the last
    # entry is the embedding width.
    output_sizes: tuple[int, ...] = (64, 64)


@chex.dataclass(frozen=True)
class AngleEncoderConfig:
    """Constructor arguments for the angle ``AngleHarmonicsMLP``.

    Field names match that encoder's keyword parameters; ``Denoiser`` expands
    this config with ``dataclasses.asdict``.
    """

    # Number of angular harmonics K (the MLP input width is 2*K).
    num_harmonics: int = 4
    # Width of each MLP layer; any number of layers is accepted and the last
    # entry is the embedding width.
    output_sizes: tuple[int, ...] = (64, 64)


@chex.dataclass(frozen=True)
class DenoiserArchitectureConfig:
    """Hyper-parameters consumed by ``DenoiserArchitecture``."""

    # Shared width for node latents, edge latents and MLP hidden layers.
    latent_size: int = 128
    # Hidden layers per MLP inside each typed graph network.
    hidden_layers: int = 1
    # Forwarded as ``aggregate_normalization`` to the original->reduced GNN.
    grid2mesh_aggregate_normalization: Optional[float] = None
    # Per-node output width; a falsy value falls back to ``target_channels``.
    node_output_size: Optional[int] = None


# ------------------------------------------------------------------------- #
# Top-level Denoiser                                                        #
# ------------------------------------------------------------------------- #
class Denoiser(nnx.Module):
    """Noise- and angle-conditioned multigraph denoiser.

    The per-sample noise level is embedded with a ``FourierFeaturesMLP`` and
    the wind angle with an ``AngleHarmonicsMLP``. Both embeddings are fused by
    ``ConditioningFusion`` into one global conditioning vector, which is
    handed to ``DenoiserArchitecture`` together with the inputs and forcings.
    """

    def __init__(
        self,
        noise_encoder_config: Optional[NoiseEncoderConfig],
        angle_encoder_config: Optional[AngleEncoderConfig],
        denoiser_architecture_config: DenoiserArchitectureConfig,
        rngs: nnx.Rngs,
        mesh,
        example_graph_structures: Dict[str, jnp.ndarray],
        target_channels: int,
        guiding_channels: int,
        precomputed_adj_mat: Optional[Any] = None,
    ):
        """Build the encoders, the fusion MLP and the graph architecture.

        ``None`` encoder configs fall back to their defaults. The remaining
        arguments are forwarded to ``DenoiserArchitecture``.

        Raises:
            ValueError: If the noise and angle embedding widths differ, since
                ``ConditioningFusion`` is built for two embeddings of the
                noise-embedding width.
        """
        if noise_encoder_config is None:
            noise_encoder_config = NoiseEncoderConfig()
        if angle_encoder_config is None:
            angle_encoder_config = AngleEncoderConfig()

        noise_dim = noise_encoder_config.output_sizes[-1]
        angle_sizes = angle_encoder_config.output_sizes
        # With no MLP layers the angle encoder emits its raw 2*K harmonics.
        angle_dim = (
            angle_sizes[-1]
            if len(angle_sizes)
            else 2 * angle_encoder_config.num_harmonics
        )
        # Fail at construction instead of with a matmul shape error in __call__.
        if angle_dim != noise_dim:
            raise ValueError(
                "ConditioningFusion expects noise and angle embeddings of equal "
                f"width; got noise={noise_dim}, angle={angle_dim}. Make "
                "noise_encoder_config.output_sizes[-1] and "
                "angle_encoder_config.output_sizes[-1] match."
            )

        self.noise_level_encoder = FourierFeaturesMLP(
            rngs=rngs, mesh=mesh, **dataclasses.asdict(noise_encoder_config)
        )
        # The angle embedding is fused with the noise embedding into the
        # global conditioning vector (see ``self.fuser``).
        self.angle_encoder = AngleHarmonicsMLP(
            rngs=rngs, mesh=mesh, **dataclasses.asdict(angle_encoder_config)
        )

        self.fuser = ConditioningFusion(
            out_dim=noise_dim,  # match noise embedding size
            rngs=rngs,
            mesh=mesh,
        )

        self.predictor = DenoiserArchitecture(
            denoiser_architecture_config=denoiser_architecture_config,
            rngs=rngs,
            mesh=mesh,
            precomputed_adj_mat=precomputed_adj_mat,
            example_graph_structures=example_graph_structures,
            target_channels=target_channels,
            guiding_channels=guiding_channels,
        )

    def __call__(
        self,
        noisy_inputs: jnp.ndarray,  # (B, N_o, C_t)
        noise_levels: jnp.ndarray,  # (B,)
        forcings: Dict[str, Any],
        **kwargs,
    ) -> jnp.ndarray:
        """Denoise ``noisy_inputs`` at the given noise levels.

        Args:
            noisy_inputs: (B, N_o, C_t) noisy target field.
            noise_levels: 1D array (B,) of noise levels.
            forcings: Must contain ``angle_deg`` (B,) in degrees plus whatever
                ``DenoiserArchitecture`` reads (``graph_structures``, ...).
            **kwargs: Forwarded to ``DenoiserArchitecture``.

        Returns:
            The architecture's per-node prediction, (B, N_o, C_out).

        Raises:
            ValueError: If ``noise_levels`` is not 1D.
        """
        noise_levels = jnp.asarray(noise_levels)
        if noise_levels.ndim != 1:
            raise ValueError("noise_levels must be 1D [B].")

        noise_enc = self.noise_level_encoder(noise_levels)  # (B, Dn)
        # Same float32 conversion as DenoiserArchitecture uses for angle_deg.
        angle_rad = jnp.asarray(forcings["angle_deg"], jnp.float32) * (
            jnp.pi / 180.0
        )  # (B,)
        angle_enc = self.angle_encoder(angle_rad)  # (B, Da)

        # Noise + angle -> single global conditioning vector for every GNN stage.
        cond = self.fuser(noise_enc, angle_enc)

        return self.predictor(
            noisy_inputs=noisy_inputs,
            noise_level_cond=cond,
            forcings=forcings,
            **kwargs,
        )



# ------------------------------------------------------------------------- #
# Architecture                                                              #
# ------------------------------------------------------------------------- #
class DenoiserArchitecture(nnx.Module):
    """Four-stage typed-graph network over an original and a reduced mesh.

    Stages, applied in order:

    1. ``ori2ori_gnn``: message passing among original nodes over ``o2o``
       edges. Their edge features are augmented with the edge direction
       projected onto the wind direction and its perpendicular.
    2. ``ori2red_gnn``: original -> reduced over ``o2r`` edges.
    3. ``red2red_gnn``: message passing among reduced nodes over ``r2r`` edges.
    4. ``red2ori_gnn``: reduced -> original over ``r2o`` edges, producing the
       per-node output.

    Every stage receives ``noise_level_cond`` as ``global_norm_conditioning``.
    Internally node and edge feature arrays are laid out as ``(N, B, F)``.
    """

    def __init__(
        self,
        denoiser_architecture_config: DenoiserArchitectureConfig,
        rngs: nnx.Rngs,
        mesh,
        precomputed_adj_mat: Optional[Any],
        example_graph_structures: Dict[str, jnp.ndarray],
        target_channels: int,
        guiding_channels: int,
    ):
        """Build single-node graph templates and the four typed graph networks.

        Args:
            denoiser_architecture_config: Latent size, MLP depth, output width
                and o2r aggregate normalisation.
            rngs: RNG streams forwarded to every sub-network.
            mesh: Forwarded to every sub-network.
            precomputed_adj_mat: Stored on the module; not read by this class.
            example_graph_structures: One graph-structure dict; only the
                feature widths (``shape[1]``) of its arrays are read here.
            target_channels: Channel count C_t of ``noisy_inputs``.
            guiding_channels: Channel count C_g of ``U_field_guiding``.
        """
        self.cfg = denoiser_architecture_config
        self.rngs = rngs
        self.mesh = mesh
        self.precomputed_adj_mat = precomputed_adj_mat
        self.target_channels = target_channels
        self.guiding_channels = guiding_channels

        # Edge keys
        self.oo_key = typed_graph.EdgeSetKey("o2o", ("orig", "orig"))
        self.o2r_key = typed_graph.EdgeSetKey("o2r", ("orig", "red"))
        self.rr_key = typed_graph.EdgeSetKey("r2r", ("red", "red"))
        self.r2o_key = typed_graph.EdgeSetKey("r2o", ("red", "orig"))

        # Feature dims from template
        tmpl = example_graph_structures
        orig_struct_dim = tmpl["original_coordinates"].shape[1]
        red_struct_dim = tmpl["reduced_coordinates"].shape[1]

        base_o2o = tmpl.get("o2o_features", None)
        base_o2o_dim = base_o2o.shape[1] if base_o2o is not None else 0
        # +2 for directional edge features (p_par_norm, p_perp_norm)
        o2o_edge_dim = base_o2o_dim + 2

        o2r_edge_dim = tmpl["o2r_features"].shape[1]
        rr_edge = tmpl.get("r2r_features", None)
        r2r_edge_dim = rr_edge.shape[1] if rr_edge is not None else 0
        r2o_edge_dim = tmpl["r2o_features"].shape[1]

        cfg = self.cfg
        C_t = target_channels
        C_g = guiding_channels
        node_out = cfg.node_output_size or C_t

        # ------------------------------------------------------------------ #
        # Graph templates — minimal (N=1) but correct feature dims
        # ------------------------------------------------------------------ #
        def zeros_nf(d):
            return jnp.zeros((1, 1, d), jnp.float32)

        # Original-node input: target + guiding + coordinates
        # + 3 (node-type one-hot) + 2 (wind direction [cos θ, sin θ])
        # + 1 (observation mask). __call__ builds features in this order.
        base_ori2ori_in_dim = C_t + C_g + orig_struct_dim + 3 + 2 + 1
        self.orig_node_input_size = base_ori2ori_in_dim

        graph_o2o_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "orig": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(base_ori2ori_in_dim)
                )
            },
            edges={
                self.oo_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([1]),
                    indices=typed_graph.EdgesIndices(
                        senders=jnp.array([0]), receivers=jnp.array([0])
                    ),
                    features=zeros_nf(o2o_edge_dim),
                )
            },
        )

        graph_o2r_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "orig": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(cfg.latent_size)
                ),
                "red": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(red_struct_dim)
                ),
            },
            edges={
                self.o2r_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([1]),
                    indices=typed_graph.EdgesIndices(
                        senders=jnp.array([0]), receivers=jnp.array([0])
                    ),
                    features=zeros_nf(o2r_edge_dim),
                )
            },
        )

        graph_r2r_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "red": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(cfg.latent_size)
                )
            },
            edges={
                self.rr_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([1]),
                    indices=typed_graph.EdgesIndices(
                        senders=jnp.array([0]), receivers=jnp.array([0])
                    ),
                    features=zeros_nf(r2r_edge_dim) if r2r_edge_dim > 0 else None,
                )
            },
        )

        graph_r2o_t = typed_graph.TypedGraph(
            context=typed_graph.Context(n_graph=jnp.array([1]), features=None),
            nodes={
                "red": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(cfg.latent_size)
                ),
                "orig": typed_graph.NodeSet(
                    n_node=jnp.array([1]), features=zeros_nf(cfg.latent_size)
                ),
            },
            edges={
                self.r2o_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([1]),
                    indices=typed_graph.EdgesIndices(
                        senders=jnp.array([0]), receivers=jnp.array([0])
                    ),
                    features=zeros_nf(r2o_edge_dim),
                )
            },
        )

        # Store templates
        self.graph_o2o_template = graph_o2o_t
        self.graph_o2r_template = graph_o2r_t
        self.graph_r2r_template = graph_r2r_t
        self.graph_r2o_template = graph_r2o_t

        # ------------------------------------------------------------------ #
        # Instantiate GNNs
        # ------------------------------------------------------------------ #
        print("Init Original→Original GNN")
        self.ori2ori_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"o2o": cfg.latent_size},
            embed_edges=True,
            embed_nodes=True,
            f32_aggregation=True,
            include_sent_messages_in_node_update=False,
            mlp_hidden_size=cfg.latent_size,
            mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"orig": cfg.latent_size},
            node_output_size=None,
            num_message_passing_steps=2,
            use_layer_norm=True,
            use_norm_conditioning=True,
            rngs=self.rngs,
            mesh=self.mesh,
            graph_template=graph_o2o_t,
        )

        print("Init Original→Reduced GNN")
        self.ori2red_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            aggregate_normalization=cfg.grid2mesh_aggregate_normalization,
            edge_latent_size={"o2r": cfg.latent_size},
            embed_edges=True,
            embed_nodes=True,
            f32_aggregation=True,
            include_sent_messages_in_node_update=False,
            mlp_hidden_size=cfg.latent_size,
            mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"orig": cfg.latent_size, "red": cfg.latent_size},
            node_output_size=None,
            num_message_passing_steps=1,
            use_layer_norm=True,
            use_norm_conditioning=True,
            rngs=self.rngs,
            mesh=self.mesh,
            graph_template=graph_o2r_t,
        )

        print("Init Reduced→Reduced GNN")
        self.red2red_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"r2r": cfg.latent_size},
            embed_edges=True,
            embed_nodes=False,
            f32_aggregation=True,
            include_sent_messages_in_node_update=False,
            mlp_hidden_size=cfg.latent_size,
            mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"red": cfg.latent_size},
            node_output_size=None,
            num_message_passing_steps=6,
            use_layer_norm=True,
            use_norm_conditioning=True,
            rngs=self.rngs,
            mesh=self.mesh,
            graph_template=graph_r2r_t,
        )

        print("Init Reduced→Original GNN")
        self.red2ori_gnn = deep_typed_graph_net.DeepTypedGraphNet(
            activation="swish",
            edge_latent_size={"r2o": cfg.latent_size},
            embed_edges=True,
            embed_nodes=False,
            f32_aggregation=True,
            include_sent_messages_in_node_update=False,
            mlp_hidden_size=cfg.latent_size,
            mlp_num_hidden_layers=cfg.hidden_layers,
            node_latent_size={"red": cfg.latent_size, "orig": cfg.latent_size},
            node_output_size={"orig": node_out},
            num_message_passing_steps=1,
            use_layer_norm=True,
            use_norm_conditioning=True,
            rngs=self.rngs,
            mesh=self.mesh,
            graph_template=graph_r2o_t,
        )

    # --------------------- helpers --------------------- #
    @staticmethod
    def _broadcast(arr: Optional[jnp.ndarray], B: int) -> Optional[jnp.ndarray]:
        """Tile a 2D ``(N, F)`` array to ``(N, B, F)``; ``None`` passes through."""
        if arr is None:
            return None
        return jnp.broadcast_to(arr[:, None, :], (arr.shape[0], B, arr.shape[1]))

    @staticmethod
    def _ensure_same_shapes(graphs: Sequence[Dict[str, jnp.ndarray]]):
        """Raise ``AssertionError`` unless every dict matches ``graphs[0]``.

        For every key of the first dict, the others must hold a value with the
        same shape (or be ``None``/absent exactly when the first one is
        ``None``). Values are not compared; the caller uses ``graphs[0]`` for
        the whole batch.
        """
        if not graphs:
            return
        base = graphs[0]
        for g in graphs[1:]:
            if g is base:  # the single-dict case repeats the same object
                continue
            for k, v in base.items():
                other = g.get(k)
                if v is None or other is None:
                    same = v is None and other is None
                else:
                    same = jnp.shape(v) == jnp.shape(other)
                if not same:
                    raise AssertionError(
                        f"Batched samples must share the same topology. Mismatch in '{k}'."
                    )

    # --------------------- forward --------------------- #
    def __call__(
        self,
        noisy_inputs: jnp.ndarray,
        noise_level_cond: jnp.ndarray,
        forcings: Dict[str, Any],
        **kwargs,
    ) -> jnp.ndarray:
        """Run the four GNN stages and return per-node predictions.

        Args:
            noisy_inputs: (B, N_o, C_t) noisy target field.
            noise_level_cond: Global conditioning passed to every stage as
                ``global_norm_conditioning`` (``Denoiser`` passes its fused
                noise+angle embedding).
            forcings: Must contain ``graph_structures`` (one dict, or a
                sequence of dicts with identical shapes) and ``angle_deg``
                (B,) in degrees. Optional: ``obs_mask`` (B, N_o) or
                (B, N_o, 1), zeros when absent; ``U_field_guiding`` (B, N_o)
                or (B, N_o, C_g), zeros of width ``guiding_channels`` when
                absent.
            **kwargs: Ignored.

        Returns:
            (B, N_o, C_out) predictions.

        Raises:
            ValueError: Empty ``graph_structures``; node-count mismatch between
                ``noisy_inputs`` and ``original_coordinates``; or original-node
                input width different from the template's.
            AssertionError: Batched graph structures differ in shape.
        """
        B, N_o, C_t = noisy_inputs.shape

        gstructs = forcings["graph_structures"]
        if isinstance(gstructs, dict):
            gstructs = [gstructs] * B
        if not gstructs:
            raise ValueError("forcings['graph_structures'] must not be empty.")
        self._ensure_same_shapes(gstructs)
        # All samples share one topology, so the first dict serves the batch.
        g = gstructs[0]

        def _optional_f32(key: str) -> Optional[jnp.ndarray]:
            # Missing keys and explicit None values are both treated as absent.
            value = g.get(key)
            return None if value is None else jnp.asarray(value, jnp.float32)

        # ---------------- Static graph arrays ---------------- #
        orig_xy = jnp.asarray(g["original_coordinates"], jnp.float32)  # (N_o,2)
        red_xy = jnp.asarray(g["reduced_coordinates"], jnp.float32)    # (N_r,2)
        N_r = red_xy.shape[0]

        # Type ids {1,2,3} -> one-hot; out-of-range ids are clipped into range.
        node_types_int = jnp.asarray(g["node_types"], jnp.int32)
        node_types_oh = jax.nn.one_hot(
            jnp.clip(node_types_int - 1, 0, 2), 3
        ).astype(jnp.float32)

        oo_s = jnp.asarray(g["o2o_senders"], jnp.int32)
        oo_r = jnp.asarray(g["o2o_receivers"], jnp.int32)
        base_oo_f = _optional_f32("o2o_features")

        o2r_s = jnp.asarray(g["o2r_senders"], jnp.int32)
        o2r_r = jnp.asarray(g["o2r_receivers"], jnp.int32)
        o2r_f = jnp.asarray(g["o2r_features"], jnp.float32)

        rr_s = jnp.asarray(g["r2r_senders"], jnp.int32)
        rr_r = jnp.asarray(g["r2r_receivers"], jnp.int32)
        rr_f = _optional_f32("r2r_features")

        r2o_s = jnp.asarray(g["r2o_senders"], jnp.int32)
        r2o_r = jnp.asarray(g["r2o_receivers"], jnp.int32)
        r2o_f = jnp.asarray(g["r2o_features"], jnp.float32)

        N_o_struct = orig_xy.shape[0]
        if N_o_struct != N_o:
            raise ValueError(
                f"Input CFD nodes ({N_o}) != original_coordinates ({N_o_struct})"
            )

        # Batch-invariant arrays tiled to (N, B, F); None entries stay None.
        stat = {
            "orig_struct": self._broadcast(orig_xy, B),
            "red_struct": self._broadcast(red_xy, B),
            "node_types": self._broadcast(node_types_oh, B),
            "o2o_edges": self._broadcast(base_oo_f, B),
            "o2r_edges": self._broadcast(o2r_f, B),
            "r2r_edges": self._broadcast(rr_f, B),
            "r2o_edges": self._broadcast(r2o_f, B),
        }

        # ---------------- Wind direction ---------------- #
        theta_deg = jnp.asarray(forcings["angle_deg"], jnp.float32)  # (B,)
        theta_rad = theta_deg * (jnp.pi / 180.0)
        d = jnp.stack([jnp.cos(theta_rad), jnp.sin(theta_rad)], axis=-1)  # (B,2)
        d_perp = jnp.stack([-d[:, 1], d[:, 0]], axis=-1)                  # (B,2)
        dir_nodes = jnp.broadcast_to(d[None, :, :], (N_o, B, 2))          # (N_o,B,2)

        # -------- Directional o2o edge features: cosines of the edge vector
        # -------- with the wind direction and its perpendicular.
        rel = orig_xy[oo_r] - orig_xy[oo_s]                               # (E,2)
        length = jnp.linalg.norm(rel, axis=-1, keepdims=True) + 1e-6      # (E,1)
        p_par_n = jnp.einsum("ed,bd->eb", rel, d) / length                # (E,B)
        p_perp_n = jnp.einsum("ed,bd->eb", rel, d_perp) / length          # (E,B)
        dir_edge = jnp.stack([p_par_n, p_perp_n], axis=-1)                # (E,B,2)

        if stat["o2o_edges"] is None:
            o2o_edge_feats = dir_edge
        else:
            o2o_edge_feats = jnp.concatenate([stat["o2o_edges"], dir_edge], axis=-1)

        # ---------------- Observations ---------------- #
        # Check for absence BEFORE converting: jnp.asarray(None) never yields None.
        obs_mask_in = forcings.get("obs_mask")
        if obs_mask_in is None:
            obs_mask = jnp.zeros((B, N_o), jnp.float32)
        else:
            obs_mask = jnp.asarray(obs_mask_in, jnp.float32)
        if obs_mask.ndim == 2:
            obs_mask = obs_mask[..., None]  # (B, N_o, 1)

        # Missing guiding field -> zeros with the channel count the template
        # was built for, so the o2o input width stays consistent.
        guiding_in = forcings.get("U_field_guiding")
        if guiding_in is None:
            guiding = jnp.zeros((B, N_o, self.guiding_channels), jnp.float32)
        else:
            guiding = jnp.asarray(guiding_in).astype(jnp.float32)
            if guiding.ndim == 2:
                guiding = guiding[..., None]

        # ---------------- Original → Original ---------------- #
        orig_feat = noisy_inputs.transpose(1, 0, 2).astype(jnp.float32)  # (N_o,B,C_t)
        prev_feat = guiding.transpose(1, 0, 2)                           # (N_o,B,C_g)
        obs_mask_node = obs_mask.transpose(1, 0, 2)                      # (N_o,B,1)

        # [noisy target | guiding values | coords | node-type one-hot | wind dir | obs_mask]
        combined = jnp.concatenate(
            [
                orig_feat,
                prev_feat,
                stat["orig_struct"],
                stat["node_types"],
                dir_nodes,
                obs_mask_node,
            ],
            axis=-1,
        )
        if combined.shape[-1] != self.orig_node_input_size:
            raise ValueError(
                f"o2o node features have {combined.shape[-1]} channels but the "
                f"template expects {self.orig_node_input_size} (target={C_t}, "
                f"guiding={prev_feat.shape[-1]}, coords={orig_xy.shape[1]}). "
                "Check target_channels / guiding_channels."
            )

        graph1 = self.graph_o2o_template._replace(
            nodes={"orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=combined)},
            edges={
                self.oo_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([oo_s.size]),
                    indices=typed_graph.EdgesIndices(senders=oo_s, receivers=oo_r),
                    features=o2o_edge_feats,
                )
            },
        )

        # The same conditioning vector is used by every stage below.
        gcond = noise_level_cond.astype(jnp.float32)
        out_oo = self.ori2ori_gnn(graph1, global_norm_conditioning=gcond)
        updated_orig = out_oo.nodes["orig"].features.astype(jnp.float32)  # (N_o,B,L)

        # ---------------- Original → Reduced ------------------ #
        graph2 = self.graph_o2r_template._replace(
            nodes={
                "orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=updated_orig),
                "red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=stat["red_struct"]),
            },
            edges={
                self.o2r_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([o2r_s.size]),
                    indices=typed_graph.EdgesIndices(senders=o2r_s, receivers=o2r_r),
                    features=stat["o2r_edges"],
                )
            },
        )
        out_o2r = self.ori2red_gnn(graph2, global_norm_conditioning=gcond)
        latent_red = out_o2r.nodes["red"].features.astype(jnp.float32)     # (N_r,B,L)

        # ---------------- Reduced → Reduced -------------------- #
        graph3 = self.graph_r2r_template._replace(
            nodes={"red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=latent_red)},
            edges={
                self.rr_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([rr_s.size]),
                    indices=typed_graph.EdgesIndices(senders=rr_s, receivers=rr_r),
                    features=stat["r2r_edges"],
                )
            },
        )
        out_rr = self.red2red_gnn(graph3, global_norm_conditioning=gcond)
        updated_red = out_rr.nodes["red"].features.astype(jnp.float32)     # (N_r,B,L)

        # ---------------- Reduced → Original ------------------- #
        graph4 = self.graph_r2o_template._replace(
            nodes={
                "red": typed_graph.NodeSet(n_node=jnp.array([N_r]), features=updated_red),
                "orig": typed_graph.NodeSet(n_node=jnp.array([N_o]), features=updated_orig),
            },
            edges={
                self.r2o_key: typed_graph.EdgeSet(
                    n_edge=jnp.array([r2o_s.size]),
                    indices=typed_graph.EdgesIndices(senders=r2o_s, receivers=r2o_r),
                    features=stat["r2o_edges"],
                )
            },
        )
        out_r2o = self.red2ori_gnn(graph4, global_norm_conditioning=gcond)

        pred = out_r2o.nodes["orig"].features.astype(jnp.float32)  # (N_o,B,C_out)
        return pred.transpose(1, 0, 2)  # (B,N_o,C_out)
