"""
Flax port of `nequix.py` (originally Equinox + Jraph) in the style of `nequip.py`.

This implementation follows the existing EGXC `BaseGNN` interface:
  (atom_features, atom_pos, atom_mask) -> (graph_readout_energy, atom_features_out)

Compared to the Equinox version:
- We build a dense (all-to-all) graph like `NequIP`, masking padding atoms and self edges.
- No species specific skip connections, but instead a single skip connection from the input features.
- A different polynomial envelope function is used for the radial basis functions.
"""

from typing import Callable, Tuple

import e3nn_jax as e3nn
import flax.linen as nn
import jax
import jax.numpy as jnp

from egxc.utils.linalg import safe_norm
from egxc.utils.typing import PRECISION, BoolA, Float1, FloatAx3, RBFType

from .gnn import BaseGNN, GraphReadout
from .mlp import MLP
from .radial_basis_fns import radial_basis_values


class RMSLayerNorm(nn.Module):
    """
    RMS layer norm over an `e3nn.IrrepsArray`.

    Based on https://github.com/facebookresearch/fairchem/blob/977a803/src/fairchem/core/models/esen/nn/layer_norm.py#L229
    from ESEN which is based on EquiformerV2.

    A single normalization factor is computed per leading index (e.g. per node):
    for every irrep chunk the squared entries are summed over the `2l+1`
    components, scaled by `1 / (dim * num_irreps)`, and averaged over the
    channels; the per-irrep values are then added up and the input is multiplied
    by `(norm + eps) ** -0.5`.

    Attributes:
        irreps: The irreducible representations specification for the input.
            Only its length (number of irrep entries) is used, for the balancing weight.
        centering: If True, subtract mean from scalar (l=0) channels before normalization
            and add learnable bias after. Generally not used for equivariant features.
        std_balance_degrees: If True, weight each irrep's contribution to the norm by
            1/(dim * num_irreps), giving equal importance to each degree regardless of
            its dimensionality (2l+1). This prevents higher-l features from dominating
            the normalization statistics. `False` is not implemented.
        eps: Small constant for numerical stability in the inverse square root.
        affine: If True, apply learnable per-channel scale (and bias for l=0 when centering).

    Raises:
        NotImplementedError: If `std_balance_degrees` is False.
    """

    irreps: e3nn.Irreps
    centering: bool = False
    std_balance_degrees: bool = True
    eps: float = 1e-12
    affine: bool = True

    @nn.compact
    def __call__(self, node_input: e3nn.IrrepsArray) -> e3nn.IrrepsArray:
        # Fail fast: only the degree-balanced statistic is implemented.
        if not self.std_balance_degrees:
            raise NotImplementedError('std_balance_degrees=False not implemented')

        num_irreps = len(self.irreps)
        input_chunks = []
        norms = []

        # Each chunk is laid out as (..., mul, ir.dim): axis -2 indexes the channels
        # (multiplicity) and axis -1 the 2l+1 components of the irrep. The centering
        # and the affine broadcast below rely on the same layout.
        for irr, x in zip(node_input.irreps, node_input.chunks):
            if self.centering and irr.ir.l == 0:
                # remove the mean over channels from the scalar features
                x = x - jnp.mean(x, axis=-2, keepdims=True)  # type: ignore
            input_chunks.append(x)

            weight = 1 / (irr.ir.dim * num_irreps)
            # degree-balanced squared l2 norm over the components of each channel
            norm = (x**2 * weight).sum(axis=-1, keepdims=True)  # type: ignore
            # mean over channels -> shape (..., 1, 1)
            norm = norm.mean(axis=-2, keepdims=True)
            norms.append(norm)

        # sum across irreps (because we already weight by 1 / len(self.irreps))
        norm = jnp.concatenate(norms, axis=-1).sum(axis=-1, keepdims=True)
        norm = jnp.pow(norm + self.eps, -0.5)

        out_chunks = []
        for i, (irr, x) in enumerate(zip(node_input.irreps, input_chunks)):
            if self.affine:
                # one learnable scale per channel, shared by all components (keeps equivariance)
                w = self.param(f'affine_weight_{i}', nn.initializers.ones, (irr.mul,))
                x = x * w[:, None]

            out = x * norm

            if self.affine and self.centering and irr.ir.l == 0:
                # a bias is only applied to scalar (l=0) channels
                b = self.param(f'affine_bias_{i}', nn.initializers.zeros, (irr.mul,))
                out = out + b[:, None]

            out_chunks.append(out)

        return e3nn.from_chunks(node_input.irreps, out_chunks, node_input.shape[:-1])


