
import math, contextlib
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torch_scatter import scatter_add

from e3nn import o3
from e3nn.o3 import Irreps, FullyConnectedTensorProduct
from e3nn.nn import Gate, BatchNorm

from rdkit import Chem
from scipy import special, optimize
import numpy as np


# --------------------------
# Device / AMP helpers
# --------------------------
def get_device():
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

def no_amp():
    """Return a context manager that switches CUDA autocast off.

    Falls back to a do-nothing context when ``torch.amp`` is missing or no
    CUDA device is available (used around the hyperbolic math).
    """
    amp_usable = hasattr(torch, "amp") and torch.cuda.is_available()
    if not amp_usable:
        return contextlib.nullcontext()
    return torch.amp.autocast('cuda', enabled=False)


# --------------------------
# SO(3) helpers
# --------------------------
def _hat(v: Tensor) -> Tensor:
    x, y, z = v[..., 0], v[..., 1], v[..., 2]
    O = torch.zeros_like(x)
    return torch.stack([
        torch.stack([ O, -z,  y], dim=-1),
        torch.stack([ z,  O, -x], dim=-1),
        torch.stack([-y,  x,  O], dim=-1)], dim=-2)

def so3_exp(omega: Tensor, step: float = 1.0) -> Tensor:
    """Rotation matrices from axis-angle vectors via Rodrigues' formula.

    The rotation angle is ``step * |omega|`` about the axis ``omega / |omega|``.
    The norm is kept with a trailing singleton axis so it broadcasts against
    both the (..., 3) axis and, after one more unsqueeze, the (..., 3, 3) matrices.
    """
    angle = omega.norm(dim=-1, keepdim=True).clamp(min=1e-12)     # (...,1)
    K = _hat(omega / angle)                                       # (...,3,3)
    angle = angle * step                                          # (...,1)
    sin_t = torch.sin(angle).unsqueeze(-1)                        # (...,1,1)
    one_minus_cos = (1 - torch.cos(angle)).unsqueeze(-1)          # (...,1,1)
    eye = torch.eye(3, device=omega.device, dtype=omega.dtype)
    return eye + sin_t * K + one_minus_cos * (K @ K)

def reorthonormalize(R: Tensor) -> Tensor:
    """Rebuild an orthonormal frame from the first two columns of ``R``.

    Gram-Schmidt on columns 0 and 1; the third column is their cross product.
    The result is stacked column-wise into a (..., 3, 3) tensor.
    """
    col0, col1 = R[..., :, 0], R[..., :, 1]
    e1 = F.normalize(col0, dim=-1)
    overlap = (e1 * col1).sum(-1, keepdim=True)
    e2 = F.normalize(col1 - overlap * e1, dim=-1)
    e3 = torch.cross(e1, e2, dim=-1)
    return torch.stack([e1, e2, e3], dim=-1)


# --------------------------
# Hyperboloid H^d (curvature -1)
# --------------------------
def lorentz_dot(u: Tensor, v: Tensor) -> Tensor:
    """Lorentzian inner product over the last axis: ``-u0*v0 + sum_i ui*vi``."""
    time_part = u[..., 0] * v[..., 0]
    space_part = (u[..., 1:] * v[..., 1:]).sum(dim=-1)
    return -time_part + space_part

def h_dist(x: Tensor, y: Tensor) -> Tensor:
    """Geodesic distance ``acosh(-<x, y>_L)`` between hyperboloid points.

    The Lorentz product is taken over the last axis, so ``x`` and ``y`` may
    carry any broadcastable leading dimensions. The argument of ``acosh`` is
    clamped to ``1 + 1e-7`` from below, so coincident points give a small
    positive distance rather than exactly zero.
    """
    with no_amp():
        # The negation must be applied BEFORE the clamp. Written as
        # `-lorentz_dot(x, y).clamp(...)` the method call binds tighter than
        # the unary minus, which pins `a` at -(1 + 1e-7) and makes acosh
        # return NaN for every input.
        a = (-lorentz_dot(x, y)).clamp(min=1.0 + 1e-7)
        return torch.acosh(a)

def project_to_hyperboloid(x: Tensor) -> Tensor:
    """Lift spatial coordinates (..., d) onto the hyperboloid as (..., d+1).

    The prepended leading coordinate is ``sqrt(1 + |x|^2)``.
    """
    with no_amp():
        sq_norm = (x * x).sum(dim=-1, keepdim=True)
        time_coord = torch.sqrt(1.0 + sq_norm)
        return torch.cat([time_coord, x], dim=-1)

