import functools
from typing import Callable, Sequence

import flax.linen as nn
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
from jax.scipy.linalg import block_diag
from nix import pmean_if_pmap
from nix.linalg import (
    skewsymmetric_quadratic,
    slog_pfaffian,
    slog_pfaffian_skewsymmetric_quadratic,
)
from nix.optax_ext import prodigy

from globe.nn import (
    ParamTree,
    ReparametrizedModule,
    block,
)
from globe.nn.parameters import ParamSpec, ParamType, inverse_softplus
from globe.systems.element import CORE_OFFSETS, MAX_CORE_ORB, VALENCY
from globe.utils import (
    chain,
    ema_make,
    ema_update,
    ema_value,
    flatten,
    itemgetter,
    np_segment_sum,
    tree_generator_zip,
)
from globe.utils.config import (
    SystemConfigs,
    group_by_config,
    inverse_group_idx,
    unique_configs,
)


def isotropic_envelope(
    x: jax.Array, sigma: jax.Array, pi: jax.Array, pi_scale: jax.Array
) -> jax.Array:
    """
    Evaluates a weighted sum of exponentially decaying envelopes over the nuclei.

    For every electron the result is ``sum_m exp(-d_m * sigma_m) * pi_m * pi_scale_m``
    where ``d_m`` is the last entry of the final axis of ``x``.

    Args:
    - x (Array): Tensor of shape (..., n_elec, n_nuc, k). Only the last component of
      the final axis is read; per the original documentation this is the
      electron-nucleus distance (k = 4).
    - sigma (Array): Decay rates. Reshaped to (n_nuc, -1) before use.
    - pi (Array): Envelope weights. Reshaped to (n_nuc, -1) before use.
    - pi_scale (Array): Multiplicative scale applied to ``pi``. Reshaped to (n_nuc, -1).

    Returns:
    - Tensor of shape (..., n_elec, n_out), where n_out is the trailing size left after
      reshaping the parameters to (n_nuc, -1).
    """
    n_nuc = x.shape[-2]
    # The parameters may arrive flattened (e.g. as (n_nuc * n_orbs, n_det) from the GNN),
    # so all three are brought to a common (n_nuc, -1) layout first.
    decay, weight, scale = (p.reshape(n_nuc, -1) for p in (sigma, pi, pi_scale))
    distance = x[..., -1:]
    # (..., n_elec, n_nuc, 1) * (n_nuc, n_out) -> (..., n_elec, n_nuc, n_out)
    per_nucleus = jnp.exp(-distance * decay) * (weight * scale)
    # Contract the nucleus axis.
    return per_nucleus.sum(axis=-2)


def group_orbital_params(params: ParamTree, config: SystemConfigs) -> ParamTree:
    """
    Splits every leaf of a parameter tree according to the system configuration.

    The kind of each leaf is inferred from the size of its leading axis, which is
    compared (in this order) against the totals expected for per-nucleus,
    nucleus-pair, nucleus-triple and per-electron parameters. Anything that matches
    none of these is treated as an electron-orbital parameter.

    Args:
    - params: The parameters of the orbitals.
    - config: The configuration of the system (provides ``spins`` and ``charges``).

    Returns:
    - A tree with the same structure as ``params`` whose leaves are the results of
      ``group_by_config`` for the matching chunk-size rule.
    """
    spins = np.array(config.spins)

    # Ordered (expected leading size, chunk size per configuration) pairs. The expected
    # sizes are thunks so that a later entry is only evaluated if all earlier ones failed.
    layouts = (
        # nuc
        (lambda: sum(map(len, config.charges)), lambda s, c: len(c)),
        # nuc nuc
        (
            lambda: sum(len(cs) ** 2 for cs in config.charges),
            lambda s, c: len(c) ** 2,
        ),
        # nuc nuc nuc
        (
            lambda: sum(len(cs) ** 3 for cs in config.charges),
            lambda s, c: len(c) ** 3,
        ),
        # electrons
        (lambda: spins.max(-1).sum(), lambda s, c: max(s)),
    )

    def split_leaf(param):
        leading = param.shape[0]
        for expected_size, chunk_size in layouts:
            if leading == expected_size():
                return group_by_config(config, param, chunk_size)
        # electron orbital
        return group_by_config(config, param, lambda s, c: max(s) * len(c))

    return jtu.tree_map(split_leaf, params)


def eval_orbitals(
    orbital_fn,
    params: ParamTree,
    h_one: jax.Array,
    r_im: jax.Array,
    config: SystemConfigs,
    nuclei: jax.Array | None = None,
) -> list[tuple[jax.Array, ...]]:
    """
    Applies an orbital function to each group of the configuration.

    The inputs are split with ``group_by_config`` and the parameters with
    ``group_orbital_params``; the groups are then walked in lockstep and
    ``orbital_fn`` is called once per group.

    Args:
    - orbital_fn: Called as ``orbital_fn(param, h, r, spins, charges)``, or with an
      additional trailing ``nuc`` argument when ``nuclei`` is given.
    - params: The parameters of the orbitals.
    - h_one: The one-electron features (grouped by total electron count).
    - r_im: The electron-nucleus distances (grouped by electrons times nuclei).
    - config: The configuration of the system.
    - nuclei: Optional per-nucleus array (grouped by number of nuclei). When None
      it is not forwarded to ``orbital_fn``.

    Returns:
    - A list with one ``orbital_fn`` result per group.
    """

    def total_electrons(s, c):
        return np.sum(s)

    def electron_nucleus_pairs(s, c):
        return np.sum(s) * len(c)

    h_groups = group_by_config(config, h_one, total_electrons)

    if nuclei is None:
        r_groups = group_by_config(
            config, r_im, electron_nucleus_pairs, return_config=True
        )
        param_groups = group_orbital_params(params, config)
        return [
            orbital_fn(param, h, r, spins, charges)
            for h, (r, (spins, charges)), param in tree_generator_zip(
                h_groups, r_groups, param_groups
            )
        ]

    nuc_groups = group_by_config(config, nuclei, lambda s, c: len(c))
    r_groups = group_by_config(
        config, r_im, electron_nucleus_pairs, return_config=True
    )
    param_groups = group_orbital_params(params, config)
    return [
        orbital_fn(param, h, r, spins, charges, nuc)
        for h, nuc, (r, (spins, charges)), param in tree_generator_zip(
            h_groups, nuc_groups, r_groups, param_groups
        )
    ]


def make_orbital_fn(
    orbital_fn: Callable[..., jax.Array],
    shared_orbitals: bool,
    full_det: bool,
    init_fn=None,
) -> Callable[..., tuple[jax.Array, ...]]:
    """
    Builds a batched function that assembles orbital matrices from ``orbital_fn``.

    Args:
    - orbital_fn: Called as ``orbital_fn(h, r, **param_group, **extra)``. Its output is
      reshaped to (n_rows, max(spins), -1).
    - shared_orbitals: If True both spin channels are evaluated together with the
      parameter groups 'eq' (and 'neq'); otherwise each channel uses its own groups
      'up_eq' / 'down_eq' (and 'up_neq' / 'down_neq').
    - full_det: If True a single matrix with the blocks [[uu, ud], [du, dd]] is built;
      otherwise the two spin blocks are returned separately.
    - init_fn: Optional zero-argument callable returning extra keyword arguments
      that are forwarded to every ``orbital_fn`` call.

    Returns:
    - A function ``(params, h_one, r_im, spins, charges)`` vmapped over the leading axis
      of the first three arguments. It returns a tuple of arrays laid out as
      (batch, n_det, n_rows, n_cols): one array if ``full_det`` else two.
    """

    def _orbitals(params, h_one, r_im, spins, charges):
        extra = {} if init_fn is None else init_fn()
        n_elec, n_nuc = h_one.shape[0], len(charges)
        n_orb = max(spins)
        r_im = r_im.reshape(n_elec, n_nuc, -1)
        # Row index at which the spin-up electrons end.
        cut = spins[:1]

        if shared_orbitals:

            def spin_blocks(key):
                # One evaluation for all electrons, split into up/down rows afterwards.
                values = orbital_fn(h_one, r_im, **params[key], **extra)
                return jnp.split(values.reshape(n_elec, n_orb, -1), cut, axis=0)

            uu, dd = spin_blocks('eq')
            if full_det:
                ud, du = spin_blocks('neq')
        else:
            h_up, h_down = jnp.split(h_one, cut)
            r_up, r_down = jnp.split(r_im, cut)

            def spin_blocks(up_key, down_key):
                # Separate parameter groups (and evaluations) for each spin channel.
                up = orbital_fn(h_up, r_up, **params[up_key], **extra).reshape(
                    spins[0], n_orb, -1
                )
                down = orbital_fn(h_down, r_down, **params[down_key], **extra).reshape(
                    spins[1], n_orb, -1
                )
                return up, down

            uu, dd = spin_blocks('up_eq', 'down_eq')
            if full_det:
                ud, du = spin_blocks('up_neq', 'down_neq')

        if full_det:
            top = jnp.concatenate([uu[:, : spins[0]], ud[:, : spins[1]]], axis=1)  # type: ignore
            bottom = jnp.concatenate([du[:, : spins[0]], dd[:, : spins[1]]], axis=1)  # type: ignore
            blocks = (jnp.concatenate([top, bottom], axis=0),)
        else:
            blocks = (uu[:, : spins[0]], dd[:, : spins[1]])
        # Move the trailing (determinant) axis to the front.
        return tuple(b.transpose(2, 0, 1) for b in blocks)

    return jax.vmap(_orbitals, in_axes=(0, 0, 0, None, None))


