from typing import Tuple, Union, NamedTuple, Mapping, Optional

import jax.numpy as jnp
import jax.random
import numpy as np
from flax import linen as nn
from scipy.spatial import Delaunay

from ol.graph.graphbuilder import GraphSet, TypedGraph
from ol.models.common import AbstractOperator, Inputs, LeadTimeConditionedNorm, FeedForward, segment_softmax
from ol.models.extender import CrossAttentionExtender
from ol.utils import Array, shuffle_arrays


class AGNO(nn.Module):
  """Edge-wise message passing layer with weighted aggregation per receiver.

  Every edge builds a message from the concatenated sender and receiver
  features through a feed-forward network. The messages are scaled by per-edge
  weights and summed over the edges that share a receiver. When no weights are
  given, they are computed with a per-receiver softmax over scaled dot-product
  scores between projected receiver (query) and sender (key) features.
  """

  output_size: int
  mlp_hidden_layers: int = 3
  mlp_hidden_size: int = 64
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  cond_norm_hidden_size: int = 4
  # Forwarded as `indices_are_sorted` to the segment operations; only set it
  # when the caller passes edges sorted by receiver index.
  sorted_per_receiver: bool = False

  def setup(self):
    mlp_layer_sizes = [self.mlp_hidden_size]*self.mlp_hidden_layers + [self.output_size]
    self.ff = FeedForward(
      layer_sizes=mlp_layer_sizes,
      use_layer_norm=self.use_layer_norm,
      use_conditional_norm=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      activation=nn.gelu,
    )
    self.query_proj = nn.Dense(features=self.mlp_hidden_size)
    self.key_proj = nn.Dense(features=self.mlp_hidden_size)
    self.scaling_factor = 1.0 / (self.mlp_hidden_size ** 0.5)

  def __call__(self, src: Array, rcv: Array, edg_indices: Array, weights: Optional[Array] = None, condition: Optional[Array] = None, deterministic: bool = False):
    """Computes the aggregated messages for every receiver node.

    Args:
      src: Sender node features, [bsz, num_src, dim_src].
      rcv: Receiver node features, [bsz, num_rcv, dim_rcv].
      edg_indices: Edge list, [bsz, num_edg, 2]; channel 0 holds the sender
        index and channel 1 the receiver index of each edge.
      weights: Optional per-edge aggregation weights, [bsz, num_edg]. They are
        re-normalized to sum to one over the edges of each receiver. If None,
        attention weights are computed from the node features.
      condition: Optional conditioning input forwarded to the edge MLP.
      deterministic: Currently unused inside this method.

    Returns:
      Aggregated receiver features, [bsz, num_rcv, output_size].
    """
    # Per-batch gather and per-batch segment sum
    batched_index = jax.vmap(lambda f, idx: f[idx])
    batched_segment_sum = jax.vmap(jax.ops.segment_sum, in_axes=(0, 0, None, None))
    num_rcv = rcv.shape[1]

    # Gather sender and receiver features
    src_indices, rcv_indices = edg_indices[:, :, 0], edg_indices[:, :, 1]
    edg_src = batched_index(src, src_indices)  # [bsz, num_edg, dim_src]
    edg_rcv = batched_index(rcv, rcv_indices)  # [bsz, num_edg, dim_rcv]
    edg = jnp.concatenate([edg_src, edg_rcv], axis=-1)  # [bsz, num_edg, dim_rcv+dim_src]

    # Prepare aggregation weights
    if weights is None:
      key = self.key_proj(edg_src)  # [bsz, num_edg, dim_attn]
      query = self.query_proj(edg_rcv)  # [bsz, num_edg, dim_attn]
      attention_scores = jnp.sum(query * key, axis=-1) * self.scaling_factor  # [bsz, num_edg]
      weights = jax.vmap(segment_softmax, in_axes=(0, 0, None, None))(attention_scores, rcv_indices, num_rcv, self.sorted_per_receiver)  # [bsz, num_edg]
    else:
      # Make sure weights sum to one per receiver.
      # The per-receiver totals must be gathered back to the edges separately
      # for every batch element; plain `totals[rcv_indices]` would index the
      # batch axis instead of the receiver axis.
      totals = batched_segment_sum(weights, rcv_indices, num_rcv, self.sorted_per_receiver)  # [bsz, num_rcv]
      weights = weights / batched_index(totals, rcv_indices)  # [bsz, num_edg]

    # Get weighted edge features
    edg = jnp.expand_dims(weights, axis=-1) * self.ff(edg, c=condition)  # [bsz, num_edg, output_size]
    # Aggregate edge features (messages)
    out_features = batched_segment_sum(edg, rcv_indices, num_rcv, self.sorted_per_receiver)  # [bsz, num_rcv, output_size]

    return out_features