def h_exp(x: Tensor, v: Tensor) -> Tensor:
    """Hyperboloid exponential map: ``cosh(|v|) * x + sinh(|v|) / |v| * v``.

    ``|v|`` is the Lorentzian norm of ``v``; its square is clamped at 1e-12
    before the square root so a zero vector does not divide by zero.
    """
    with no_amp():
        lorentz_sq = (v[..., 1:] * v[..., 1:]).sum(-1) - v[..., 0] * v[..., 0]
        vn = torch.sqrt(torch.clamp(lorentz_sq, min=1e-12))
        base_scale = torch.cosh(vn)[..., None]
        tangent_scale = torch.sinh(vn)[..., None] / (vn[..., None] + 1e-12)
        return base_scale * x + tangent_scale * v

def h_log(x: Tensor, y: Tensor) -> Tensor:
    """Hyperboloid logarithmic map of ``y`` at base point ``x``.

    With ``a = -<x, y>_L`` the result is ``acosh(a) / sqrt(a^2 - 1) * (y - a x)``.
    The direction ``y - a x`` is Lorentz-orthogonal to ``x`` and has squared
    Lorentz norm ``a^2 - 1``, which is what the normaliser below assumes.
    ``x`` and ``y`` broadcast over leading dimensions.
    """
    with no_amp():
        # Negate first, then clamp: `-lorentz_dot(...).clamp(...)` would clamp
        # the (negative) product and negate afterwards, which feeds acosh a
        # value below 1 and yields NaN everywhere.
        a = (-lorentz_dot(x, y)).clamp(min=1.0 + 1e-7)
        d = torch.acosh(a)
        # Tangent direction is y + <x, y>_L x, i.e. y - a x (not y + a x).
        u = y - a[..., None] * x
        un = torch.sqrt(torch.clamp(a * a - 1.0, min=1e-12))[..., None]
        return (d[..., None] / (un + 1e-12)) * u

def h_barycenter(points: Tensor, w: Tensor, iters: int = 6) -> Tensor:
    """Iteratively estimate a weighted mean of hyperboloid points.

    ``points`` is indexed as (n, d+1) and ``w`` as (n,). The start is the
    weighted mean of the spatial coordinates lifted back onto the hyperboloid;
    each iteration moves half-way along the weighted mean of the log maps.
    """
    with no_amp():
        w_col = w[:, None]
        w_total = w.sum() + 1e-12
        spatial_mean = (w_col * points[..., 1:]).sum(dim=0) / w_total
        x = project_to_hyperboloid(spatial_mean)
        for _ in range(iters):
            mean_tangent = (w_col * h_log(x, points)).sum(dim=0) / w_total
            x = h_exp(x, 0.5 * mean_tangent)
        return x


# --------------------------
# Exact spherical-Bessel radial basis
# --------------------------
# Zeros of j_l: use a SciPy routine if one is exposed, otherwise bracket and refine numerically.

def spherical_bessel_zeros(l: int, n_zeros: int) -> torch.Tensor:
    """
    Return the first ``n_zeros`` positive roots β_{l,n} of the spherical Bessel j_l(x).

    If ``scipy.special`` exposes ``spherical_jn_zeros`` it is used directly
    (NOTE(review): this attribute is not known to exist in released SciPy, so
    the scan below is the path normally taken). Otherwise the function samples
    j_l on a fixed grid, brackets each sign change and refines it with Brent's
    method.

    Returns a 1-D float64 tensor (empty when ``n_zeros <= 0``).
    Raises RuntimeError if fewer than ``n_zeros`` roots could be bracketed.
    """
    if n_zeros <= 0:
        return torch.from_numpy(np.asarray([], dtype=np.float64))

    # Fast path if SciPy provides it
    if hasattr(special, "spherical_jn_zeros"):
        betas = special.spherical_jn_zeros(l, n_zeros)
        return torch.from_numpy(np.asarray(betas, dtype=np.float64))

    def f(z: float) -> float:
        return float(special.spherical_jn(l, z))

    # Consecutive zeros of j_l are at least pi apart, so a pi/2 grid puts at
    # most one root in each cell and the scan can simply march forward.
    step = math.pi / 2.0
    # The first zero of J_nu lies above nu (= l + 1/2), so no root is skipped
    # by starting here, and j_l is not vanishingly small at this point (at
    # x -> 0 it underflows to exactly 0 for large l and looks like a root).
    x_lo = l + 0.5
    f_lo = f(x_lo)
    # Zeros of successive orders interlace, and those of j_0 are k*pi, hence
    # beta_{l,n} < (n + l) * pi. One extra pi of slack is added.
    max_x = (n_zeros + l + 1) * math.pi

    betas: List[float] = []
    while len(betas) < n_zeros and x_lo < max_x:
        x_hi = x_lo + step
        f_hi = f(x_hi)
        if f_hi == 0.0:
            # Landed exactly on a root; the next cell starts on it and is
            # skipped by the `f_lo != 0.0` guard below.
            betas.append(x_hi)
        elif f_lo != 0.0 and (f_lo < 0.0) != (f_hi < 0.0):
            betas.append(optimize.brentq(f, x_lo, x_hi, maxiter=200))
        x_lo, f_lo = x_hi, f_hi

    if len(betas) < n_zeros:
        raise RuntimeError(f"Could not bracket {n_zeros} spherical Bessel zeros for l={l}. Found {len(betas)}.")
    return torch.from_numpy(np.asarray(betas[:n_zeros], dtype=np.float64))