class _NequixLayerFlax(nn.Module):
    """
    A single Nequix interaction block.

    The block applies a linear map to the node features, gathers them on the
    edges, takes the tensor product with the edge spherical harmonics, scales
    each resulting irrep by the output of a bias-free radial MLP, sums the
    messages per receiver (divided by `sqrt(avg_num_neighbors)`), mixes them with
    a second linear map, adds a linear skip connection from the block input,
    optionally applies `RMSLayerNorm`, and finishes with an equivariant gate.

    Note: The original Nequix uses per-species indexed weights in the skip connection.
    This port uses vanilla (non-indexed) linear layers since species indices are not
    available in the BaseGNN interface (atom types are encoded in the input IrrepsArray).
    """

    input_irreps: e3nn.Irreps
    output_irreps: e3nn.Irreps
    sph_irreps: e3nn.Irreps
    radial_basis_size: int
    radial_mlp_size: int
    radial_mlp_layers: int
    mlp_init_scale: float
    avg_num_neighbors: float
    use_layer_norm: bool

    even_act: Callable[[jax.Array], jax.Array] = jax.nn.silu
    odd_act: Callable[[jax.Array], jax.Array] = jax.nn.tanh
    even_gate_act: Callable[[jax.Array], jax.Array] = jax.nn.silu

    def setup(self):
        """Create the linear maps, the radial MLP and the optional layer norm."""
        # Flax converts e3nn.Irreps to tuples, so rebuild proper Irreps objects first
        irreps_in = e3nn.Irreps(self.input_irreps)
        irreps_out = e3nn.Irreps(self.output_irreps)
        irreps_sh = e3nn.Irreps(self.sph_irreps)

        self.linear_1 = e3nn.flax.Linear(irreps_in, irreps_in)

        # Irreps produced by (features x spherical harmonics), restricted to the
        # requested outputs; the radial MLP emits one weight per such irrep.
        tp_irreps = e3nn.tensor_product(
            irreps_in,  # type: ignore
            irreps_sh,  # type: ignore
            filter_ir_out=irreps_out,  # type: ignore
        )
        hidden_widths = (self.radial_mlp_size,) * self.radial_mlp_layers
        radial_kernel_init = nn.initializers.variance_scaling(
            scale=self.mlp_init_scale, mode='fan_in', distribution='normal'
        )
        self.radial_mlp = MLP(
            dims=hidden_widths + (tp_irreps.num_irreps,),  # type: ignore
            activation=jax.nn.silu,
            dtype=PRECISION.gnn,
            use_bias=False,
            last_layer_use_bias=False,
            kernel_init=radial_kernel_init,
        )

        # The gate needs one extra even scalar for every irrep that is not a `0e`
        num_gates = irreps_out.num_irreps - irreps_out.count('0e')  # type: ignore
        gated_irreps = (irreps_out + e3nn.Irreps(f'{num_gates}x0e')).regroup()

        self.linear_2 = e3nn.flax.Linear(gated_irreps, tp_irreps)  # type: ignore
        self.skip = e3nn.flax.Linear(gated_irreps, irreps_in, force_irreps_out=True)

        if self.use_layer_norm:
            self.rms_layer_norm = RMSLayerNorm(
                irreps=e3nn.Irreps(gated_irreps),
                centering=False,  # default from Nequix
                std_balance_degrees=True,  # default from Nequix
            )

    def __call__(
        self,
        features: e3nn.IrrepsArray,  # (A, irreps_in)
        sph: e3nn.IrrepsArray,  # (num_edges, sh_dim)
        radial_basis: jax.Array,  # (num_edges, radial_basis_size)
        senders: jax.Array,  # (num_edges,) atom indices
        receivers: jax.Array,  # (num_edges,) atom indices
    ) -> e3nn.IrrepsArray:  # (A, irreps_out)
        """Run one round of message passing and return the gated node features."""
        irreps_out = e3nn.Irreps(self.output_irreps)
        num_nodes = features.shape[0]

        # Per-edge messages: sender features combined with the edge direction
        sender_features = self.linear_1(features)[senders]  # (num_edges, irreps_in)
        edge_messages = e3nn.tensor_product(
            sender_features,
            sph,
            filter_ir_out=irreps_out,  # type: ignore
        )  # (num_edges, tp_irreps)

        # Distance-dependent weight for every irrep of the message
        edge_weights = self.radial_mlp(radial_basis)
        edge_messages = edge_messages * edge_weights

        # Sum incoming messages per receiver and normalize by the expected degree
        aggregated = e3nn.scatter_sum(
            edge_messages, dst=receivers, output_size=num_nodes
        )
        aggregated = aggregated / jnp.sqrt(self.avg_num_neighbors)

        skip = self.skip(features)
        out = self.linear_2(aggregated) + skip

        if self.use_layer_norm:
            out = self.rms_layer_norm(out)

        gated = e3nn.gate(
            out,
            even_act=self.even_act,  # type: ignore[arg-type]
            odd_act=self.odd_act,  # type: ignore[arg-type]
            even_gate_act=self.even_gate_act,  # type: ignore[arg-type]
        )
        return gated