class Encoder(nn.Module):
  """Lifts node features and passes messages from physical to regional nodes.

  Physical ('pnodes') and regional ('rnodes') node features are lifted to
  `latent_size` by two separate feed-forward networks. An AGNO layer then
  aggregates the lifted physical-node features onto the regional nodes along
  the 'p2r' edges. Outside deterministic mode, a random subset of the edges is
  kept, controlled by `p_edge_masking`.
  """

  latent_size: int
  output_size: int
  mlp_hidden_layers: int = 1
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  cond_norm_hidden_size: int = 4
  p_edge_masking: float = .0
  dropout: float = .0

  def setup(self):
    def build_lifting():
      # Both node types use an identically configured (but separate) lifting network
      return FeedForward(
        conv=True,
        layer_sizes=[self.latent_size],
        activation=nn.gelu,
        use_layer_norm=self.use_layer_norm,
        use_conditional_norm=False,
        dropout=self.dropout,
      )

    self.lifting_pnodes = build_lifting()
    self.lifting_rnodes = build_lifting()
    self.agno = AGNO(
      output_size=self.output_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      mlp_hidden_size=self.latent_size,
      use_layer_norm=self.use_layer_norm,
      conditioned_normalization=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      sorted_per_receiver=True,
    )

  def __call__(self,
    graph: TypedGraph,
    input_pnode_features: Array,
    input_rnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> tuple[Array, Array]:
    """Runs the p2r GNN and returns (latent regional nodes, lifted physical nodes)."""

    # Append the structural node features of the graph to the given input features
    lifted_p = jnp.concatenate([input_pnode_features, graph.nodes['pnodes'].features], axis=-1)
    lifted_r = jnp.concatenate([input_rnode_features, graph.nodes['rnodes'].features], axis=-1)

    # Fetch the physical-to-regional edges
    p2r = graph.edges[graph.edge_key_by_name('p2r')]
    senders = p2r.indices.senders
    receivers = p2r.indices.receivers

    # In stochastic mode, shuffle the edges and keep only the leading fraction
    if not deterministic:
      rngkey = self.make_rng('masking')
      num_kept = int((1 - self.p_edge_masking) * p2r.features.shape[1])
      senders, receivers = shuffle_arrays(
        rngkey=rngkey, arrays=[senders, receivers], axis=1)
      senders = senders[:, :num_kept]
      receivers = receivers[:, :num_kept]

    # Lift both node types to the latent dimension
    lifted_p = self.lifting_pnodes(lifted_p, c=tau, deterministic=deterministic)
    lifted_r = self.lifting_rnodes(lifted_r, c=tau, deterministic=deterministic)

    # The AGNO layer is configured with sorted_per_receiver=True,
    # so the edges are ordered by ascending receiver index first
    by_receiver = jnp.argsort(receivers, axis=1)
    senders = jnp.take_along_axis(senders, by_receiver, axis=1)
    receivers = jnp.take_along_axis(receivers, by_receiver, axis=1)
    edge_list = jnp.stack([senders, receivers], axis=-1)
    latent_r = self.agno(src=lifted_p, rcv=lifted_r, edg_indices=edge_list, weights=None, condition=tau, deterministic=deterministic)

    return latent_r, lifted_p

class Decoder(nn.Module):
  """Passes messages from regional to physical nodes and projects to the output size.

  An AGNO layer aggregates the regional-node features onto the physical nodes
  along the 'r2p' edges. The result is added to the incoming physical-node
  features when their width equals `latent_size` (otherwise it replaces them),
  and a final feed-forward layer maps the features to `output_size`.
  """

  latent_size: int
  output_size: int
  mlp_hidden_layers: int = 1
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  cond_norm_hidden_size: int = 4
  p_edge_masking: float = .0
  dropout: float = .0

  def setup(self):
    self.agno = AGNO(
      output_size=self.latent_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      mlp_hidden_size=self.latent_size,
      use_layer_norm=self.use_layer_norm,
      conditioned_normalization=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      sorted_per_receiver=True,
    )
    self.projection = FeedForward(
      conv=True,
      layer_sizes=[self.output_size],
      activation=nn.gelu,
      use_layer_norm=False,
      use_conditional_norm=False,
      dropout=self.dropout,
    )

  def __call__(self,
    graph: TypedGraph,
    rnode_features: Array,
    pnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> Array:
    """Runs the r2p GNN and returns the projected physical node features.

    Args:
      graph: Graph that provides the 'r2p' edges.
      rnode_features: Regional node features (senders).
      pnode_features: Physical node features (receivers).
      tau: Optional conditioning input forwarded to the AGNO layer and the projection.
      deterministic: If False, a random subset of the edges is kept and the
        'masking' RNG stream is required.

    Returns:
      Physical node features after the final projection to `output_size`.
    """

    rnodes = rnode_features
    pnodes = pnode_features
    # Get edges
    r2p_edges_key = graph.edge_key_by_name('r2p')
    edges = graph.edges[r2p_edges_key]

    # Drop out edges randomly with the given probability
    if deterministic:
      edge_senders = edges.indices.senders
      edge_receivers = edges.indices.receivers
    else:
      rngkey = self.make_rng('masking')
      n_edges_after = int((1 - self.p_edge_masking) * edges.features.shape[1])
      [edge_senders, edge_receivers] = shuffle_arrays(
        rngkey=rngkey, arrays=[edges.indices.senders, edges.indices.receivers], axis=1)
      edge_senders = edge_senders[:, :n_edges_after]
      edge_receivers = edge_receivers[:, :n_edges_after]

    # Feed to AGNO and project to output dimension
    # Sort edges by receiver index so AGNO's per-receiver aggregation groups messages.
    order = jnp.argsort(edge_receivers, axis=1)
    edge_senders = jnp.take_along_axis(edge_senders, order, axis=1)
    edge_receivers = jnp.take_along_axis(edge_receivers, order, axis=1)
    indices = jnp.stack([edge_senders, edge_receivers], axis=-1)
    pnodes_update = self.agno(src=rnodes, rcv=pnodes, edg_indices=indices, weights=None, condition=tau, deterministic=deterministic)
    # Residual connection is only possible when the feature widths match
    if pnodes.shape[-1] == self.latent_size:
      pnodes = pnodes + pnodes_update  # non-mutating add: never touches the caller's array
    else:
      pnodes = pnodes_update
    pnodes = self.projection(pnodes, c=tau, deterministic=deterministic)

    return pnodes

class GroupQueryFlashAttention(nn.Module):
  """Multi-head self-attention with bias-free projections.

  Queries, keys and values are all projected to `num_heads * head_dim`
  features (the same number of heads each), attended with
  `jax.nn.dot_product_attention`, and projected back to the input width.
  """

  num_heads: int
  head_dim: int
  use_conditional_norm: bool = False
  cond_norm_hidden_size: int = 4

  def setup(self):
    projected_size = self.num_heads * self.head_dim
    self.q_proj = nn.Dense(features=projected_size, use_bias=False)
    self.k_proj = nn.Dense(features=projected_size, use_bias=False)
    self.v_proj = nn.Dense(features=projected_size, use_bias=False)

  @nn.compact
  def __call__(self, x, condition: Optional[float] = None):
    """Applies self-attention to x of shape [batch_size, seq_len, input_size]."""
    model_dim = x.shape[-1]

    # Optional normalization conditioned on `condition`
    if self.use_conditional_norm:
      cond_norm = LeadTimeConditionedNorm(self.cond_norm_hidden_size, x.shape[-1])
      x = cond_norm(c=condition, x=x)

    queries = self.q_proj(x)
    keys = self.k_proj(x)
    values = self.v_proj(x)

    bsz, length, _ = queries.shape

    def split_heads(flat):
      # [bsz, length, num_heads * head_dim] -> [bsz, length, num_heads, head_dim]
      return flat.reshape(bsz, length, self.num_heads, self.head_dim)

    # NOTE: jax.nn.dot_product_attention expects [batch, seq, heads, head_dim],
    # unlike torch.nn.functional.scaled_dot_product_attention
    attended = jax.nn.dot_product_attention(split_heads(queries), split_heads(keys), split_heads(values))
    merged = attended.reshape(bsz, length, -1)  # [bsz, length, num_heads * head_dim]
    out_proj = nn.Dense(model_dim, use_bias=False)

    return out_proj(merged)  # [bsz, length, input_size]

class FFN(nn.Module):
  """Gated feed-forward layer: out(silu(gate(x)) * up(x)), all without bias."""

  output_size: int
  ffn_hidden_size: int
  use_conditional_norm: bool = False
  cond_norm_hidden_size: int = 4

  @nn.compact
  def __call__(self, x, condition: Optional[float] = None):
    # The layers are constructed in this exact order (output, gate, up) because
    # Flax auto-names compact submodules by construction order (Dense_0, Dense_1, Dense_2)
    out_layer = nn.Dense(self.output_size, use_bias=False)
    gate_layer = nn.Dense(self.ffn_hidden_size, use_bias=False)
    up_layer = nn.Dense(self.ffn_hidden_size, use_bias=False)

    gated = nn.silu(gate_layer(x)) * up_layer(x)
    y = out_layer(gated)

    # Optional normalization of the output, conditioned on `condition`
    if not self.use_conditional_norm:
      return y
    return LeadTimeConditionedNorm(self.cond_norm_hidden_size, y.shape[-1])(c=condition, x=y)

class RMSNorm(nn.Module):
  """Root-mean-square normalization over the last axis with a learned scale."""

  eps: float = 1e-6

  def _norm(self, x):
    """Divides x by the root of its mean square over the last axis (computed in float)."""
    x_float = jnp.astype(x, float)
    mean_square = jnp.mean(jnp.square(x_float), axis=-1, keepdims=True)
    return x_float * jax.lax.rsqrt(mean_square + self.eps)

  @nn.compact
  def __call__(self, x):
    num_features = x.shape[-1]

    def init_ones(rng, shape):
      # The scale starts as identity; the rng is not needed
      return jnp.ones(shape)

    scale = self.param('weights', init_ones, num_features)
    # Cast back to the input dtype before applying the learned scale
    normalized = jnp.astype(self._norm(x), x.dtype)
    return normalized * scale

class TransformerBlock(nn.Module):
  """Pre-norm transformer block: self-attention and gated FFN, each with a residual.

  With `skip_connection=True` and a `skip` tensor given, the input is first
  concatenated with `skip` on the feature axis and linearly mapped back to the
  input width.
  """

  num_heads: int
  hidden_size: int
  ffn_multiplier: int = 2
  skip_connection: bool = False
  use_attn_norm: bool = True
  use_ffn_norm: bool = True
  use_conditional_norm: bool = False

  def setup(self):
    self.attn = GroupQueryFlashAttention(
      num_heads=self.num_heads,
      head_dim=self.hidden_size,
      use_conditional_norm=self.use_conditional_norm,
      cond_norm_hidden_size=4,
    )

  @nn.compact
  def __call__(self, x, condition=None, skip=None):
    """Applies the block to x.

    Args:
      x: Input features; the last axis is the feature axis.
      condition: Optional conditioning input forwarded to the attention and FFN layers.
      skip: Optional tensor merged into x when `skip_connection` is enabled.

    Returns:
      Features with the same last-axis size as x.
    """
    input_shape = x.shape[-1]
    if self.skip_connection and skip is not None:
      x = jnp.concatenate([x, skip], axis=-1)
      x = nn.Dense(input_shape)(x)

    # Attention sub-layer: normalize only the branch input, keep the residual stream raw
    attn_in = RMSNorm()(x) if self.use_attn_norm else x
    h = x + self.attn(attn_in, condition=condition)

    # FFN sub-layer: same pre-norm pattern. The residual must use the
    # un-normalized stream `h`, not the normalized FFN input.
    ffn_in = RMSNorm()(h) if self.use_ffn_norm else h
    out = h + FFN(
      output_size=input_shape, ffn_hidden_size=(self.hidden_size * self.ffn_multiplier),
      use_conditional_norm=self.use_conditional_norm, cond_norm_hidden_size=4,
    )(ffn_in, condition=condition)

    return out

class Transformer(nn.Module):
  """Stack of transformer blocks with optional long-range skip connections.

  The stack has `num_layers // 2` encoder blocks, one middle block when
  `num_layers` is odd, and `num_layers // 2` decoder blocks. Each decoder
  block can receive the output of the mirrored encoder block as `skip`.
  """

  output_size: int
  patch_size: int
  hidden_size: int
  num_layers: int
  num_heads: int
  use_long_range_skip: bool = True

  def setup(self):
    def build_block(with_skip):
      return TransformerBlock(hidden_size=self.hidden_size, num_heads=self.num_heads, skip_connection=with_skip)

    half_depth = self.num_layers // 2
    has_middle = (self.num_layers % 2 == 1)
    self.encoder_layers = [build_block(False) for _ in range(half_depth)]
    self.middle_layer = build_block(False) if has_middle else None
    self.decoder_layers = [build_block(True) for _ in range(half_depth)]

  @nn.compact
  def __call__(self, x, condition):
    # Map the input to the hidden width if it does not match already
    if x.shape[-1] != self.hidden_size:
      x = nn.Dense(self.hidden_size)(x)

    # Encoder half: remember every output for the mirrored decoder block
    encoder_outputs = []
    for block in self.encoder_layers:
      x = block(x, condition=condition)
      encoder_outputs.append(x)

    if self.middle_layer is not None:
      x = self.middle_layer(x, condition=condition)

    # Decoder half: consume the stored outputs in reverse (last in, first out)
    for block in self.decoder_layers:
      if self.use_long_range_skip:
        mirrored = encoder_outputs.pop()
      else:
        mirrored = None
      x = block(x, condition=condition, skip=mirrored)

    # Map to the requested output width if necessary
    if x.shape[-1] != self.output_size:
      x = nn.Dense(self.output_size)(x)

    return x

class Processor(nn.Module):
  """Patch-based transformer acting on regional node features laid out on a grid.

  The grid is split into non-overlapping `patch_size` x `patch_size` patches;
  each patch is flattened into one token, linearly embedded, given a sinusoidal
  positional embedding, processed by a `Transformer`, and reshaped back onto
  the grid.

  NOTE: `gridres`, `conditioned_normalization` and `cond_norm_hidden_size` are
  not read inside this module. `tau` is forwarded to the transformer, whose
  blocks are built with their default `use_conditional_norm=False`.
  """

  gridres: Tuple
  patch_size: int
  hidden_size: int
  num_layers: int
  num_heads: int
  conditioned_normalization: bool = True
  cond_norm_hidden_size: int = 4

  @nn.compact
  def __call__(self,
    rnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> Array:
    """Processes grid features of shape [bsz, W, H, C] and returns the same shape.

    Raises:
      ValueError: If W or H is not divisible by `patch_size`.
    """

    # Get dimensions
    B, W, H, C = rnode_features.shape
    P = self.patch_size
    if (W % P != 0) or (H % P != 0):
      raise ValueError(f'Grid resolution ({W}, {H}) must be divisible by patch_size={P}.')
    num_patches_w, num_patches_h = (W // P), (H // P)
    token_size = P * P * C

    # Reshape to patches: [bsz, W, H, C] -> [bsz, num_patches, P*P*C]
    rnode_features = rnode_features.reshape(B, num_patches_w, P, num_patches_h, P, C)
    rnode_features = jnp.permute_dims(rnode_features, axes=(0, 1, 3, 2, 4, 5))
    patch_features = rnode_features.reshape(B, num_patches_w * num_patches_h, token_size)

    # Linear transformation
    patch_features = nn.Dense(patch_features.shape[-1])(patch_features)

    # Positional embedding (shared by all batch entries, added via broadcasting)
    pos = jnp.stack(jnp.meshgrid(jnp.arange(num_patches_w), jnp.arange(num_patches_h), indexing='ij'), axis=-1).reshape(-1, 2).astype(jnp.float32)
    pos_emb = self._compute_absolute_embeddings(pos, token_size)
    patch_features = patch_features + pos_emb[None, :, :]

    # Transformer
    patch_features = Transformer(output_size=token_size, patch_size=self.patch_size, hidden_size=self.hidden_size, num_layers=self.num_layers, num_heads=self.num_heads)(x=patch_features, condition=tau)

    # Reshape back to the original grid: [bsz, num_patches, P*P*C] -> [bsz, W, H, C]
    rnode_features = patch_features.reshape(B, num_patches_w, num_patches_h, P, P, C)
    rnode_features = jnp.permute_dims(rnode_features, axes=(0, 1, 3, 2, 4, 5))
    rnode_features = rnode_features.reshape(B, W, H, C)

    return rnode_features

  def _compute_absolute_embeddings(self, positions, embed_dim):
    """Returns sinusoidal embeddings of shape [num_positions, embed_dim].

    Each of the `positions.shape[1]` coordinates gets
    `embed_dim // (2 * positions.shape[1])` sine and cosine channels. When
    `embed_dim` is not divisible by `2 * positions.shape[1]`, the remaining
    trailing channels are filled with zeros so that the result always has
    exactly `embed_dim` channels.
    """
    num_pos_dims = positions.shape[1]
    dim_touse = embed_dim // (2 * num_pos_dims)
    freq_seq = jnp.arange(dim_touse, dtype=jnp.float32)
    inv_freq = 1.0 / (10000 ** (freq_seq / dim_touse))
    sinusoid_inp = positions[:, :, None] * inv_freq[None, None, :]
    pos_emb = jnp.concatenate([jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)], axis=-1)
    pos_emb = pos_emb.reshape(positions.shape[0], -1)
    # Zero-pad the channels lost to the integer division above
    missing = embed_dim - pos_emb.shape[-1]
    if missing > 0:
      pos_emb = jnp.pad(pos_emb, ((0, 0), (0, missing)))
    return pos_emb

class GAOT(AbstractOperator):
  """A JAX implementation of GAOT.

  The operator is an encode-process-decode pipeline: an `Encoder` transfers
  physical-node features to regional nodes, a `Processor` runs a patch-based
  transformer on the regional nodes arranged on a grid of shape `gridres`, and
  a `Decoder` transfers the result back to the physical nodes.
  """

  num_outputs: int
  gridres: Tuple
  patch_size: int
  transformer_hidden_size: int = 256
  processor_steps: int = 5
  processor_attn_heads: int = 1
  latent_size: int = 128
  mlp_hidden_layers: int = 1
  # Forwarded to the encoder and decoder as their edge-masking probability
  p_edge_masking: float = 0.5
  # Whether time (t) and lead time (tau) are used as inputs / conditioning
  tdep: bool = False

  def setup(self):
    self.variable_mesh = False

    self.encoder = Encoder(
      latent_size=self.latent_size,
      output_size=self.latent_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      use_layer_norm=True,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=4,
      p_edge_masking=self.p_edge_masking,
      dropout=.0,
      name='encoder',
    )

    self.processor = Processor(
      gridres=self.gridres,
      patch_size=self.patch_size,
      hidden_size=self.transformer_hidden_size,
      num_layers=self.processor_steps,
      num_heads=self.processor_attn_heads,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=4,
      name='processor',
    )

    self.decoder = Decoder(
      # variable_mesh=self.variable_mesh,
      latent_size=self.latent_size,
      output_size=self.num_outputs,
      mlp_hidden_layers=self.mlp_hidden_layers,
      use_layer_norm=True,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=4,
      p_edge_masking=self.p_edge_masking,
      dropout=.0,
      name='decoder',
    )

  @staticmethod
  def _prepare_features(feats: Array) -> Array:
    """Inserts a singleton time axis at position 1."""
    # Expand time axis
    feats = jnp.expand_dims(feats, axis=1)
    return feats

  def _encode_process_decode(self, graphs: GraphSet, input_pnode_features: Array, input_rnode_features: Array, tau: Union[None, float], deterministic: bool = False) -> Array:
    """Runs encoder, processor and decoder and returns the output physical node features.

    The intermediate regional and physical node features (without the trailing
    dummy node) are stored in the 'intermediates' collection.
    """

    # Add dummy node features (same dtype as the inputs, so that the concatenation does not promote them)
    dummy_pnode_features = jnp.zeros(
      shape=(input_pnode_features.shape[0], 1, input_pnode_features.shape[2]),
      dtype=input_pnode_features.dtype,
    )
    input_pnode_features = jnp.concatenate([input_pnode_features, dummy_pnode_features], axis=1)

    # Transfer data from the physical mesh to the regional mesh
    # -> [batch_size, num_nodes, latent_size]
    latent_rnode_features, latent_pnode_features = self.encoder(graphs.p2r, input_pnode_features, input_rnode_features, tau, deterministic=deterministic)
    self.sow(col='intermediates', name='rnodes-encoded', value=self._prepare_features(latent_rnode_features[:, :-1]))

    # Re-arrange rnodes on a grid (excluding the dummy rnode)
    latent_rnode_features_on_grid = latent_rnode_features[:, :-1, :].reshape(latent_rnode_features.shape[0], *self.gridres, latent_rnode_features.shape[2])

    # Run message passing steps in the regional mesh
    # -> [batch_size, num_rnodes, latent_size]
    processed_rnode_features_on_grid = self.processor(latent_rnode_features_on_grid, tau, deterministic=deterministic)
    processed_rnode_features = processed_rnode_features_on_grid.reshape(latent_rnode_features.shape[0], latent_rnode_features.shape[1]-1, latent_rnode_features.shape[2])
    # Re-attach the (unprocessed) dummy rnode as the last node
    processed_rnode_features = jnp.concatenate([processed_rnode_features, latent_rnode_features[:, [-1], :]], axis=1)
    self.sow(col='intermediates', name='rnodes-processed', value=self._prepare_features(processed_rnode_features[:, :-1]))

    # Transfer data from the regional mesh to the physical mesh
    # -> [batch_size, num_pnodes_out, latent_size]
    output_pnode_features = self.decoder(graphs.r2p, processed_rnode_features, latent_pnode_features, tau, deterministic=deterministic)
    self.sow(col='intermediates', name='pnodes-decoded', value=self._prepare_features(output_pnode_features[:, :-1]))

    # Remove the dummy node features
    output_pnode_features = output_pnode_features[:, :-1, :]

    return output_pnode_features

  def call(self, inputs: Inputs, graphs: GraphSet, input_pnode_features: Optional[Array] = None, input_rnode_features: Optional[Array] = None, deterministic: bool = False) -> Array:
    """Applies the operator to a batch of inputs.

    Inputs must be of shape [batch_size, 1, num_physical_nodes, num_variables]

    Args:
      inputs: Batch of inputs; `a`, `x_inp` and `x_out` are read, plus `t` and
        `tau` when `tdep` is set.
      graphs: Graph set that provides the `p2r` and `r2p` graphs.
      input_pnode_features: Optional extra features concatenated to the
        physical node features.
      input_rnode_features: Optional regional node input features. If None, a
        single zero channel is used.
      deterministic: Forwarded to the encoder, processor and decoder.

    Returns:
      Output of shape [batch_size, 1, num_pnodes_out, num_outputs].
    """

    # Read dimensions
    batch_size = inputs.a.shape[0]
    num_pnodes_inp = inputs.x_inp.shape[2]
    num_rnodes = graphs.p2r.nodes['rnodes'].features.shape[1]

    if self.tdep:
      # Prepare the time channel
      assert inputs.t is not None
      t = jnp.array(inputs.t, dtype=jnp.float32)
      if t.ndim == 4:
        t = t[:, :, 0, 0]
      if t.size == 1:
        t = jnp.tile(t.reshape(1, 1), reps=(batch_size, 1))
      # Prepare the lead time channel
      assert inputs.tau is not None
      tau = jnp.array(inputs.tau, dtype=jnp.float32)
      if tau.ndim == 4:
        tau = tau[:, :, 0, 0]
      if tau.size == 1:
        tau = jnp.tile(tau.reshape(1, 1), reps=(batch_size, 1))
    else:
      tau = None

    # Prepare the physical node features
    # q, u -> [batch_size, num_pnodes_inp, num_inputs]
    pnode_features = inputs.a.squeeze(1)
    # Concatenate with forced and input features
    pnode_features_forced = []
    if input_pnode_features is not None:
      pnode_features_forced.append(input_pnode_features)
    if self.tdep:
      # Broadcast time and lead time to every physical node
      pnode_features_forced.append(jnp.tile(jnp.expand_dims(t, axis=1), reps=(1, num_pnodes_inp, 1)))
      pnode_features_forced.append(jnp.tile(jnp.expand_dims(tau, axis=1), reps=(1, num_pnodes_inp, 1)))
    pnode_features = jnp.concatenate([pnode_features, *pnode_features_forced], axis=-1)

    # Prepare the input regional features
    if input_rnode_features is None:
      rnode_features = jnp.zeros(shape=(batch_size, num_rnodes, 1), dtype=inputs.a.dtype)
    else:
      rnode_features = input_rnode_features

    # Pass through the GNNs
    output_pnodes = self._encode_process_decode(graphs=graphs, input_pnode_features=pnode_features, input_rnode_features=rnode_features, tau=tau, deterministic=deterministic)

    # Reshape the output
    # [batch_size, num_pnodes_out, num_outputs] -> [batch_size, 1, num_pnodes_out, num_outputs]
    output = self._prepare_features(output_pnodes)
    self._check_function(output, x=inputs.x_out)

    return output

class XGAOT(AbstractOperator):
  """GAOT operator that additionally receives boundary-condition functions.

  The boundary functions (`inputs.q`) and their masks (`inputs.m`) are either
  extended onto the regional nodes by a `CrossAttentionExtender` and fed to the
  wrapped `GAOT` as regional node features (`use_extender=True`), or
  concatenated with their masks and fed as extra physical node features.
  """

  # Keyword arguments of the wrapped GAOT operator
  configs_core: Mapping
  # Keyword arguments of the CrossAttentionExtender
  configs_extender: Mapping
  use_extender: bool
  # If True, every boundary function component is extended separately
  independent: bool
  # Fixed (padded) number of boundary entries gathered for each sample
  boundary_size: int
  # If True, all boundary conditions are unified into a single Robin form
  unify: bool = True
  # If True, the extension is symmetrized over the sign of the boundary functions
  even: bool = False

  def setup(self):
    self.operator = GAOT(**self.configs_core)
    self.extender = CrossAttentionExtender(**self.configs_extender)

  @nn.compact
  def call(self, inputs: Inputs, graphs: GraphSet, deterministic: bool = False):
    """Extended GAOT: prepares the boundary features and applies the wrapped operator."""

    # Unify all the boundary conditions into a Robin form
    q, m = dict(), dict()
    if self.unify:
      prefixes = {'-'.join(key.split('-')[:2]) for key in inputs.q.keys()}  # NOTE: assuming a structure as "bc-d-typ" in the names
      prefixes = sorted(list(prefixes))  # NOTE: otherwise it leads to random ordering of the prefixes
      for prefix in prefixes:
        m_prefix, g_prefix, alpha_prefix, beta_prefix = self._unify_all_boundary_conditions(
          q={key: val for key, val in inputs.q.items() if key.startswith(prefix)},
          m={key: val for key, val in inputs.m.items() if key.startswith(prefix)},
        )
        q[f'{prefix}-uni'] = jnp.concatenate([g_prefix, alpha_prefix, beta_prefix], axis=-1)
        m[f'{prefix}-uni'] = m_prefix
      # NOTE: Concatenate all and use a single CrossAttentionExtender on them
      # NOTE: Only viable if the masks exactly match
      q = {'bc-uni': jnp.concatenate(list(q.values()), axis=-1)}
      # Stay within jax.numpy so that the reduction is an explicit JAX operation
      m = {'bc-uni': jnp.any(jnp.stack(list(m.values()), axis=-1), axis=-1)}
    # Or leave them as they are
    else:
      q, m = inputs.q, inputs.m

    if self.use_extender:
      # Project all the boundary functions (e.g., bc-0-dir, bc-1-dir) separately
      xq = jnp.concatenate([inputs.x_inp, inputs.s], axis=-1)  # Geometric features of the whole domain (to be used for boundaries)
      rnodes = graphs.p2r.nodes['rnodes'].features  # Contains only positional encodings of the rnodes
      if self.independent:
        def extend_group(_q, _m):
          # Fold the component axis of _q into the batch axis so that each
          # component is extended on its own, then unfold it again
          _q_batched_components = jnp.moveaxis(_q, -1, 1).reshape(-1, *_q.shape[1:-1])[..., None]
          _m_batched_components = jnp.repeat(_m, repeats=_q.shape[-1], axis=0)
          _rnodes_batched_components = jnp.repeat(rnodes, repeats=_q.shape[-1], axis=0)
          _xq_batched_components = jnp.repeat(xq, repeats=_q.shape[-1], axis=0)
          _extensions = self._extend(_xq_batched_components, _q_batched_components, _m_batched_components, f_domain=_rnodes_batched_components, deterministic=deterministic)
          _extensions = jnp.moveaxis(_extensions.reshape(_q.shape[0], -1, *_extensions.shape[1:]), 1, -1)
          _extensions = _extensions.reshape(*_extensions.shape[:-2], -1)
          return _extensions
      else:
        def extend_group(_q, _m):
          _extensions = self._extend(xq, _q, _m, f_domain=rnodes, deterministic=deterministic)
          return _extensions
      extensions = jax.tree.map(extend_group, q, m)
      # Store every extension in the 'intermediates' collection
      jax.tree.map_with_path(lambda key, psi: self.sow(col='intermediates', name=f'extensions-{key[0].key}', value=psi), extensions)
      # Concatenate all the extensions
      extensions = jnp.concatenate(jax.tree.flatten(extensions)[0], axis=-1)
      # Pass through the operator
      # NOTE: feeding in rnode features
      output = self.operator(inputs, graphs=graphs, input_rnode_features=extensions, deterministic=deterministic)
    else:
      # Add a variable dimension to the masks
      m = jax.tree.map(lambda m: m[..., None], m)
      # Extend the boundary functions with zeros and add a binary mask
      extensions = jax.tree.map(lambda _q, _m: jnp.concatenate([_q, _m.astype(float)], axis=-1), q, m)
      extensions = jnp.concatenate(jax.tree.flatten(extensions)[0], axis=-1)
      extensions = jnp.squeeze(extensions, axis=1)

      # Pass through the operator
      # NOTE: feeding in pnode features
      output = self.operator(inputs, graphs=graphs, input_pnode_features=extensions, deterministic=deterministic)

    return output

  def _unify_all_boundary_conditions(self, q, m):
    """Unifies non-overlapping Dirichlet, Neumann, and Robin BCs into a Robin form.

    Args:
      q: Mapping from BC name to its boundary functions. The name suffix
        ('dir' / 'neu') selects which coefficient is zeroed.
      m: Mapping with the same keys holding the corresponding masks.

    Returns:
      Tuple (m_unified, g_unified, alpha_unified, beta_unified), where g, alpha
      and beta are divided by the 2-norm of [alpha, beta] inside the mask.
    """

    # Get unified mask
    # NOTE: assuming that the masks do not overlap
    m_unified = jnp.any(jnp.stack(jax.tree.leaves(m), axis=-1), axis=-1)

    # Get unified g
    # NOTE: assuming g is always the first function
    gm = jax.tree.map(lambda m, q: m * q[..., 0], m, q)
    g_unified = jnp.stack(jax.tree.leaves(gm), axis=-1).sum(axis=-1, keepdims=True)

    # Get alpha and beta for all BC types
    # NOTE: assuming alpha is only present in Robin BC as the second function
    # Pad the functions after g with ones so that exactly two channels (alpha, beta) remain
    ones = jax.tree.map(lambda q: jnp.concatenate([q[..., 1:], jnp.ones(shape=(*q.shape[:3], 3-q.shape[3]))], axis=-1), q)  # contains alpha=1 and beta=1
    alpha = jax.tree.map_with_path(lambda key, _ones: jnp.zeros_like(_ones[..., 0]) if (key[0].key.endswith('neu')) else _ones[..., 0], ones)  # alpha=0 for Neumann
    beta = jax.tree.map_with_path(lambda key, _ones: jnp.zeros_like(_ones[..., 1]) if (key[0].key.endswith('dir')) else _ones[..., 1], ones)  # beta=0 for Dirichlet
    # Get unified alpha and beta
    am = jax.tree.map(lambda m, a: m * a, m, alpha)
    bm = jax.tree.map(lambda m, b: m * b, m, beta)
    alpha_unified = jnp.stack(jax.tree.leaves(am), axis=-1).sum(axis=-1, keepdims=True)
    beta_unified = jnp.stack(jax.tree.leaves(bm), axis=-1).sum(axis=-1, keepdims=True)
    # Normalize alpha, beta, and g with respect to the 2-norm of the vector [alpha, beta]
    zeta_norm = jnp.linalg.norm(jnp.stack([alpha_unified, beta_unified]), axis=0)
    zeta_norm = jnp.where(m_unified[:, :, :, None], zeta_norm, 1.)  # avoid division by zero
    alpha_unified /= zeta_norm
    beta_unified /= zeta_norm
    g_unified /= zeta_norm

    return m_unified, g_unified, alpha_unified, beta_unified

  def _extend(self, xq: Array, q: Array, m: Array, f_domain: Array, deterministic: bool = False) -> Array:
    """Extends the boundary functions onto the domain with the extender.

    Args:
      xq: Geometric features of all nodes; axis 1 is squeezed.
      q: Boundary functions on all nodes; axis 1 is squeezed.
      m: Boolean boundary mask; axis 1 is squeezed.
      f_domain: Domain features passed to the extender.
      deterministic: If False, the boundary entries are shuffled with the
        'other' RNG stream. Also forwarded to the extender.

    Returns:
      The output of the extender (sum of two passes when `even` is set).
    """
    # Gather a fixed number (boundary_size) of boundary entries per sample;
    # entries beyond the true boundary are padding and are flagged by the pad mask
    batched_slice = jax.vmap(lambda f, m: f[jnp.where(m, size=self.boundary_size)[0]])
    batched_padmask = jax.vmap(lambda m: jnp.where(jnp.where(m, size=self.boundary_size, fill_value=-1)[0] > -1, 1, 0))
    xs_bnd = batched_slice(xq.squeeze(1), m.squeeze(1))
    q_bnd = batched_slice(q.squeeze(1), m.squeeze(1))
    m_bnd = batched_padmask(m.squeeze(1))  # A binary mask indicating the padded boundary entries
    # Shuffle on the coordinate axis
    if not deterministic:
      rngkey = self.make_rng(name='other')
      xs_bnd, q_bnd, m_bnd = shuffle_arrays(rngkey=rngkey, arrays=[xs_bnd, q_bnd, m_bnd], axis=1)
    # Use the same extender twice to make the extension even
    if self.even:
      assert not self.independent
      psi_pos = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )
      psi_neg = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, -q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )
      psi = psi_pos + psi_neg  # even function
    # Use the extender only once
    else:
      psi = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )

    return psi