def _odd_double_factorial(n: int) -> int:
    """(2n+1)!! as an integer (used for small-x series of j_l)."""
    x = 1
    for k in range(n + 1):
        x *= (2 * k + 1)
    return x

def _sph_bessel(l: int, x: torch.Tensor) -> torch.Tensor:
    """
    Vectorized, numerically safe spherical Bessel j_l(x).
    Uses small-x series (up to O(x^2) correction) and stable closed forms / upward recursion otherwise.
    """
    x = x
    ax = x.abs()
    small = ax < 1e-4  # threshold for series use
    x2 = x * x

    # small-x series: j_l(x) ≈ x^l / (2l+1)!! * (1 - x^2 / (2(2l+3)))
    if l >= 0:
        denom = float(_odd_double_factorial(l))  # scalar
        series = (x ** l) / denom * (1.0 - x2 / (2.0 * (2 * l + 3.0)))
    else:
        raise ValueError("l must be nonnegative")

    # closed-form / recursion for non-small x
    if l == 0:
        large = torch.sin(x) / x
    elif l == 1:
        large = torch.sin(x) / (x2) - torch.cos(x) / x
    elif l == 2:
        s, c = torch.sin(x), torch.cos(x)
        x3 = x2 * x
        large = (3.0 / x3 - 1.0 / x) * s - 3.0 * c / x2
    else:
        # Start from j0, j1 with closed forms, then recur up to l
        j0 = torch.sin(x) / x
        j1 = torch.sin(x) / (x2) - torch.cos(x) / x
        jm2, jm1 = j0, j1
        # For elements that are small, the division by x is unstable; we'll override them by series later via torch.where.
        for n in range(1, l):
            jp = (2 * n + 1) * jm1 / x - jm2
            jm2, jm1 = jm1, jp
        large = jm1

    # Blend series for small |x| and closed-form/recursion elsewhere
    out = torch.where(small, series, large)
    # Handle exact x=0 cases explicitly for l=0 to ensure j0(0)=1 (already covered by series), others j_l(0)=0 (series does that too)
    return out


class RadialBesselExact(nn.Module):
    """Learned radial functions built from spherical Bessel modes.

    For every degree ``l`` in ``0..L`` the module stores the first ``N`` zeros
    of j_l as a float64 buffer ``roots_{l}`` and a learnable (N, 1) mixing
    column ``mix[str(l)]``. ``forward`` evaluates ``j_l(root * d / r_cut)`` for
    all roots, mixes them into one channel per degree and multiplies by a
    raised-cosine cutoff envelope.
    """

    def __init__(self, L: int, N: int, r_cut: float):
        super().__init__()
        self.L, self.N, self.r_cut = L, N, r_cut
        # exact zeros (computed once)
        for l in range(L + 1):
            roots = spherical_bessel_zeros(l, N).to(dtype=torch.float64)
            self.register_buffer(f"roots_{l}", roots)
        self.mix = nn.ParameterDict({str(l): nn.Parameter(torch.randn(N, 1) * 0.05) for l in range(L + 1)})

    def forward(self, d: torch.Tensor) -> List[torch.Tensor]:
        """Return ``L + 1`` tensors of shape (E, 1), one radial coefficient per degree.

        ``d`` is a 1-D tensor of E distances (it is indexed as ``d[:, None]``).
        """
        # Raised-cosine envelope: 1 at d = 0, falling to 0 at d = r_cut.
        # cos() is periodic, so without the explicit mask the envelope would
        # rise again beyond the cutoff (back to 1 at d = 2 * r_cut).
        env = 0.5 * (torch.cos(math.pi * d / self.r_cut) + 1.0)
        env = torch.where(d < self.r_cut, env, torch.zeros_like(env))
        out = []
        x = d[:, None] / (self.r_cut + 1e-12)            # (E,1)
        for l in range(self.L + 1):
            z = getattr(self, f"roots_{l}")[None, :].to(d.device, d.dtype)  # (1,N)
            X = z * x                                      # (E,N)
            jl = _sph_bessel(l, X)                         # (E,N)
            c = (jl @ self.mix[str(l)]) * env[:, None]     # (E,1)
            out.append(c.nan_to_num(0.0, posinf=0.0, neginf=0.0))  # safety
        return out


