from typing import Tuple, Union, NamedTuple, Mapping

import flax.typing
import jax.numpy as jnp
import jax.random
import numpy as np
import jraph
from flax import linen as nn
from scipy.spatial import Delaunay

from ol.graph.entities import TypedGraph, EdgeSet, EdgeSetKey, EdgesIndices, NodeSet, Context
from ol.graph.graphbuilder import GraphSet
from ol.graph.graphnet import DeepTypedGraphNet
from ol.models.common import AbstractOperator, Inputs
from ol.models.extender import CrossAttentionExtender
from ol.utils import Array, shuffle_arrays


class Encoder(nn.Module):
  """Physical-to-regional (p2r) encoder.

  Concatenates the input node features with the structural node features
  stored in the graph, optionally drops a random subset of the ``p2r`` edges,
  and runs a single message-passing step that embeds nodes and edges.

  Attributes:
    node_latent_size: Latent size of both ``pnodes`` and ``rnodes``.
    edge_latent_size: Latent size of the ``p2r`` edges.
    mlp_hidden_layers: Number of hidden layers of the GNN MLPs.
    use_layer_norm: Forwarded to ``DeepTypedGraphNet``.
    conditioned_normalization: Forwarded to ``DeepTypedGraphNet``; the GNN is
      called with ``condition=tau``.
    cond_norm_hidden_size: Forwarded to ``DeepTypedGraphNet``.
    p_edge_masking: Fraction of ``p2r`` edges dropped when
      ``deterministic=False``.
  """

  node_latent_size: int
  edge_latent_size: int
  mlp_hidden_layers: int = 1
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  # NOTE: This is a size, not a flag (it used to be annotated ``bool = True``).
  # The default ``1`` is numerically identical to the former ``True``.
  cond_norm_hidden_size: int = 1
  p_edge_masking: float = .0

  def setup(self):
    self.gnn = DeepTypedGraphNet(
      embed_nodes=True,  # Embed raw features of all nodes
      embed_edges=True,  # Embed raw features of the edges
      edge_latent_size=dict(p2r=self.edge_latent_size),
      node_latent_size=dict(rnodes=self.node_latent_size, pnodes=self.node_latent_size),
      mlp_num_hidden_layers=self.mlp_hidden_layers,
      num_message_passing_steps=1,
      use_layer_norm=self.use_layer_norm,
      conditioned_normalization=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      include_sent_messages_in_node_update=False,
      activation='swish',
      f32_aggregation=True,
      aggregate_edges_for_nodes_fn=jraph.segment_mean,
    )

  def __call__(self,
    graph: TypedGraph,
    input_pnode_features: Array,
    input_rnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> tuple[Array, Array]:
    """Runs the p2r GNN, extracting latent physical and regional nodes.

    Args:
      graph: The p2r graph with ``pnodes``, ``rnodes`` and a ``p2r`` edge set.
      input_pnode_features: Input pnode features. Axis 0 is the batch axis.
        They are concatenated (last axis) in front of the structural pnode
        features of ``graph``.
      input_rnode_features: Input rnode features. They are concatenated (last
        axis) in front of the structural rnode features of ``graph``.
      tau: Condition passed to the GNN; ``None`` when unconditioned.
      deterministic: If False, a random subset of the ``p2r`` edges is kept
        using the ``'masking'`` RNG stream.

    Returns:
      The latent rnode features and the latent pnode features, in this order.
    """

    # Get batch size
    batch_size = input_pnode_features.shape[0]

    # Concatenate node structural features with input features
    pnodes = graph.nodes['pnodes']
    rnodes = graph.nodes['rnodes']
    new_pnodes = pnodes._replace(features=jnp.concatenate([input_pnode_features, pnodes.features], axis=-1))
    new_rnodes = rnodes._replace(features=jnp.concatenate([input_rnode_features, rnodes.features], axis=-1))

    # Get edges
    p2r_edges_key = graph.edge_key_by_name('p2r')
    edges = graph.edges[p2r_edges_key]
    # Drop out edges randomly with the given probability
    # NOTE: Edges are stored along axis 1 of the edge arrays
    if deterministic:
      n_edges_after = edges.features.shape[1]
      new_edge_features = edges.features
      new_edge_senders = edges.indices.senders
      new_edge_receivers = edges.indices.receivers
    else:
      rngkey = self.make_rng('masking')
      n_edges_after = int((1 - self.p_edge_masking) * edges.features.shape[1])
      # Shuffle features and indices jointly, then keep the first edges
      [new_edge_features, new_edge_senders, new_edge_receivers] = shuffle_arrays(
        rngkey=rngkey, arrays=[edges.features, edges.indices.senders, edges.indices.receivers], axis=1)
      new_edge_features = new_edge_features[:, :n_edges_after]
      new_edge_senders = new_edge_senders[:, :n_edges_after]
      new_edge_receivers = new_edge_receivers[:, :n_edges_after]
    # Change edge feature dtype
    new_edge_features = new_edge_features.astype(input_pnode_features.dtype)
    # Build new edge set; n_edge has shape [batch_size, 1]
    new_edges = EdgeSet(
      n_edge=jnp.tile(jnp.array([n_edges_after]), reps=(batch_size, 1)),
      indices=EdgesIndices(senders=new_edge_senders, receivers=new_edge_receivers),
      features=new_edge_features,
    )

    input_graph = graph._replace(
      edges={p2r_edges_key: new_edges},
      nodes={'pnodes': new_pnodes, 'rnodes': new_rnodes}
    )

    # Run the GNN
    p2r_out = self.gnn(input_graph, condition=tau)
    latent_rnodes = p2r_out.nodes['rnodes'].features
    latent_pnodes = p2r_out.nodes['pnodes'].features

    return latent_rnodes, latent_pnodes

class Processor(nn.Module):
  """Regional-to-regional (r2r) processor.

  Runs ``steps`` message-passing steps over the ``r2r`` edges of the regional
  mesh and returns the updated latent regional-node features.

  Attributes:
    steps: Number of message-passing steps.
    node_latent_size: Latent size of the ``rnodes``.
    edge_latent_size: Latent size of the ``r2r`` edges.
    mlp_hidden_layers: Number of hidden layers of the GNN MLPs.
    use_layer_norm: Forwarded to ``DeepTypedGraphNet``.
    conditioned_normalization: Forwarded to ``DeepTypedGraphNet``; the GNN is
      called with ``condition=tau``.
    cond_norm_hidden_size: Forwarded to ``DeepTypedGraphNet``.
    p_edge_masking: Fraction of ``r2r`` edges dropped when
      ``deterministic=False``.
  """

  steps: int
  node_latent_size: int
  edge_latent_size: int
  mlp_hidden_layers: int = 1
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  # NOTE: This is a size, not a flag (it used to be annotated ``bool = True``).
  # The default ``1`` is numerically identical to the former ``True``.
  cond_norm_hidden_size: int = 1
  p_edge_masking: float = .0

  def setup(self):
    self.gnn = DeepTypedGraphNet(
      embed_nodes=False,  # Node features are already embedded by previous layers
      embed_edges=True,  # Embed raw features of the edges
      edge_latent_size=dict(r2r=self.edge_latent_size),
      node_latent_size=dict(rnodes=self.node_latent_size),
      mlp_num_hidden_layers=self.mlp_hidden_layers,
      num_message_passing_steps=self.steps,
      # NOTE: Previously hard-coded to True, which silently ignored this field
      use_layer_norm=self.use_layer_norm,
      conditioned_normalization=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      include_sent_messages_in_node_update=False,
      activation='swish',
      f32_aggregation=False,
      # NOTE: segment_mean because number of edges is not balanced
      aggregate_edges_for_nodes_fn=jraph.segment_mean,
    )

  def __call__(self,
    graph: TypedGraph,
    rnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> Array:
    """Runs the r2r GNN, extracting updated latent regional nodes.

    Args:
      graph: The r2r graph. It must contain exactly one edge set (named
        ``'r2r'``) and the ``'rnodes'`` node set.
      rnode_features: Latent rnode features. Axis 0 is the batch axis. They
        replace the rnode features of ``graph``.
      tau: Condition passed to the GNN; ``None`` when unconditioned.
      deterministic: If False, a random subset of the ``r2r`` edges is kept
        using the ``'masking'`` RNG stream.

    Returns:
      The updated latent rnode features.

    Raises:
      AssertionError: If ``graph`` has more than one edge set.
    """

    # Get batch size
    batch_size = rnode_features.shape[0]

    # Replace the node features
    # NOTE: We don't need to add the structural node features, because these are
    # already part of the latent state, via the original p2r gnn.
    rnodes = graph.nodes['rnodes']
    new_rnodes = rnodes._replace(features=rnode_features)

    # Get edges
    r2r_edges_key = graph.edge_key_by_name('r2r')
    # NOTE: We are assuming here that the r2r gnn uses a single set of edge keys
    # named 'r2r' for the edges and that it uses a single set of nodes named 'rnodes'
    msg = ('The setup currently requires to only have one kind of edge in the mesh GNN.')
    assert len(graph.edges) == 1, msg
    edges = graph.edges[r2r_edges_key]
    # Drop out edges randomly with the given probability
    # NOTE: We need the structural edge features, because it is the first
    # time we are seeing this particular set of edges.
    # NOTE: Edges are stored along axis 1 of the edge arrays
    if deterministic:
      n_edges_after = edges.features.shape[1]
      new_edge_features = edges.features
      new_edge_senders = edges.indices.senders
      new_edge_receivers = edges.indices.receivers
    else:
      rngkey = self.make_rng('masking')
      n_edges_after = int((1 - self.p_edge_masking) * edges.features.shape[1])
      # Shuffle features and indices jointly, then keep the first edges
      [new_edge_features, new_edge_senders, new_edge_receivers] = shuffle_arrays(
        rngkey=rngkey, arrays=[edges.features, edges.indices.senders, edges.indices.receivers], axis=1)
      new_edge_features = new_edge_features[:, :n_edges_after]
      new_edge_senders = new_edge_senders[:, :n_edges_after]
      new_edge_receivers = new_edge_receivers[:, :n_edges_after]
    # Change edge feature dtype
    new_edge_features = new_edge_features.astype(rnode_features.dtype)
    # Build new edge set; n_edge has shape [batch_size, 1]
    new_edges = EdgeSet(
      n_edge=jnp.tile(jnp.array([n_edges_after]), reps=(batch_size, 1)),
      indices=EdgesIndices(
        senders=new_edge_senders,
        receivers=new_edge_receivers,
      ),
      features=new_edge_features,
    )

    # Build the graph
    input_graph = graph._replace(
      edges={r2r_edges_key: new_edges},
      nodes={'rnodes': new_rnodes},
    )

    # Run the GNN
    output_graph = self.gnn(input_graph, condition=tau)
    output_rnodes = output_graph.nodes['rnodes'].features

    return output_rnodes

class Decoder(nn.Module):
  """Regional-to-physical (r2p) decoder.

  Runs a single message-passing step over the ``r2p`` edges and maps the
  physical-node features to ``num_outputs`` output channels.

  Attributes:
    variable_mesh: If True, the output physical mesh may differ from the input
      mesh. In that case the structural pnode features of the r2p graph are
      embedded instead of reusing the latent input pnodes.
    num_outputs: Output size of the ``pnodes``.
    node_latent_size: Latent size of both ``pnodes`` and ``rnodes``.
    edge_latent_size: Latent size of the ``r2p`` edges.
    mlp_hidden_layers: Number of hidden layers of the GNN MLPs.
    use_layer_norm: Forwarded to ``DeepTypedGraphNet``.
    conditioned_normalization: Forwarded to ``DeepTypedGraphNet``; the GNN is
      called with ``condition=tau``.
    cond_norm_hidden_size: Forwarded to ``DeepTypedGraphNet``.
    p_edge_masking: Fraction of ``r2p`` edges dropped when
      ``deterministic=False``.
  """

  variable_mesh: bool
  num_outputs: int
  node_latent_size: int
  edge_latent_size: int
  mlp_hidden_layers: int = 1
  use_layer_norm: bool = True
  conditioned_normalization: bool = True
  # NOTE: This is a size, not a flag (it used to be annotated ``bool = True``).
  # The default ``1`` is numerically identical to the former ``True``.
  cond_norm_hidden_size: int = 1
  p_edge_masking: float = .0

  def setup(self):
    self.gnn = DeepTypedGraphNet(
      # NOTE: with variable mesh, the output pnode features must be embedded
      embed_nodes=(dict(pnodes=True) if self.variable_mesh else False),
      embed_edges=True,  # Embed raw features of the edges
      # Require a specific node dimensionality for the physical node outputs
      # NOTE: This triggers the independent mapping for pnodes
      node_output_size=dict(pnodes=self.num_outputs),
      edge_latent_size=dict(r2p=self.edge_latent_size),
      node_latent_size=dict(rnodes=self.node_latent_size, pnodes=self.node_latent_size),
      mlp_num_hidden_layers=self.mlp_hidden_layers,
      num_message_passing_steps=1,
      # NOTE: Previously hard-coded to True, which silently ignored this field
      use_layer_norm=self.use_layer_norm,
      conditioned_normalization=self.conditioned_normalization,
      cond_norm_hidden_size=self.cond_norm_hidden_size,
      include_sent_messages_in_node_update=False,
      activation='swish',
      f32_aggregation=False,
      # NOTE: segment_mean because number of edges is not balanced
      aggregate_edges_for_nodes_fn=jraph.segment_mean,
    )

  def __call__(self,
    graph: TypedGraph,
    rnode_features: Array,
    pnode_features: Array,
    tau: Union[None, float],
    deterministic: bool = False,
  ) -> Array:
    """Runs the r2p GNN, extracting the output physical nodes.

    Args:
      graph: The r2p graph with ``pnodes``, ``rnodes`` and an ``r2p`` edge set.
      rnode_features: Latent rnode features. Axis 0 is the batch axis.
      pnode_features: Latent pnode features of the input mesh. They are ignored
        when ``variable_mesh`` is True, but their dtype is still used to cast
        the edge features.
      tau: Condition passed to the GNN; ``None`` when unconditioned.
      deterministic: If False, a random subset of the ``r2p`` edges is kept
        using the ``'masking'`` RNG stream.

    Returns:
      The output pnode features.
    """

    # Get batch size
    batch_size = rnode_features.shape[0]

    # NOTE: We don't need to add the structural node features, because these are
    # already part of the latent state, via the original p2r gnn.
    rnodes = graph.nodes['rnodes']
    pnodes = graph.nodes['pnodes']
    new_rnodes = rnodes._replace(features=rnode_features)
    if self.variable_mesh:
      # NOTE: We can't use latent pnodes of the input mesh for the output mesh,
      # so the structural pnode features of the graph are kept as they are
      new_pnodes = pnodes
    else:
      new_pnodes = pnodes._replace(features=pnode_features)

    # Get edges
    r2p_edges_key = graph.edge_key_by_name('r2p')
    edges = graph.edges[r2p_edges_key]
    # Drop out edges randomly with the given probability
    # NOTE: Edges are stored along axis 1 of the edge arrays
    if deterministic:
      n_edges_after = edges.features.shape[1]
      new_edge_features = edges.features
      new_edge_senders = edges.indices.senders
      new_edge_receivers = edges.indices.receivers
    else:
      rngkey = self.make_rng('masking')
      n_edges_after = int((1 - self.p_edge_masking) * edges.features.shape[1])
      # Shuffle features and indices jointly, then keep the first edges
      [new_edge_features, new_edge_senders, new_edge_receivers] = shuffle_arrays(
        rngkey=rngkey, arrays=[edges.features, edges.indices.senders, edges.indices.receivers], axis=1)
      new_edge_features = new_edge_features[:, :n_edges_after]
      new_edge_senders = new_edge_senders[:, :n_edges_after]
      new_edge_receivers = new_edge_receivers[:, :n_edges_after]
    # Change edge feature dtype
    new_edge_features = new_edge_features.astype(pnode_features.dtype)
    # Build new edge set; n_edge has shape [batch_size, 1]
    new_edges = EdgeSet(
      n_edge=jnp.tile(jnp.array([n_edges_after]), reps=(batch_size, 1)),
      indices=EdgesIndices(
        senders=new_edge_senders,
        receivers=new_edge_receivers,
      ),
      features=new_edge_features,
    )

    # Build the new graph
    input_graph = graph._replace(
      edges={r2p_edges_key: new_edges},
      nodes={'rnodes': new_rnodes, 'pnodes': new_pnodes}
    )

    # Run the GNN
    output_graph = self.gnn(input_graph, condition=tau)
    output_pnodes = output_graph.nodes['pnodes'].features

    return output_pnodes

class RIGNO(AbstractOperator):
  """
  RIGNO: Region Interaction Graph Neural Operator.

  The model runs in three stages:
  - An ``Encoder`` maps the physical-node inputs onto a regional mesh.
  - A ``Processor`` runs message passing on the regional mesh.
  - A ``Decoder`` maps the result back to the physical nodes.

  Args:
    num_outputs: Number of output variables.
    processor_steps: Number of message passing blocks in the processor.
    node_latent_size: Dimension of the latent node features.
    edge_latent_size: Dimension of the latent edge features.
    mlp_hidden_layers: Number of hidden layers in the MLPs.
    p_edge_masking: Fraction of edges dropped in each GNN when not
      deterministic.
    tdep: Whether the model is time-dependent. If True:
      - ``inputs.t`` and ``inputs.tau`` are required.
      - Both are appended to the physical-node features.
      - ``tau`` conditions the normalization layers of all three GNNs.
  """

  num_outputs: int
  processor_steps: int = 18
  node_latent_size: int = 128
  edge_latent_size: int = 128
  mlp_hidden_layers: int = 1
  p_edge_masking: float = 0.5
  tdep: bool = False

  def setup(self):
    # NOTE: There are a few architectural considerations for variable mesh
    # NOTE: variable_mesh=True means that the input and the output mesh can be different
    self.variable_mesh = False

    self.encoder = Encoder(
      edge_latent_size=self.edge_latent_size,
      node_latent_size=self.node_latent_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=16,
      p_edge_masking=self.p_edge_masking,
      name='encoder',
    )

    self.processor = Processor(
      steps=self.processor_steps,
      edge_latent_size=self.edge_latent_size,
      node_latent_size=self.node_latent_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=16,
      p_edge_masking=self.p_edge_masking,
      name='processor',
    )

    self.decoder = Decoder(
      variable_mesh=self.variable_mesh,
      num_outputs=self.num_outputs,
      edge_latent_size=self.edge_latent_size,
      node_latent_size=self.node_latent_size,
      mlp_hidden_layers=self.mlp_hidden_layers,
      conditioned_normalization=self.tdep,
      cond_norm_hidden_size=16,
      p_edge_masking=self.p_edge_masking,
      name='decoder',
    )

  @staticmethod
  def _prepare_features(feats: Array) -> Array:
    """Inserts a singleton (time) axis at position 1."""
    return jnp.expand_dims(feats, axis=1)

  @staticmethod
  def _prepare_time_channel(value, batch_size: int) -> Array:
    """Converts a time channel (``t`` or ``tau``) to a float32 array.

    A 4-D input is reduced to its first two axes, and an input with a single
    element is broadcast to shape ``[batch_size, 1]``.
    """
    value = jnp.array(value, dtype=jnp.float32)
    if value.ndim == 4:
      value = value[:, :, 0, 0]
    if value.size == 1:
      value = jnp.tile(value.reshape(1, 1), reps=(batch_size, 1))
    return value

  def _encode_process_decode(self, graphs: GraphSet, input_pnode_features: Array, input_rnode_features: Array, tau: Union[None, float], deterministic: bool = False) -> Array:
    """Runs encoder, processor and decoder, and sows their intermediate outputs.

    Args:
      graphs: Graph set providing the ``p2r``, ``r2r`` and ``r2p`` graphs.
      input_pnode_features: Pnode input features, ``[batch, num_pnodes, features]``.
      input_rnode_features: Rnode input features.
      tau: Condition of the GNNs; ``None`` when not time-dependent.
      deterministic: If False, edges are randomly masked in every GNN.

    Returns:
      The decoded pnode features, without the appended dummy pnode.
    """

    # Append one all-zero pnode to the inputs
    # NOTE: The Encoder concatenates these features node-by-node with the
    # graph's pnode features, so the graphs presumably carry one extra
    # (padding) pnode at the end.
    # NOTE: Use the input dtype so that e.g. bfloat16 inputs (or float32 inputs
    # under x64) are not silently promoted by the default zeros dtype.
    dummy_pnode_features = jnp.zeros(
      shape=(input_pnode_features.shape[0], 1, input_pnode_features.shape[2]),
      dtype=input_pnode_features.dtype,
    )
    input_pnode_features = jnp.concatenate([input_pnode_features, dummy_pnode_features], axis=1)

    # Transfer data from the physical mesh to the regional mesh
    # -> [batch_size, num_nodes, latent_size]
    latent_rnode_features, latent_pnode_features = self.encoder(graphs.p2r, input_pnode_features, input_rnode_features, tau, deterministic=deterministic)
    # NOTE(review): The last rnode is sliced away as well in the sown
    # intermediates. This assumes the graphs also carry a padding rnode;
    # confirm against the graph builder.
    self.sow(col='intermediates', name='pnodes-encoded', value=self._prepare_features(latent_pnode_features[:, :-1]))
    self.sow(col='intermediates', name='rnodes-encoded', value=self._prepare_features(latent_rnode_features[:, :-1]))

    # Run message passing steps in the regional mesh
    # -> [batch_size, num_rnodes, latent_size]
    processed_rnode_features = self.processor(graphs.r2r, latent_rnode_features, tau, deterministic=deterministic)
    self.sow(col='intermediates', name='rnodes-processed', value=self._prepare_features(processed_rnode_features[:, :-1]))

    # Transfer data from the regional mesh to the physical mesh
    # -> [batch_size, num_pnodes_out, latent_size]
    output_pnode_features = self.decoder(graphs.r2p, processed_rnode_features, latent_pnode_features, tau, deterministic=deterministic)
    self.sow(col='intermediates', name='pnodes-decoded', value=self._prepare_features(output_pnode_features[:, :-1]))

    # Remove the dummy node features
    output_pnode_features = output_pnode_features[:, :-1, :]

    return output_pnode_features

  def call(self, inputs: Inputs, graphs: GraphSet, input_pnode_features: Union[None, Array] = None, input_rnode_features: Union[None, Array] = None, deterministic: bool = False) -> Array:
    """Runs the operator.

    Inputs must be of shape [batch_size, 1, num_physical_nodes, num_variables].

    Args:
      inputs: Model inputs. Uses ``a``, ``x_inp`` and ``x_out``; when ``tdep``
        is set, it also uses ``t`` and ``tau``.
      graphs: Graph set providing the ``p2r``, ``r2r`` and ``r2p`` graphs.
      input_pnode_features: Optional extra pnode features. They are
        concatenated (last axis) after ``inputs.a``.
      input_rnode_features: Optional rnode input features. Defaults to a single
        zero channel.
      deterministic: If False, edges are randomly masked in every GNN.

    Returns:
      Output of shape [batch_size, 1, num_pnodes_out, num_outputs].
    """

    # Read dimensions
    batch_size = inputs.a.shape[0]
    num_pnodes_inp = inputs.x_inp.shape[2]
    num_rnodes = graphs.p2r.nodes['rnodes'].features.shape[1]

    if self.tdep:
      # Prepare the time channel
      assert inputs.t is not None
      t = self._prepare_time_channel(inputs.t, batch_size)
      # Prepare the lead time channel
      assert inputs.tau is not None
      tau = self._prepare_time_channel(inputs.tau, batch_size)
    else:
      tau = None

    # Prepare the physical node features
    # q, u -> [batch_size, num_pnodes_inp, num_inputs]
    pnode_features = inputs.a.squeeze(1)
    # Concatenate with forced and input features
    pnode_features_forced = []
    if input_pnode_features is not None:
      pnode_features_forced.append(input_pnode_features)
    if self.tdep:
      # Broadcast the time channels to every physical node
      pnode_features_forced.append(jnp.tile(jnp.expand_dims(t, axis=1), reps=(1, num_pnodes_inp, 1)))
      pnode_features_forced.append(jnp.tile(jnp.expand_dims(tau, axis=1), reps=(1, num_pnodes_inp, 1)))
    pnode_features = jnp.concatenate([pnode_features, *pnode_features_forced], axis=-1)

    # Prepare the input regional features
    if input_rnode_features is None:
      rnode_features = jnp.zeros(shape=(batch_size, num_rnodes, 1), dtype=inputs.a.dtype)
    else:
      rnode_features = input_rnode_features

    # Pass through the GNNs
    output_pnodes = self._encode_process_decode(graphs=graphs, input_pnode_features=pnode_features, input_rnode_features=rnode_features, tau=tau, deterministic=deterministic)

    # Reshape the output
    # [batch_size, num_pnodes_out, num_outputs] -> [batch_size, 1, num_pnodes_out, num_outputs]
    output = self._prepare_features(output_pnodes)
    self._check_function(output, x=inputs.x_out)

    return output

class XRIGNO(AbstractOperator):
  """Extended RIGNO: RIGNO with boundary-condition inputs.

  The boundary functions ``inputs.q`` (with boolean masks ``inputs.m``) are
  optionally unified into a Robin form. They are then fed to RIGNO in one of
  two ways:
  - With ``use_extender=True``, they are extended into the domain with a
    cross-attention extender and fed as regional-node features.
  - With ``use_extender=False``, they are concatenated with their masks and
    fed as physical-node features.

  Attributes:
    configs_core: Keyword arguments of the core ``RIGNO`` operator.
    configs_extender: Keyword arguments of the ``CrossAttentionExtender``.
    use_extender: Whether to extend the boundary functions with the extender.
    independent: If True, every component (last axis) of a boundary function
      is extended separately and the extensions are concatenated.
    boundary_size: Number of boundary entries gathered per sample; the gathered
      entries are padded or truncated to this size (see ``_extend``).
    unify: Whether to unify the boundary conditions into a Robin form.
    even: If True, the extension is made even in ``q`` by summing the
      extensions of ``q`` and ``-q``. This requires ``independent=False``.
  """

  configs_core: Mapping
  configs_extender: Mapping
  use_extender: bool
  independent: bool
  boundary_size: int
  unify: bool = True
  even: bool = False

  def setup(self):
    self.operator = RIGNO(**self.configs_core)
    self.extender = CrossAttentionExtender(**self.configs_extender)

  @nn.compact
  def call(self, inputs: Inputs, graphs: GraphSet, deterministic: bool = False):
    """Extended RIGNO.

    Args:
      inputs: Model inputs. Uses ``q`` and ``m`` (dicts keyed by boundary
        function name) in addition to what ``RIGNO.call`` reads; when
        ``use_extender`` is set, it also uses ``x_inp`` and ``s``.
      graphs: Graph set forwarded to RIGNO.
      deterministic: If False, boundary entries are shuffled and edges are
        masked.

    Returns:
      The output of the core RIGNO operator.
    """

    # Unify all the boundary conditions into a Robin form
    q, m = {}, {}
    if self.unify:
      prefixes = {'-'.join(key.split('-')[:2]) for key in inputs.q.keys()}  # NOTE: assuming a structure as "bc-d-typ" in the names
      prefixes = sorted(list(prefixes))  # NOTE: otherwise it leads to random ordering of the prefixes
      for prefix in prefixes:
        m_prefix, g_prefix, alpha_prefix, beta_prefix = self._unify_all_boundary_conditions(
          q={key: val for key, val in inputs.q.items() if key.startswith(prefix)},
          m={key: val for key, val in inputs.m.items() if key.startswith(prefix)},
        )
        q[f'{prefix}-uni'] = jnp.concatenate([g_prefix, alpha_prefix, beta_prefix], axis=-1)
        m[f'{prefix}-uni'] = m_prefix
      # NOTE: Concatenate all and use a single CrossAttentionExtender on them
      # NOTE: Only viable if the masks exactly match
      q = {'bc-uni': jnp.concatenate(list(q.values()), axis=-1)}
      # NOTE: jnp (not np) so the reduction stays a native JAX op, as in _unify_all_boundary_conditions
      m = {'bc-uni': jnp.any(jnp.stack(list(m.values()), axis=-1), axis=-1)}
    # Or leave them as they are
    else:
      q, m = inputs.q, inputs.m

    if self.use_extender:
      # Project all the boundary functions (e.g., bc-0-dir, bc-1-dir) separately
      xq = jnp.concatenate([inputs.x_inp, inputs.s], axis=-1)  # Geometric features of the whole domain (to be used for boundaries)
      rnodes = graphs.p2r.nodes['rnodes'].features  # Contains only positional encodings of the rnodes
      if self.independent:
        def extend_group(_q, _m):
          # Fold the components (last axis of _q) into the batch axis
          # (batch-major), and repeat the other inputs with the same order.
          _q_batched_components = jnp.moveaxis(_q, -1, 1).reshape(-1, *_q.shape[1:-1])[..., None]
          _m_batched_components = jnp.repeat(_m, repeats=_q.shape[-1], axis=0)
          _rnodes_batched_components = jnp.repeat(rnodes, repeats=_q.shape[-1], axis=0)
          _xq_batched_components = jnp.repeat(xq, repeats=_q.shape[-1], axis=0)
          _extensions = self._extend(_xq_batched_components, _q_batched_components, _m_batched_components, f_domain=_rnodes_batched_components, deterministic=deterministic)
          # Unfold the components back and concatenate them on the last axis
          _extensions = jnp.moveaxis(_extensions.reshape(_q.shape[0], -1, *_extensions.shape[1:]), 1, -1)
          _extensions = _extensions.reshape(*_extensions.shape[:-2], -1)
          return _extensions
      else:
        def extend_group(_q, _m):
          _extensions = self._extend(xq, _q, _m, f_domain=rnodes, deterministic=deterministic)
          return _extensions
      extensions = jax.tree.map(extend_group, q, m)
      # Concatenate all the extensions
      jax.tree.map_with_path(lambda key, psi: self.sow(col='intermediates', name=f'extensions-{key[0].key}', value=psi), extensions)
      extensions = jnp.concatenate(jax.tree.flatten(extensions)[0], axis=-1)
      # Pass through the operator
      # NOTE: feeding in rnode features
      output = self.operator(inputs, graphs=graphs, input_rnode_features=extensions, deterministic=deterministic)
    else:
      # Add a variable dimension to the masks
      m = jax.tree.map(lambda m: m[..., None], m)
      # Extend the boundary functions with zeros and add a binary mask
      extensions = jax.tree.map(lambda _q, _m: jnp.concatenate([_q, _m.astype(float)], axis=-1), q, m)
      extensions = jnp.concatenate(jax.tree.flatten(extensions)[0], axis=-1)
      extensions = jnp.squeeze(extensions, axis=1)

      # Pass through the operator
      # NOTE: feeding in pnode features
      output = self.operator(inputs, graphs=graphs, input_pnode_features=extensions, deterministic=deterministic)

    return output

  def _unify_all_boundary_conditions(self, q, m):
    """Unifies non-overlapping Dirichlet, Neumann, and Robin BCs into a Robin form.

    The unified form is ``alpha * u + beta * du/dn = g``. The BC type is
    inferred from the key suffix:
    - ``'dir'`` (Dirichlet): ``alpha = 1``, ``beta = 0``.
    - ``'neu'`` (Neumann): ``alpha = 0``, ``beta = 1``.
    - Anything else is read as ``[g, alpha, beta]`` on the last axis, with
      missing coefficients defaulting to 1.
    Where a mask is set and ``[alpha, beta]`` is non-zero, the values
    ``g``, ``alpha`` and ``beta`` are divided by the 2-norm of ``[alpha, beta]``.

    Args:
      q: Dict of boundary functions with at most 3 channels on the last axis;
        channel 0 is ``g``.
      m: Dict of boolean masks with the same keys. The masks have the shape of
        ``q`` without its last axis.

    Returns:
      ``(m_unified, g_unified, alpha_unified, beta_unified)``. The last three
      keep a trailing axis of size 1.
    """

    # Get unified mask
    # NOTE: assuming that the masks do not overlap
    m_unified = jnp.any(jnp.stack(jax.tree.leaves(m), axis=-1), axis=-1)

    # Get unified g
    # NOTE: assuming g is always the first function
    gm = jax.tree.map(lambda m, q: m * q[..., 0], m, q)
    g_unified = jnp.stack(jax.tree.leaves(gm), axis=-1).sum(axis=-1, keepdims=True)

    # Get alpha and beta for all BC types
    # NOTE: assuming alpha is only present in Robin BC as the second function
    # Pad the coefficient channels with ones up to [alpha, beta]
    ones = jax.tree.map(lambda q: jnp.concatenate([q[..., 1:], jnp.ones(shape=(*q.shape[:-1], 3 - q.shape[-1]))], axis=-1), q)  # contains alpha=1 and beta=1
    alpha = jax.tree.map_with_path(lambda key, _ones: jnp.zeros_like(_ones[..., 0]) if (key[0].key.endswith('neu')) else _ones[..., 0], ones)  # alpha=0 for Neumann
    beta = jax.tree.map_with_path(lambda key, _ones: jnp.zeros_like(_ones[..., 1]) if (key[0].key.endswith('dir')) else _ones[..., 1], ones)  # beta=0 for Dirichlet
    # Get unified alpha and beta
    am = jax.tree.map(lambda m, a: m * a, m, alpha)
    bm = jax.tree.map(lambda m, b: m * b, m, beta)
    alpha_unified = jnp.stack(jax.tree.leaves(am), axis=-1).sum(axis=-1, keepdims=True)
    beta_unified = jnp.stack(jax.tree.leaves(bm), axis=-1).sum(axis=-1, keepdims=True)
    # Normalize alpha, beta, and g with respect to the 2-norm of the vector [alpha, beta]
    zeta_norm = jnp.linalg.norm(jnp.stack([alpha_unified, beta_unified]), axis=0)
    # Avoid division by zero outside the masks and for degenerate alpha=beta=0
    zeta_norm = jnp.where(m_unified[..., None] & (zeta_norm > 0), zeta_norm, 1.)
    alpha_unified /= zeta_norm
    beta_unified /= zeta_norm
    g_unified /= zeta_norm

    return m_unified, g_unified, alpha_unified, beta_unified

  def _extend(self, xq: Array, q: Array, m: Array, f_domain: Array, deterministic: bool = False) -> Array:
    """Extends a boundary function into the domain with the extender.

    Args:
      xq: Geometric features of all physical nodes; axis 1 is squeezed.
      q: Boundary function values at all physical nodes; axis 1 is squeezed.
      m: Boolean boundary mask; axis 1 is squeezed.
      f_domain: Domain features passed to the extender.
      deterministic: If False, the gathered boundary entries are shuffled
        using the ``'other'`` RNG stream.

    Returns:
      The extender output, or the sum of two extender calls when ``self.even``.
    """
    # Gather the first `boundary_size` masked entries per sample.
    # Missing entries are filled with index 0 and flagged as padding in `m_bnd`.
    batched_slice = jax.vmap(lambda f, m: f[jnp.where(m, size=self.boundary_size)[0]])
    batched_padmask = jax.vmap(lambda m: jnp.where(jnp.where(m, size=self.boundary_size, fill_value=-1)[0] > -1, 1, 0))
    xs_bnd = batched_slice(xq.squeeze(1), m.squeeze(1))
    q_bnd = batched_slice(q.squeeze(1), m.squeeze(1))
    m_bnd = batched_padmask(m.squeeze(1))  # A binary mask indicating the padded boundary entries
    # Shuffle on the coordinate axis
    if not deterministic:
      rngkey = self.make_rng(name='other')
      xs_bnd, q_bnd, m_bnd = shuffle_arrays(rngkey=rngkey, arrays=[xs_bnd, q_bnd, m_bnd], axis=1)
    # Use the same extender twice to make the extension even
    if self.even:
      assert not self.independent
      psi_pos = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )
      psi_neg = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, -q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )
      psi = psi_pos + psi_neg  # even function
    # Use the extender only once
    else:
      psi = self.extender(
        f_boundary=jnp.concatenate([xs_bnd, q_bnd], axis=-1),
        f_domain=f_domain,
        m_boundary=m_bnd,
        deterministic=deterministic,
      )

    return psi
