"""
A pseudoscalar allegro 3d-encoder with node transformer paths
"""

import math

from e3nn import o3

import torch
from torch import nn

from coarsebind_public.mol_encoder.models.loose_modules.activations import NewGELU
from coarsebind_public.mol_encoder.models.loose_modules.stoich_encoding import (
    StoichiometryEncoder,
)
from coarsebind_public.mol_encoder.models.transformer.attention import (
    MaskedAttentionBlock,
    AttAgg,
)

from coarsebind_public.mol_encoder.models.encoder_3d.allegro_modules import (
    allegro_xy_oh,
    allegro_sph_edges,
    IrrepWeightLayer,
)
from coarsebind_public.mol_encoder.data.tokenizer.mol_graph import dense_to_sparse
from coarsebind_public.mol_encoder.models.encoder_3d._contract import Contracter
from coarsebind_public.mol_encoder.models.encoder_3d._linear import StridedLinear


def make_neighborlist_system_bath(
    x, node_mask, nonstatic, cutoff=torch.tensor(5.0, requires_grad=False)
):
    """
    Build a sparse neighbor list that drops all bath-bath edges.

    This is the usual neighbor list, except that the source index J is only
    allowed to be a "system" node (flagged in ``nonstatic``), while the
    neighbor index K may be any node flagged in ``node_mask``. Self pairs
    (J == K) are removed.

    Only the first ``nonstatic.sum(1).max()`` nodes along the node axis are
    considered as sources, so nonstatic nodes are expected to be packed at
    the front of that axis.

    Args:
        x: (batch, n_node, n_dim) coordinates.
        node_mask: (batch, n_node) mask, truthy for nodes that may be neighbors.
        nonstatic: (batch, n_node) mask, truthy for nodes that may be sources.
        cutoff: distance cutoff, either a Python number or a tensor. A tensor
            is moved to ``x.device`` before the comparison.

    Returns:
        ``(Is, Js, Ks, Ds)``: 1-D tensors holding, per kept edge, the batch
        index, the source node index, the neighbor node index and the
        pairwise distance ``d[Is, Js, Ks]``.
    """
    # Accept plain floats as well as tensors; only tensors need a device move.
    if isinstance(cutoff, torch.Tensor):
        cutoff = cutoff.to(x.device)

    # Distances are only computed from the leading (nonstatic) nodes to all nodes.
    max_n_free = nonstatic.sum(1).max()
    d = torch.cdist(x[:, :max_n_free, :], x)

    # Rows of `nonzero()` are (batch, source, neighbor) triples within the cutoff.
    Is_, Js_, Ks_ = (d < cutoff).nonzero().unbind(-1)

    # Keep edges whose source is nonstatic, whose neighbor is a real node,
    # and which are not self loops.
    good_nodes = torch.logical_and(nonstatic[Is_, Js_], node_mask[Is_, Ks_])
    whole_mask = torch.logical_and(Js_ != Ks_, good_nodes)

    Is, Js, Ks = Is_[whole_mask], Js_[whole_mask], Ks_[whole_mask]
    return Is, Js, Ks, d[Is, Js, Ks]


class SwiGLU(nn.Module):
    """
    Gated SiLU activation.

    The last axis of the input is split into two equal halves; the second half
    is passed through SiLU and multiplies the first half, so the output has
    half as many features as the input.
    """

    def forward(self, x):
        # First half carries the values, second half the gates.
        values, gates = torch.chunk(x, 2, dim=-1)
        return torch.nn.functional.silu(gates) * values