# --------------------------
# Geometry helpers: torsion (even features downstream)
# --------------------------
def dihedral(p0, p1, p2, p3, eps: float = 1e-8):
    """Signed torsion angle of the chain p0-p1-p2-p3, in radians within [-pi, pi].

    Both outer bond vectors are projected onto the plane perpendicular to the
    central bond p1->p2 and the angle between the projections is returned.
    Points are (..., 3) tensors; ``eps`` floors the central bond length.

    NOTE(review): the first bond is taken as ``p1 - p0`` (pointing towards p1),
    so a planar cis arrangement gives +/-pi and trans gives 0, i.e. the result
    is shifted by pi relative to the usual chemistry convention.
    """
    bond_a = p1 - p0
    bond_mid = p2 - p1
    bond_b = p3 - p2
    axis = bond_mid / (bond_mid.norm(dim=-1, keepdim=True).clamp(min=eps))

    def _reject(vec):
        # Remove the component of `vec` along the central bond.
        return vec - (vec * axis).sum(-1, keepdim=True) * axis

    v = _reject(bond_a)
    w = _reject(bond_b)
    cos_part = (v * w).sum(-1)
    sin_part = (torch.cross(axis, v, dim=-1) * w).sum(-1)
    return torch.atan2(sin_part, cos_part)


# --------------------------
# Proper Junction-Tree (JT-VAE style)
# --------------------------
def build_jt_proper(mol: Chem.Mol) -> Tuple[List[List[int]], List[Tuple[int, int]]]:
    """Decompose a molecule into ring/bond cliques and link them into a tree.

    Cliques are the SSSR rings (sorted atom indices) followed by every bond
    whose two atoms do not both lie in a common ring. Clique pairs sharing
    atoms are joined greedily, largest overlap first, without creating cycles;
    any clique groups left disconnected are then attached to the group that
    contains clique 0.

    Returns ``(cliques, jt_edges)`` where ``jt_edges`` holds index pairs into
    ``cliques``. Both are empty when the molecule has no rings and no bonds.
    """
    rings = [sorted(ring) for ring in Chem.GetSymmSSSR(mol)]
    ring_atom_sets = [set(ring) for ring in rings]

    # Rings first, then every bond that is not contained in any ring.
    cliques = list(rings)
    for bond in mol.GetBonds():
        a, c = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        inside_ring = any(a in atoms and c in atoms for atoms in ring_atom_sets)
        if not inside_ring:
            cliques.append([a, c])

    if not cliques:
        return [], []

    n_cliques = len(cliques)
    atom_sets = [set(members) for members in cliques]

    # Candidate tree edges: clique pairs sharing at least one atom.
    overlaps = []
    for i in range(n_cliques):
        for j in range(i + 1, n_cliques):
            shared = len(atom_sets[i] & atom_sets[j])
            if shared > 0:
                overlaps.append((i, j, shared))

    # Union-find with path halving.
    parent = list(range(n_cliques))

    def find(node):
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    # Kruskal on descending overlap; the stable sort keeps ties in scan order.
    overlaps.sort(key=lambda item: item[2], reverse=True)
    jt_edges = []
    for u, v, _shared in overlaps:
        root_u, root_v = find(u), find(v)
        if root_u != root_v:
            parent[root_u] = root_v
            jt_edges.append((u, v))

    # Attach remaining components to the one containing clique 0.
    components = {}
    for i in range(n_cliques):
        components.setdefault(find(i), []).append(i)
    groups = list(components.values())
    anchor = groups[0][0]
    for group in groups[1:]:
        jt_edges.append((anchor, group[0]))

    return cliques, jt_edges