class OrbitalModule(ReparametrizedModule):
    """
    Base class for all orbital modules.

    Provides the parameter specification shared by determinant-based orbital
    modules, the determinant operation (``det_op``) and the pairing of network and
    Hartree-Fock orbitals (``match_hf``). Subclasses implement ``__call__``.
    """

    @staticmethod
    def param_spec(
        shared_orbitals: bool, full_det: bool, orbital_inp: int, determinants: int
    ):
        """
        Returns a dictionary of parameter specifications for the orbital module.

        Args:
        shared_orbitals: bool
            If True the specification is keyed by 'eq' (and 'neq'). If False every
            entry is duplicated under an 'up_' and a 'down_' prefix, i.e. each spin
            channel gets its own parameters.
        full_det: bool
            Whether the full determinant is used; adds the 'neq' entries.
        orbital_inp: int
            The input dimension of the orbitals.
        determinants: int
            The number of determinants.

        Returns:
        SpecTree:
            A dictionary of parameter specifications for the orbital module.
        """

        def spec_for(k):
            # Only the 'neq' envelopes keep their distribution; 'eq' envelopes do not.
            is_neq = k == 'neq'
            spec = dict(
                orbital_embedding=ParamSpec(
                    ParamType.ORBITAL,
                    shape=(determinants, orbital_inp),
                    mean=0,
                    std=1 / jnp.sqrt(orbital_inp),
                    segments=determinants,
                    keep_distr=True,
                    group='orb_embedding',
                ),
                orbital_bias=ParamSpec(
                    ParamType.ORBITAL,
                    shape=(determinants,),
                    mean=0,
                    std=0.1,
                    keep_distr=True,
                    group='orb_bias',
                ),
            )
            spec['envelope'] = dict(
                sigma=ParamSpec(
                    ParamType.ORBITAL_NUCLEI,
                    shape=(determinants,),
                    mean=1.0,
                    std=0.1,
                    transform=jnn.softplus,
                    keep_distr=is_neq,
                    group='env_sigma',
                ),
                pi=ParamSpec(
                    ParamType.ORBITAL_NUCLEI,
                    shape=(determinants,),
                    mean=0.0,
                    # The off-diagonal ('neq') weights start two orders of magnitude smaller.
                    std=0.5 if k == 'eq' else 1e-2,
                    transform=jnn.tanh,
                    keep_distr=is_neq,
                    group='env_pi',
                    use_bias=False,
                ),
                pi_scale=ParamSpec(
                    ParamType.ORBITAL_NUCLEI,
                    shape=(determinants,),
                    mean=inverse_softplus(1.0),
                    std=0.1,
                    use_bias=True,
                    transform=jnn.softplus,
                    keep_distr=is_neq,
                    group='env_pi_scale',
                ),
            )
            return spec

        restricted_spec = {}
        for k in ('eq', 'neq') if full_det else ('eq',):
            restricted_spec[k] = spec_for(k)

        if shared_orbitals:
            return restricted_spec

        # For unrestricted we need double the number of parameters
        unrestricted_spec = {}
        for k, spec in restricted_spec.items():
            for spin in ('up', 'down'):
                unrestricted_spec[f'{spin}_{k}'] = spec
        return unrestricted_spec

    def __call__(self, params, h_one, r_im, config):
        raise NotImplementedError()

    # Sign and log-magnitude of the determinant.
    det_op = staticmethod(jnp.linalg.slogdet)

    @staticmethod
    def match_hf(nn_orbitals, hf_orbitals, config, full_det):
        """
        Pairs network orbitals with Hartree-Fock orbitals.

        With ``full_det`` the first network matrix of every system is cut into its
        diagonal blocks (first ``na`` rows/columns and the remainder, ``na`` being the
        first entry of the system's spins). Otherwise the network orbitals are passed
        through. The third entry of every Hartree-Fock triple is dropped.

        Returns:
            A pair ``(nn, hf)`` of matching orbital collections.
        """
        if full_det:
            nn_pairs = []
            for orbs, (na, _) in zip(nn_orbitals, config.spins):
                full = orbs[0]
                nn_pairs.append((full[..., :na, :na], full[..., na:, na:]))
            return nn_pairs, [(up, down) for up, down, _ in hf_orbitals]
        return nn_orbitals, [(up, down) for up, down, _ in hf_orbitals]


class ProductOrbitals(OrbitalModule):
    """
    Orbitals given by a linear readout of the electron features times an envelope.
    phi_i(r_j) = (h_j^T*w_i + b_i) * envelope_i(r_j)
    """

    full_det: bool
    determinants: int
    shared_orbitals: bool

    @nn.compact
    def __call__(self, params, h_one, r_im, config):
        """Evaluates the orbitals for every group of ``config`` via ``eval_orbitals``."""
        n_det = self.determinants

        def _orbital_fn(h, r, orbital_embedding, orbital_bias, envelope):
            weights = orbital_embedding.reshape(-1, n_det, h.shape[-1])
            bias = orbital_bias.reshape(-1, n_det)
            # n -> n_elec, o -> n_orb, k -> n_det, d -> inp_dim
            linear = jnp.einsum('nd,okd->nok', h, weights) + bias
            decay = isotropic_envelope(r, **envelope)
            # The envelope comes back flattened over (orbital, determinant).
            return linear * decay.reshape(linear.shape)

        batched_fn = make_orbital_fn(_orbital_fn, self.shared_orbitals, self.full_det)
        return eval_orbitals(batched_fn, params, h_one, r_im, config)


def caylay_transform(x: jax.Array) -> jax.Array:
    """
    Maps an arbitrary square matrix to an orthogonal one.

    The input is antisymmetrised, ``A = (x - x^T) / 2``, then
    ``Q = (A + I)^{-1} (A - I)`` is formed (a Cayley-type transform, orthogonal for
    skew-symmetric ``A``) and ``Q @ Q`` is returned. An all-zero input yields the
    identity.

    Args:
    - x (Array): Tensor of shape (..., n, n).

    Returns:
    - Tensor of shape (..., n, n).
    """
    skew = (x - x.mT) / 2
    identity = jnp.eye(skew.shape[-1])
    q = jnp.linalg.solve(skew + identity, skew - identity)
    # Squaring keeps the result orthogonal and maps the zero input to +I instead of -I.
    return q @ q


def to_skewsymmetric_orthogonal(x: jax.Array):
    """
    Builds a matrix from ``x`` by combining its Cayley-type orthogonalisation with
    the canonical block-diagonal skew-symmetric matrix J.

    J consists of ``x.shape[-1] // 2`` copies of [[0, 1], [-1, 0]] on the diagonal.
    The result is ``skewsymmetric_quadratic(caylay_transform(x), J)``; presumably this
    is the quadratic form Q J Q^T, which is skew-symmetric and orthogonal for
    orthogonal Q -- confirm against nix.linalg.
    """
    n_pairs = x.shape[-1] // 2
    unit_block = jnp.array([[0, 1], [-1, 0]])
    # The skew-symmetric identity matrix
    J = block_diag(*(unit_block for _ in range(n_pairs)))
    return skewsymmetric_quadratic(caylay_transform(x), J)


# NOTE: declared as a staticmethod at module level because it is assigned as a class
# attribute (``Pfaffian.match_hf``) and must not be bound to the instance.
@staticmethod
def match_nn_and_hf_pfaffian(
    nn_orbitals,
    hf_orbitals,
    config,
    cache,
    pretrain_match_steps: int,
    pretrain_match_lr: float,
    pretrain_match_orbitals: bool | float,
    pretrain_match_pfaffian: bool | float,
):
    """
    Aligns Hartree-Fock orbitals with the Pfaffian network outputs for pretraining.

    For every group of systems sharing a configuration, a small set of rotation
    parameters is optimised so that the transformed Hartree-Fock quantities match
    the (gradient-stopped) network quantities. The optimised parameters are kept
    as an exponential moving average in ``cache`` and reused on the next call.

    Args:
    - nn_orbitals: Per-system network outputs; each entry is unpacked as
      ``(nn_pf, nn_orb)``.
    - hf_orbitals: Per-system Hartree-Fock orbitals; each entry is unpacked as
      ``(hf_up, hf_down)``.
    - config: The configuration of the systems.
    - cache: Per-system cache entries ``(ema_params, step)`` or None for a cold start.
    - pretrain_match_steps: Number of inner optimisation steps per call.
    - pretrain_match_lr: Learning rate passed to ``prodigy``.
    - pretrain_match_orbitals: Enables matching of the per-spin orbitals. When both
      modes are enabled its value is also used as the weight of these terms.
    - pretrain_match_pfaffian: Enables matching of the antisymmetrised (Pfaffian)
      matrix. When both modes are enabled its value is also used as its weight.

    Returns:
    - ``(nn_out, hf_out, out_caches)``: flat lists of network and Hartree-Fock
      targets in matching order, and the updated per-system caches in the order
      given by ``inverse_group_idx(config)``.
    """

    @jax.vmap
    def hf_match(
        hf_up: jax.Array, hf_down: jax.Array, nn_pf: jax.Array, nn_orb: jax.Array, cache
    ):
        leading_dims = hf_up.shape[:-2]
        n_up, n_down = hf_up.shape[-2], hf_down.shape[-2]
        # Number of orbitals per spin channel in the network output.
        D = nn_orb.shape[-1] // 2
        if (n_up + n_down) % 2 == 1:
            # Odd electron count: grow the smaller spin block by one row/column by
            # embedding it into the top-left corner of an identity matrix.
            if n_up > n_down:
                eye = jnp.broadcast_to(
                    jnp.eye(n_down + 1), (*leading_dims, n_down + 1, n_down + 1)
                )
                hf_down = eye.at[..., :n_down, :n_down].set(hf_down)
                n_down += 1
            else:
                eye = jnp.broadcast_to(
                    jnp.eye(n_up + 1), (*leading_dims, n_up + 1, n_up + 1)
                )
                hf_up = eye.at[..., :n_up, :n_up].set(hf_up)
                n_up += 1
        if hf_up.ndim == nn_orb.ndim - 1:
            # Insert the axis the network output has in addition (presumably the
            # determinant axis -- confirm against the caller).
            hf_up = hf_up[..., None, :, :]
            hf_down = hf_down[..., None, :, :]

        # Now we padded correctly
        # Zero-pad the orbital axis of both spin blocks up to D columns.
        hf_up_pad = jnp.concatenate(
            [hf_up, jnp.zeros((*hf_up.shape[:-1], D - n_up))],
            axis=-1,
        )
        hf_down_pad = jnp.concatenate(
            [
                hf_down,
                jnp.zeros((*hf_down.shape[:-1], D - n_down)),
            ],
            axis=-1,
        )
        # Block matrix with the spin blocks on the diagonal and zeros elsewhere.
        hf_full = block(
            hf_up,
            jnp.zeros((*hf_up.shape[:-1], n_down)),
            jnp.zeros((*hf_down.shape[:-1], n_up)),
            hf_down,
        )
        # Prepare NN
        nn_up = nn_orb[..., :n_up, :D]
        nn_down = nn_orb[..., n_up : n_up + n_down, D:]

        def transforms(x: tuple[jax.Array, ...]):
            """Maps raw parameters to (up, down, pfaffian) transforms; None if disabled."""
            x_iter = iter(x)
            result = ()
            if pretrain_match_orbitals:
                hf_t = next(x_iter)
                # One orthogonal matrix per spin channel.
                result += tuple(jax.vmap(caylay_transform)(hf_t))
            else:
                result += (None,) * 2
            if pretrain_match_pfaffian:
                result += (to_skewsymmetric_orthogonal(next(x_iter)),)
            else:
                result += (None,)
            return result

        def get_pairs(x, enable_diff=False):
            """Returns matching lists of transformed HF targets and network outputs."""
            up_t, down_t, pf_t = transforms(x)
            hf_out, nn_out = [], []
            up, down, pf = nn_up, nn_down, nn_pf
            if not enable_diff:
                # While fitting the transforms the network is treated as constant.
                up, down, pf = jax.lax.stop_gradient((up, down, pf))
            if up_t is not None:
                hf_out.append(hf_up_pad @ up_t)
                nn_out.append(up)
            if down_t is not None:
                hf_out.append(hf_down_pad @ down_t)
                nn_out.append(down)
            if pf_t is not None:
                hf_out.append(skewsymmetric_quadratic(hf_full, pf_t))
                nn_out.append(pf)
            return hf_out, nn_out

        def loss(x):
            """Sum of the mean squared differences over all enabled pairs."""
            loss_val = 0
            for hf_val, nn_val in zip(*get_pairs(x)):
                loss_val += ((hf_val - nn_val) ** 2).mean()
            return loss_val

        def avg_loss_and_grad(x):
            return pmean_if_pmap(jax.value_and_grad(loss)(x))

        def optim(optimizer, x, maxiter=pretrain_match_steps):
            """Runs ``maxiter`` optimiser steps starting from ``x``."""

            def step(state, i):
                params, opt_state = state
                value, grads = avg_loss_and_grad(params)
                updates, opt_state = optimizer.update(grads, opt_state, params)
                params = optax.apply_updates(params, updates)
                return (params, opt_state), value

            (x, _), _ = jax.lax.scan(step, (x, optimizer.init(x)), jnp.arange(maxiter))
            return x

        optimizer = optax.chain(prodigy(pretrain_match_lr), optax.scale(-1))
        if cache is None:
            # Cold start: zero parameters, which caylay_transform maps to the identity.
            params = ()
            if pretrain_match_orbitals:
                params += (jnp.zeros((2, D, D), dtype=jnp.float32),)
            if pretrain_match_pfaffian:
                params += (
                    jnp.zeros((n_up + n_down, n_up + n_down), dtype=jnp.float32),
                )
            cache = (ema_make(params), jnp.zeros((), dtype=jnp.int32))
        params, step = cache
        step += 1
        # Refine the current average and fold the result back into it.
        params = ema_update(
            params,
            optim(optimizer, ema_value(params)),
            0.99,
        )
        hf_out, nn_out = get_pairs(ema_value(params), True)
        if pretrain_match_orbitals and pretrain_match_pfaffian:
            # Scale both sides by sqrt(weight) so that a squared error between them
            # is scaled by the weight itself. Order matches get_pairs: up, down, pf.
            orbital_weight = pretrain_match_orbitals**0.5
            pfaffian_weight = pretrain_match_pfaffian**0.5
            weights = (orbital_weight, orbital_weight, pfaffian_weight)
            hf_out = [h * w for h, w in zip(hf_out, weights)]
            nn_out = [n * w for n, w in zip(nn_out, weights)]
        # Always (network, Hartree-Fock, cache) -- the caller unpacks in this order.
        return tuple(nn_out), tuple(hf_out), (params, step)

    nn_out, hf_out, out_caches = [], [], []
    configs, indices, *_ = unique_configs(config)
    for idx, conf in zip(indices, configs):
        getter = itemgetter(*idx)
        nn_group, hf_group, guess = (
            getter(nn_orbitals),
            getter(hf_orbitals),
            getter(cache),
        )
        # Stack all systems of this configuration along a new leading axis.
        hf_up, hf_down = jtu.tree_map(lambda *x: jnp.stack(x), *hf_group)
        nn_pf, nn_orb = jtu.tree_map(lambda *x: jnp.stack(x), *nn_group)
        if guess[0] is not None:
            init_guess = jtu.tree_map(lambda *x: jnp.stack(x), *guess)
        else:
            init_guess = None
        nn_o, hf_o, out_cache = hf_match(hf_up, hf_down, nn_pf, nn_orb, init_guess)
        nn_out += list(nn_o)
        hf_out += list(hf_o)
        # Unstack the caches again, one entry per system.
        for i in range(len(nn_pf)):
            out_caches.append(jtu.tree_map(lambda x: x[i], out_cache))
    # Restore the original system order.
    out_caches = itemgetter(*inverse_group_idx(config))(out_caches)
    return nn_out, hf_out, out_caches


