
"""Constructors for MLPs.

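To enable multi-GPU sharding, create a `jax.sharding.Mesh` with a 'model' axis at
the top level and pass it to all modules. Every kernel and bias is annotated with
`nnx.with_partitioning` so that its output-feature axis is split over 'model'.
Note that the `mesh` argument is currently accepted but not read by the modules
in this file; the annotations take effect only once the model state is actually
sharded under that mesh.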

"""
import jax
import jax.numpy as jnp
import jraph
import flax.nnx as nnx
from jax.sharding import NamedSharding, PartitionSpec as P

import functools
from typing import Optional


### Flax NNX modules ###
class LinearNormConditioning(nnx.Module):
  """Feature-wise affine modulation (scale and shift) driven by a conditioning input.

  One linear layer maps the conditioning input to ``2 * feature_size`` values,
  which are split into a scale delta and an offset. The result is
  ``inputs * (1 + scale_delta) + offset``. The kernel starts from a truncated
  normal with stddev 1e-8 and the bias from zeros. At initialisation the
  modulation is therefore (very nearly) the identity.

  Args:
    feature_size: Size C of the last axis of the inputs being modulated.
    rngs: NNX random streams used to initialise the linear layer.
    mesh: Accepted for API uniformity with the other modules; not read here.
    conditioning_dim: Size D of the last axis of the conditioning input.
  """

  def __init__(self, feature_size: int, rngs: nnx.Rngs, mesh, conditioning_dim: int):
    self.feature_size = feature_size
    # Kernel is (D, 2C) and bias is (2C,); both shard the output axis over 'model'.
    self.conditional_linear_layer = nnx.Linear(
        in_features=conditioning_dim,
        out_features=2 * feature_size,
        kernel_init=nnx.with_partitioning(
            nnx.initializers.truncated_normal(stddev=1e-8), P(None, 'model')),
        bias_init=nnx.with_partitioning(
            nnx.initializers.zeros_init(), P('model')),
        rngs=rngs,
    )

  def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array):
    """Modulate ``inputs`` with a scale and offset predicted from ``norm_conditioning``.

    Args:
      inputs: Array whose last axis has size ``feature_size`` (C).
      norm_conditioning: Array whose last axis has size D. After the linear
        projection its leading axes must broadcast against those of ``inputs``.

    Returns:
      ``inputs * (1 + scale_delta) + offset``, broadcast over the leading axes.
    """
    projected = self.conditional_linear_layer(norm_conditioning)  # (..., 2C)
    # The first half is the scale delta, the second half is the offset.
    scale_delta, shift = jnp.split(projected, 2, axis=-1)  # each (..., C)
    return inputs * (scale_delta + 1.) + shift
  

class MLPWithNormConditioning(nnx.Module):
  """MLP followed by an optional LayerNorm and optional conditioning-driven modulation.

  Forward pass:
    1. ``x = MLP(inputs)``.
    2. If enabled, ``x = LayerNorm(x)``.
    3. If enabled, ``x = x * (1 + scale) + offset``. Here ``scale`` and
       ``offset`` are predicted from ``global_norm_conditioning`` by a
       ``LinearNormConditioning`` layer.

  When norm conditioning is enabled, the LayerNorm is built without its own
  learnable scale/bias, because the conditioning layer supplies them.

  Args:
    mlp_input_size: Size of the last axis of ``inputs``.
    mlp_hidden_size: Width of each hidden layer of the MLP.
    mlp_num_hidden_layers: Number of hidden (Linear, activation) pairs.
    mlp_output_size: Size C of the last axis of the output.
    activation: Callable applied after each hidden Linear layer.
    use_layer_norm: Whether to apply LayerNorm to the MLP output.
    use_norm_conditioning: Whether to apply conditioning-driven modulation.
    rngs: NNX random streams used to initialise all parameters.
    mesh: Forwarded to the submodules; not read directly here.
    norm_conditioning_dim: Size D of the conditioning vector. Required when
      ``use_norm_conditioning`` is True.

  Raises:
    ValueError: If norm conditioning is enabled without ``norm_conditioning_dim``.
  """

  def __init__(self,
               mlp_input_size: int,
               mlp_hidden_size: int,
               mlp_num_hidden_layers: int,
               mlp_output_size: int,
               activation,
               *,
               use_layer_norm: bool,
               use_norm_conditioning: bool,
               rngs: nnx.Rngs,
               mesh,
               norm_conditioning_dim: Optional[int] = None):
    # Validate before building any submodule so a misconfiguration fails fast.
    if use_norm_conditioning and norm_conditioning_dim is None:
      raise ValueError("norm_conditioning_dim must be set when norm conditioning is enabled.")

    self._use_layer_norm = use_layer_norm
    self._use_norm_conditioning = use_norm_conditioning

    # Keep the construction order (MLP, LayerNorm, conditioning layer).
    # `rngs` is stateful, so reordering would change the initial parameter values.
    self.network = MLP(
        mlp_input_size=mlp_input_size,
        mlp_hidden_size=mlp_hidden_size,
        mlp_num_hidden_layers=mlp_num_hidden_layers,
        mlp_output_size=mlp_output_size,
        activation=activation,
        rngs=rngs,
        mesh=mesh,
    )

    if use_layer_norm:
      # With norm conditioning, the scale/offset come from the conditioning
      # layer, so this LayerNorm only normalises. The init functions are
      # ignored when the corresponding use_* flag is False.
      self.layer_norm = nnx.LayerNorm(
          num_features=mlp_output_size,
          use_scale=not use_norm_conditioning,
          use_bias=not use_norm_conditioning,
          feature_axes=-1,
          scale_init=nnx.initializers.ones_init(),
          bias_init=nnx.initializers.zeros_init(),
          rngs=rngs,
      )

    if use_norm_conditioning:
      self.norm_conditioning_layer = LinearNormConditioning(
          feature_size=mlp_output_size,
          conditioning_dim=norm_conditioning_dim,
          rngs=rngs,
          mesh=mesh,
      )

  @staticmethod
  def _align_conditioning(x: jax.Array, conditioning: jax.Array) -> jax.Array:
    """Insert axes into ``conditioning`` so that it broadcasts against ``x``.

    Supported layouts:
      * conditioning (D,): shared by every element; returned unchanged, since
        it broadcasts against any ``x`` after the linear projection.
      * conditioning (B, D) with x (N, B, C): returns (1, B, D).
      * conditioning (B, D) with x (B, N, C): returns (B, 1, D).
      * conditioning (B, D) with x (B, C): returned unchanged.

    If ``x`` is 3-D and both of its leading axes equal B, the layout is
    ambiguous. (N, B, C) is assumed in that case, matching the historical
    behaviour.

    Raises:
      ValueError: If ``conditioning`` is not 1-D or 2-D, if ``x`` has an
        unsupported rank, or if no axis of ``x`` matches the batch size B.
    """
    if conditioning.ndim == 1:
      return conditioning
    if conditioning.ndim != 2:
      raise ValueError(
          f"global_norm_conditioning must have shape (B, D) or (D,), got {conditioning.shape}")

    batch = conditioning.shape[0]
    if x.ndim == 3:
      if x.shape[1] == batch:      # (N, B, C)
        return conditioning[None, :, :]  # (1, B, D)
      if x.shape[0] == batch:      # (B, N, C)
        return conditioning[:, None, :]  # (B, 1, D)
    elif x.ndim == 2:
      if x.shape[0] == batch:      # (B, C)
        return conditioning
    else:
      raise ValueError(f"Unsupported input shape {x.shape} for norm conditioning")
    raise ValueError(f"Cannot align conditioning {conditioning.shape} with x {x.shape}")

  def __call__(self, inputs: jax.Array, global_norm_conditioning: Optional[jax.Array] = None):
    """Apply the MLP, then the optional LayerNorm, then the optional norm conditioning.

    Args:
      inputs: Array whose last axis has size ``mlp_input_size``.
      global_norm_conditioning: Conditioning of shape (B, D), which is aligned
        to an MLP output of shape (N, B, C), (B, N, C) or (B, C). Shape (D,)
        is also accepted and is shared by all elements. Required when norm
        conditioning is enabled; ignored otherwise.

    Returns:
      Array with the leading axes of ``inputs`` and last axis ``mlp_output_size``.

    Raises:
      ValueError: If conditioning is required but missing, or cannot be
        aligned with the MLP output.
    """
    if self._use_norm_conditioning and global_norm_conditioning is None:
      raise ValueError("global_norm_conditioning must be provided when norm conditioning is enabled.")

    x = self.network(inputs)

    if self._use_layer_norm:
      x = self.layer_norm(x)

    if self._use_norm_conditioning:
      cond = self._align_conditioning(x, global_norm_conditioning)
      x = self.norm_conditioning_layer(x, cond)

    return x

  


class MLP(nnx.Module):
  """A simple MLP: ``mlp_num_hidden_layers`` x (Linear -> activation), then a final Linear.

  Every kernel is initialised with Xavier-uniform and every bias with zeros.
  Both are annotated so that their output-feature axis is sharded over the
  'model' mesh axis.

  Args:
    mlp_input_size: Size of the last axis of the inputs.
    mlp_hidden_size: Width of each hidden Linear layer.
    mlp_num_hidden_layers: Number of hidden (Linear, activation) pairs. With 0,
      the MLP is a single Linear map from ``mlp_input_size`` to ``mlp_output_size``.
    mlp_output_size: Size of the last axis of the output.
    activation: Callable applied after each hidden Linear layer (not after the
      final one).
    rngs: NNX random streams used to initialise the parameters.
    mesh: Accepted for API uniformity with the other modules; not read here.
  """

  def __init__(self,
               mlp_input_size: int,
               mlp_hidden_size: int,
               mlp_num_hidden_layers: int,
               mlp_output_size: int,
               activation,
               *,
               rngs: nnx.Rngs,
               mesh):
    """Initializes the MLP module."""
    # Initializers are pure functions, so one wrapped instance can be shared
    # by all layers. Each layer still draws its own key from `rngs`.
    kernel_init = nnx.with_partitioning(
        nnx.initializers.xavier_uniform(),
        P(None, 'model'),
    )
    bias_init = nnx.with_partitioning(
        nnx.initializers.zeros_init(),
        P('model'),
    )

    def dense(in_features: int, out_features: int) -> nnx.Linear:
      # Build one sharded Linear layer with this MLP's initializers.
      return nnx.Linear(
          in_features=in_features,
          out_features=out_features,
          kernel_init=kernel_init,
          bias_init=bias_init,
          rngs=rngs,
      )

    layers = []
    feature_size = mlp_input_size
    for _ in range(mlp_num_hidden_layers):
      layers.append(dense(feature_size, mlp_hidden_size))
      layers.append(activation)
      feature_size = mlp_hidden_size
    # The final projection reads the running `feature_size`, not
    # `mlp_hidden_size`. With zero hidden layers it must consume the raw
    # input width (mlp_input_size).
    layers.append(dense(feature_size, mlp_output_size))
    self.network = nnx.Sequential(*layers)

  def __call__(self, inputs: jax.Array):
    """Apply the layer stack to ``inputs`` (last axis ``mlp_input_size``)."""
    return self.network(inputs)

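# --- Disabled alternative implementation (kept for reference) ---
# Same public API as the modules above, but MLPWithNormConditioning builds a
# SwiGLU-based MLP (SwiGLUMLP / SwiGLULayer) and ignores its `activation` argument.
# import jax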
# import jax.numpy as jnp
# import jraph
# import flax.nnx as nnx
# from jax.sharding import NamedSharding, PartitionSpec as P

# import functools
# from typing import Optional


# def _silu(x: jax.Array) -> jax.Array:
#   return x * jax.nn.sigmoid(x)


# ### Flax NNX modules ###
# class LinearNormConditioning(nnx.Module):
#   def __init__(self, feature_size: int, rngs: nnx.Rngs, mesh, conditioning_dim: int):
#     self.feature_size = feature_size
#     kernel_init = nnx.with_partitioning(
#         nnx.initializers.truncated_normal(stddev=1e-8),
#         P(None, 'model')
#     )
#     bias_init = nnx.with_partitioning(
#         nnx.initializers.zeros_init(),
#         P('model')
#     )
#     self.conditional_linear_layer = nnx.Linear(
#       in_features=conditioning_dim,
#       out_features=2 * feature_size,
#       kernel_init=kernel_init,
#       bias_init=bias_init,
#       rngs=rngs,
#     )
  
#   def __call__(self, inputs: jax.Array, norm_conditioning: jax.Array):
#     # inputs: (..., C)
#     # norm_conditioning: (..., D) broadcastable to inputs[..., None, :]
#     cond = self.conditional_linear_layer(norm_conditioning)  # (..., 2*C)
#     scale_minus_one, offset = jnp.split(cond, 2, axis=-1)
#     scale = scale_minus_one + 1.
#     return inputs * scale + offset
  

# class MLPWithNormConditioning(nnx.Module):
#   def __init__(self,
#                mlp_input_size: int,
#                mlp_hidden_size: int,
#                mlp_num_hidden_layers: int,
#                mlp_output_size: int,
#                activation,  # kept for API compatibility; ignored in SwiGLU path
#                *,
#                use_layer_norm: bool,
#                use_norm_conditioning: bool,
#                rngs: nnx.Rngs,
#                mesh,
#                norm_conditioning_dim: Optional[int] = None):
#     self._use_layer_norm = use_layer_norm
#     self._use_norm_conditioning = use_norm_conditioning

#     # === SwiGLU MLP ===
#     self.network = SwiGLUMLP(
#         mlp_input_size=mlp_input_size,
#         mlp_hidden_size=mlp_hidden_size,
#         mlp_num_hidden_layers=mlp_num_hidden_layers,
#         mlp_output_size=mlp_output_size,
#         rngs=rngs,
#         mesh=mesh,
#     )

#     if self._use_layer_norm:
#       self.layer_norm = nnx.LayerNorm(
#           num_features=mlp_output_size,
#           use_scale=not use_norm_conditioning,
#           use_bias=not use_norm_conditioning,
#           feature_axes=-1,
#           scale_init=nnx.initializers.ones_init() if not use_norm_conditioning else None,
#           bias_init=nnx.initializers.zeros_init() if not use_norm_conditioning else None,
#           rngs=rngs
#       )

#     if self._use_norm_conditioning:
#       if norm_conditioning_dim is None:
#         raise ValueError("norm_conditioning_dim must be set when norm conditioning is enabled.")
#       self.norm_conditioning_layer = LinearNormConditioning(
#           feature_size=mlp_output_size,
#           conditioning_dim=norm_conditioning_dim,
#           rngs=rngs,
#           mesh=mesh
#       )

#   def __call__(self, inputs: jax.Array, global_norm_conditioning: Optional[jax.Array] = None):
#     if self._use_norm_conditioning and global_norm_conditioning is None:
#       raise ValueError("global_norm_conditioning must be provided when norm conditioning is enabled.")

#     x = self.network(inputs)

#     if self._use_layer_norm:
#       x = self.layer_norm(x)

#     if self._use_norm_conditioning:
#       # Expect global_norm_conditioning of shape (B, D)
#       # Match to inputs of shape (N, B, C) or (B, N, C) or (B, C)
#       if x.ndim == 3:
#         if x.shape[1] == global_norm_conditioning.shape[0]:      # (N,B,C)
#           cond = global_norm_conditioning[None, :, :]            # (1,B,D)
#         elif x.shape[0] == global_norm_conditioning.shape[0]:    # (B,N,C)
#           cond = global_norm_conditioning[:, None, :]            # (B,1,D)
#         else:
#           raise ValueError(f"Cannot align conditioning {global_norm_conditioning.shape} with x {x.shape}")
#       elif x.ndim == 2:
#         if x.shape[0] == global_norm_conditioning.shape[0]:      # (B,C)
#           cond = global_norm_conditioning
#         else:
#           raise ValueError(f"Cannot align conditioning {global_norm_conditioning.shape} with x {x.shape}")
#       else:
#         raise ValueError(f"Unsupported input shape {x.shape} for norm conditioning")

#       x = self.norm_conditioning_layer(x, cond)

#     return x


# class SwiGLULayer(nnx.Module):
#   """One SwiGLU FFN block: up-proj -> SwiGLU -> down-proj."""
#   def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, rngs: nnx.Rngs):
#     up_kernel_init = nnx.with_partitioning(
#         nnx.initializers.xavier_uniform(),
#         P(None, 'model')
#     )
#     up_bias_init = nnx.with_partitioning(
#         nnx.initializers.zeros_init(),
#         P('model')
#     )
#     down_kernel_init = nnx.with_partitioning(
#         nnx.initializers.xavier_uniform(),
#         P(None, 'model')
#     )
#     down_bias_init = nnx.with_partitioning(
#         nnx.initializers.zeros_init(),
#         P('model')
#     )

#     # Up-projection to 2*hidden_dim, then split for SwiGLU
#     self.up = nnx.Linear(
#         in_features=in_dim,
#         out_features=2 * hidden_dim,
#         kernel_init=up_kernel_init,
#         bias_init=up_bias_init,
#         rngs=rngs
#     )
#     # Down-projection to output dimension
#     self.down = nnx.Linear(
#         in_features=hidden_dim,
#         out_features=out_dim,  # Fixed: was hidden_dim, should be out_dim
#         kernel_init=down_kernel_init,
#         bias_init=down_bias_init,
#         rngs=rngs
#     )

#   def __call__(self, x: jax.Array) -> jax.Array:
#     u = self.up(x)                               # (..., 2*hidden_dim)
#     a, b = jnp.split(u, 2, axis=-1)              # each (..., hidden_dim)
#     g = _silu(a) * b                              # SwiGLU activation
#     return self.down(g)                           # (..., out_dim)


# class SwiGLUMLP(nnx.Module):
#   """SwiGLU MLP stack that mirrors the original MLP API."""
#   def __init__(self,
#                mlp_input_size: int,
#                mlp_hidden_size: int,
#                mlp_num_hidden_layers: int,
#                mlp_output_size: int,
#                *,
#                rngs: nnx.Rngs,
#                mesh):
#     layers = []
#     current_dim = mlp_input_size

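#     # Build mlp_num_hidden_layers SwiGLU blocks; the last block projects to
#     # mlp_output_size. NOTE(review): with mlp_num_hidden_layers == 0 no block
#     # is built and mlp_output_size is never used; guard this before enabling.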
#     for i in range(mlp_num_hidden_layers):
#       if i == mlp_num_hidden_layers - 1:
#         # Last layer: project to final output size
#         next_dim = mlp_output_size
#       else:
#         # Intermediate layers: project to hidden size
#         next_dim = mlp_hidden_size
      
#       layers.append(SwiGLULayer(
#           in_dim=current_dim, 
#           hidden_dim=mlp_hidden_size, 
#           out_dim=next_dim,
#           rngs=rngs
#       ))
#       current_dim = next_dim

#     self.network = nnx.Sequential(*layers)

#   def __call__(self, inputs: jax.Array) -> jax.Array:
#     return self.network(inputs)