# --------------------------
# Edge encoder: real SH(L) * radial (no internal normalization)
# --------------------------
class EdgeEncoder(nn.Module):
    """Edge attributes: spherical harmonics of the edge direction, scaled per degree.

    Each degree-``l`` block of harmonics is multiplied by the matching radial
    coefficient produced by ``radial``.
    """

    def __init__(self, L: int, radial: RadialBesselExact):
        super().__init__()
        self.L = L
        self.radial = radial
        self.irreps = o3.Irreps.spherical_harmonics(L)

    def forward(self, rhat: Tensor, d: Tensor) -> Tensor:
        """Return (E, sum_l (2l+1)) edge attributes from directions ``rhat`` and lengths ``d``."""
        degrees = list(range(self.L + 1))
        # The caller passes unit vectors, so e3nn's own normalisation (and its
        # division by the vector length) is switched off.
        Y = o3.spherical_harmonics(degrees, rhat,
                                   normalize=False, normalization='component')  # (E, sum (2l+1))
        radial_coeffs = self.radial(d)  # list of (E,1)
        widths = [2 * l + 1 for l in degrees]
        per_degree = torch.split(Y, widths, dim=1)
        scaled = [sh_l * radial_coeffs[l] for l, sh_l in enumerate(per_degree)]
        eattr = torch.cat(scaled, dim=-1)
        # Final guard: replace any non-finite entry by zero.
        return torch.nan_to_num(eattr, nan=0.0, posinf=0.0, neginf=0.0)


# --------------------------
# Layer with full e3nn Gate
# --------------------------
class EquivariantLayer(nn.Module):
    """One message-passing layer: tensor-product messages, gate nonlinearity, batch norm."""

    def __init__(self, irreps_node: Irreps, irreps_edge: Irreps,
                 mul0: int, mul1: int, mul2: int):
        super().__init__()
        self.irreps_node = irreps_node
        self.irreps_edge = irreps_edge

        # Message and self-interaction maps (node irreps in, node irreps out).
        self.tp = FullyConnectedTensorProduct(irreps_node, irreps_edge, irreps_node)
        self.lin_self = o3.Linear(irreps_node, irreps_node)

        # Gate layout: mul0 activated scalars, one scalar gate per gated irrep,
        # then the gated 1o / 2e channels themselves.
        scalar_irreps = Irreps(f"{mul0}x0e")
        gate_irreps = Irreps(f"{mul1 + mul2}x0e")
        gated_irreps = Irreps(f"{mul1}x1o + {mul2}x2e")
        gate_input_irreps = scalar_irreps + gate_irreps + gated_irreps
        self.pre = o3.Linear(irreps_node, gate_input_irreps)

        self.gate = Gate(scalar_irreps, [F.silu],
                         gate_irreps, [torch.sigmoid],
                         gated_irreps)

        self.post = o3.Linear(self.gate.irreps_out, irreps_node)
        self.bn = BatchNorm(irreps_node, eps=1e-5)

    def forward(self, x: Tensor, src: Tensor, dst: Tensor, eattr: Tensor, alpha: Tensor):
        """Return ``(updated node features, per-edge messages)``.

        Messages are computed from the source-node features and ``eattr``,
        scaled by ``alpha`` and summed onto their destination nodes.
        """
        edge_msgs = self.tp(x[src], eattr) * alpha
        pooled = scatter_add(edge_msgs, dst, dim=0, dim_size=x.size(0))
        h = self.lin_self(x) + pooled
        h = self.post(self.gate(self.pre(h)))
        h = self.bn(h)
        return h, edge_msgs


# --------------------------
# Whole model
# --------------------------
@dataclass
class ModelCfg:
    """Hyper-parameters read by ``ProductManifoldGNN``."""
    # Highest spherical-harmonic degree used for edge attributes and the radial basis.
    L: int = 2
    # Number of spherical-Bessel roots (radial modes) per degree.
    Nrad: int = 8
    # Cutoff radius passed to the radial basis.
    r_cut: float = 5.0
    # Number of spatial coordinates of the hyperbolic latents (points carry d_h + 1 entries).
    d_h: int = 16
    # Number of stacked equivariant layers.
    n_layers: int = 4
    # Multiplicities of the 0e / 1o / 2e node channels.
    mul_l0: int = 32
    mul_l1: int = 16
    mul_l2: int = 8
    # Factor applied to the rotation angle in the per-atom SO(3) update.
    so3_step: float = 0.2
    # Factor applied to the tangent vector in the hyperbolic updates
    # (the junction-tree update uses half of it).
    hyp_step: float = 0.2
    # Keep a hyperbolic latent per atom; adds one gate feature and one readout scalar.
    use_hyp_atoms: bool = True
    # Keep a hyperbolic latent per junction-tree clique; adds one gate feature and one readout scalar.
    use_hyp_scaffolds: bool = True
    # Feed the size of the destination atom's clique to the edge gate.
    use_jt_invariants: bool = True