# NOTE: declared as a staticmethod at module level because it is assigned as a class
# attribute (``Pfaffian.det_op``) and must not receive the instance as first argument.
@staticmethod
def slog_pfaffian_unpack(args):
    """
    Forwards a packed argument sequence to ``slog_pfaffian_skewsymmetric_quadratic``.

    ``args`` is unpacked positionally. ``Pfaffian.__call__`` emits ``(x, A)`` pairs,
    which is presumably what is passed here -- confirm against the caller.
    """
    return slog_pfaffian_skewsymmetric_quadratic(*args)


class Pfaffian(ReparametrizedModule):
    correlation: str
    determinants: int
    orbitals_per_atom: int
    sigma_per_det: bool
    pi_per_det: bool
    det_precision: str
    down_projection: int | None
    minimal_orbitals: dict[str, int] | int | None = None
    use_spin_mask: bool = False
    use_many_spin_embeddings: bool = False

    @staticmethod
    def param_spec(
        correlation: str,
        determinants: int,
        orbital_inp: int,
        down_projection: int | None,
        orbitals_per_atom: int,
        envelope_per_atom: int,
        sigma_per_det: bool,
        pi_per_det: bool,
        use_spin_mask: bool = False,
        use_many_spin_embeddings: bool = False,
    ):
        inp_dim = down_projection if down_projection else orbital_inp
        result = {
            k: dict(
                orbital_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, inp_dim),
                    mean=0,
                    std=1 / jnp.sqrt(orbital_inp),
                    segments=determinants,
                    keep_distr=True,
                    det_axis=1,
                ),
                spin_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, 5)
                    if use_many_spin_embeddings
                    else (determinants, orbitals_per_atom),
                    mean=0,
                    std=1.0,
                    keep_distr=keep_distr,
                    det_axis=1,
                ),
                sigma=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(
                        determinants,
                        envelope_per_atom,
                    )
                    if sigma_per_det
                    else (envelope_per_atom,),
                    mean=1,
                    std=0.1,
                    transform=jnn.softplus,
                    keep_distr=keep_distr,
                    group='env_sigma',
                    det_axis=1 if sigma_per_det else None,
                ),
                pi=ParamSpec(
                    ParamType.NUCLEI_NUCLEI,
                    shape=(
                        determinants,
                        orbitals_per_atom,
                        envelope_per_atom,
                    )
                    if pi_per_det
                    else (orbitals_per_atom, envelope_per_atom),
                    mean=0.0,
                    std=1e-3 if keep_distr else 1,
                    keep_distr=keep_distr,
                    use_bias=False,
                    gating='sigmoid',
                    det_axis=1 if pi_per_det else None,
                ),
            )
            | (
                dict(
                    spin_mask=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, 4),
                        mean=0,
                        std=1,
                        transform=jnn.sigmoid,
                        keep_distr=keep_distr,
                        det_axis=1,
                    ),
                )
                if use_spin_mask
                else {}
            )
            for k, keep_distr in zip(('same', 'diff'), (False, True))
        }
        match correlation.lower():
            case 'full':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(
                            determinants,
                            orbitals_per_atom * 2,
                            orbitals_per_atom * 2,
                        ),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'dense':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'bidiagonal':
                result = result | {
                    k: ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                    for k in ('diag1', 'diag2', 'off')
                }
            case 'minimal':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                )
            case x if x in ('simple', 'static'):
                pass
            case _:
                raise ValueError(f'Unknown correlation type: {correlation}')
        return result

    def atomic_orbital_mask(self, charges):
        if self.minimal_orbitals is None:
            return None
        mask = np.zeros((len(charges) * self.orbitals_per_atom,), dtype=bool)
        for i, charge in enumerate(charges):
            start = i * self.orbitals_per_atom
            if isinstance(self.minimal_orbitals, int):
                add_orbs = self.minimal_orbitals
            else:
                add_orbs = self.minimal_orbitals[str(charge)]
            n_orbs = (charge + 1) // 2 + add_orbs
            assert n_orbs <= self.orbitals_per_atom
            mask[start : start + n_orbs] = True
        return mask

    @nn.compact
    def __call__(
        self, params, h_one, r_im, config, return_atomic_orbitals: bool = False
    ):
        if self.down_projection:
            projection = self.param(
                'projection',
                jnn.initializers.normal(1 / jnp.sqrt(h_one.shape[-1])),
                (self.down_projection, h_one.shape[-1]),
            )

        def _atomic_orbitals(h, r, params, spins, charges):
            n_up, n_down, m = *spins, len(charges)
            # Feature projection
            inp_dim = (
                h.shape[-1] if self.down_projection is None else self.down_projection
            )
            # Prepare parameters
            orbital_embedding = params['orbital_embedding'].reshape(-1, inp_dim)
            ao_mask = self.atomic_orbital_mask(charges)
            pi = params['pi'].reshape(m, m * self.orbitals_per_atom, -1)
            spin_emb = params['spin_embedding'].reshape(m * self.orbitals_per_atom, -1)
            # Mask out unnecessary orbitals
            if ao_mask is not None:
                orbital_embedding = orbital_embedding[ao_mask]
                pi = pi[:, ao_mask]
                spin_emb = spin_emb[ao_mask]
            if 'spin_mask' in params:
                spin_mask = params['spin_mask'].reshape(-1, 4)
                if ao_mask is not None:
                    spin_mask = spin_mask[ao_mask]

            # compute orbitals
            if self.down_projection:
                orbital_embedding = orbital_embedding @ projection
            orb = jnp.einsum('nd,kd->nk', h, orbital_embedding)

            # envelopes
            dist = r[..., -1:].reshape(h.shape[0], m, 1)
            orb *= jnp.einsum(
                'nmd,mod->no',
                jnp.exp(-dist * params['sigma']),
                pi / jnp.sqrt(self.orbitals_per_atom),
            )
            # spin mask
            up, down = orb[:n_up], orb[n_up:]
            if 'spin_mask' in params:
                if n_up > n_down:
                    down *= spin_mask[..., n_up - n_down - 1]
                elif n_up < n_down:
                    up *= spin_mask[..., n_down - n_up - 1]

            # pad to even
            if (n_up + n_down) % 2 == 1:
                if n_up > n_down:
                    if spin_emb.shape[-1] == 1:
                        down = jnp.concatenate([down, spin_emb[:, 0][None]])
                    else:
                        down = jnp.concatenate([down, spin_emb[:, n_up - n_down][None]])
                else:
                    if spin_emb.shape[-1] == 1:
                        up = jnp.concatenate([up, spin_emb[:, 0][None]])
                    else:
                        up = jnp.concatenate([up, spin_emb[:, -1][None]])
            return up, down

        param_axes = jtu.tree_map(
            lambda x: x.det_axis,
            self.param_spec(
                self.correlation,
                1,
                1,
                1,
                1,
                1,
                self.sigma_per_det,
                self.pi_per_det,
                use_spin_mask=self.use_spin_mask,
                use_many_spin_embeddings=self.use_many_spin_embeddings,
            ),
            is_leaf=lambda x: isinstance(x, ParamSpec),
        )

        def full_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            correlation = correlation.reshape(
                m, m, self.orbitals_per_atom, 2, self.orbitals_per_atom, 2
            )
            correlation = jnp.transpose(correlation, (3, 0, 2, 5, 1, 4)).reshape(
                m * self.orbitals_per_atom * 2, m * self.orbitals_per_atom * 2
            )
            correlation = correlation - correlation.mT
            return correlation

        def dense_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            mat = correlation.reshape(
                m, m, self.orbitals_per_atom, self.orbitals_per_atom
            )
            mat = jnp.transpose(mat, (0, 2, 1, 3)).reshape(
                m * self.orbitals_per_atom, m * self.orbitals_per_atom
            )
            A, S = (mat - mat.mT) / 2, (mat + mat.mT) / 2
            correlation = block(A, S, -S, A)
            return correlation

        def diag_correlation_pfaffian(A1, A2, S, n_nuc):
            if A2 is None:
                A2 = A1
            if S is None:
                S = A1
            A1, A2, S = A1 - A1.mT, A2 - A2.mT, S + S.mT

            A1, A2, S = jtu.tree_map(
                lambda x: jnp.repeat(x, n_nuc, axis=0) if x.shape[0] == 1 else x,
                (A1, A2, S),
            )
            correlation = block_diag(*A1, *A2)
            off_corr = block_diag(*S)
            correlation += block(
                jnp.zeros_like(off_corr),
                off_corr,
                -off_corr.T,
                jnp.zeros_like(off_corr),
            )
            return correlation

        @functools.partial(
            jax.vmap, in_axes=(0, None, 0, 0, None, None)
        )  # configurations
        @functools.partial(
            jax.vmap, in_axes=(param_axes, 0, None, None, None, None)
        )  # determinants
        def _orbital_fn(params, correlation, h, r, spins, charges):
            same_up, same_down = _atomic_orbitals(
                h, r, params['same'], spins=spins, charges=charges
            )
            diff_up, diff_down = _atomic_orbitals(
                h, r, params['diff'], spins=spins, charges=charges
            )
            norm = (max(spins) / same_up.shape[-1]) ** 0.5
            same_up, same_down, diff_up, diff_down = jtu.tree_map(
                lambda x: x * norm, (same_up, same_down, diff_up, diff_down)
            )

            if return_atomic_orbitals:
                diff_up, diff_down = jnp.zeros_like(diff_up), jnp.zeros_like(diff_down)
            x = block(same_up, diff_up, diff_down, same_down)

            match self.correlation.lower():
                case 'full':
                    A = full_correlation_pfaffian(params['correlation'])
                case 'dense':
                    A = dense_correlation_pfaffian(params['correlation'])
                case 'bidiagonal':
                    A = diag_correlation_pfaffian(
                        params['diag1'], params['diag2'], params['off'], len(charges)
                    )
                case 'minimal':
                    A = diag_correlation_pfaffian(
                        params['correlation'], None, None, len(charges)
                    )
                case 'static':
                    A = diag_correlation_pfaffian(correlation, None, None, len(charges))
                case 'simple':
                    A = block_diag(*[jnp.array([[0, 1], [-1, 0]])] * (x.shape[-1] // 2))
                case _:
                    raise ValueError(f'Unknown correlation type: {self.correlation}')

            if (ao_mask := self.atomic_orbital_mask(charges)) is not None:
                ao_mask = np.repeat(ao_mask, 2)
                A = A[ao_mask][:, ao_mask]

            assert (
                A.shape[0] == x.shape[-1]
            ), f'Shape mismatch between antisymmetrizer ({A.shape}) and orbitals ({x.shape})'
            if return_atomic_orbitals:
                return skewsymmetric_quadratic(x, A), x
            return ((x.astype(self.det_precision), A),)

        if self.correlation.lower() == 'static':
            correlation = self.param(
                'correlation',
                jnn.initializers.normal(1),
                (self.determinants, 1, self.orbitals_per_atom, self.orbitals_per_atom),
            )
        else:
            correlation = None

        def orbitals(params, h, r, spins, charges):
            return _orbital_fn(params, correlation, h, r, spins, charges)

        return eval_orbitals(orbitals, params, h_one, r_im, config)

    det_op = slog_pfaffian_unpack
    match_hf = match_nn_and_hf_pfaffian


class FullPfaffian(ReparametrizedModule):
    correlation: str
    determinants: int
    orbitals_per_atom: int
    det_precision: str
    down_projection: int | None
    minimal_orbitals: dict[str, int] | int | None = None
    use_spin_mask: bool = False
    use_many_spin_embeddings: bool = False

    @staticmethod
    def param_spec(
        correlation: str,
        determinants: int,
        orbital_inp: int,
        down_projection: int | None,
        orbitals_per_atom: int,
        use_spin_mask: bool = False,
        use_many_spin_embeddings: bool = False,
    ):
        inp_dim = down_projection if down_projection else orbital_inp
        result = {
            k: dict(
                orbital_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, inp_dim),
                    mean=0,
                    std=1 / jnp.sqrt(orbital_inp),
                    segments=determinants,
                    keep_distr=True,
                    det_axis=1,
                ),
                spin_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, 5)
                    if use_many_spin_embeddings
                    else (determinants, orbitals_per_atom),
                    mean=0,
                    std=1.0,
                    keep_distr=keep_distr,
                    det_axis=1,
                ),
                sigma=ParamSpec(
                    ParamType.NUCLEI_NUCLEI,
                    shape=(determinants, orbitals_per_atom),
                    mean=1,
                    std=0.1,
                    transform=jnn.softplus,
                    keep_distr=keep_distr,
                    group='env_sigma',
                    det_axis=1,
                ),
                pi=ParamSpec(
                    ParamType.NUCLEI_NUCLEI,
                    shape=(determinants, orbitals_per_atom),
                    mean=0.0,
                    std=1e-3 if keep_distr else 1,
                    keep_distr=keep_distr,
                    use_bias=False,
                    gating='sigmoid',
                    det_axis=1,
                ),
            )
            | (
                dict(
                    spin_mask=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, 4),
                        mean=0,
                        std=1,
                        transform=jnn.sigmoid,
                        keep_distr=keep_distr,
                        det_axis=1,
                    ),
                )
                if use_spin_mask
                else {}
            )
            for k, keep_distr in zip(('same', 'diff'), (False, True))
        }
        match correlation.lower():
            case 'full':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(
                            determinants,
                            orbitals_per_atom * 2,
                            orbitals_per_atom * 2,
                        ),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'dense':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'bidiagonal':
                result = result | {
                    k: ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                    for k in ('diag1', 'diag2', 'off')
                }
            case 'minimal':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                )
            case x if x in ('simple', 'static'):
                pass
            case _:
                raise ValueError(f'Unknown correlation type: {correlation}')
        return result

    def atomic_orbital_mask(self, charges):
        if self.minimal_orbitals is None:
            return None
        mask = np.zeros((len(charges) * self.orbitals_per_atom,), dtype=bool)
        for i, charge in enumerate(charges):
            start = i * self.orbitals_per_atom
            if isinstance(self.minimal_orbitals, int):
                add_orbs = self.minimal_orbitals
            else:
                add_orbs = self.minimal_orbitals[str(charge)]
            n_orbs = (charge + 1) // 2 + add_orbs
            assert n_orbs <= self.orbitals_per_atom
            mask[start : start + n_orbs] = True
        return mask

    @nn.compact
    def __call__(
        self, params, h_one, r_im, config, return_atomic_orbitals: bool = False
    ):
        if self.down_projection:
            projection = self.param(
                'projection',
                jnn.initializers.normal(1 / jnp.sqrt(h_one.shape[-1])),
                (self.down_projection, h_one.shape[-1]),
            )

        def _atomic_orbitals(h, r, params, spins, charges):
            n_up, n_down, m = *spins, len(charges)
            # Feature projection
            inp_dim = (
                h.shape[-1] if self.down_projection is None else self.down_projection
            )
            # Prepare parameters
            orbital_embedding = params['orbital_embedding'].reshape(-1, inp_dim)
            ao_mask = self.atomic_orbital_mask(charges)
            sigma = params['sigma'].reshape(m, m * self.orbitals_per_atom)
            pi = params['pi'].reshape(m, m * self.orbitals_per_atom)
            spin_emb = params['spin_embedding'].reshape(m * self.orbitals_per_atom, -1)
            # Mask out unnecessary orbitals
            if ao_mask is not None:
                orbital_embedding = orbital_embedding[ao_mask]
                sigma = sigma[:, ao_mask]
                pi = pi[:, ao_mask]
                spin_emb = spin_emb[ao_mask]
            if 'spin_mask' in params:
                spin_mask = params['spin_mask'].reshape(-1, 4)
                if ao_mask is not None:
                    spin_mask = spin_mask[ao_mask]

            # compute orbitals
            if self.down_projection:
                orbital_embedding = orbital_embedding @ projection
            orb = jnp.einsum('nd,kd->nk', h, orbital_embedding)

            # envelopes
            dist = r[..., -1:].reshape(h.shape[0], m, 1)
            orb *= jnp.einsum('nmo,mo->no', jnp.exp(-dist * sigma), pi)
            # spin mask
            up, down = orb[:n_up], orb[n_up:]
            if 'spin_mask' in params:
                if n_up > n_down:
                    down *= spin_mask[..., n_up - n_down - 1]
                elif n_up < n_down:
                    up *= spin_mask[..., n_down - n_up - 1]

            # pad to even
            if (n_up + n_down) % 2 == 1:
                if n_up > n_down:
                    if spin_emb.shape[-1] == 1:
                        down = jnp.concatenate([down, spin_emb[:, 0][None]])
                    else:
                        down = jnp.concatenate([down, spin_emb[:, n_up - n_down][None]])
                else:
                    if spin_emb.shape[-1] == 1:
                        up = jnp.concatenate([up, spin_emb[:, 0][None]])
                    else:
                        up = jnp.concatenate([up, spin_emb[:, -1][None]])
            return up, down

        param_axes = jtu.tree_map(
            lambda x: x.det_axis,
            self.param_spec(
                self.correlation,
                1,
                1,
                1,
                1,
                use_spin_mask=self.use_spin_mask,
                use_many_spin_embeddings=self.use_many_spin_embeddings,
            ),
            is_leaf=lambda x: isinstance(x, ParamSpec),
        )

        def full_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            correlation = correlation.reshape(
                m, m, self.orbitals_per_atom, 2, self.orbitals_per_atom, 2
            )
            correlation = jnp.transpose(correlation, (3, 0, 2, 5, 1, 4)).reshape(
                m * self.orbitals_per_atom * 2, m * self.orbitals_per_atom * 2
            )
            correlation = correlation - correlation.mT
            return correlation

        def dense_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            mat = correlation.reshape(
                m, m, self.orbitals_per_atom, self.orbitals_per_atom
            )
            mat = jnp.transpose(mat, (0, 2, 1, 3)).reshape(
                m * self.orbitals_per_atom, m * self.orbitals_per_atom
            )
            A, S = (mat - mat.mT) / 2, (mat + mat.mT) / 2
            correlation = block(A, S, -S, A)
            return correlation

        def diag_correlation_pfaffian(A1, A2, S, n_nuc):
            if A2 is None:
                A2 = A1
            if S is None:
                S = A1
            A1, A2, S = A1 - A1.mT, A2 - A2.mT, S + S.mT

            A1, A2, S = jtu.tree_map(
                lambda x: jnp.repeat(x, n_nuc, axis=0) if x.shape[0] == 1 else x,
                (A1, A2, S),
            )
            correlation = block_diag(*A1, *A2)
            off_corr = block_diag(*S)
            correlation += block(
                jnp.zeros_like(off_corr),
                off_corr,
                -off_corr.T,
                jnp.zeros_like(off_corr),
            )
            return correlation

        @functools.partial(
            jax.vmap, in_axes=(0, None, 0, 0, None, None)
        )  # configurations
        @functools.partial(
            jax.vmap, in_axes=(param_axes, 0, None, None, None, None)
        )  # determinants
        def _orbital_fn(params, correlation, h, r, spins, charges):
            same_up, same_down = _atomic_orbitals(
                h, r, params['same'], spins=spins, charges=charges
            )
            diff_up, diff_down = _atomic_orbitals(
                h, r, params['diff'], spins=spins, charges=charges
            )
            norm = (max(spins) / same_up.shape[-1]) ** 0.5
            same_up, same_down, diff_up, diff_down = jtu.tree_map(
                lambda x: x * norm, (same_up, same_down, diff_up, diff_down)
            )

            if return_atomic_orbitals:
                diff_up, diff_down = jnp.zeros_like(diff_up), jnp.zeros_like(diff_down)
            x = block(same_up, diff_up, diff_down, same_down)

            match self.correlation.lower():
                case 'full':
                    A = full_correlation_pfaffian(params['correlation'])
                case 'dense':
                    A = dense_correlation_pfaffian(params['correlation'])
                case 'bidiagonal':
                    A = diag_correlation_pfaffian(
                        params['diag1'], params['diag2'], params['off'], len(charges)
                    )
                case 'minimal':
                    A = diag_correlation_pfaffian(
                        params['correlation'], None, None, len(charges)
                    )
                case 'static':
                    A = diag_correlation_pfaffian(correlation, None, None, len(charges))
                case 'simple':
                    A = block_diag(*[jnp.array([[0, 1], [-1, 0]])] * (x.shape[-1] // 2))
                case _:
                    raise ValueError(f'Unknown correlation type: {self.correlation}')

            if (ao_mask := self.atomic_orbital_mask(charges)) is not None:
                ao_mask = np.repeat(ao_mask, 2)
                A = A[ao_mask][:, ao_mask]

            assert (
                A.shape[0] == x.shape[-1]
            ), f'Shape mismatch between antisymmetrizer ({A.shape}) and orbitals ({x.shape})'
            if return_atomic_orbitals:
                return skewsymmetric_quadratic(x, A), x
            return ((x.astype(self.det_precision), A),)

        if self.correlation.lower() == 'static':
            correlation = self.param(
                'correlation',
                jnn.initializers.normal(1),
                (self.determinants, 1, self.orbitals_per_atom, self.orbitals_per_atom),
            )
        else:
            correlation = None

        def orbitals(params, h, r, spins, charges):
            return _orbital_fn(params, correlation, h, r, spins, charges)

        return eval_orbitals(orbitals, params, h_one, r_im, config)

    det_op = slog_pfaffian_unpack
    match_hf = match_nn_and_hf_pfaffian


class BottleneckPfaffian(ReparametrizedModule):
    correlation: str
    determinants: int
    orbitals_per_atom: int
    det_precision: str
    down_projection: int | None
    minimal_orbitals: dict[str, int] | int | None = None
    use_spin_mask: bool = False
    use_many_spin_embeddings: bool = False

    @staticmethod
    def param_spec(
        correlation: str,
        determinants: int,
        orbital_inp: int,
        down_projection: int | None,
        orbitals_per_atom: int,
        bottleneck_envelopes: int,
        use_spin_mask: bool = False,
        use_many_spin_embeddings: bool = False,
    ):
        inp_dim = down_projection if down_projection else orbital_inp
        result = {
            k: dict(
                orbital_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, inp_dim),
                    mean=0,
                    std=1 / jnp.sqrt(orbital_inp),
                    segments=determinants,
                    keep_distr=True,
                    det_axis=1,
                ),
                spin_embedding=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, 5)
                    if use_many_spin_embeddings
                    else (determinants, orbitals_per_atom),
                    mean=0,
                    std=1.0,
                    keep_distr=keep_distr,
                    det_axis=1,
                ),
                sigma=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(bottleneck_envelopes,),
                    mean=1,
                    std=0.1,
                    transform=jnn.softplus,
                    keep_distr=keep_distr,
                    group='env_sigma',
                    det_axis=None,
                ),
                pi=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(bottleneck_envelopes,),
                    mean=0.0,
                    std=1,
                    keep_distr=keep_distr,
                    gating='sigmoid',
                    det_axis=None,
                ),
                env_w=ParamSpec(
                    ParamType.NUCLEI,
                    shape=(determinants, orbitals_per_atom, bottleneck_envelopes),
                    mean=0.0,
                    std=1e-3 if keep_distr else 1,
                    keep_distr=keep_distr,
                    det_axis=1,
                ),
            )
            | (
                dict(
                    spin_mask=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, 4),
                        mean=0,
                        std=1,
                        transform=jnn.sigmoid,
                        keep_distr=keep_distr,
                        det_axis=1,
                    ),
                )
                if use_spin_mask
                else {}
            )
            for k, keep_distr in zip(('same', 'diff'), (False, True))
        }
        match correlation.lower():
            case 'full':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(
                            determinants,
                            orbitals_per_atom * 2,
                            orbitals_per_atom * 2,
                        ),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'dense':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI_NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        use_bias=False,
                        det_axis=1,
                    ),
                )
            case 'bidiagonal':
                result = result | {
                    k: ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                    for k in ('diag1', 'diag2', 'off')
                }
            case 'minimal':
                result = result | dict(
                    correlation=ParamSpec(
                        ParamType.NUCLEI,
                        shape=(determinants, orbitals_per_atom, orbitals_per_atom),
                        mean=0,
                        std=1,
                        det_axis=1,
                    )
                )
            case x if x in ('simple', 'static'):
                pass
            case _:
                raise ValueError(f'Unknown correlation type: {correlation}')
        return result

    def atomic_orbital_mask(self, charges):
        if self.minimal_orbitals is None:
            return None
        mask = np.zeros((len(charges) * self.orbitals_per_atom,), dtype=bool)
        for i, charge in enumerate(charges):
            start = i * self.orbitals_per_atom
            if isinstance(self.minimal_orbitals, int):
                add_orbs = self.minimal_orbitals
            else:
                add_orbs = self.minimal_orbitals[str(charge)]
            n_orbs = (charge + 1) // 2 + add_orbs
            assert n_orbs <= self.orbitals_per_atom
            mask[start : start + n_orbs] = True
        return mask

    @nn.compact
    def __call__(
        self, params, h_one, r_im, config, return_atomic_orbitals: bool = False
    ):
        if self.down_projection:
            projection = self.param(
                'projection',
                jnn.initializers.normal(1 / jnp.sqrt(h_one.shape[-1])),
                (self.down_projection, h_one.shape[-1]),
            )

        def _atomic_orbitals(h, r, params, spins, charges):
            n_up, n_down, m = *spins, len(charges)
            # Feature projection
            inp_dim = (
                h.shape[-1] if self.down_projection is None else self.down_projection
            )
            # Prepare parameters
            orbital_embedding = params['orbital_embedding'].reshape(-1, inp_dim)
            ao_mask = self.atomic_orbital_mask(charges)
            spin_emb = params['spin_embedding'].reshape(m * self.orbitals_per_atom, -1)
            # Mask out unnecessary orbitals
            if ao_mask is not None:
                orbital_embedding = orbital_embedding[ao_mask]
                spin_emb = spin_emb[ao_mask]
            if 'spin_mask' in params:
                spin_mask = params['spin_mask'].reshape(-1, 4)
                if ao_mask is not None:
                    spin_mask = spin_mask[ao_mask]

            # compute orbitals
            if self.down_projection:
                orbital_embedding = orbital_embedding @ projection
            orb = jnp.einsum('nd,kd->nk', h, orbital_embedding)

            # envelopes
            dist = r[..., -1:].reshape(h.shape[0], m, 1)
            envelopes = jnp.einsum(
                'nmd,md->nd',
                jnp.exp(-dist * params['sigma']),
                params['pi'],
            )
            orb *= jnp.einsum(
                'nd,od->no',
                envelopes,
                params['env_w'].reshape(-1, envelopes.shape[-1])[ao_mask],
            )
            # spin mask
            up, down = orb[:n_up], orb[n_up:]
            if 'spin_mask' in params:
                if n_up > n_down:
                    down *= spin_mask[..., n_up - n_down - 1]
                elif n_up < n_down:
                    up *= spin_mask[..., n_down - n_up - 1]

            # pad to even
            if (n_up + n_down) % 2 == 1:
                if n_up > n_down:
                    if spin_emb.shape[-1] == 1:
                        down = jnp.concatenate([down, spin_emb[:, 0][None]])
                    else:
                        down = jnp.concatenate([down, spin_emb[:, n_up - n_down][None]])
                else:
                    if spin_emb.shape[-1] == 1:
                        up = jnp.concatenate([up, spin_emb[:, 0][None]])
                    else:
                        up = jnp.concatenate([up, spin_emb[:, -1][None]])
            return up, down

        param_axes = jtu.tree_map(
            lambda x: x.det_axis,
            self.param_spec(
                self.correlation,
                1,
                1,
                1,
                1,
                1,
                use_spin_mask=self.use_spin_mask,
                use_many_spin_embeddings=self.use_many_spin_embeddings,
            ),
            is_leaf=lambda x: isinstance(x, ParamSpec),
        )

        def full_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            correlation = correlation.reshape(
                m, m, self.orbitals_per_atom, 2, self.orbitals_per_atom, 2
            )
            correlation = jnp.transpose(correlation, (3, 0, 2, 5, 1, 4)).reshape(
                m * self.orbitals_per_atom * 2, m * self.orbitals_per_atom * 2
            )
            correlation = correlation - correlation.mT
            return correlation

        def dense_correlation_pfaffian(correlation):
            m = int(correlation.shape[0] ** 0.5)
            mat = correlation.reshape(
                m, m, self.orbitals_per_atom, self.orbitals_per_atom
            )
            mat = jnp.transpose(mat, (0, 2, 1, 3)).reshape(
                m * self.orbitals_per_atom, m * self.orbitals_per_atom
            )
            A, S = (mat - mat.mT) / 2, (mat + mat.mT) / 2
            correlation = block(A, S, -S, A)
            return correlation

        def diag_correlation_pfaffian(A1, A2, S, n_nuc):
            if A2 is None:
                A2 = A1
            if S is None:
                S = A1
            A1, A2, S = A1 - A1.mT, A2 - A2.mT, S + S.mT

            A1, A2, S = jtu.tree_map(
                lambda x: jnp.repeat(x, n_nuc, axis=0) if x.shape[0] == 1 else x,
                (A1, A2, S),
            )
            correlation = block_diag(*A1, *A2)
            off_corr = block_diag(*S)
            correlation += block(
                jnp.zeros_like(off_corr),
                off_corr,
                -off_corr.T,
                jnp.zeros_like(off_corr),
            )
            return correlation

        @functools.partial(
            jax.vmap, in_axes=(0, None, 0, 0, None, None)
        )  # configurations
        @functools.partial(
            jax.vmap, in_axes=(param_axes, 0, None, None, None, None)
        )  # determinants
        def _orbital_fn(params, correlation, h, r, spins, charges):
            same_up, same_down = _atomic_orbitals(
                h, r, params['same'], spins=spins, charges=charges
            )
            diff_up, diff_down = _atomic_orbitals(
                h, r, params['diff'], spins=spins, charges=charges
            )
            norm = (max(spins) / same_up.shape[-1]) ** 0.5
            same_up, same_down, diff_up, diff_down = jtu.tree_map(
                lambda x: x * norm, (same_up, same_down, diff_up, diff_down)
            )

            if return_atomic_orbitals:
                diff_up, diff_down = jnp.zeros_like(diff_up), jnp.zeros_like(diff_down)
            x = block(same_up, diff_up, diff_down, same_down)

            match self.correlation.lower():
                case 'full':
                    A = full_correlation_pfaffian(params['correlation'])
                case 'dense':
                    A = dense_correlation_pfaffian(params['correlation'])
                case 'bidiagonal':
                    A = diag_correlation_pfaffian(
                        params['diag1'], params['diag2'], params['off'], len(charges)
                    )
                case 'minimal':
                    A = diag_correlation_pfaffian(
                        params['correlation'], None, None, len(charges)
                    )
                case 'static':
                    A = diag_correlation_pfaffian(correlation, None, None, len(charges))
                case 'simple':
                    A = block_diag(*[jnp.array([[0, 1], [-1, 0]])] * (x.shape[-1] // 2))
                case _:
                    raise ValueError(f'Unknown correlation type: {self.correlation}')

            if (ao_mask := self.atomic_orbital_mask(charges)) is not None:
                ao_mask = np.repeat(ao_mask, 2)
                A = A[ao_mask][:, ao_mask]

            assert (
                A.shape[0] == x.shape[-1]
            ), f'Shape mismatch between antisymmetrizer ({A.shape}) and orbitals ({x.shape})'
            if return_atomic_orbitals:
                return skewsymmetric_quadratic(x, A), x
            return ((x.astype(self.det_precision), A),)

        if self.correlation.lower() == 'static':
            correlation = self.param(
                'correlation',
                jnn.initializers.normal(1),
                (self.determinants, 1, self.orbitals_per_atom, self.orbitals_per_atom),
            )
        else:
            correlation = None

        def orbitals(params, h, r, spins, charges):
            return _orbital_fn(params, correlation, h, r, spins, charges)

        return eval_orbitals(orbitals, params, h_one, r_im, config)

    det_op = slog_pfaffian_unpack
    match_hf = match_nn_and_hf_pfaffian


class ExplicitPfaffian(ReparametrizedModule):
    """Pfaffian whose pair matrix is produced directly by a per-determinant MLP.

    For each determinant an MLP maps every ordered electron pair ``(i, j)`` to
    a scalar, the result is damped by per-electron envelopes and then
    antisymmetrized (``M - M^T``). The antisymmetric matrix itself is what
    ``det_op`` (``slog_pfaffian``) consumes.
    """

    # Number of determinants; leading axis of every MLP weight.
    determinants: int
    # Whether sigma carries a determinant axis.
    sigma_per_det: bool
    # Number of envelopes per nucleus (only used by callers of param_spec).
    envelope_per_atom: int
    # Whether pi carries a determinant axis.
    pi_per_det: bool
    # dtype the pair matrix is cast to before it is returned.
    det_precision: str
    # Hidden widths of the pair MLP; a final layer of width 1 is appended.
    hidden_dims: Sequence[int] = (64, 64)

    @staticmethod
    def param_spec(
        determinants: int,
        envelope_per_atom: int,
        sigma_per_det: bool,
        pi_per_det: bool,
    ):
        """Describe the reparametrized envelope parameters.

        Args:
            determinants: Number of determinants.
            envelope_per_atom: Number of envelopes per nucleus.
            sigma_per_det: Give ``sigma`` a determinant axis.
            pi_per_det: Give ``pi`` a determinant axis.

        Returns:
            A dict with the ``ParamSpec`` entries ``'sigma'`` and ``'pi'``.
        """
        sigma_shape = (
            (determinants, envelope_per_atom) if sigma_per_det else (envelope_per_atom,)
        )
        pi_shape = (
            (determinants, envelope_per_atom) if pi_per_det else (envelope_per_atom,)
        )
        result = dict(
            sigma=ParamSpec(
                ParamType.NUCLEI,
                shape=sigma_shape,
                mean=1,
                std=0.1,
                transform=jnn.softplus,
                group='env_sigma',
                det_axis=1 if sigma_per_det else None,
            ),
            pi=ParamSpec(
                ParamType.NUCLEI,
                shape=pi_shape,
                mean=0.0,
                std=1,
                use_bias=False,
                gating='sigmoid',
                det_axis=1 if pi_per_det else None,
            ),
        )
        return result

    @nn.compact
    def __call__(
        self, params, h_one, r_im, config, return_atomic_orbitals: bool = False
    ):
        """Evaluate the antisymmetric pair matrices.

        Args:
            params: Parameter tree laid out like ``param_spec``.
            h_one: Electron features; the last axis is the feature width.
            r_im: Electron-nucleus features; only the last entry of the final
                axis is used (presumably the distance -- confirm with caller).
            config: Forwarded unchanged to ``eval_orbitals``.
            return_atomic_orbitals: If True, each evaluation yields
                ``(A, zeros_like(A))`` instead of ``(A,)``.

        Returns:
            The result of ``eval_orbitals`` applied to the per-configuration
            orbital function.
        """
        # First layer: separate weights for the "row" and "column" electron.
        first_dense = {
            'W1': self.param(
                'first_W1',
                jnn.initializers.lecun_normal(),
                (self.determinants, h_one.shape[-1], self.hidden_dims[0]),
            ),
            'W2': self.param(
                'first_W2',
                jnn.initializers.lecun_normal(),
                (self.determinants, h_one.shape[-1], self.hidden_dims[0]),
            ),
            'b': self.param(
                'first_b',
                jnn.initializers.zeros,
                (
                    self.determinants,
                    self.hidden_dims[0],
                ),
            ),
        }
        # Remaining layers: hidden_dims[i] -> hidden_dims[i + 1], ending in 1.
        mlp = [
            {
                'W': self.param(
                    f'W{i}',
                    jnn.initializers.lecun_normal(),
                    (self.determinants, inp, out),
                ),
                'b': self.param(
                    f'b{i}',
                    jnn.initializers.zeros,
                    (
                        self.determinants,
                        out,
                    ),
                ),
            }
            for i, (inp, out) in enumerate(
                zip(self.hidden_dims, [*self.hidden_dims[1:], 1])
            )
        ]
        mlp_params = (first_dense, mlp)

        def _atomic_orbitals(h, r, params, spins, charges, mlp_p):
            """Return the antisymmetrized, envelope-damped pair matrix."""
            m = len(charges)
            # orbitals: x[i, j] = W1 h_i + W2 h_j + b, then the MLP per pair.
            first, mlp = mlp_p
            x = (h @ first['W1'])[:, None] + h @ first['W2'] + first['b']
            for layer in mlp:
                x = jnn.silu(x) @ layer['W'] + layer['b']
            orbitals = x.squeeze(-1)
            # envelopes: one scalar per electron, summed over nuclei and
            # envelope functions.
            dist = r[..., -1:].reshape(h.shape[0], m, 1)
            envelopes = jnp.einsum(
                'nmd,md->n',
                jnp.exp(-dist * params['sigma']),
                params['pi'],
            )
            # Scale entry (i, j) by env_i * env_j. The explicit outer product
            # is required: `envelopes * envelopes[None]` broadcasts to shape
            # (1, n) and would only scale column j by env_j ** 2, leaving a
            # term that does not depend on electron i's envelope.
            orbitals *= envelopes[:, None] * envelopes[None]
            return orbitals - orbitals.mT

        # Tree of determinant axes (one per parameter) used as vmap in_axes;
        # the sizes passed to param_spec are irrelevant here.
        param_axes = jtu.tree_map(
            lambda x: x.det_axis,
            self.param_spec(1, 1, self.sigma_per_det, self.pi_per_det),
            is_leaf=lambda x: isinstance(x, ParamSpec),
        )

        @functools.partial(
            jax.vmap, in_axes=(0, 0, 0, None, None, None)
        )  # configurations
        @functools.partial(
            jax.vmap, in_axes=(param_axes, None, None, None, None, 0)
        )  # determinants
        def _orbital_fn(params, h, r, spins, charges, mlp_p):
            A = _atomic_orbitals(
                h, r, params, spins=spins, charges=charges, mlp_p=mlp_p
            )

            if return_atomic_orbitals:
                return A, jnp.zeros_like(A)

            return (A.astype(self.det_precision),)

        def orbitals(params, h, r, spins, charges):
            return _orbital_fn(params, h, r, spins, charges, mlp_params)

        return eval_orbitals(orbitals, params, h_one, r_im, config)

    det_op = staticmethod(slog_pfaffian)
    match_hf = match_nn_and_hf_pfaffian


class AGP(ReparametrizedModule):
    """
    Orbital module that returns, per determinant, the product ``up @ down.mT``
    of envelope-weighted atomic orbitals evaluated for the two spin blocks.

    NOTE(review): the name presumably stands for "antisymmetrized geminal
    power" -- confirm against the project documentation.
    """

    # Number of determinants; leading axis of the per-determinant parameters.
    determinants: int
    # Number of atomic orbitals attached to every nucleus.
    orbitals_per_atom: int
    # Whether sigma / pi carry their own determinant axis (see param_spec).
    sigma_per_det: bool
    pi_per_det: bool
    # NOTE(review): declared but not read anywhere in this class; a sibling
    # orbital module casts its output to this precision -- confirm whether the
    # same cast is intended here.
    det_precision: str
    # Width of the optional learned projection of the orbital embedding.
    # Any falsy value (None or 0) disables the projection.
    down_projection: int | None
    # Extra orbitals per atom kept by atomic_orbital_mask: a single int for all
    # atoms, a dict keyed by str(charge), or None to keep every orbital.
    minimal_orbitals: dict[str, int] | int | None = None

    @staticmethod
    def param_spec(
        determinants: int,
        orbital_inp: int,
        down_projection: int | None,
        orbitals_per_atom: int,
        envelope_per_atom: int,
        sigma_per_det: bool,
        pi_per_det: bool,
    ):
        """
        Describes the parameters this module expects.

        Args:
        - determinants: Number of determinants.
        - orbital_inp: Feature dimension of the electron embedding.
        - down_projection: Width of the projected embedding; if falsy the embedding uses ``orbital_inp`` directly.
        - orbitals_per_atom: Number of atomic orbitals per nucleus.
        - envelope_per_atom: Number of envelope functions per nucleus.
        - sigma_per_det: If True, ``sigma`` gets a leading determinant axis.
        - pi_per_det: If True, ``pi`` gets a leading determinant axis.

        Returns:
        - A dict with the ``ParamSpec`` entries ``orbital_embedding``, ``sigma`` and ``pi``.
        """
        inp_dim = down_projection if down_projection else orbital_inp
        result = dict(
            orbital_embedding=ParamSpec(
                ParamType.NUCLEI,
                shape=(determinants, orbitals_per_atom, inp_dim),
                mean=0,
                std=1 / jnp.sqrt(orbital_inp),
                segments=determinants,
                keep_distr=True,
                det_axis=1,
            ),
            sigma=ParamSpec(
                ParamType.NUCLEI,
                shape=(
                    determinants,
                    envelope_per_atom,
                )
                if sigma_per_det
                else (envelope_per_atom,),
                mean=1,
                std=0.1,
                transform=jnn.softplus,
                group='env_sigma',
                det_axis=1 if sigma_per_det else None,
            ),
            pi=ParamSpec(
                ParamType.NUCLEI_NUCLEI,
                shape=(
                    determinants,
                    orbitals_per_atom,
                    envelope_per_atom,
                )
                if pi_per_det
                else (orbitals_per_atom, envelope_per_atom),
                mean=0.0,
                std=1,
                use_bias=False,
                gating='sigmoid',
                det_axis=1 if pi_per_det else None,
            ),
        )
        return result

    def atomic_orbital_mask(self, charges):
        """
        Builds a boolean mask selecting the atomic orbitals that are kept.

        For every atom the first ``(charge + 1) // 2 + extra`` of its
        ``orbitals_per_atom`` slots are enabled, where ``extra`` comes from
        ``minimal_orbitals``.

        Args:
        - charges: Sequence of nuclear charges, one per atom.

        Returns:
        - None if ``minimal_orbitals`` is None, otherwise a boolean NumPy array of shape (len(charges) * orbitals_per_atom,).
        """
        if self.minimal_orbitals is None:
            return None
        mask = np.zeros((len(charges) * self.orbitals_per_atom,), dtype=bool)
        for i, charge in enumerate(charges):
            start = i * self.orbitals_per_atom
            if isinstance(self.minimal_orbitals, int):
                add_orbs = self.minimal_orbitals
            else:
                add_orbs = self.minimal_orbitals[str(charge)]
            n_orbs = (charge + 1) // 2 + add_orbs
            assert n_orbs <= self.orbitals_per_atom, (
                f'charge {charge} needs {n_orbs} orbitals but only '
                f'{self.orbitals_per_atom} are available per atom'
            )
            mask[start : start + n_orbs] = True
        return mask

    @nn.compact
    def __call__(
        self, params, h_one, r_im, config, return_atomic_orbitals: bool = False
    ):
        """
        Evaluates the orbital matrices for all configurations.

        Args:
        - params: Parameter tree with the entries described by ``param_spec``.
        - h_one: Electron embeddings; the last axis is the feature dimension.
        - r_im: Electron-nucleus features; the last entry of the last axis is used as the distance.
        - config: System configuration forwarded to ``eval_orbitals``.
        - return_atomic_orbitals: If True, each determinant yields ``(skewsymmetric_quadratic(x, A), x)`` instead of ``(up @ down.mT,)``.

        Returns:
        - Whatever ``eval_orbitals`` returns for the per-configuration orbital function.
        """
        if self.down_projection:
            projection = self.param(
                'projection',
                jnn.initializers.normal(1 / jnp.sqrt(h_one.shape[-1])),
                (self.down_projection, h_one.shape[-1]),
            )

        def _atomic_orbitals(h, r, params, spins, charges):
            n_up, _, m = *spins, len(charges)
            # Feature projection.
            # Use the same truthiness test as param_spec and the projection
            # creation above so that a falsy down_projection (None or 0)
            # consistently means "no projection".
            inp_dim = self.down_projection if self.down_projection else h.shape[-1]
            # Prepare parameters
            orbital_embedding = params['orbital_embedding'].reshape(-1, inp_dim)
            ao_mask = self.atomic_orbital_mask(charges)
            pi = params['pi'].reshape(m, m * self.orbitals_per_atom, -1)
            # Mask out unnecessary orbitals
            if ao_mask is not None:
                orbital_embedding = orbital_embedding[ao_mask]
                pi = pi[:, ao_mask]

            # compute orbitals
            if self.down_projection:
                # Lift the low-dimensional embedding back to the feature size of h.
                orbital_embedding = orbital_embedding @ projection
            orb = jnp.einsum('nd,kd->nk', h, orbital_embedding)

            # envelopes: sum over nuclei (m) and envelope functions (d)
            dist = r[..., -1:].reshape(h.shape[0], m, 1)
            orb *= jnp.einsum(
                'nmd,mod->no',
                jnp.exp(-dist * params['sigma']),
                pi / jnp.sqrt(self.orbitals_per_atom),
            )
            # The first n_up rows belong to the first spin block, the rest to the second.
            return orb[:n_up], orb[n_up:]

        # Per-parameter vmap axis for the determinant dimension; the sizes
        # passed to param_spec are placeholders because only det_axis is read.
        param_axes = jtu.tree_map(
            lambda x: x.det_axis,
            self.param_spec(
                1,
                1,
                1,
                1,
                1,
                self.sigma_per_det,
                self.pi_per_det,
            ),
            is_leaf=lambda x: isinstance(x, ParamSpec),
        )

        @functools.partial(jax.vmap, in_axes=(0, 0, 0, None, None))  # configurations
        @functools.partial(
            jax.vmap, in_axes=(param_axes, None, None, None, None)
        )  # determinants
        def _orbital_fn(params, h, r, spins, charges):
            up, down = _atomic_orbitals(h, r, params, spins=spins, charges=charges)
            # Both spin blocks must have the same number of electrons.
            assert up.shape == down.shape
            norm = (max(spins) / up.shape[-1]) ** 0.5
            up, down = jtu.tree_map(lambda x: x * norm, (up, down))

            if return_atomic_orbitals:
                # The block matrices are only needed for this output, so they
                # are built only when it is requested.
                x = block(up, jnp.zeros_like(up), jnp.zeros_like(down), down)
                D = up.shape[-1]
                A = block(
                    jnp.zeros((D, D)), jnp.eye(D), -jnp.eye(D), jnp.zeros((D, D))
                )
                return skewsymmetric_quadratic(x, A), x
            return (up @ down.mT,)

        def orbitals(params, h, r, spins, charges):
            return _orbital_fn(params, h, r, spins, charges)

        return eval_orbitals(orbitals, params, h_one, r_im, config)

    det_op = staticmethod(jnp.linalg.slogdet)
    match_hf = match_nn_and_hf_pfaffian


class ProductAGP(OrbitalModule):
    """
    Orbital module that contracts the two per-spin orbital blocks over their
    shared last axis, yielding one ``(..., up, down)`` matrix per configuration.
    """

    full_det: bool
    determinants: int
    shared_orbitals: bool

    @staticmethod
    def param_spec(
        shared_orbitals: bool, full_det: bool, orbital_inp: int, determinants: int
    ):
        """
        Returns the parameter specification of ``OrbitalModule``.

        ``full_det`` is accepted but not forwarded: the base specification is
        always requested with ``False`` in its place.
        """
        spec = OrbitalModule.param_spec(
            shared_orbitals, False, orbital_inp, determinants
        )
        return spec

    @nn.compact
    def __call__(self, params, h_one, r_im, config):
        """
        Evaluates the orbitals and contracts each pair of spin blocks.

        Returns:
        - A list with one 1-tuple per entry produced by ``eval_orbitals``; each tuple holds the einsum ``'...uk,...dk->...ud'`` of that entry's blocks.
        """
        n_det = self.determinants

        def _orbital_fn(h, r, orbital_embedding, orbital_bias, envelope):
            # n -> n_elec, o -> n_orb, k -> n_det, d -> inp_dim
            weights = orbital_embedding.reshape(-1, n_det, h.shape[-1])
            bias = orbital_bias.reshape(-1, n_det)
            linear = jnp.einsum('nd,okd->nok', h, weights) + bias
            env = isotropic_envelope(r, **envelope)
            return linear * env.reshape(linear.shape)

        orbital_fn = make_orbital_fn(_orbital_fn, self.shared_orbitals, self.full_det)
        per_config = eval_orbitals(orbital_fn, params, h_one, r_im, config)

        products = []
        for spin_blocks in per_config:
            contracted = jnp.einsum('...uk,...dk->...ud', *spin_blocks)
            products.append((contracted,))
        return products

    @staticmethod
    def match_hf(nn_orbitals, hf_orbitals, config, full_det):
        """
        Brings the HF orbitals into the same contracted form as the network output.

        Returns:
        - ``nn_orbitals`` unchanged, and a list of 1-tuples with the einsum ``'...uk,...dk->...ud'`` of each ``(up, down)`` pair in ``hf_orbitals``.
        """
        hf_products = []
        for up, down in hf_orbitals:
            hf_products.append((jnp.einsum('...uk,...dk->...ud', up, down),))
        return nn_orbitals, hf_products


def _get_orbital_edges(
    nuc: jax.Array, valency: tuple[int, ...] | jax.Array
) -> tuple[jax.Array, jax.Array, int]:
    """
    Given a set of nuclei and their valencies, returns the indices of the edges
    connecting the nuclei in the molecular graph.

    Args:
    - nuc: A 2D array of shape (n_atoms, 3) representing the positions of the nuclei.
    - valency: A tuple of length n_atoms representing the valency of each atom.

    Returns:
    - A tuple of three elements:
        - idx_i: A 1D array of shape (n_edges,) representing the indices of the first node of each edge.
        - idx_j: A 1D array of shape (n_edges,) representing the indices of the second node of each edge.
        - N: An integer representing the total number of edges in the molecular graph.
    """
    idx_i, idx_j = jnp.triu_indices(len(valency))
    n_iter = int(np.ceil(sum(valency) / 2))
    nuc = jnp.array(nuc, dtype=jnp.float32)
    valency = jnp.array(valency, dtype=jnp.float32)
    dists = jnp.linalg.norm(nuc[:, None] - nuc, axis=-1)
    dists += jnp.eye(len(valency)) * 15
    dists = dists[idx_i, idx_j]
    counts = jnp.zeros_like(dists, dtype=jnp.int32)

    centered = nuc - nuc.mean(0)
    edge_pos = (centered[:, None] + centered) / 2
    xyz = edge_pos[idx_i, idx_j]
    r = jnp.linalg.norm(xyz, axis=-1)

    def _select_next(carry, _):
        val, counts = carry
        valid = (val[idx_i] > 0) * (val[idx_j] > 0)
        scores = valid / (dists + counts.astype(dists.dtype) / 2)
        mask = jnp.ones_like(scores, dtype=bool)
        for crit in [scores, r, xyz[..., 0], xyz[..., 1], xyz[..., 2]]:
            crit: jnp.ndarray = jnp.where(mask, crit, -np.inf)  # type: ignore
            max_val = crit.max()
            mask = jnp.abs(crit - max_val) < 1e-5
        idx = jnp.argmax(mask)

        count = counts[idx]
        counts = counts.at[idx].add(1)
        val = val.at[idx_i[idx]].add(-1).at[idx_j[idx]].add(-1)
        return (val, counts), (idx_i[idx], idx_j[idx], count)

    (valency, counts), (first, second, N) = jax.lax.scan(
        _select_next, (valency, counts), np.arange(n_iter)
    )
    return first, second, N


def get_valence_orbitals(
    nuclei: jax.Array, valency: tuple[int, ...] | np.ndarray, config: SystemConfigs
):
    """
    Computes the valence orbital edges for every molecular graph in a batch.

    Graphs that share the same signature are grouped and processed together
    with a vmapped ``_get_orbital_edges``; the results are put back into the
    original graph order afterwards.

    Args:
    - nuclei: A 2D array of shape (n_atoms, 3) representing the positions of the nuclei.
    - valency: A sequence of length n_atoms representing the valency of each atom.
    - config: A SystemConfigs object representing the configuration of the system.

    Returns:
    - A tuple of four elements:
        - idx_i: A 1D array of shape (n_edges,) with the index of the first atom of each edge, shifted by the graph's offset into the flat atom list.
        - idx_j: A 1D array of shape (n_edges,) with the index of the second atom of each edge, shifted the same way.
        - type: A 1D array of shape (n_edges,) with the third output of ``_get_orbital_edges`` (the bond type).
        - counts: A 1D array of shape (n_graphs,) representing the number of orbitals per graph.
    """
    valency = np.array(valency)
    # Position of each graph's first atom inside the flat atom list.
    graph_offsets = np.cumsum([0, *config.n_nuc[:-1]])

    def atoms_per_graph(s, c):
        return len(c)

    def one_per_graph(s, c):
        return 1

    batched_edges = jax.vmap(_get_orbital_edges, in_axes=(0, None))

    # One entry per group of graphs with identical signature.
    first_parts, second_parts, type_parts, count_parts = [], [], [], []
    grouped = zip(
        group_by_config(config, nuclei, atoms_per_graph),
        group_by_config(config, valency, atoms_per_graph),
        group_by_config(config, graph_offsets, one_per_graph),
    )
    for group_nuc, group_val, group_off in grouped:
        # All graphs of a group share the valency tuple of the first one.
        first, second, bond_idx = batched_edges(group_nuc, tuple(group_val[0]))
        first_parts.append(first + group_off)
        second_parts.append(second + group_off)
        type_parts.append(bond_idx)
        # Every graph of the group has bond_idx.shape[1] orbitals.
        count_parts.append(
            np.full((bond_idx.shape[0],), bond_idx.shape[1], dtype=np.int32)
        )

    # Flatten the groups and restore the input order of the graphs.
    flattened = tuple(
        chain(parts)
        for parts in (first_parts, second_parts, type_parts, count_parts)
    )
    reverse_idx = inverse_group_idx(config)
    reordered = [itemgetter(*reverse_idx)(seq) for seq in flattened]

    *edge_fields, per_graph_counts = reordered
    return (
        *(jnp.concatenate(field) for field in edge_fields),
        np.stack(per_graph_counts),
    )


def get_core_orbitals(nuclei: jax.Array, config: SystemConfigs):
    """
    Enumerates the core orbitals of every nucleus in a batch of graphs.

    Each nucleus contributes ``(charge - valency) // 2`` core orbitals, all
    located at the nucleus itself.

    Args:
    - nuclei: A 2D array of shape (n_atoms, 3) representing the positions of the nuclei.
    - config: A SystemConfigs object representing the configuration of the system.

    Returns:
    - A tuple of five elements:
        - core_i: A 1D array of shape (n_core_orbitals,) with the index of the nucleus owning each core orbital.
        - core_loc: A 2D array of shape (n_core_orbitals, 3) with the position of each core orbital.
        - core_type: A 1D array of shape (n_core_orbitals,) with the type of each core orbital (local index plus ``CORE_OFFSETS[charge]``).
        - core_N: A 1D array of shape (n_core_orbitals,) with the index of each core orbital within its nucleus.
        - core_N_orb: A 1D array of shape (n_graphs,) with the number of core orbitals per graph.
    """
    atoms_per_graph = config.n_nuc
    flat_charges = np.array(tuple(flatten(config.charges)))
    valency = np.array(itemgetter(*flat_charges)(VALENCY))
    n_core_orbitals = (flat_charges - valency) // 2

    # Every core orbital sits on its nucleus.
    core_loc = nuclei.repeat(n_core_orbitals, axis=0)

    type_chunks, owner_chunks, local_chunks = [], [], []
    for atom, (n_core, charge) in enumerate(zip(n_core_orbitals, flat_charges)):
        local = np.arange(n_core)
        type_chunks.append(local + CORE_OFFSETS[charge])
        owner_chunks.append(np.full(n_core, atom))
        local_chunks.append(local)
    core_type = np.concatenate(type_chunks)
    core_i = np.concatenate(owner_chunks)
    core_N = np.concatenate(local_chunks)

    # Sum the per-atom counts within each graph.
    graph_of_atom = np.repeat(np.arange(len(atoms_per_graph)), atoms_per_graph)
    core_N_orb = np_segment_sum(n_core_orbitals, graph_of_atom)
    return core_i, core_loc, core_type, core_N, core_N_orb


# JIT-compiled concatenation of a sequence of arrays along axis 0.
_concat = jax.jit(functools.partial(jnp.concatenate, axis=0))


def get_orbitals(nuclei: jax.Array, config: SystemConfigs):
    """
    Collects core and valence orbitals for every molecular graph in a batch.

    Within each graph the core orbitals come first, followed by the valence
    orbitals; graphs are laid out one after another.

    Args:
    - nuclei: A 2D array of shape (n_atoms, 3) representing the positions of the nuclei.
    - config: A SystemConfigs object representing the configuration of the system.

    Returns:
    - A tuple of four elements:
        - orb_loc: A 2D array of shape (n_orbitals, 3) with the position of each orbital (the nucleus for core orbitals, the midpoint of the two atoms for valence orbitals).
        - orb_type: A 1D array of shape (n_orbitals,) with the type of each orbital.
        - orb_assoc: A 2D array of shape (n_orbitals, 2) with the indices of the nuclei associated with each orbital.
        - N_orbs: A 1D array of shape (n_graphs,) with the number of orbitals per graph.
    """
    flat_charges = np.array(tuple(flatten(config.charges)))
    valency = np.array(itemgetter(*flat_charges)(VALENCY))

    # Core orbitals are associated with a single nucleus, listed twice.
    core_i, core_loc, core_type, core_N, core_N_orb = get_core_orbitals(nuclei, config)
    core_ij = jnp.stack([core_i, core_i], -1)

    # Valence orbitals sit halfway between their two atoms.
    val_i, val_j, val_type, val_N_orb = get_valence_orbitals(
        nuclei, tuple(valency), config
    )
    val_ij = jnp.stack([val_i, val_j], -1)
    val_loc = (nuclei[val_i] + nuclei[val_j]) / 2

    # Valence types are shifted past the range reserved for core types. The
    # charges of the two atoms are not taken into account.
    val_type = val_type + MAX_CORE_ORB

    N_orbs = core_N_orb + val_N_orb

    # Interleave per graph: first its core slice, then its valence slice.
    loc_parts, type_parts, assoc_parts = [], [], []
    core_start, val_start = 0, 0
    for n_core, n_val in zip(core_N_orb, val_N_orb):
        core_sl = slice(core_start, core_start + n_core)
        val_sl = slice(val_start, val_start + n_val)
        loc_parts += [core_loc[core_sl], val_loc[val_sl]]
        type_parts += [core_type[core_sl], val_type[val_sl]]
        assoc_parts += [core_ij[core_sl], val_ij[val_sl]]
        core_start = core_start + n_core
        val_start = val_start + n_val

    orb_loc = _concat(loc_parts)
    orb_type = _concat(type_parts)
    orb_assoc = _concat(assoc_parts)

    assert sum(N_orbs) == len(orb_type) == len(orb_assoc)
    return orb_loc, orb_type, orb_assoc, N_orbs