class SwiGLUNet(nn.Module):
    """
    Small feed-forward block: LayerNorm -> Dropout -> Linear -> SwiGLU -> Linear.

    The first linear layer produces ``2 * d_out`` features because SwiGLU
    halves the feature axis.
    """

    def __init__(self, d_in, d_out, residual=False, dropout=0.0, use_weight_norm=False, bias=True):
        """
        Args:
            d_in: input feature size (also the LayerNorm size).
            d_out: output feature size.
            residual: if True, ``forward`` adds the input to the output
                (which needs ``d_in`` and ``d_out`` to be broadcast-compatible).
            dropout: dropout probability applied right after the LayerNorm.
            use_weight_norm: accepted but not used by this implementation.
            bias: whether both linear layers carry a bias.

        History: 10/25 - added dropout.
        """
        super().__init__()
        self.residual = residual
        stages = [
            nn.LayerNorm(d_in),
            torch.nn.Dropout(p=dropout),
            # Open question from the original author: should this linear be
            # weight-normed as well (vs. just the second one)?
            nn.Linear(d_in, 2 * d_out, bias=bias),
            SwiGLU(),
            nn.Linear(d_out, d_out, bias=bias),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out + x if self.residual else out


# I've never had good luck with learnable radial functions,
# so the radial embeddings below just hard-code some sensible choices.
def fourier_r_embedding(timesteps, num_basis: int, max_r=2000.0):
    """
    Fixed-frequency Fourier (sin/cos) embedding of distances.

    Three bands of fixed frequencies are used to get reasonable short-range
    resolution together with low-resolution long-range coverage (~2000A):
    half of the frequencies reach up to ``2*3.14 / (max_r/100)``, a quarter
    reach up to ``2*3.14 / (max_r/10)`` and the remainder reach up to
    ``2*3.14 / max_r``.

    Args:
        timesteps: tensor of distances (the name is historical). Any shape;
            the embedding is appended as a new last axis.
        num_basis: size of the returned embedding; must be a multiple of 4.
        max_r: length scale that sets the frequency bands.

    Returns:
        Tensor of shape ``timesteps.shape + (num_basis,)`` (a 0-D input is
        returned as ``(1, num_basis)``), laid out as ``[sin(...), cos(...)]``.
    """
    assert num_basis % 4 == 0
    num_timescales = num_basis // 2
    n_band0 = num_timescales // 2
    n_band1 = num_timescales // 4
    # The last band takes whatever is left so that the three bands always add
    # up to num_timescales (with `// 4` here, num_basis=12 produced only 10
    # output columns). Identical to the old split when num_basis % 8 == 0.
    n_band2 = num_timescales - n_band0 - n_band1

    # NOTE: 3.14 (not math.pi) is kept on purpose so the embedding values stay
    # identical to what existing checkpoints were trained with.
    inv_timescales0 = torch.linspace(
        0.001 / (max_r / 10.0) + 1e-3,
        2 * 3.14 / (max_r / 100.0),
        n_band0,
        device=timesteps.device,
    )
    inv_timescales1 = torch.linspace(
        0.001 / (max_r / 10.0) + 1e-3,
        2 * 3.14 / (max_r / 10.0),
        n_band1,
        device=timesteps.device,
    )
    inv_timescales2 = torch.linspace(
        0.001 / (max_r) + 1e-3,
        2 * 3.14 / (max_r),
        n_band2,
        device=timesteps.device,
    )
    inv_timescales = torch.cat([inv_timescales0, inv_timescales1, inv_timescales2])

    # (..., 1) * (1, D/2) -> (..., D/2)
    emb = timesteps.unsqueeze(-1) * inv_timescales.unsqueeze(0)
    # Concatenate on the last axis so batched (>1-D) inputs are handled too.
    return torch.cat([emb.sin(), emb.cos()], dim=-1)  # (..., D)


def poly_cutoff(x: torch.Tensor, factor: float, p: float = 6.0) -> torch.Tensor:
    """
    Polynomial cutoff envelope.

    With ``u = x * factor`` the envelope is

        1 - (p+1)(p+2)/2 * u^p + p(p+2) * u^(p+1) - p(p+1)/2 * u^(p+2)

    for ``u < 1`` and exactly zero otherwise. It equals 1 at ``u = 0`` and the
    polynomial itself vanishes at ``u = 1``.

    Args:
        x: distances.
        factor: scale applied to ``x`` (typically ``1 / cutoff_radius``).
        p: polynomial order.

    Returns:
        Tensor with the same shape as ``x``.
    """
    u = x * factor
    coeff_p = (p + 1.0) * (p + 2.0) / 2.0
    coeff_p1 = p * (p + 2.0)
    coeff_p2 = p * (p + 1.0) / 2

    envelope = 1.0 - (coeff_p * torch.pow(u, p))
    envelope = envelope + (coeff_p1 * torch.pow(u, p + 1.0))
    envelope = envelope - (coeff_p2 * torch.pow(u, p + 2.0))

    # Hard zero at and beyond the cutoff.
    inside = u < 1.0
    return envelope * inside


def bessel_radial(x, num_basis=12):
    """
    Short-range radial basis built from spherical-Bessel-shaped functions.

    The ``num_basis`` frequencies ``pi * (1..num_basis)`` are split into three
    groups that are evaluated with the j0, j1 and j2 functional forms
    respectively (the third group uses half the radial scale). The stacked
    values are clamped to [-1, 1] and multiplied by ``poly_cutoff`` with a
    fixed radius of 8.0.

    Args:
        x: distances; the basis is appended as a new last axis.
        num_basis: total number of basis functions.

    Returns:
        Tensor of shape ``x.shape + (num_basis,)``.
    """
    r_max = 8.0
    prefactor = 2.0 / r_max
    freqs = (
        torch.linspace(start=1.0, end=num_basis, steps=num_basis, device=x.device, dtype=x.dtype)
        * torch.pi
    )
    third = num_basis // 3
    two_thirds = 2 * (num_basis // 3)

    # Small offsets keep the 1/r terms finite at x == 0.
    shifted = x.unsqueeze(-1) + 1e-3
    arg_j0 = freqs[:third] * shifted / (r_max + 1e-4)
    arg_j1 = freqs[third:two_thirds] * shifted / (r_max + 1e-4)
    arg_j2 = freqs[two_thirds:] * shifted / (r_max / 2 + 1e-4)

    # j0(r) = sin(r) / r
    j0_block = 2.0 * prefactor * (torch.sin(arg_j0) / arg_j0)

    # j1(r) = sin(r) / r^2 - cos(r) / r
    sin_1, cos_1 = torch.sin(arg_j1), torch.cos(arg_j1)
    j1_block = 5.0 * prefactor * ((sin_1 / (arg_j1.pow(2.0))) - (cos_1 / arg_j1))

    # j2(r) = (1 / r) * ((3 / r^2 - 1) * sin(r) - 3 * cos(r) / r)
    sin_2, cos_2 = torch.sin(arg_j2), torch.cos(arg_j2)
    j2_block = (
        8.0
        * prefactor
        * (1.0 / arg_j2)
        * ((3.0 / (arg_j2.pow(2.0)) - 1.0) * sin_2 - (3 * cos_2) / (arg_j2))
    )

    basis = torch.cat([j0_block, j1_block, j2_block], -1).clamp(-1, 1)
    envelope = poly_cutoff(x, factor=1.0 / r_max).unsqueeze(-1)
    return envelope * basis


class SelfAttention(nn.Module):
    """
    Plain multi-head self-attention.

    No attention mask of any kind is applied: every position attends to every
    other position.
    """

    def __init__(self, dim_embed, nhead):
        """
        Args:
            dim_embed: embedding size; must be divisible by ``nhead``.
            nhead: number of attention heads.
        """
        super().__init__()
        assert dim_embed % nhead == 0
        # Fused projection producing queries, keys and values for all heads.
        self.c_attn = nn.Linear(dim_embed, 3 * dim_embed)
        # Output projection applied to the re-assembled heads.
        self.c_proj = nn.Linear(dim_embed, dim_embed)
        self.n_head = nhead
        self.n_embd = dim_embed

    def forward(self, x):
        """
        Args:
            x: (batch, tokens, dim_embed)

        Returns:
            Tensor of the same shape as ``x``.
        """
        batch, n_tok, channels = x.size()
        head_dim = channels // self.n_head

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, n_tok, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        weights = torch.nn.functional.softmax(scores, dim=-1)

        # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        mixed = weights @ v
        # Put the heads back side by side: (B, T, C)
        merged = mixed.transpose(1, 2).contiguous().view(batch, n_tok, channels)
        return self.c_proj(merged)


class AttentionBlock(nn.Module):
    """
    A non-causal self-attention block.

    Pre-LayerNorm residual layout: ``x + attn(ln_1(x))`` followed by
    ``x + mlp(ln_2(x))``, where the MLP expands to ``4 * dim_embed`` with a
    NewGELU activation in between.
    """

    def __init__(self, dim_embed, nhead):
        super().__init__()
        hidden = 4 * dim_embed
        self.ln_1 = nn.LayerNorm(dim_embed)
        self.attn = SelfAttention(dim_embed, nhead)
        self.ln_2 = nn.LayerNorm(dim_embed)
        mlp_layers = [
            nn.Linear(dim_embed, hidden),
            NewGELU(),
            nn.Linear(hidden, dim_embed),
        ]
        self.mlpf = nn.Sequential(*mlp_layers)

    def forward(self, x):
        # Residual attention update.
        attended = self.attn(self.ln_1(x))
        x = x + attended
        # Residual feed-forward update.
        return x + self.mlpf(self.ln_2(x))


def gather_message(T, Is, Js, batch_size=0, max_n_atoms=0, vp=True):
    """
    Sum per-edge messages onto their (batch, node) target, optionally variance preserving.

    Each row of ``T`` is added to the slot ``Is * max_n_atoms + Js`` of a
    flattened ``(batch_size * max_n_atoms, features)`` buffer. With ``vp=True``
    every slot is then divided by ``sqrt(count + 1e-6)``, where ``count`` is
    the number of messages that landed on that slot; slots that received
    nothing stay zero.

    Args:
        T: (n_edges, features) messages.
        Is: (n_edges,) batch index of each message.
        Js: (n_edges,) target node index of each message; must be smaller
            than ``max_n_atoms``.
        batch_size: number of batch entries in the output.
        max_n_atoms: number of node slots per batch entry in the output.
        vp: whether to apply the ``1 / sqrt(count)`` normalisation.

    Returns:
        Tensor of shape ``(batch_size, max_n_atoms, features)``.
    """
    n_feat = T.shape[-1]
    n_slots = batch_size * max_n_atoms
    flat_index = Is * (max_n_atoms) + Js

    tore = torch.zeros((n_slots, n_feat), device=T.device, dtype=T.dtype)
    tore.index_add_(dim=0, index=flat_index, source=T)

    if vp:
        # The message count is the same for every feature column, so a single
        # column is accumulated and broadcast instead of a full-width buffer.
        counts = torch.zeros((n_slots, 1), device=T.device, dtype=T.dtype)
        counts.index_add_(dim=0, index=flat_index, source=torch.ones_like(T[:, :1]))
        tore = tore / (1e-6 + counts).sqrt()

    return tore.reshape(batch_size, max_n_atoms, n_feat)


class AttentiveAllegroLayer(torch.nn.Module):
    """
    One Allegro-style edge update combined with an attention update of the nodes.

    In ``forward`` the layer
      1. gathers the scalar edge features onto the nodes and runs a masked
         attention block over the leading nodes,
      2. updates the scalar edge features from both endpoint node features,
         the radial embeddings, the edge embeddings and the global vector,
      3. builds per-node environments from the weighted ``edge_ys`` and
         contracts them with the incoming equivariant edge features ``V``,
      4. mixes the scalar part of that contraction back into the scalar edge
         features and linearly maps the contraction to the output irreps.
    """

    def __init__(
        self,
        irreps_edge_sh,  # irrep of the edge features
        irreps_in,  # irrep of the previous layer
        irreps_out,  # irreps returned layer
        layer_idx: int = 0,
        cutoff: float = 5.0,  # cutoff radius for dynamic edges and SR radial.
        dim_node=64,
        dim_scalar: int = 512,  # dimension of the scalar even features
        dim_equivariant: int = 64,  # channel multiplicity of the vector features.
        dim_edge_emb: int = 32,  # dimension of edge categories.
        dim_bessel=12,
        dim_fourier=32,
        dim_global=64,
    ):
        """
        This version takes a separate set of known edges without cutoff.

        Args:
            irreps_edge_sh: irreps of the per-edge angular features (``edge_ys``).
            irreps_in: irreps of the equivariant edge features ``V`` fed to this layer.
            irreps_out: irreps of the equivariant edge features this layer returns.
            layer_idx: index of the layer; only stored on the instance.
            cutoff: cutoff radius; only stored on the instance (``forward``
                receives the per-edge cutoff values as ``edge_cutoffs``).
            dim_node: size of the per-node features ``Y``.
            dim_scalar: size of the scalar edge features ``X``.
            dim_equivariant: channel multiplicity of the equivariant features.
            dim_edge_emb: size of the edge-category embedding.
            dim_bessel: size of the short-range (Bessel) radial embedding.
            dim_fourier: size of the long-range (Fourier) radial embedding.
            dim_global: size of the per-batch-entry global vector ``G``.
        """
        super().__init__()

        self.irreps_edge_sh = irreps_edge_sh
        self.irreps_in = irreps_in
        self.irreps_out = irreps_out
        self.layer_idx = layer_idx
        self.cutoff = cutoff

        self.dim_node = dim_node
        self.dim_scalar = dim_scalar
        self.dim_equivariant = dim_equivariant
        self.dim_edge_emb = dim_edge_emb
        self.dim_bessel = dim_bessel
        self.dim_fourier = dim_fourier
        self.dim_global = dim_global

        # Modules are named corresponding to Fig. 1 of (https://arxiv.org/pdf/2204.05249.pdf)
        # embed_mlp maps the scalar edge features to one weight per (channel, edge irrep).
        self.embed_mlp = SwiGLUNet(dim_scalar, len(o3.Irreps(irreps_edge_sh)) * dim_equivariant)

        # These are shared though.
        self.irrep_weight_layer = IrrepWeightLayer(irreps_edge_sh, dim_equivariant)

        # Compute the tensor product irreps and shapes.
        # sh_irreps: the edge irreps with multiplicity dim_equivariant.
        self.sh_irreps = o3.Irreps(
            [(self.dim_equivariant, (mir.ir.l, mir.ir.p)) for mir in irreps_edge_sh]
        )
        # Number of tensor-product paths that end in an even scalar (0e).
        self.inv_tp_output_dim = 0
        tmp_i_out: int = 0
        instr = []
        full_out_irreps = []
        # Enumerate every (environment irrep, input irrep) pair whose product
        # contains a requested output irrep; each such pair becomes its own
        # path with its own output slot. The outer loop runs over the output
        # irreps, so paths are grouped by output irrep in irreps_out order.
        for i_out, (_, ir_out) in enumerate(self.irreps_out):
            for i_1, (_, ir_1) in enumerate(self.sh_irreps):
                for i_2, (_, ir_2) in enumerate(self.irreps_in):
                    if ir_out in ir_1 * ir_2:
                        if ir_out == o3.Irrep("0e"):
                            self.inv_tp_output_dim += 1
                        instr.append((i_1, i_2, tmp_i_out))
                        full_out_irreps.append((dim_equivariant, ir_out))
                        tmp_i_out += 1

        # Weightless channel-wise ("uuu") contraction of environments with V.
        self.tp = Contracter(
            self.sh_irreps,
            self.irreps_in,
            full_out_irreps,
            instructions=instr,
            connection_mode="uuu",  # like embed initial edge = True.
            shared_weights=False,
            has_weight=False,
            pad_to_alignment=1,
            sparse_mode=None,
        )

        self.pre_latent_norm_scalar = torch.nn.LayerNorm(dim_scalar)
        self.pre_latent_norm_tp = torch.nn.LayerNorm(self.inv_tp_output_dim * self.dim_equivariant)
        self.final_norm = torch.nn.LayerNorm(dim_scalar)

        # Maps scalar edge features to node-sized messages.
        self.X_to_dY = SwiGLUNet(self.dim_scalar, self.dim_node)

        # Second argument is dim_node // 8 (presumably the number of heads —
        # confirm against MaskedAttentionBlock).
        self.node_attention_block = MaskedAttentionBlock(self.dim_node, self.dim_node // 8)

        # Input width of the two-body MLP: both endpoint node features, both
        # radial embeddings, the edge embedding and the global vector.
        self.dim_two_body = (
            2 * self.dim_node
            + self.dim_bessel
            + self.dim_fourier
            + self.dim_edge_emb
            + self.dim_global
        )
        self.two_body = SwiGLUNet(self.dim_two_body, dim_scalar)

        # Consumes [normed X, normed scalar tensor-product outputs].
        self.latent_mlp = SwiGLUNet(
            dim_scalar + self.inv_tp_output_dim * self.dim_equivariant, dim_scalar
        )

        # Linear map from all tensor-product paths to the requested output irreps.
        self.equivariant_linear = StridedLinear(
            full_out_irreps,
            [(dim_equivariant, ir) for _, ir in self.irreps_out],
            shared_weights=True,
            internal_weights=True,
            pad_to_alignment=1,
        )

    def forward(
        self,
        atoms,
        coords,
        node_is_free,  # whether a node is frozen.
        G,
        Y,  # global, node representations from previous layer.
        X,
        V,  # edge representations from the previous layer.
        Is,
        Js,
        Ks,
        Ds,
        edge_ys,
        edge_cutoffs,
        edge_embs,
    ):
        """
        All atom-gathers only occur onto nonfrozen atoms,
        and all nonfrozen atoms _must_ precede frozen atoms
        in the order of atoms. (to get massive savings on all scatters)

        [non-frozen]... [possibly many frozen atoms.]

        Args:
            atoms, coords : batch X max_n_atoms simple cartesian atoms
                (``atoms > 0`` is used as the attention mask).
            node_is_free : batch X max_n_atoms is this atom frozen or free?
            G : per-batch-entry global features, indexed as ``G[Is]``
                (last axis presumably dim_global).
            Y : batch X max_n_atoms X dim_node node features.
                NOTE: the leading ``num_free_nodes`` rows are overwritten in
                place, so the caller's tensor is modified.
            X : edges X dim_scalar scalar edge features
            V : edges X mul X irreps , vector features of edges.
            Is, Js, Ks, Ds: per-edge batch index, node J, node K and distance
                (same layout as the output of make_neighborlist_system_bath).
            edge_ys : per-edge angular features passed to the irrep weight layer.
            edge_cutoffs : (edges,) multiplicative cutoff value per edge.
            edge_embs : edges X dim_edge_emb edge-category embeddings.

        Returns:
            ``(Y, new_X, new_V)``: the updated node features, the updated
            scalar edge features and the updated equivariant edge features.
        """
        # Gather the message update to the invariant node rep
        batch_size = atoms.shape[0]
        # Largest number of free nodes over the batch; all per-node buffers
        # below only cover this leading slice of the node axis.
        num_free_nodes = node_is_free.sum(-1).max()
        assert (~torch.isnan(Y)).all()

        R_sr_embs = bessel_radial(Ds, self.dim_bessel)
        R_lr_embs = fourier_r_embedding(Ds, self.dim_fourier, max_r=2_000.0)

        # free nodes must precede all frozen nodes.
        # Edge scalars -> node-sized messages, damped by the cutoff and summed
        # onto node J of each edge (variance preserving).
        dY_from_X = gather_message(
            self.X_to_dY(X) * (edge_cutoffs.unsqueeze(-1)),
            Is,
            Js,
            batch_size=batch_size,
            max_n_atoms=num_free_nodes,
            vp=True,
        )
        to_attend = Y[:, :num_free_nodes, :].clone()
        to_attend = to_attend + dY_from_X

        # TODO add rotary distance embedding to the attention here.
        # right now the swigluMLP is (hopefully) slurping that up.
        # In-place residual update of the leading node features.
        Y[:, :num_free_nodes, :] = Y[:, :num_free_nodes, :] + self.node_attention_block(
            to_attend[:, :num_free_nodes, :], atoms[:, :num_free_nodes] > 0
        )

        # This is pretty meaty.
        # Two-body update of the edge scalars from the refreshed node features.
        dX = self.two_body(
            torch.cat(
                [
                    Y[Is, Js],
                    Y[Is, Ks],
                    R_sr_embs,
                    R_lr_embs,
                    edge_embs,
                    G[Is],
                ],
                dim=-1,
            )
        ) * (edge_cutoffs.unsqueeze(-1))

        X = X + dX

        # Now we do an ordinary Allegro layer.
        w = (self.embed_mlp(X) * edge_cutoffs.unsqueeze(-1)).reshape(
            Is.shape[0], self.dim_equivariant, len(self.irreps_edge_sh)
        )  # edges X irreps*mul => edges X mul X irreps
        wy_ = self.irrep_weight_layer(w, edge_ys)  # Edges X dim_equivariant X irreps
        # gotta scatter_add this onto the Jth atom.
        atom_envs_ = torch.zeros(
            atoms.shape[0] * num_free_nodes,
            self.dim_equivariant,
            self.irreps_edge_sh.dim,
            dtype=coords.dtype,
            device=coords.device,
        )  # batch*atoms X mul X irreps
        denom = torch.zeros_like(atom_envs_)
        # Below is the summation over neighbors.
        # To make this smooth, we need to have this be smooth WRT #neighbors.
        atom_envs_.index_add_(
            dim=0, index=Is * num_free_nodes + Js, source=wy_
        )  # batch*atoms X mul X irreps
        # denom counts how many edges were summed onto each node.
        denom.index_add_(
            dim=0, index=Is * num_free_nodes + Js, source=torch.ones_like(wy_)
        )  # batch*atoms X mul X irreps
        # Same 1/sqrt(count) normalisation as gather_message(vp=True).
        atom_envs = atom_envs_ / (denom + 1e-6).sqrt()

        # scatter this back out to edges X mul X irreps
        atom_envs_edge = atom_envs[Is * num_free_nodes + Js, :, :]
        DV = self.tp(atom_envs_edge, V)
        # Extract the scalar features which should be the first irrep. (check this?)
        # NOTE(review): this assumes the 0e paths occupy the first
        # inv_tp_output_dim entries of DV's last axis, i.e. that 0e is the
        # first entry of irreps_out — confirm against Contracter's layout.
        DV_scalar = DV[:, :, : self.inv_tp_output_dim].reshape(DV.shape[0], -1)

        new_latent = self.latent_mlp(
            torch.cat(
                [
                    self.pre_latent_norm_scalar(X),
                    self.pre_latent_norm_tp(DV_scalar),
                ],
                dim=-1,
            )
        ) * edge_cutoffs.unsqueeze(-1)

        # The equivariant output is not multiplied by the cutoff here.
        new_V = self.equivariant_linear(DV)  # *edge_cut_facs.unsqueeze(-1).unsqueeze(-1)
        # Allegro source concatenates previous scalars after layer 2? (lines 600-604)
        # which seems pretty high-D to me. maybe implement later.
        new_X = self.final_norm(new_latent + X)

        # DEBUG
        assert (~torch.isnan(Y)).all()
        assert (~torch.isnan(new_X)).all()
        assert (~torch.isnan(new_V)).all()

        return Y, new_X, new_V


class AttentiveAllegro(torch.nn.Module):
    """
    Stack of AttentiveAllegroLayer modules with input embeddings and a read-out.

    Nodes are embedded from atomic numbers and graph node types, edges from the
    dense graph edge types plus dynamically found neighbors within ``cutoff``.
    After the layers, even-scalar and pseudoscalar edge features are summed
    onto the nodes and, optionally, aggregated by an attention decoder.
    """

    def __init__(
        self,
        irreps_edge_sh_="1x0e+1x1o+1x2e",
        irreps_out_="1x0e+1x0o+1x1e+1x1o",
        dim_scalar=256,  # dimension of the scalar even features
        dim_equivariant=32,  # channel multiplicity of the vector features.
        dim_bessel=12,
        dim_fourier=32,
        dim_node=32,
        dim_global=256,
        layers_allegro=4,  # number of allegro-like tensor product layers
        cutoff=5.0,
        max_node_types=50,
        max_edge_types=50,
        edge_emb_dim=32,
        n_out_tokens=1,  # the final output will be dim_scalar * n_out_tokens.
        device=torch.device("cuda"),
    ):
        """
        Takes atoms, coords, nodes, edges
        and returns a [stop] xformer decoded vector

        Args:
            irreps_edge_sh_: irreps string for the per-edge angular features.
            irreps_out_: irreps string for the equivariant output of every
                layer; its first two entries must be 0e and 0o (asserted below).
            dim_scalar: size of the scalar edge features.
            dim_equivariant: channel multiplicity of the equivariant features.
            dim_bessel: size of the short-range radial embedding.
            dim_fourier: size of the long-range radial embedding.
            dim_node: size of the node features.
            dim_global: size of the stoichiometry (global) embedding.
            layers_allegro: number of AttentiveAllegroLayer modules.
            cutoff: radius used for the dynamic neighbor list and its envelope.
            max_node_types: vocabulary size of the node-type embedding.
            max_edge_types: vocabulary size of the edge-type embedding.
            edge_emb_dim: size of the edge-type embedding.
            n_out_tokens: forwarded to the AttAgg read-out.
            device: forwarded to ``allegro_xy_oh``. NOTE: defaults to CUDA.
        """
        super().__init__()
        self.cutoff = cutoff
        self.dim_equivariant = dim_equivariant
        self.dim_scalar = dim_scalar
        self.dim_bessel = dim_bessel
        self.dim_global = dim_global
        # This now gets the graph information
        # in the form of miles' graph encoder.
        self.max_node_types = max_node_types
        self.max_edge_types = max_edge_types

        self.device = device

        irreps_edge_sh = o3.Irreps(irreps_edge_sh_)
        irreps_out = o3.Irreps(irreps_out_)

        self.sph_edges = allegro_sph_edges(irreps_edge_sh=irreps_edge_sh)
        self.node_embedding = allegro_xy_oh(device=self.device)

        self.dim_node = dim_node
        # Index 0 is the padding node type.
        self.node_graph_emb = nn.Embedding(self.max_node_types, self.dim_node, padding_idx=0)
        self.node_emb = SwiGLUNet(
            self.node_embedding.out_features + self.dim_node, self.dim_node
        )  # does allegro_xy_oh + the graph types.

        self.dim_fourier = dim_fourier
        self.edge_emb_dim = edge_emb_dim
        # Index 0 is the padding edge type.
        self.edge_emb = nn.Embedding(self.max_edge_types, self.edge_emb_dim, padding_idx=0)

        self.irreps_edge_sh = irreps_edge_sh
        # Output irreps with multiplicity dim_equivariant.
        self.irreps_out = o3.Irreps(
            [(self.dim_equivariant, (mir.ir.l, mir.ir.p)) for mir in irreps_out]
        )
        assert self.irreps_out[0].ir == (
            0,
            1,
        )  # 1x0e is the first irrep for even scalar output
        assert self.irreps_out[1].ir == (
            0,
            -1,
        )  # 1x0o is the second irrep for pseudoscalar output

        # w of eq 15 of https://arxiv.org/pdf/2204.05249.pdf
        # Initial two-body MLP; unlike the per-layer one it has no global input.
        self.two_body = SwiGLUNet(
            self.dim_node * 2 + self.dim_bessel + self.dim_fourier + self.edge_emb_dim,
            self.dim_scalar,
        )
        self.embed_mlp = SwiGLUNet(self.dim_scalar, len(self.irreps_edge_sh) * self.dim_equivariant)

        self.irrep_weight_layer = IrrepWeightLayer(self.irreps_edge_sh, self.dim_equivariant)

        # Soon enough we need to prune paths overall
        # The path logic will go here.
        # for now just ignore the dead paths.
        # The first layer consumes features with the edge irreps; every later
        # layer consumes the previous layer's output irreps.
        self.allegro_layers = torch.nn.ModuleList(
            [
                AttentiveAllegroLayer(
                    irreps_edge_sh,
                    irreps_in=o3.Irreps(
                        [(dim_equivariant, (x.ir.l, x.ir.p)) for x in irreps_edge_sh]
                    ),
                    irreps_out=self.irreps_out,
                    layer_idx=0,
                    cutoff=self.cutoff,
                    dim_scalar=self.dim_scalar,  # dimension of the scalar even features
                    dim_equivariant=self.dim_equivariant,  # channel multiplicity of the vector features.
                    dim_node=self.dim_node,
                    dim_bessel=self.dim_bessel,
                    dim_fourier=self.dim_fourier,
                    dim_edge_emb=self.edge_emb_dim,
                    dim_global=dim_global,
                )
            ]
        )

        for k in range(1, layers_allegro):
            self.allegro_layers.append(
                AttentiveAllegroLayer(
                    irreps_edge_sh,
                    irreps_in=self.allegro_layers[-1].irreps_out,
                    irreps_out=self.irreps_out,  # Not quite pruned, but those are getting filled with zero.
                    layer_idx=k,
                    cutoff=self.cutoff,
                    dim_scalar=self.dim_scalar,  # dimension of the scalar even features
                    dim_equivariant=dim_equivariant,  # channel multiplicity of the vector features.
                    dim_node=self.dim_node,
                    dim_bessel=self.dim_bessel,
                    dim_fourier=self.dim_fourier,
                    dim_edge_emb=self.edge_emb_dim,
                    dim_global=dim_global,
                )
            )

        self.stoich_emb = StoichiometryEncoder(dim=dim_global)

        self.stop_proj = SwiGLUNet(
            self.dim_scalar + self.dim_node + self.dim_equivariant, self.dim_scalar
        )  # projects the concatenated [Y, H_even, H_odd] node features to dim_scalar
        self.stop_decoder = AttAgg(dim_scalar, n_head=16, n_layers=2, n_out_tokens=n_out_tokens)

    def forward(self, atoms, coords, nodes, edges, apply_stop_decode=True):
        """
        Encode a batch of molecules.

        This is a maturation of the frzn_edges of allegro_pocket_vector_field_6
        which uses the node and edge types of the graph encoder.
        edges = 0 is pad, edges = 1 is mask. Here they convey bond order,
        charge, isotope etc. information, and keep symmetry with the 3d flow
        decoder.

        Args:
            atoms: batch X max_n_atoms batch of input atomic numbers
                (entries <= 0 are treated as padding).
                system (atoms with forces) should precede bath (no forces).
            coords: batch X max_n_atoms X 3
                coordinates in the same order as atoms.
            nodes: batch X max_n_atoms batch of node types
            edges: batch X max_n_atoms*(max_n_atoms-1)/2 batch of edges
            apply_stop_decode: whether to run the attention read-out.

        Returns:
            ``(Y, H_even, H_odd)`` when ``apply_stop_decode`` is False,
            otherwise ``(decoded, Y, H_even, H_odd)`` where ``decoded`` is the
            output of the stop decoder. ``Y`` holds the node features,
            ``H_even`` is batch X max_n_atoms X dim_scalar and ``H_odd`` is
            batch X max_n_atoms X dim_equivariant.
        """
        assert atoms.dim() == 2
        assert (~torch.isnan(atoms)).all()
        assert (~torch.isnan(nodes)).all()

        nbatch = atoms.shape[0]
        natoms = atoms.shape[1]

        # Global (per batch entry) embedding of the composition.
        G = self.stoich_emb(atoms)

        atom_part = self.node_embedding(atoms)
        node_part = self.node_graph_emb(nodes)
        assert (~torch.isnan(atom_part)).all()
        assert (~torch.isnan(node_part)).all()

        Y0 = torch.cat([atom_part, node_part], -1)
        assert (~torch.isnan(Y0)).all()
        Y = self.node_emb(Y0)
        assert (~torch.isnan(Y)).all()

        node_mask = atoms > 0.0
        # node_mask is passed for both masks, so every real atom is treated
        # as a nonstatic (system) node when building the dynamic edges.
        Is_, Js_, Ks_, _ = make_neighborlist_system_bath(
            coords,
            node_mask,
            node_mask,
            torch.tensor(self.cutoff, device=coords.device, requires_grad=False),
        )

        # convert the dense edges into sparse edges.
        # presumably (batch index, node J, node K, edge category) — confirm
        # against dense_to_sparse.
        fIs, fJs, fKs, fCs = dense_to_sparse(edges)

        # concatenate the frozen edges onto the dynamic edges.
        # Create the tensors needed to have force on the frozen edges.
        # frzn_edges.shape = batch X max_edges X 3 (atom 1 , atom 2, edge cat)

        # The dynamic edges will map to category one.
        Is = torch.cat([fIs, Is_], 0)
        Js = torch.cat([fJs, Js_], 0)
        Ks = torch.cat([fKs, Ks_], 0)
        Cs = torch.cat([fCs, torch.ones(Is_.shape[0], device=Is_.device, dtype=torch.long)], 0)

        edge_vectors = coords[Is, Js] - coords[Is, Ks]
        edge_ys = self.sph_edges(edge_vectors)  # edges X irreps
        # Distances are recomputed from the edge vectors (the neighbor-list
        # distances only cover the dynamic edges) and clamped away from zero.
        Ds = torch.linalg.norm(edge_vectors, dim=-1).clamp(1e-6, 3000.0)

        # Create the multiplicative cutoffs which will be 1. for any frozen edges.
        # and otherwise dynamic
        nfrzn_edges = fIs.shape[0]
        cutoffs = torch.cat(
            [
                torch.ones(nfrzn_edges, device=Is.device, dtype=coords.dtype),
                poly_cutoff(Ds[nfrzn_edges:], 1 / self.cutoff),
            ],
            0,
        )

        edge_embs = self.edge_emb(Cs)
        R_lr_embs = fourier_r_embedding(Ds, self.dim_fourier, max_r=2_000.0)

        # Initial scalar edge features from both endpoints, the radial
        # embeddings and the edge-type embedding.
        X = self.two_body(
            torch.cat(
                [
                    Y[Is, Js],
                    Y[Is, Ks],
                    bessel_radial(Ds.clamp(min=1e-6), self.dim_bessel),
                    R_lr_embs,
                    edge_embs,
                ],
                dim=-1,
            )
        )  # edges X dim_scalar

        # Perform eqs. 14,15
        w = self.embed_mlp(X).reshape(Is.shape[0], self.dim_equivariant, len(self.irreps_edge_sh))
        V = self.irrep_weight_layer(w, edge_ys)

        for i, layer in enumerate(self.allegro_layers):
            Y, X, V = layer(
                atoms,
                coords,
                node_mask,
                G,
                Y,
                X,
                V,
                Is,
                Js,
                Ks,
                Ds,
                edge_ys,
                cutoffs,
                edge_embs,
            )

        # TODO: Experiment with pulling out the node path, vs. edge path
        # and using [STOP] decodes.

        Xe_final = X
        # Index 1 of the last axis is the 0o (pseudoscalar) output, per the
        # irreps_out asserts in __init__.
        Xo_final = V[:, :, 1]  # MLP has to have defined (even) parity,  now edges X mul

        atom_envs_even = torch.zeros(
            nbatch * atoms.shape[1],
            self.dim_scalar,
            dtype=coords.dtype,
            device=coords.device,
        )  # batch*atoms X dim_scalar
        # NOTE(review): index_add_ works in place; the trailing .reshape(...)
        # result is discarded (the actual reshape happens when H_even is built).
        atom_envs_even.index_add_(
            dim=0,
            index=Is * (atoms.shape[1]) + Js,
            source=(Xe_final * cutoffs.unsqueeze(-1)),
        ).reshape(
            nbatch, natoms, self.dim_scalar
        )  # batch*atoms X dim_scalar

        atom_envs_odd = torch.zeros(
            nbatch * atoms.shape[1],
            self.dim_equivariant,
            dtype=coords.dtype,
            device=coords.device,
        )  # batch*atoms X mul
        # NOTE(review): as above, the .reshape(...) result is discarded.
        atom_envs_odd.index_add_(
            dim=0,
            index=Is * (atoms.shape[1]) + Js,
            source=(Xo_final * cutoffs.unsqueeze(-1)),
        ).reshape(
            nbatch, natoms, self.dim_equivariant
        )  # batch*atoms X dim_equivariant

        H_even = atom_envs_even.reshape(nbatch, natoms, -1)
        H_odd = atom_envs_odd.reshape(nbatch, natoms, -1)

        if not apply_stop_decode:
            return Y, H_even, H_odd

        # Project the concatenated node features and aggregate over real atoms.
        Hin = self.stop_proj(torch.cat([Y, H_even, H_odd], -1))
        return self.stop_decoder(Hin, mask_=(atoms > 0)), Y, H_even, H_odd