class ProductManifoldGNN(nn.Module):
    """Equivariant GNN with auxiliary SO(3) and hyperbolic latents; one scalar per graph.

    Node features are e3nn irreps (0e + 1o + 2e) updated by ``EquivariantLayer``.
    Each edge message is scaled by a learned gate ``alpha`` computed from
    invariant edge features: the l=0 radial coefficient, two torsion terms and,
    depending on the config, hyperbolic distances between atom latents and
    between junction-tree clique latents, plus the destination clique size.
    The readout combines squared norms of the node features, the spread of the
    hyperbolic latents around their barycenter and the mean scalar channels.
    """

    def __init__(self, cfg: ModelCfg, num_z_embeddings: int):
        super().__init__()
        self.cfg = cfg
        self.irreps_node = Irreps(f"{cfg.mul_l0}x0e + {cfg.mul_l1}x1o + {cfg.mul_l2}x2e")
        self.irreps_edge = o3.Irreps.spherical_harmonics(cfg.L)

        # Precompute ℓ=1 slices to aggregate vectors robustly
        self.l1_slices: List[Tuple[int, slice]] = []
        for mir, sl in zip(self.irreps_node, self.irreps_node.slices()):
            if mir.ir.l == 1:
                self.l1_slices.append((mir.mul, sl))
        if not self.l1_slices:
            raise RuntimeError("No ℓ=1 channels present; cannot drive SO(3) updates.")

        self.atom_emb = nn.Embedding(num_z_embeddings, cfg.mul_l0)
        self.radial = RadialBesselExact(cfg.L, cfg.Nrad, cfg.r_cut)
        self.edge_enc = EdgeEncoder(cfg.L, self.radial)

        self.layers = nn.ModuleList([
            EquivariantLayer(self.irreps_node, self.irreps_edge,
                             cfg.mul_l0, cfg.mul_l1, cfg.mul_l2)
            for _ in range(cfg.n_layers)
        ])

        # Gate input: radial c0, cos(phi), sin^2(phi) + one entry per enabled option.
        gate_in_dim = 3 \
            + (1 if cfg.use_hyp_atoms else 0) \
            + (1 if cfg.use_hyp_scaffolds else 0) \
            + (1 if cfg.use_jt_invariants else 0)
        self.alpha_mlp = nn.Sequential(
            nn.Linear(gate_in_dim, 64), nn.SiLU(),
            nn.Linear(64, 64), nn.SiLU(),
            nn.Linear(64, 1)
        )

        # Readout input: three squared norms, optional hyperbolic spreads, mean scalars.
        readout_in = 3 \
            + (1 if cfg.use_hyp_atoms else 0) \
            + (1 if cfg.use_hyp_scaffolds else 0) \
            + cfg.mul_l0
        self.readout = nn.Sequential(
            nn.Linear(readout_in, 128), nn.SiLU(),
            nn.Linear(128, 64), nn.SiLU(),
            nn.Linear(64, 1)
        )

    def _init_latents(self, n: int, dev):
        """Return identity frames (n, 3, 3) and hyperboloid-origin latents (n, d_h + 1)."""
        with no_amp():
            R = torch.eye(3, device=dev).repeat(n, 1, 1)
            hA = project_to_hyperboloid(torch.zeros(n, self.cfg.d_h, device=dev))
        return R, hA

    def _jt_init(self, cliques: List[List[int]], hA: Tensor, dev) -> Tensor:
        """Initial latent per clique: barycenter of its atoms' latents, else the origin."""
        if len(cliques) == 0:
            return torch.zeros(0, self.cfg.d_h + 1, device=dev)
        outs = []
        for S in cliques:
            if self.cfg.use_hyp_atoms and hA is not None and len(S) > 0:
                pts = hA[S]; w = torch.ones(pts.size(0), device=dev)
                outs.append(h_barycenter(pts, w))
            else:
                outs.append(project_to_hyperboloid(torch.zeros(self.cfg.d_h, device=dev)))
        return torch.stack(outs, dim=0)

    def forward(self, data):
        """Predict one scalar per graph in the batch.

        Reads ``data.pos`` (positions), ``data.z`` (used both as embedding index
        and as the atomic number of the RDKit atom), ``data.edge_index`` (unpacked
        into source and destination rows) and ``data.batch`` (graph id per node).
        The nodes of a graph are assumed to be stored contiguously, because
        local indices are obtained by subtracting the graph's first node index.

        Returns a tensor of shape (B,), B being ``batch.max() + 1``.
        """
        pos, z, edge_index, batch = data.pos, data.z, data.edge_index, data.batch
        dev = pos.device
        if batch.numel() == 0:
            # No nodes at all: batch.max() below would raise.
            return pos.new_zeros(0)
        B = int(batch.max().item()) + 1

        # Node features (irreps vector)
        scalars0 = self.atom_emb(z.long())
        zeros = pos.new_zeros(scalars0.size(0), self.cfg.mul_l1*3 + self.cfg.mul_l2*5)
        x_all = torch.cat([scalars0, zeros], dim=-1)

        # Edge geometry (safe unit vectors for SH)
        src, dst = edge_index
        rij = pos[src] - pos[dst]
        dij = rij.norm(dim=-1)
        eps = 1e-8
        rhat = rij / (dij[:, None] + eps)
        # For zero-length edges (rare), pick a fixed unit direction
        zmask = dij <= eps
        if zmask.any():
            rhat[zmask] = 0.0
            rhat[zmask, 0] = 1.0
        eattr = self.edge_enc(rhat, dij.clamp(min=0.0))

        # Even torsion proxies: for edge j -> i take the first other in-neighbour k
        # of i and the first other out-neighbour l of j, then use the dihedral
        # k-i-j-l. Edges lacking either neighbour keep the value 0.
        cos_phi = torch.zeros_like(dij); sin2_phi = torch.zeros_like(dij)
        N = pos.size(0)
        # One host transfer for the whole edge list instead of one per edge.
        src_list, dst_list = src.tolist(), dst.tolist()
        nbh_dst = [[] for _ in range(N)]
        nbh_src = [[] for _ in range(N)]
        for j, i in zip(src_list, dst_list):
            nbh_dst[i].append(j); nbh_src[j].append(i)
        tors_e, tors_k, tors_l = [], [], []
        for e, (j, i) in enumerate(zip(src_list, dst_list)):
            k = next((u for u in nbh_dst[i] if u != j), None)
            l = next((u for u in nbh_src[j] if u != i), None)
            if k is None or l is None:
                continue
            tors_e.append(e); tors_k.append(k); tors_l.append(l)
        if tors_e:
            # Single batched dihedral evaluation over all qualifying edges.
            e_t = torch.tensor(tors_e, dtype=torch.long, device=dev)
            k_t = torch.tensor(tors_k, dtype=torch.long, device=dev)
            l_t = torch.tensor(tors_l, dtype=torch.long, device=dev)
            phi = dihedral(pos[k_t], pos[dst[e_t]], pos[src[e_t]], pos[l_t])
            sin_phi = torch.sin(phi)
            cos_phi[e_t] = torch.cos(phi)
            sin2_phi[e_t] = sin_phi * sin_phi

        y_pred = pos.new_zeros(B)
        for g in range(B):
            mask = (batch == g)
            idx = torch.nonzero(mask, as_tuple=False).view(-1)
            if idx.numel() == 0:
                continue
            n = idx.numel()
            offset = int(idx[0])
            emask = mask[dst] & mask[src]
            eidx = torch.nonzero(emask, as_tuple=False).view(-1)
            s_loc = src[eidx] - offset
            d_loc = dst[eidx] - offset

            xg = x_all[idx]
            eattr_g = eattr[eidx]
            cos_g = cos_phi[eidx][:, None]; sin2_g = sin2_phi[eidx][:, None]
            c0 = self.radial(dij[eidx])[0]  # (E,1)

            # Build JT: an RDKit molecule with one single bond per undirected edge.
            mol = Chem.RWMol()
            for Z in z[idx].tolist(): mol.AddAtom(Chem.Atom(int(Z)))
            added = set()
            for s, t in zip(s_loc.tolist(), d_loc.tolist()):
                if s == t:
                    continue  # RDKit refuses a bond from an atom to itself
                a, b = min(s, t), max(s, t)
                if (a, b) in added: continue
                mol.AddBond(a, b, Chem.BondType.SINGLE)
                added.add((a, b))
            mol = mol.GetMol()

            cliques, jt_edges = build_jt_proper(mol)
            if len(cliques) == 0:
                cliques, jt_edges = [[i for i in range(n)]], []
            # Map every atom to the first clique containing it; atoms in no
            # clique get a singleton clique linked to clique 0.
            atom2u = [-1] * n
            for u, S in enumerate(cliques):
                for a in S:
                    if atom2u[a] == -1: atom2u[a] = u
            for i_ in range(n):
                if atom2u[i_] == -1:
                    cliques.append([i_]); atom2u[i_] = len(cliques) - 1
                    if len(cliques) > 1: jt_edges.append((0, len(cliques) - 1))
            atom2u = torch.tensor(atom2u, device=dev)
            size_u = torch.tensor([len(S) for S in cliques], device=dev, dtype=torch.float32)

            # Latents
            if self.cfg.use_hyp_atoms:
                R_i, hA_i = self._init_latents(n, dev)
            else:
                R_i = torch.eye(3, device=dev).repeat(n, 1, 1)
                hA_i = None
            hT_u = self._jt_init(cliques, hA_i, dev) if self.cfg.use_hyp_scaffolds else None

            # Layers
            for layer in self.layers:
                # Gate α features (hyperbolic distances are cast to the node dtype)
                alpha_feats = [c0, cos_g, sin2_g]
                if self.cfg.use_hyp_atoms:
                    hdist_A = h_dist(hA_i[d_loc], hA_i[s_loc]).unsqueeze(-1).to(xg.dtype)
                    alpha_feats.append(hdist_A)
                if self.cfg.use_jt_invariants:
                    size_dst = size_u[atom2u[d_loc]].unsqueeze(-1).to(xg.dtype)
                    alpha_feats.append(size_dst)
                if self.cfg.use_hyp_scaffolds and hT_u is not None and hT_u.numel() > 0:
                    u_dst = atom2u[d_loc]; u_src = atom2u[s_loc]
                    hdist_T = h_dist(hT_u[u_dst], hT_u[u_src]).unsqueeze(-1).to(xg.dtype)
                    alpha_feats.append(hdist_T)

                alpha_in = torch.cat(alpha_feats, dim=-1)
                alpha = torch.sigmoid(self.alpha_mlp(alpha_in))
                alpha = torch.nan_to_num(alpha, nan=0.0, posinf=1.0, neginf=0.0)

                # Equivariant update
                xg, m_ij = layer(xg, s_loc, d_loc, eattr_g, alpha)

                # SO(3) update from ℓ=1 (sum over multiplicities)
                # NOTE(review): R_i is updated every layer but is not read by
                # the readout below, so it does not influence y_pred.
                node_msgs = scatter_add(m_ij, d_loc, dim=0, dim_size=n)
                M_R = torch.zeros(n, 3, device=dev, dtype=node_msgs.dtype)
                for mul, sl in self.l1_slices:
                    chunk = node_msgs[:, sl].view(n, mul, 3).sum(dim=1)
                    M_R = M_R + chunk
                R_i = torch.einsum("nij,njk->nik", R_i, so3_exp(M_R, step=self.cfg.so3_step))
                R_i = reorthonormalize(R_i)

                # Hyperbolic atom update: gate-weighted sum of log maps towards
                # the source atoms, followed by an exponential-map step.
                if self.cfg.use_hyp_atoms:
                    logs = h_log(hA_i[d_loc], hA_i[s_loc])
                    V = scatter_add(alpha * logs, d_loc, dim=0, dim_size=n)
                    hA_i = h_exp(hA_i, self.cfg.hyp_step * V)

                # Hyperbolic JT update: each tree edge pulls both cliques towards each other.
                if self.cfg.use_hyp_scaffolds and hT_u is not None and len(jt_edges) > 0:
                    Vt = torch.zeros_like(hT_u)
                    for (u, v) in jt_edges:
                        Vt[u] = Vt[u] + h_log(hT_u[u], hT_u[v])
                        Vt[v] = Vt[v] + h_log(hT_u[v], hT_u[u])
                    hT_u = h_exp(hT_u, 0.5 * self.cfg.hyp_step * Vt)

            # Invariant readout: squared norms per irrep type, hyperbolic spreads, mean scalars.
            s0 = xg[:, :self.cfg.mul_l0]
            v1 = xg[:, self.cfg.mul_l0:self.cfg.mul_l0 + self.cfg.mul_l1*3]
            q2 = xg[:, self.cfg.mul_l0 + self.cfg.mul_l1*3:]
            inv0 = (s0 * s0).sum(); inv1 = (v1 * v1).sum(); inv2 = (q2 * q2).sum()
            parts = [inv0, inv1, inv2]
            if self.cfg.use_hyp_atoms and hA_i is not None:
                hA_bar = h_barycenter(hA_i, torch.ones(n, device=dev))
                parts.append(h_dist(hA_i, hA_bar).mean())
            if self.cfg.use_hyp_scaffolds and hT_u is not None and hT_u.numel() > 0:
                hT_bar = h_barycenter(hT_u, torch.ones(hT_u.size(0), device=dev))
                parts.append(h_dist(hT_u, hT_bar).mean())
            read_scalars = torch.stack(parts, dim=0)
            read_scalars = torch.nan_to_num(read_scalars, nan=0.0, posinf=0.0, neginf=0.0).to(s0.dtype)
            read = torch.cat([read_scalars, s0.mean(dim=0)], dim=0)
            y_pred[g] = self.readout(read.unsqueeze(0)).squeeze(0)

        return y_pred