class Nequix(BaseGNN):
    """
    Koker, T.; Smidt, T. 2025
    "Training a Foundation Model for Materials on a Budget"
    https://doi.org/10.48550/arXiv.2508.16067.

    Notes:
    - The original Equinox/Jraph Nequix predicted energies/forces/stress; EGXC only needs
      a scalar graph readout and per-atom features, so this module follows `BaseGNN`.
    - No per-species skip connections.
    - The graph is dense (all ordered atom pairs); padding atoms and self edges are
      neutralised by giving them a displacement that lies outside the cutoff.
    """

    irreps_str: str = '128x0e + 128x1o + 128x2e + 128x3o'
    output_irreps_str: str = '16x0e + 16x1o + 16x2e + 16x3o'
    message_cutoff: float = 6.0
    layers: int = 4
    energy_graph_readout_hidden_dims: Tuple[int, ...] = (
        128,
        1,
    )  # additional hidden layer
    n_radial_basis: int = 8  # number of RBF used in the message passing layer
    init_graph_readout_to_zero: bool = True

    radial_mlp_size: int = 64
    radial_mlp_layers: int = 3
    mlp_init_scale: float = 4.0
    scalar_readout_bias: float = 0.0
    scalar_readout_std: float = 0.01  # DFT specific:10mEh / atom
    avg_num_neighbors: float = 20.0  # computed on QM9(7) for 6 Angstrom cutoff
    layer_norm: bool = False
    radial_basis_type: RBFType = 'trigonometric'  # deviate from Nequix default (bessel)

    @nn.compact
    def __call__(
        self,
        atom_features: e3nn.IrrepsArray,  # (A, Irreps[RBF, (l,m)])  with m,l as in Y_{l,m}
        atom_pos: FloatAx3,  # (A, 3)
        atom_mask: BoolA,  # (A,)
    ) -> Tuple[
        Float1, e3nn.IrrepsArray
    ]:  # energy, (A, Irreps[F_out, (l,m)])  with m,l as in Y_{l,m}
        """Return the masked, rescaled graph readout and the masked per-atom output features."""
        num_atoms: int = atom_mask.shape[0]
        cutoff = self.message_cutoff

        # dense graph: one edge for every ordered (sender, receiver) pair
        all_pairs = jnp.ones((num_atoms, num_atoms), dtype=bool)
        senders, receivers = jnp.nonzero(
            all_pairs, size=num_atoms**2
        )  # (num_edges,) each, containing atom indices

        # An edge is real only if both atoms are real and it is not a self edge.
        valid_pair = (
            atom_mask[:, None] & atom_mask[None, :] & ~jnp.eye(num_atoms, dtype=bool)
        )
        displacements = atom_pos[:, None] - atom_pos  # (A, A, 3)
        # Invalid edges get the vector (cutoff, cutoff, cutoff), whose length
        # cutoff * sqrt(3) lies beyond the cutoff (presumably the radial basis
        # vanishes there, so these edges carry no message).
        displacements = jnp.where(
            valid_pair[..., None],  # shape (A, A, 1)
            displacements,
            cutoff * jnp.ones_like(displacements),
        )
        edge_vectors = displacements.reshape(-1, 3)  # (num_edges, 3)

        hidden_irreps = e3nn.Irreps(self.irreps_str)
        sph_irreps = e3nn.s2_irreps(hidden_irreps.lmax)
        sph = e3nn.spherical_harmonics(
            sph_irreps,
            e3nn.IrrepsArray('1o', edge_vectors),
            normalize=True,
            normalization='component',
        )

        edge_lengths = safe_norm(edge_vectors, axis=-1)
        radial_basis_vals = radial_basis_values(
            edge_lengths,
            cutoff,
            self.n_radial_basis,
            self.radial_basis_type,
        ).reshape(-1, self.n_radial_basis)  # (num_edges, n_radial_basis)

        # The first layer consumes the caller's irreps; later layers map hidden -> hidden.
        for layer_idx in range(self.layers):
            if layer_idx == 0:
                layer_input_irreps = atom_features.irreps
            else:
                layer_input_irreps = hidden_irreps
            layer = _NequixLayerFlax(
                input_irreps=layer_input_irreps,
                output_irreps=hidden_irreps,
                sph_irreps=sph_irreps,
                radial_basis_size=self.n_radial_basis,
                radial_mlp_size=self.radial_mlp_size,
                radial_mlp_layers=self.radial_mlp_layers,
                mlp_init_scale=self.mlp_init_scale,
                avg_num_neighbors=self.avg_num_neighbors,
                use_layer_norm=self.layer_norm,
            )
            atom_features = layer(
                atom_features, sph, radial_basis_vals, senders, receivers
            )

        # TODO: check against reference implementation
        readout_module = GraphReadout(
            self.energy_graph_readout_hidden_dims, self.init_graph_readout_to_zero
        )
        graph_readout = readout_module(atom_features)

        # Sum the contributions of real atoms only, then rescale and shift.
        graph_readout = (graph_readout * atom_mask[:, None]).sum()
        graph_readout = graph_readout * self.scalar_readout_std + self.scalar_readout_bias

        # TODO: Ablate skip connections for these?
        output_projection = e3nn.flax.Linear(e3nn.Irreps(self.output_irreps_str))
        atom_features = output_projection(atom_features)
        atom_features *= atom_mask[:, None]  # TODO: is this necessary?
        return graph_readout, atom_features
