import functools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Protocol, TypeVar, cast

import einops
import jax
import jax.numpy as jnp
import numpy as np
from flax.struct import PyTreeNode, field
from jaxtyping import Array, ArrayLike, Bool, DTypeLike, Float, Int
from optax import Schedule

from neural_pfaffian.nn.wave_function import (
    GeneralizedWaveFunction,
    WaveFunctionParameters,
)
from neural_pfaffian.sample_reweighting import (
    ReweightingFactor,
    SampleMask,
    compute_logpsi,
    convert_reweighted_tensor_to_unreweighted,
    pack_reweighted_tensor,
    unpack_reweighted_tensor,
)
from neural_pfaffian.systems import Systems
from neural_pfaffian.utils import EMA, Modules
from neural_pfaffian.utils.cg import cg
from neural_pfaffian.utils.jax_utils import (
    _BATCH_AXIS,
    jit,
    pall_to_all,
    pgather,
    pidx,
    pmean,
    pmean_if_pmap,
    psum_if_pmap,
    vmap,
)
from neural_pfaffian.utils.schedule import ScheduleConfig, get_schedule
from neural_pfaffian.utils.segment_utils import unsegment_axis
from neural_pfaffian.utils.summary_stats import weighted_centering
from neural_pfaffian.utils.tree_utils import (
    tree_add,
    tree_dot,
    tree_mul,
    tree_squared_norm,
    tree_sub,
    tree_to_dtype,
)

if TYPE_CHECKING:
    from neural_pfaffian.vmc import ReweightedLocalEnergy

# Type of the state a preconditioner threads between steps: `Identity` uses
# `None`, `CG` uses `CGState` and `Spring` uses `SpringState`.
PS = TypeVar('PS')


class Preconditioner(Protocol[PS]):
    """Structural interface implemented by `Identity`, `CG` and `Spring`.

    ``init`` builds the state of type ``PS``; ``apply`` turns the per-sample
    cotangent ``dL_dlogpsi`` into a parameter update and returns it together with
    the new state and a dict of scalar diagnostics.
    """

    def init(
        self,
        key: Array,
        params: WaveFunctionParameters,
        systems: Systems,
    ) -> PS:
        """Return the initial preconditioner state."""
        ...

    def apply(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: PS,
        auxiliary_grads: WaveFunctionParameters,
        *,
        sample_reweighting: bool = False,
    ) -> tuple[WaveFunctionParameters, PS, dict[str, Float[Array, '']]]:
        """Compute ``(update, new_state, diagnostics)`` for one optimization step."""
        ...


class Identity(PyTreeNode, Preconditioner[None]):
    """Stateless preconditioner returning the plain (centered) loss gradient.

    The gradient is the VJP of the masked, normalized log-amplitudes with the
    centered cotangent ``dL_dlogpsi``, summed across devices, plus
    ``auxiliary_grads``.
    """

    wave_function: GeneralizedWaveFunction = field(pytree_node=False)
    # The dtype to compute the gradient in. Careful! Contrary to other
    # preconditioners this one returns the gradient in `dtype` and does not cast
    # back into the input dtypes.
    dtype: DTypeLike | None = field(pytree_node=False, default=None)

    def init(self, key: Array, params: WaveFunctionParameters, systems: Systems) -> None:
        """Identity carries no state, so this always returns ``None``."""
        del key, params, systems
        return None

    @jit(static_argnames=('sample_reweighting',))
    def apply(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: None,
        auxiliary_grads: WaveFunctionParameters,
        *,
        sample_reweighting: bool = False,
    ) -> tuple[WaveFunctionParameters, None, dict[str, Array]]:
        """Return ``(gradient, None, {})``.

        ``sample_reweighting`` selects the importance-reweighted code path.
        """
        if self.dtype is not None:
            # Cast every input of the VJP, so the gradient comes out in `dtype`.
            params, systems, dL_dlogpsi = tree_to_dtype(
                (params, systems, dL_dlogpsi),
                self.dtype,
            )

        if sample_reweighting:
            return self._apply_reweighted(
                params,
                systems,
                dL_dlogpsi,
                reweighting_factor,
                sample_mask,
                state,
                auxiliary_grads,
            )
        return self._apply_unreweighted(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            sample_mask,
            state,
            auxiliary_grads,
        )

    def _centered_vjp(
        self,
        log_psi_fn,
        params: WaveFunctionParameters,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        sample_mask: SampleMask,
        centering,
        auxiliary_grads: WaveFunctionParameters,
    ) -> WaveFunctionParameters:
        """Shared core of both code paths.

        Pulls the centered cotangent back through ``log_psi_fn`` (masked and
        divided by the active-sample count times ``n_mols``), sums over devices,
        adds ``auxiliary_grads`` and casts to the dtypes of ``params``.
        """
        n_active_per_mol = psum_if_pmap(jnp.sum(sample_mask, axis=0, keepdims=True))
        denominator = n_active_per_mol * dL_dlogpsi.shape[-1]
        param_dtypes = jax.tree.map(lambda leaf: leaf.dtype, params)

        def normalized_log_psi(p):
            # Masked-out samples are replaced by a constant 0.
            return jnp.where(sample_mask, log_psi_fn(p), 0.0) / denominator

        _, pullback = jax.vjp(normalized_log_psi, params)
        cotangent = centering(dL_dlogpsi.reshape(dL_dlogpsi.shape)).astype(
            dL_dlogpsi.dtype,
        )
        summed = psum_if_pmap(pullback(cotangent)[0])
        with_aux = tree_add(summed, auxiliary_grads)
        return jax.tree.map(jax.lax.convert_element_type, with_aux, param_dtypes)

    def _apply_reweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: None,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """Gradient for importance-reweighted (packed) samples."""

        def packed_log_psi(p):
            log_psi = compute_logpsi(self.wave_function, p, systems)[0]
            return pack_reweighted_tensor(log_psi, systems)  # (batch_size, n_mols)

        def centering(x):
            return weighted_centering(
                x,
                sample_mask,
                reweighting_factor,
                data_is_reweighted=True,
            )

        grad = self._centered_vjp(
            packed_log_psi,
            params,
            dL_dlogpsi,
            sample_mask,
            centering,
            auxiliary_grads,
        )
        return grad, state, {}

    def _apply_unreweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: None,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """Gradient for plain samples; ``reweighting_factor`` is unused here."""

        def batched_log_psi(p):
            return self.wave_function.batched_apply(p, systems)

        def centering(x):
            return weighted_centering(x, sample_mask)

        grad = self._centered_vjp(
            batched_log_psi,
            params,
            dL_dlogpsi,
            sample_mask,
            centering,
            auxiliary_grads,
        )
        return grad, state, {}


class CGState(PyTreeNode):
    """State carried between `CG` preconditioner steps."""

    # Previous natural-gradient solution; decayed as momentum and used as the CG
    # warm start.
    last_grad: WaveFunctionParameters
    # Damping added to the Fisher matrix (initialised from `CG.damping`).
    damping: Float[Array, '']


class CG(PyTreeNode, Preconditioner[CGState]):
    """Natural-gradient preconditioner solving the damped Fisher system with CG.

    Each step solves ``(F + damping * I) x = g + damping * decay_factor * x_prev``
    with the conjugate-gradient routine ``cg``, warm-started at ``x_prev``
    (``CGState.last_grad``). ``F`` is applied matrix-free as ``J^T J v`` on the
    centered, normalized per-sample log-amplitude Jacobian ``J``.
    """

    wave_function: GeneralizedWaveFunction = field(pytree_node=False)
    # Initial damping; copied into `CGState.damping` by `init`.
    damping: Float[ArrayLike, '']
    # Momentum factor applied to the previous solution.
    decay_factor: Float[ArrayLike, '']
    # Maximum number of CG iterations.
    maxiter: int = field(pytree_node=False)
    # If True, auxiliary gradients are added to the CG right-hand side (and thus
    # preconditioned); otherwise they are added to the CG solution unchanged.
    precondition_aux_grads: bool = field(pytree_node=False, default=True)

    def init(self, key: Array, params: WaveFunctionParameters, systems: Systems):
        return CGState(
            last_grad=jax.tree.map(lambda x: jnp.zeros_like(x), params),
            damping=jnp.asarray(self.damping, dtype=jnp.float32),
        )

    @jit(static_argnames=('sample_reweighting',))
    def apply(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: CGState,
        auxiliary_grads: WaveFunctionParameters,
        *,
        sample_reweighting: bool = False,
    ):
        """Return ``(natgrad, new_state, aux_data)`` for one optimization step."""
        _apply_fn = (
            self._apply_reweighted if sample_reweighting else self._apply_unreweighted
        )
        return _apply_fn(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            sample_mask,
            state,
            auxiliary_grads,
        )

    def _apply_reweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: 'ReweightedLocalEnergy',
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: CGState,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """CG natural gradient for importance-reweighted (packed) samples."""
        n_dev = jax.device_count()
        active_N_per_mol = psum_if_pmap(jnp.sum(sample_mask, axis=0, keepdims=True))
        N = active_N_per_mol * dL_dlogpsi.shape[-1]
        normalization = 1 / jnp.sqrt(N)

        @jit
        def log_p_closure(params):
            logpsis = compute_logpsi(self.wave_function, params, systems)[0]
            logpsis = pack_reweighted_tensor(logpsis, systems)
            logpsis = jnp.where(sample_mask, logpsis, 0.0)
            return logpsis * normalization

        _, vjp_fn = jax.vjp(log_p_closure, params)
        _, jvp = jax.linearize(log_p_closure, params)

        def vjp(x):
            # Sums the pulled-back cotangent over all devices.
            return psum_if_pmap(
                vjp_fn(x.reshape(dL_dlogpsi.shape).astype(dL_dlogpsi.dtype))[0],
            )

        # The auxiliary gradients are handled below according to
        # `precondition_aux_grads`; Identity must therefore not add them as well,
        # otherwise they would be counted twice.
        zero_aux_grads = jax.tree.map(jnp.zeros_like, auxiliary_grads)
        grad = Identity(self.wave_function).apply(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            sample_mask,
            None,
            zero_aux_grads,
            sample_reweighting=True,
        )[0]
        grad_types = jax.tree.map(lambda x: x.dtype, params)
        aux_grad = jax.tree.map(
            jax.lax.convert_element_type,
            auxiliary_grads,
            grad_types,
        )

        if self.precondition_aux_grads:
            grad = tree_add(grad, aux_grad)

        last_grad = jax.tree.map(
            jax.lax.convert_element_type,
            state.last_grad,
            grad_types,
        )
        decayed_last_grad = tree_mul(last_grad, self.decay_factor)
        b = tree_add(grad, tree_mul(decayed_last_grad, state.damping))
        avg_grad = vjp(reweighting_factor * normalization)

        @jit
        def Fisher_matmul(v):
            # Weighted, centered J^T J v
            jv = jvp(v) - tree_dot(avg_grad, v)
            inp = jv * reweighting_factor
            result = vjp(inp)
            # `vjp` and `avg_grad` are sums over the samples of *all* devices, so
            # the centering correction needs the global sum too; a local sum
            # would make the matvec (and the CG solution) differ per device.
            result = tree_sub(result, tree_mul(avg_grad, psum_if_pmap(inp.sum())))
            # add damping
            result = tree_add(result, tree_mul(v, state.damping))
            return result

        # Compute natural gradient
        natgrad = cg(
            A=Fisher_matmul,
            b=b,
            x0=last_grad,
            fixed_iter=n_dev > 1,  # multi gpu
            maxiter=self.maxiter,
        )[0]

        if not self.precondition_aux_grads:
            natgrad = tree_add(natgrad, aux_grad)

        aux_data = {
            'grad_norm': tree_squared_norm(grad) ** 0.5,
            'natgrad_norm': tree_squared_norm(natgrad) ** 0.5,
            'decayed_last_grad_norm': tree_squared_norm(decayed_last_grad) ** 0.5,
        }
        return (
            natgrad,
            state.replace(last_grad=natgrad),
            aux_data,
        )

    def _apply_unreweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: CGState,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """CG natural gradient for plain (batch_size, n_mols) samples."""
        n_dev = jax.device_count()
        active_N_per_mol = psum_if_pmap(jnp.sum(sample_mask, axis=0, keepdims=True))
        N = active_N_per_mol * dL_dlogpsi.shape[-1]
        normalization = 1 / jnp.sqrt(N)

        def log_p_closure(p):
            log_p = 2 * self.wave_function.batched_apply(p, systems) * normalization
            return jnp.where(sample_mask, log_p, 0.0)

        _, vjp_fn = jax.vjp(log_p_closure, params)
        _, jvp_fn = jax.linearize(log_p_closure, params)

        def center_fn(x):
            x = x.reshape(dL_dlogpsi.shape)
            return weighted_centering(x, sample_mask)

        def vjp(x):
            # Centers the cotangent and sums the result over all devices.
            return psum_if_pmap(vjp_fn(center_fn(x).astype(dL_dlogpsi.dtype))[0])

        def jvp(x):
            return center_fn(jvp_fn(x))

        # `vjp` already reduces across devices; a second psum would scale the
        # gradient by the number of devices.
        grad = vjp(dL_dlogpsi * normalization)
        grad_types = jax.tree.map(lambda x: x.dtype, params)
        aux_grad = jax.tree.map(
            jax.lax.convert_element_type,
            auxiliary_grads,
            grad_types,
        )

        if self.precondition_aux_grads:
            grad = tree_add(grad, aux_grad)

        last_grad = jax.tree.map(
            jax.lax.convert_element_type,
            state.last_grad,
            grad_types,
        )
        decayed_last_grad = tree_mul(last_grad, self.decay_factor)
        b = tree_add(grad, tree_mul(decayed_last_grad, state.damping))

        @jit
        def Fisher_matmul(v):
            # J^T J v
            result = vjp(jvp(v))
            # add damping
            result = tree_add(result, tree_mul(v, state.damping))
            return result

        # Compute natural gradient
        natgrad = cg(
            A=Fisher_matmul,
            b=b,
            x0=last_grad,
            fixed_iter=n_dev > 1,  # multi gpu
            maxiter=self.maxiter,
        )[0]

        if not self.precondition_aux_grads:
            natgrad = tree_add(natgrad, aux_grad)

        aux_data = {
            'grad_norm': tree_squared_norm(grad) ** 0.5,
            'natgrad_norm': tree_squared_norm(natgrad) ** 0.5,
            'decayed_last_grad_norm': tree_squared_norm(decayed_last_grad) ** 0.5,
        }
        return (
            natgrad,
            state.replace(last_grad=natgrad),
            aux_data,
        )


def batch_parameters(
    arrays: Sequence[jax.Array],
    *extras: Sequence[Array],
    batch_size: int,
):
    current_set = [arrays[0]]
    current_ext_set = [[ext[0]] for ext in extras]
    current_size = arrays[0].shape[-1]

    def _yield(current_set, current_ext_set):
        # Skipping concatenation so no extra buffer is created for huge matrices.
        if len(current_set) == 1:
            return (current_set[0], *[ext[0] for ext in current_ext_set])
        return (
            jnp.concatenate(current_set, axis=-1),
            *[jnp.concatenate(ext, axis=-1) for ext in current_ext_set],
        )

    for arr, *ext in zip(arrays[1:], *[ext[1:] for ext in extras], strict=True):
        n = arr.shape[-1]
        if current_size + n > batch_size:
            yield _yield(current_set, current_ext_set)
            current_set = [arr]
            current_ext_set = [[e] for e in ext]
            current_size = n
        else:
            current_set.append(arr)
            for s, e in zip(current_ext_set, ext, strict=True):
                s.append(e)
            current_size += n
    yield _yield(current_set, current_ext_set)


class SpringState(PyTreeNode):
    """State carried between `Spring` preconditioner steps."""

    # Previous natural gradient (in `Spring.dtype` when that is set); decayed by
    # `Spring.decay_factor` and reused as momentum in the next step.
    last_grad: WaveFunctionParameters
    # Lower bound for the damping added to the eigenvalues of the sample-space
    # Gram matrix (initialised from `Spring.damping`).
    damping: Float[Array, '']


class Spring(PyTreeNode, Preconditioner[SpringState]):
    """SPRING natural-gradient preconditioner.

    Builds the sample-space Gram matrix ``T = J J^T`` of the centered, normalized
    per-sample log-amplitude Jacobian ``J``, eigendecomposes it, solves for the
    update in sample space and maps it back to parameter space with a VJP.
    Momentum is carried in ``SpringState.last_grad``.
    """

    wave_function: GeneralizedWaveFunction = field(pytree_node=False)
    # Initial damping; copied into `SpringState.damping` by `init`.
    damping: Float[ArrayLike, '']
    # Momentum factor applied to `SpringState.last_grad`.
    decay_factor: Float[ArrayLike, '']
    # Eigenvalues below this cutoff get special treatment in the aux-grad solve
    # (see `cutoff_to_zero`).
    aux_grad_cutoff: Float[ArrayLike, '']
    # Extra per-eigenvalue and global damping used in the aux-grad solve.
    aux_grad_damping: Float[ArrayLike, '']
    aux_grad_global_damping: Float[ArrayLike, '']
    # Compute dtype (None keeps the input dtypes); the returned update is cast
    # back to the dtypes of `params`.
    dtype: DTypeLike | None = field(pytree_node=False)
    # Eigen-directions with eigenvalue <= clip_eigenvals are dropped from the
    # energy-gradient solve.
    clip_eigenvals: Float[ArrayLike, '']
    # Below `aux_grad_cutoff`: drop the aux component (True) or only apply the
    # damped inverse (False).
    cutoff_to_zero: bool = field(pytree_node=False, default=True)
    # Maximum number of elements per concatenated Jacobian buffer.
    max_acc_size: int = field(
        pytree_node=False,
        default=262_144_000,
    )  # max. 2000 MiB concat buffer

    def init(
        self,
        key: Array,
        params: WaveFunctionParameters,
        systems: Systems,
    ) -> SpringState:
        return SpringState(
            last_grad=jax.tree.map(
                lambda x: jnp.zeros_like(x, dtype=self.dtype),
                params,
            ),
            damping=jnp.asarray(self.damping, dtype=jnp.float32),
        )

    @jit(static_argnames=('sample_reweighting',))
    def apply(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: SpringState,
        auxiliary_grads: WaveFunctionParameters,
        *,
        sample_reweighting: bool = False,
    ):
        """Return ``(update, new_state, aux_data)`` for one optimization step."""
        _apply_fn = (
            self._apply_reweighted if sample_reweighting else self._apply_unreweighted
        )
        return _apply_fn(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            sample_mask,
            state,
            auxiliary_grads,
        )

    def _apply_reweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: 'ReweightedLocalEnergy',
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: SpringState,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """In SPRING we are not able to directly precondition the (sample_src*walker, n_mols)
        cotangent, since this would require us to form a (sample_src*walker, sample_src*walker) preconditioner.

        Instead, we use the standard Fisher and precondition a (walker, n_mols) cotangent, which has
        been centered using importance sampling."""

        # Convert to unreweighted cotangent
        dL_dlogpsi /= reweighting_factor
        dL_dlogpsi = convert_reweighted_tensor_to_unreweighted(dL_dlogpsi, systems)
        unreweighted_sample_mask = convert_reweighted_tensor_to_unreweighted(
            sample_mask,
            systems,
        )

        return self._apply_unreweighted(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            unreweighted_sample_mask,
            state,
            auxiliary_grads,
        )

    def _apply_unreweighted(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: SampleMask,
        state: SpringState,
        auxiliary_grads: WaveFunctionParameters,
    ):
        """SPRING step for plain (batch_size, n_mols) cotangents.

        Returns ``(update, new_state, aux_data)``: ``update`` is cast back to the
        dtypes of ``params``, while ``new_state.last_grad`` stays in the compute
        dtype. ``reweighting_factor`` is not used on this path.
        """
        n_dev = jax.device_count()
        shape_N = dL_dlogpsi.size * n_dev  # total number of samples
        active_N_per_mol = psum_if_pmap(jnp.sum(sample_mask, axis=0, keepdims=True))
        active_N = active_N_per_mol * dL_dlogpsi.shape[-1]
        masked_normalization = jnp.broadcast_to(1 / jnp.sqrt(active_N), dL_dlogpsi.shape)
        normalization = 1 / jnp.sqrt(shape_N)

        out_dtypes = jax.tree.map(lambda x: x.dtype, params)
        if self.dtype is not None:
            params, systems, dL_dlogpsi = tree_to_dtype(
                (params, systems, dL_dlogpsi),
                self.dtype,
            )

        def to_compute_dtype(x: jax.Array) -> jax.Array:
            # `x.astype(None)` does not mean "keep the dtype" in JAX (it selects
            # the default float dtype), so only cast when a dtype is configured.
            return x if self.dtype is None else x.astype(self.dtype)

        @jit
        def log_p(params, systems):
            log_p = self.wave_function.apply(params, systems) * normalization
            return log_p

        @jit
        def log_p_closure(params):
            batch_log_p = jax.vmap(log_p, in_axes=(None, systems.electron_vmap))(
                params,
                systems,
            )  # (batch_size, n_mols)
            return batch_log_p

        def vjp(x):
            # Centered cotangent -> parameter space, summed over devices.
            return psum_if_pmap(jax.vjp(log_p_closure, params)[1](center_fn(x))[0])

        def jvp(x):
            # Parameter-space tangent -> centered per-sample directional derivative.
            return center_fn(jax.jvp(log_p_closure, (params,), (x,))[1])

        def center_fn(
            x: Float[Array, 'batch_size n_mols'],
        ) -> Float[Array, 'batch_size n_mols']:
            x = x.reshape(dL_dlogpsi.shape)
            center = pmean(jnp.mean(x, axis=0))
            return x - center

        jacs: list[list[jax.Array]] = []
        segments: list = []
        for sub_systems in systems.iter_stacked_sub_systems():

            @vmap(in_axes=(None, sub_systems.electron_vmap, None))  # walker
            @vmap(in_axes=(None, sub_systems.electron_vmap, 0))  # n_states
            @vmap(in_axes=(None, sub_systems.molecule_vmap, None))  # n_unique_mol
            @jax.grad
            def jac_fn(params, systems, excitation):
                _log_p = log_p(
                    params,
                    systems.replace(excitations=excitation[..., None], mol_ids=(0,)),
                ).sum()
                return _log_p

            # Reduce n_mols dimension to save compute / memory
            electrons = sub_systems.electrons  # (walker, n_mols, n_electrons, 3)
            electrons = unsegment_axis(
                electrons,
                sub_systems.mol_id_groups,
                axis=1,
                num_segments=sub_systems.n_unique_mols,
            )  # (walker, n_unique_mols, n_states, n_electrons, 3)
            electrons = jnp.swapaxes(
                electrons,
                1,
                2,
            )  # (walker, n_states, n_unique_mols, n_electrons, 3)
            nuclei = sub_systems.nuclei  # (n_mols, n_nuclei, 3)
            nuclei = unsegment_axis(
                nuclei,
                sub_systems.mol_id_groups,
                axis=0,
                num_segments=sub_systems.n_unique_mols,
            )  # (n_unique_mols, n_states, n_nuclei, 3)
            # The nuclei are the same for all states, so we can just take the first one
            nuclei = nuclei[:, 0, ...]  # (n_unique_mols, n_nuclei, 3)

            _sub_systems = sub_systems.replace(
                electrons=electrons,
                nuclei=nuclei,
            )

            jacs.append(
                jac_fn(
                    params,
                    _sub_systems,
                    # HACK: Assumes sorted and identical excitations
                    jnp.arange(_sub_systems.max_num_states, dtype=jnp.int32),
                ),
            )
            segments.append((sub_systems.mol_id_groups, sub_systems.excitations))

        @jit
        def concat_jacobians(*jacs: jax.Array) -> jax.Array:
            # merge all systems into a single jacobian
            # each jac is (N, n_states, n_unique_mol, params)
            _jacs = []
            for sub_jac, (mol_id_groups, excitations) in zip(
                jacs,
                segments,
                strict=True,
            ):
                # First resegment the (..., n_states, n_unique_mol, ...) dims based on the
                # individual groups segmentation
                sub_jac = sub_jac[:, excitations, mol_id_groups]
                # jac_array is now (N, n_mols, *params) -> flatten params
                _jacs.append(sub_jac.reshape(*sub_jac.shape[:2], -1))
            jac = jnp.concatenate(_jacs, axis=1)[:, systems.inverse_unique_indices]
            return jac

        jac = jax.tree.map(concat_jacobians, *jacs)
        jac_tensors = jax.tree.leaves(jac)

        @jit
        def to_covariance(jac_: tuple[jax.Array, ...]) -> jax.Array:
            jac, *_ = jac_
            n_params = jac.shape[-1]
            # check for parameters that are not split evenly across devices
            num_even = n_params // n_dev * n_dev
            jac, remainder = jac[..., :num_even], jac[..., num_even:]
            # Shard the parameter axis across devices and gather all samples.
            jac: Array = pall_to_all(jac, split_axis=2, concat_axis=0, tiled=True)
            jac -= jac.mean(axis=0)
            jac = jac.reshape(shape_N, -1)

            # no need to materalize an NxN constant zero matrix!
            if n_params % n_dev == 0:
                return jac @ jac.T

            # for the remainder we copy it to all devices
            remainder = pgather(remainder, axis=0, tiled=True)
            remainder -= remainder.mean(axis=0)
            remainder = remainder.reshape(shape_N, -1)
            # Since the remainder is summed n_dev times we need to divide by n_dev
            return jac @ jac.T + remainder @ remainder.T / n_dev

        JT_J = psum_if_pmap(
            sum(
                map(
                    to_covariance,
                    batch_parameters(
                        jac_tensors,
                        # concat buffers are [N, batch_size]
                        batch_size=self.max_acc_size // shape_N,
                    ),
                ),
            ),
        )

        # Constructing the Fisher matrix
        T = (JT_J + JT_J.T) / 2
        s, U = jnp.linalg.eigh(T)
        damping = jnp.maximum(s[-1] / 1e10, state.damping)
        s = jnp.maximum(s, 0)  # Ensure positive definiteness
        damped_s = s + damping
        log10_condition = jnp.log10(damped_s[-1]) - jnp.log10(
            damped_s[0],
        )

        # Collect aux grads
        if self.dtype is not None:
            auxiliary_grads = tree_to_dtype(auxiliary_grads, self.dtype)

        # Process aux grads
        epsilon_aux = to_compute_dtype(
            pgather(jvp(auxiliary_grads), axis=0, tiled=True),
        ).reshape(-1)
        # Each device keeps only the rows of U that belong to its own samples.
        aux_coeffs = U.reshape(n_dev, -1, shape_N)[pidx()] @ jnp.where(
            s < self.aux_grad_cutoff,
            0 if self.cutoff_to_zero else (U.T @ epsilon_aux) / damped_s,
            (U.T @ epsilon_aux)
            / (damped_s * (s + self.aux_grad_damping) + self.aux_grad_global_damping),
        )

        # Process momentum
        decayed_last_grad = tree_mul(state.last_grad, self.decay_factor)
        # Process energy gradient
        epsilon_E = jnp.where(sample_mask, dL_dlogpsi, 0.0) * masked_normalization - jvp(
            decayed_last_grad,
        )
        epsilon_E = to_compute_dtype(pgather(epsilon_E, axis=0, tiled=True)).reshape(-1)

        x = (U.T @ epsilon_E) / damped_s
        x = jnp.where(s > self.clip_eigenvals, x, 0.0)
        x = U.reshape(n_dev, -1, shape_N)[pidx()] @ x
        preconditioned = vjp(x + aux_coeffs)

        natgrad = tree_add(preconditioned, decayed_last_grad)

        aux_data = {
            'log10_cond': log10_condition,
            'largest_eigenvalue': s[-1],
            'lowest_eigenvalue': s[0],
            'epsilon_E_norm': jnp.linalg.norm(epsilon_E),
            'dL_dlogpsi_norm': jnp.sqrt(psum_if_pmap(jnp.sum(dL_dlogpsi**2))),
            'T_inv_epsilon_norm': jnp.sqrt(psum_if_pmap(jnp.sum(x**2))),
            'epsilon_aux_norm': jnp.linalg.norm(epsilon_aux),
            'aux_coeffs_norm': jnp.sqrt(psum_if_pmap(jnp.sum(aux_coeffs**2))),
            'preconditioned_grad_norm': tree_squared_norm(preconditioned) ** 0.5,
            'natgrad_norm': tree_squared_norm(natgrad) ** 0.5,
        }
        # Convert back to the original dtype
        update = jax.tree.map(jax.lax.convert_element_type, natgrad, out_dtypes)

        return update, state.replace(last_grad=natgrad), aux_data


@functools.partial(jax.jit, static_argnames=('num_iters',))
@functools.partial(jnp.vectorize, signature='(N,M),()->(N),(M)', excluded={2})
def kron_factors_2d(V: jax.Array, n_factors: jax.Array, num_iters: int = 5):
    """Rough rank-one factorisation ``V ~= outer(left, right)`` via power iteration.

    Runs ``num_iters`` alternating power-iteration steps and measures the scale
    ``sigma = left_unit^T V right_unit``. The scale is split unevenly:
    ``left`` (length N) receives ``sigma ** (1 / n_factors)`` and ``right``
    (length M) receives ``sigma ** (1 - 1 / n_factors)``, which suits the
    recursive use in ``kron_factors`` where ``right`` is factorised further.
    ``jnp.vectorize`` broadcasts over leading batch dimensions of ``V`` and
    ``n_factors``; ``num_iters`` is excluded from vectorisation and static
    under ``jit``.

    Returns:
        ``(left, right)`` with shapes ``(N,)`` and ``(M,)``.
    """

    def unit(vec: jax.Array) -> jax.Array:
        length: jax.Array = jnp.linalg.norm(vec)
        # (Near-)zero vectors are returned unscaled instead of dividing by ~0.
        return vec / jnp.where(length < 1e-8, 1.0, length)

    def power_step(carry, _):
        right, _previous_left = carry
        left = unit(V @ right)
        right = unit(V.T @ left)
        return (right, left), None

    start = (unit(V[0]), unit(V[:, 0]))
    (right, left), _ = jax.lax.scan(power_step, start, None, length=num_iters)
    sigma = jnp.einsum('i,ij,j->', left, V, right)
    right_share = 1 - 1 / n_factors
    right = right * sigma**right_share
    left = left * sigma ** (1 - right_share)
    return left, right


def kron_factors(V: jax.Array, num_iters: int = 5) -> tuple[jax.Array, ...]:
    """Approximate ``V`` as an outer product of one vector per axis.

    A 0-d input becomes a single length-1 factor and a 1-d input is returned as
    its own single factor. For ``ndim >= 2`` the first axis is split off with
    ``kron_factors_2d``, and the remaining axes are factorised recursively.
    """
    ndim = V.ndim
    if ndim == 0:
        return (V.reshape(1),)
    if ndim == 1:
        return (V,)
    if ndim == 2:
        return kron_factors_2d(V, 2, num_iters)
    if ndim > 2:
        flattened = V.reshape(V.shape[0], -1)
        head, tail = kron_factors_2d(flattened, V.ndim, num_iters=num_iters)
        rest = kron_factors(tail.reshape(V.shape[1:]), num_iters=num_iters)
        return (head, *rest)
    raise ValueError(f'Unsupported number of dimensions: {V.ndim}')


def kron_factors_to_full_tensor(*x: jax.Array) -> jax.Array:
    """Outer product of 1-D factors: ``out[i, j, ...] = x[0][i] * x[1][j] * ...``."""
    subscripts = [chr(ord('a') + axis) for axis in range(len(x))]
    spec = ','.join(subscripts) + '->' + ''.join(subscripts)
    return jnp.einsum(spec, *x)


def rank_k_kron_factors(
    V: jax.Array,
    k: int,
    num_iters: int = 1,
) -> tuple[jax.Array, ...]:
    """Greedy rank-``k`` Kronecker approximation of ``V``.

    Each of the ``k`` ``lax.scan`` steps fits one rank-one Kronecker term (via
    ``kron_factors``) to the residual ``V - estimate`` and adds it to the
    estimate. The returned factors are stacked by the scan, so each has a
    leading axis of size ``k``. Inputs with ``ndim < 2`` are returned as their
    single factorisation with a leading axis of size 1, regardless of ``k``.
    """
    if V.ndim >= 2:

        def deflate(estimate, _step_index):
            residual = V - estimate
            rank_one = kron_factors(residual, num_iters=num_iters)
            return estimate + kron_factors_to_full_tensor(*rank_one), rank_one

        _, stacked = jax.lax.scan(deflate, jnp.zeros_like(V), jnp.arange(k))
        return stacked

    return tuple(factor[None] for factor in kron_factors(V, num_iters=num_iters))


def pi_adjusted_kronecker_factors(
    *factors: Array,
    damping: Float[ArrayLike, ''],
) -> tuple[Array, ...]:
    # https://github.com/google-deepmind/kfac-jax/blob/main/kfac_jax/_src/curvature_blocks/kronecker_factored.py
    """Computes Kronecker factors with pi-adjusted factored damping.

    The `f1 kron f2 kron ... kron fn + damping * I` is not a Kronecker product
    in general, because of the added identity. [1] proposed a pi-adjusted factored
    damping approach to approximate it as a Kronecker product. [2] generalized
    this approach from two to tree factors, and [3] generalized it to arbitrary
    numbers of factors. This function implements the generalized approach.

    [1] - https://arxiv.org/abs/1503.05671
    [2] - https://openreview.net/forum?id=SkkTMpjex
    [3] - https://ui.adsabs.harvard.edu/abs/2021arXiv210602925R/abstract

    Args:
      *factors: A list of factors represented as 2D arrays, vectors (which are
        interpreted as representing the diagonal of a matrix) or scalars (which
        are interpreted as being a 1x1 matrix). All factors must be PSD.
      damping: The weight of the identity added to the Kronecker product.

    Returns:
      A list of factors with the same length as `factors`, and with the same
      corresponding representations, whose Kronecker product approximates
      `(f1 kron f2 kron ... kron fn) + damping * I` according to the
      pi-adjusted factored-damping approach.
    """

    # The implementation writes each single factor as `c_i u_i`, where the matrix
    # `u_i` is such that `trace(u_i) / dim(u_i) = 1`. We then factor out all the
    # scalar factors `c_i` into a single overall scaling coefficient and
    # distribute the damping to each single non-scalar factor `u_i` equally.
    norms = jnp.array([jnp.trace(f) / f.shape[0] for f in factors])

    k = len(factors)

    def regular_case() -> tuple[Array, ...]:
        num_non_scalars = sum(1 if f.size != 1 else 0 for f in factors)
        # Compute the normalized factors `u_i`, such that Trace(u_i) / dim(u_i) = 1
        us = [fi / ni for fi, ni in zip(factors, norms, strict=True)]

        if num_non_scalars != 0:
            # Distribute c and damping/c among k factors, where c = jnp.prod(norms),
            # satisfying kron(factors) = c * kron(us).

            # NOTE: c_k (geometric mean of norms) can also be calculated by
            # c ** (1/k) = jnp.prod(norms) ** (1 / len(norms)), but this alternative
            # can make the result zero due to the multiplication of (potentially)
            # small values, i.e. jnp.prod(norms).
            c_k = jnp.exp(jnp.mean(jnp.log(norms)))
            d_k = jnp.power(damping, 1.0 / k) / c_k

            if k > num_non_scalars:
                c_non_scalar = c_k ** (float(k) / num_non_scalars)
                # We distribute the damping only inside the non-scalar factors
                d_hat = jnp.power(damping, 1.0 / num_non_scalars) / c_non_scalar
            else:
                d_hat = d_k

        else:
            # This could cause under/overflow, but it's unavoidable here.
            c = jnp.prod(jnp.array(norms))
            # In the case where all factors are scalar we need to add the damping and
            # then take the k-th root
            c_k = jnp.power(c + damping, 1.0 / k)

        u_hats = []

        for u in us:
            if u.size == 1:  # scalar case
                u_hat = jnp.ones_like(u)  # damping not used in the scalar factors
            elif u.ndim == 2:
                u_hat = u + d_hat * jnp.eye(u.shape[0], dtype=u.dtype)  # type: ignore
            else:  # diagonal case
                assert u.ndim == 1
                u_hat = u + d_hat  # type: ignore
            u_hats.append(u_hat * c_k)

        return tuple(u_hats)

    def zero_case() -> tuple[Array, ...]:
        # In the special case where for some reason one of the factors is zero, then
        # the we write each factor as `damping^(1/k) * I`.
        c_k = jnp.power(damping, 1.0 / k)
        return tuple(
            c_k
            * (
                jnp.eye(fi.shape[0], dtype=fi.dtype)
                if fi.ndim == 2
                else jnp.ones_like(fi)
            )
            for fi in factors
        )

    return jax.lax.cond(jnp.greater(jnp.min(norms), 0.0), regular_case, zero_case)


def _reconstruct_jac_from_factors(factors: tuple[jax.Array, ...]) -> jax.Array:
    """Rebuild dense per-sample tensors from their rank-R Kronecker factors.

    Every factor is expected to have shape ``(batch, rank, dim_i)``. The result
    has shape ``(batch, dim_0, ..., dim_{m-1})``. It is the sum over the shared
    rank axis of the outer products of the factors, i.e. the einsum
    ``'BRa,BRb,...->Bab...'``.

    Raises:
        ValueError: if ``factors`` is empty.
    """
    n_modes = len(factors)
    if n_modes == 0:
        raise ValueError('No modes to reconstruct.')
    # Uppercase B/R label the shared batch and rank axes; one lowercase letter
    # labels each factor's own parameter axis (einsum labels are case-sensitive).
    axis_labels = [chr(ord('a') + idx) for idx in range(n_modes)]
    operand_specs = ','.join('BR' + label for label in axis_labels)
    output_spec = 'B' + ''.join(axis_labels)
    return jnp.einsum(operand_specs + '->' + output_spec, *factors)


def _reconstruction_error(
    jac: jax.Array,
    factors: tuple[jax.Array, ...],
) -> Float[Array, '']:
    """Computes the mean reconstruction error over all samples.

    ``jac`` has two leading sample axes followed by the parameter axes.
    Each entry of ``factors`` is laid out as
    ``(walker, n_mol, rank, dim_length)``. Both sample axes are merged, the
    tensors are rebuilt from the factors, and the relative Frobenius error
    ``||approx - jac|| / (||jac|| + 1e-12)`` is computed per sample. The
    per-sample errors are averaged and then pmean'ed when running under pmap.
    """

    def _per_sample_frobenius(x):
        # Norm over every axis except the leading (merged) sample axis.
        return jnp.linalg.norm(x.reshape(x.shape[0], -1), axis=-1)

    merged_factors = jax.tree.map(
        lambda f: einops.rearrange(
            f,
            'walker n_mol rank dim_length -> (walker n_mol) rank dim_length',
        ),
        factors,
    )
    target = jac.reshape(-1, *jac.shape[2:])
    approx = _reconstruct_jac_from_factors(merged_factors)

    relative_err = _per_sample_frobenius(approx - target) / (
        _per_sample_frobenius(target) + 1e-12
    )
    return pmean_if_pmap(jnp.mean(relative_err))


class KroneckerFactors(PyTreeNode):
    """Kronecker-factored covariance estimate for a single parameter tensor.

    ``covariances`` holds one square matrix per factor of the decomposition
    produced by ``rank_k_kron_factors``. The curvature is approximated by
    their (damped) Kronecker product. An empty tuple means that no
    factorization is available; ``__call__`` then returns the gradient
    unchanged.
    """

    # One (dim_i, dim_i) covariance per factor; empty when factorization was skipped.
    covariances: tuple[jax.Array, ...]
    # Mean relative reconstruction error of the factorization that produced
    # ``covariances``. Preconditioning is skipped when it exceeds the tolerance.
    reconstruction_error: Float[Array, ''] = field(
        pytree_node=True,
        default_factory=lambda: jnp.zeros((), dtype=jnp.float32),
    )

    @classmethod
    def from_jac(
        cls,
        jac: jax.Array,
        rank: int = 1,
        num_iters: int = 5,
    ) -> 'KroneckerFactors':
        """Estimate Kronecker-factored covariances from per-sample Jacobians.

        Args:
            jac: Jacobian with two leading sample axes followed by the
                parameter axes, i.e. ``(walker, n_mols, *param_dims)``.
            rank: Number of Kronecker terms, forwarded to
                ``rank_k_kron_factors`` as ``k``.
            num_iters: Forwarded to ``rank_k_kron_factors``.

        Returns:
            A ``KroneckerFactors`` holding one covariance per factor together
            with the factorization's reconstruction error. If the parameter
            is too large to factorize, it holds no covariances, which means
            identity preconditioning.
        """
        param_dims = jac.shape[2:]
        # Very long axes or very large tensors make the factor computation and
        # the later solves too expensive. In that case, fall back to an empty
        # factorization, which leaves the gradient untouched.
        if any(s > 2048 for s in param_dims):
            print(
                f'Warning: KroneckerFactors.from_jac received a Jacobian with large parameter dimensions {param_dims} '
                '(an axis exceeds 2048). Skipping Kronecker factorization.',
            )
            return KroneckerFactors(())
        if np.prod(param_dims) > (512**2):
            print(
                f'Warning: KroneckerFactors.from_jac received a Jacobian with large parameter dimensions {param_dims}. '
                'This may lead to memory issues; skipping Kronecker factorization.',
            )
            return KroneckerFactors(())
        jac -= jnp.mean(jac, axis=0, keepdims=True)  # Center over the walker axis
        kron_fn = functools.partial(rank_k_kron_factors, k=rank, num_iters=num_iters)
        # Factorize every (walker, molecule) sample independently.
        factors = jax.vmap(jax.vmap(kron_fn))(jac)

        error = _reconstruction_error(jac, factors)

        def to_covariance(x: jax.Array) -> jax.Array:
            # Sum f^T f over all leading axes (walkers, molecules, rank terms),
            # normalize by the number of samples, and average across devices.
            x = x.reshape(-1, x.shape[-1])
            return pmean_if_pmap(x.T @ x) / np.prod(jac.shape[:2])

        return KroneckerFactors(jax.tree.map(to_covariance, factors), error)

    def damped_covariances(self, damping: jax.Array):
        """Return the covariances with ``damping`` distributed across the factors."""
        return pi_adjusted_kronecker_factors(*self.covariances, damping=damping)

    def __call__(
        self,
        gradient: jax.Array,
        damping: jax.Array,
        *,
        reconstruction_tol: float = 0.1,
    ) -> jax.Array:
        """Precondition ``gradient`` with the inverse damped Kronecker factors.

        Args:
            gradient: Gradient with the same shape as the parameter tensor.
            damping: Damping added through ``damped_covariances``.
            reconstruction_tol: If ``reconstruction_error`` exceeds this value,
                the gradient is returned unchanged.

        Returns:
            The preconditioned gradient, with the shape of ``gradient``.
        """
        if not self.covariances:
            return gradient

        def _precondition_grad(gradient, damping):
            covs = self.damped_covariances(damping)
            if len(covs) == 1:
                return jnp.linalg.solve(covs[0], gradient.reshape(-1)).reshape(
                    gradient.shape,
                )
            # Apply C_i^{-1} along parameter axis i. Moving that axis to
            # position -2 makes ``solve`` treat the last two axes as a matrix
            # right-hand side and broadcast over the remaining leading axes.
            for axis, cov in enumerate(covs):
                gradient = jnp.moveaxis(gradient, axis, -2)
                gradient = jnp.linalg.solve(cov, gradient)
                gradient = jnp.moveaxis(gradient, -2, axis)
            return gradient

        return jax.lax.cond(
            jnp.greater(self.reconstruction_error, reconstruction_tol),
            lambda: gradient,
            lambda: _precondition_grad(gradient, damping),
        )


class KFACState(PyTreeNode):
    """Optimizer state carried between ``KFAC.apply`` calls."""

    # EMA over the per-parameter ``KroneckerFactors`` (one per parameter leaf).
    factors: EMA[WaveFunctionParameters]
    # Natural gradient from the previous step (before norm constraint / lr).
    # ``KFAC.apply`` adds it, scaled by ``decay_factor * damping``, to the
    # current gradient.
    last_grad: WaveFunctionParameters
    # Current norm-constraint value; multiplied by ``1 - norm_constraint_decay``
    # on every step.
    norm_constraint: Float[Array, '']
    # Number of completed ``apply`` calls; indexes the damping and lr schedules.
    step: Int[Array, '']


class KFAC(PyTreeNode, Preconditioner[KFACState]):
    """Kronecker-factored natural-gradient preconditioner.

    Each step proceeds as follows:
    - Per-sample Jacobians of log-psi with respect to every parameter tensor
      are factorized with ``KroneckerFactors.from_jac``.
    - The factors are blended into an EMA of Kronecker covariances.
    - The EMA is used to precondition the plain gradient produced by
      ``Identity``.
    - The resulting natural gradient is rescaled by a decaying norm
      constraint and the learning-rate schedule.
    """

    wave_function: GeneralizedWaveFunction = field(pytree_node=False)
    # Maps the step counter to the damping used for the Kronecker factors.
    damping_schedule: Schedule = field(pytree_node=False)
    # EMA weight passed to ``EMA.update`` when merging newly estimated factors.
    ema: Float[ArrayLike, '']
    # Maps the step counter to the learning rate.
    lr_schedule: Schedule = field(pytree_node=False)
    # Initial norm-constraint value (see the norm constraint in ``apply``).
    norm_constraint: Float[ArrayLike, '']
    # Per-step decay: constraint <- constraint * (1 - norm_constraint_decay).
    norm_constraint_decay: Float[ArrayLike, '']
    # Weight of the previous natural gradient, which is added (times the
    # damping) to the current gradient.
    decay_factor: Float[ArrayLike, '']
    # Optional dtype for the internal computation; the returned update is cast
    # back to the original parameter dtypes.
    dtype: DTypeLike | None = field(pytree_node=False, default=None)
    # If True, Jacobians come from ``_reweighted_jac`` instead of
    # ``_unreweighted_jac``.
    fisher_reweighting: bool = field(pytree_node=False, default=False)
    # Rank and iteration count forwarded to ``KroneckerFactors.from_jac`` in ``apply``.
    kronecker_rank: int = field(pytree_node=False, default=1)
    num_iters: int = field(pytree_node=False, default=5)
    # Preconditioning of a parameter is skipped when the reconstruction error
    # stored in its factors exceeds this value.
    reconstruction_tol: float = field(pytree_node=False, default=0.1)

    @staticmethod
    def create(
        wave_function: GeneralizedWaveFunction,
        damping_schedule: ScheduleConfig,
        ema: Float[ArrayLike, ''],
        lr_schedule: ScheduleConfig,
        norm_constraint: Float[ArrayLike, ''],
        norm_constraint_decay: Float[ArrayLike, ''],
        decay_factor: Float[ArrayLike, ''],
        dtype: DTypeLike | None = None,
        fisher_reweighting: bool = False,
        kronecker_rank: int = 1,
        num_iters: int = 5,
        reconstruction_tol: float = 0.1,
    ) -> 'KFAC':
        """Build a ``KFAC``, turning the schedule configs into schedules via ``get_schedule``."""
        return KFAC(
            wave_function=wave_function,
            damping_schedule=get_schedule(damping_schedule),
            ema=ema,
            lr_schedule=get_schedule(lr_schedule),
            norm_constraint=norm_constraint,
            norm_constraint_decay=norm_constraint_decay,
            decay_factor=decay_factor,
            dtype=dtype,
            fisher_reweighting=fisher_reweighting,
            kronecker_rank=kronecker_rank,
            num_iters=num_iters,
            reconstruction_tol=reconstruction_tol,
        )

    def init(
        self,
        key: Array,
        params: WaveFunctionParameters,
        systems: Systems,
    ) -> KFACState:
        """Create the initial optimizer state.

        The factor EMA is seeded from all-zero Jacobians of shape
        ``(1, 1, *param.shape)``, and the previous natural gradient is zero.
        ``key`` and ``systems`` are not used.
        """
        # Dummy Jacobians with singleton walker/molecule axes; presumably these
        # only fix the pytree structure and shapes of the factors. TODO: verify.
        zeros = jax.tree.map(lambda x: jnp.zeros((1, 1, *x.shape), dtype=x.dtype), params)
        last_grad = jax.tree.map(lambda x: jnp.zeros_like(x), params)
        if self.dtype is not None:
            zeros = tree_to_dtype(zeros, self.dtype)
            last_grad = tree_to_dtype(last_grad, self.dtype)
        return KFACState(
            factors=EMA.init(
                # NOTE(review): this uses from_jac's default rank/num_iters, not
                # self.kronecker_rank/self.num_iters as in ``apply``.
                jax.tree.map(KroneckerFactors.from_jac, zeros),
            ),
            last_grad=last_grad,
            norm_constraint=jnp.array(self.norm_constraint, dtype=jnp.float32),
            step=jnp.zeros((), dtype=jnp.int32),
        )

    # TODO: Sample mask!
    def _unreweighted_jac(
        self,
        params,
        systems,
        reweighting_factor,
        sample_mask,
    ):
        """Per-sample gradients of log-psi for every stacked sub-system.

        Returns one gradient pytree per ``systems.iter_stacked_sub_systems()``
        entry, vmapped over the electron and molecule axes. Samples where the
        mask is False contribute a zero gradient.

        NOTE(review): ``reweighting_factor`` is unused here. The full
        ``sample_mask`` is passed for every sub-system without selecting that
        sub-system's molecules; confirm that the shapes line up.
        """

        @jit
        def log_p(params, systems):
            return self.wave_function.apply(params, systems)  # type: ignore

        jacs: list[WaveFunctionParameters] = []
        for sub_systems in systems.iter_stacked_sub_systems():

            @vmap(in_axes=(None, sub_systems.electron_vmap, None, 0))
            @vmap(in_axes=(None, sub_systems.molecule_vmap, 0, 0))
            @jax.grad
            def jac_fn(params, systems, excitation, mask):
                # Evaluate this sample's excitation with mol_ids collapsed to
                # (0,); the mask zeroes out the output (and thus the gradient).
                return jnp.where(
                    mask,
                    log_p(
                        params,
                        systems.replace(excitations=excitation[..., None], mol_ids=(0,)),
                    ).sum(),
                    0,
                )

            jacs.append(
                jac_fn(
                    params,
                    sub_systems,
                    jnp.asarray(sub_systems.excitations),
                    sample_mask,
                ),
            )

        return jacs

    # TODO: Sample mask!
    def _reweighted_jac(self, params, systems, reweighting_factor):
        """Per-sample gradients of the reweighted log-psi for each sub-system.

        For every sample, log-psi is evaluated for all ``max_num_states``
        excitations, weighted by ``reweighting_factor / max_num_states`` and
        summed. The gradient of twice that sum is returned, one pytree per
        stacked sub-system.
        """
        # sample_src is segmented in the batch dimension, but we want
        # the sample_src in the n_mols dimension here
        reweighting_factor = unpack_reweighted_tensor(reweighting_factor, systems)
        # (walker, n_mol_ids, sample_src, wf_state)

        jacs: list[WaveFunctionParameters] = []

        for sub_systems, indices in zip(
            systems.iter_stacked_sub_systems(),
            systems.unique_indices,
            strict=True,
        ):

            @vmap(in_axes=(None, sub_systems.electron_vmap, 0, None))
            @vmap(in_axes=(None, sub_systems.molecule_vmap, 0, None))
            @jax.grad
            def jac_fn(params, systems, reweighting_factor, max_num_states: int):
                # Reparametrization is computed once and shared by all excitations.
                reparams = self.wave_function.reparams(
                    params,
                    systems.replace(mol_ids=(0,)),
                )

                def excitation_closure(ex):
                    sys = systems.replace(excitations=ex)
                    return self.wave_function.apply(params, sys, reparams).squeeze()

                logpsi = jax.vmap(excitation_closure)(
                    jnp.arange(max_num_states)[..., None],
                )  # (max_num_states,)
                logpsi *= reweighting_factor / max_num_states

                return 2 * logpsi.sum()

            # construct a n_mols axis from n_mol_ids and sample_src
            # then select the molecules for this sub_systems object
            reweighting = reweighting_factor[
                :,
                sub_systems.mol_ids,
                sub_systems.excitations,
                :,
            ][:, indices]

            jacs.append(
                jac_fn(
                    params,
                    sub_systems,
                    reweighting,
                    sub_systems.max_num_states,
                ),  # (walker, (n_mol_ids*sample_src = n_mols), params)
            )

        return jacs

    @jit(static_argnames=('sample_reweighting',))
    def apply(
        self,
        params: WaveFunctionParameters,
        systems: Systems,
        dL_dlogpsi: Float[Array, 'batch_size n_mols'],
        reweighting_factor: ReweightingFactor,
        sample_mask: Bool[Array, 'walker n_mol_ids sample_src'],
        state: KFACState,
        auxiliary_grads: WaveFunctionParameters,
        *,
        sample_reweighting: bool = False,
    ):
        """Compute a KFAC-preconditioned parameter update.

        Steps:
        1. Plain gradient via ``Identity``, plus
           ``decay_factor * damping * state.last_grad``.
        2. Per-sample Jacobians, then new ``KroneckerFactors``, then an EMA
           update.
        3. Precondition each parameter's gradient with the EMA'd factors and
           pmean the result.
        4. Apply the norm constraint and scale by the learning rate.

        Returns:
            ``(update, new_state, stats)``. ``update`` is cast back to the
            original parameter dtypes.
        """
        out_dtypes = jax.tree.map(lambda x: x.dtype, params)

        damping = jnp.array(self.damping_schedule(state.step), dtype=jnp.float32)
        if self.dtype is not None:
            params, systems, dL_dlogpsi = tree_to_dtype(
                (params, systems, dL_dlogpsi),
                self.dtype,
            )

        gradient, _, identity_grad_aux = Identity(
            self.wave_function,
            dtype=self.dtype,
        ).apply(
            params,
            systems,
            dL_dlogpsi,
            reweighting_factor,
            sample_mask,
            None,
            auxiliary_grads,
            sample_reweighting=sample_reweighting,
        )
        # Add the previous natural gradient, weighted by decay_factor * damping.
        gradient = tree_add(
            gradient,
            tree_mul(state.last_grad, self.decay_factor * damping),
        )

        @jit
        def merge_jacs(*jacs: jax.Array) -> jax.Array:
            # Concatenate the sub-system Jacobians along the molecule axis and
            # reorder them via inverse_unique_indices (presumably restoring the
            # original molecule order).
            return jnp.concatenate(jacs, axis=1)[:, systems.inverse_unique_indices]

        @jit
        def kfac(gradient: jax.Array, factors: KroneckerFactors):
            return factors(
                gradient,
                jnp.asarray(damping, dtype=jnp.float32),
                reconstruction_tol=self.reconstruction_tol,
            )

        if self.fisher_reweighting:
            jacs = self._reweighted_jac(params, systems, reweighting_factor)
        else:
            jacs = self._unreweighted_jac(
                params,
                systems,
                reweighting_factor,
                sample_mask,
            )
        jac = jax.tree.map(merge_jacs, *jacs)
        new_cov = jax.tree.map(
            lambda x: KroneckerFactors.from_jac(x, self.kronecker_rank, self.num_iters),
            jac,
        )
        new_factors = state.factors.update(new_cov, self.ema)
        natgrad = jax.tree.map(kfac, gradient, new_factors.value())
        natgrad = pmean(natgrad)

        # Norm constraint
        # https://jimmylba.github.io/papers/nsync.pdf
        # The effective step size is scale * lr = min(lr, sqrt(C / (g . natgrad))),
        # where C is the current norm constraint.
        lr = jnp.array(self.lr_schedule(state.step), dtype=jnp.float32)
        fisher_norm = tree_dot(gradient, natgrad)
        scale = jnp.minimum(1, jnp.sqrt(state.norm_constraint / fisher_norm) / lr)
        update = tree_mul(natgrad, scale * lr)
        # Convert back to the original dtype
        update = jax.tree.map(jax.lax.convert_element_type, update, out_dtypes)

        return (
            update,
            KFACState(
                factors=new_factors,
                # Store the unscaled natural gradient (before norm constraint / lr).
                last_grad=natgrad,
                norm_constraint=state.norm_constraint * (1 - self.norm_constraint_decay),
                step=jnp.array(state.step + 1, dtype=jnp.int32),
            ),
            {
                'identity_grad_norm': tree_squared_norm(gradient) ** 0.5,
                'natgrad_norm': tree_squared_norm(natgrad) ** 0.5,
                'scale': scale,
                'fisher_norm': fisher_norm**0.5,
            }
            | identity_grad_aux,
        )


# Registry of the built-in preconditioners, keyed by lower-cased class name
# (e.g. 'identity', 'kfac').
PRECONDITIONER = Modules[Preconditioner](
    {
        preconditioner.__name__.lower(): preconditioner
        for preconditioner in (Identity, CG, Spring, KFAC)
    },
)

# Optional preconditioner backed by the ``kfac_jax`` library. It is defined and
# registered as 'realkfac' only if the imports below succeed.
try:
    import kfac_jax  # pyright: ignore[reportMissingImports]
    from kfac_jax import (  # pyright: ignore[reportMissingImports]
        OptaxPreconditioner,
        OptaxPreconditionState,
    )

    from neural_pfaffian.kfac_blocks import LayerNormBlock, RepeatedDenseBlock
    from neural_pfaffian.utils.jax_utils import register_kfac_state_with_flax

    class RealKFACState(PyTreeNode):
        """State of ``RealKFAC``: the wrapped kfac_jax state plus a step counter."""

        # kfac_jax preconditioner state.
        state: OptaxPreconditionState
        # Number of completed ``apply`` calls; indexes the lr schedule.
        step: Int[Array, '']

    register_kfac_state_with_flax()

    class RealKFAC(PyTreeNode, Preconditioner[RealKFACState]):
        """Preconditioner that delegates to ``kfac_jax.OptaxPreconditioner``."""

        wave_function: GeneralizedWaveFunction = field(pytree_node=False)
        lr_schedule: Schedule = field(pytree_node=False)
        damping: Float[ArrayLike, '']
        norm_constraint: Float[ArrayLike, '']
        ema: Float[ArrayLike, '']
        optimizer: OptaxPreconditioner = field(pytree_node=False)
        # If True, masked-out samples are zeroed in the registered outputs and
        # excluded from the batch size.
        apply_mask_to_fisher: bool = field(pytree_node=False, default=False)

        @staticmethod
        def create(
            wave_function: GeneralizedWaveFunction,
            lr_schedule: ScheduleConfig,
            damping: Float[ArrayLike, ''],
            norm_constraint: Float[ArrayLike, ''],
            ema: Float[ArrayLike, ''],
            *,
            apply_mask_to_fisher: bool = False,
        ) -> 'RealKFAC':
            """Build a ``RealKFAC`` around a configured ``OptaxPreconditioner``.

            The loss function registers the (optionally masked) batched log-psi
            values as a normal predictive distribution. Curvature is estimated
            with ``estimation_mode='fisher_exact'``, and inverses are updated
            every step.
            """

            def _loss_fn(params, sys_and_mask):
                systems, sample_mask = sys_and_mask
                result = wave_function.batched_apply(params, systems)
                # We expect an unreweighted sample_mask here!
                if apply_mask_to_fisher:
                    result = jnp.where(sample_mask, result, 0.0)
                kfac_jax.register_normal_predictive_distribution(result)
                return result

            def _batch_size_extractor(sys_and_mask):
                # Number of unmasked samples, or walkers x molecules when the
                # mask is not applied to the Fisher estimate.
                systems, sample_mask = sys_and_mask
                return (
                    jnp.sum(sample_mask)
                    if apply_mask_to_fisher
                    else systems.electrons.shape[0] * systems.n_mols
                )

            optimizer = OptaxPreconditioner(
                _loss_fn,
                l2_reg=0,
                damping=damping,  # type:ignore
                norm_constraint=norm_constraint,  # type:ignore
                estimation_mode='fisher_exact',
                curvature_ema=ema,  # type: ignore
                inverse_update_period=1,
                pmap_axis_name=_BATCH_AXIS,
                batch_size_extractor=_batch_size_extractor,
                layer_tag_to_block_ctor={
                    'dense': RepeatedDenseBlock,
                    'scale_and_shift': LayerNormBlock,
                },
                auto_register_kwargs={'raise_error_on_diff_jaxpr': False},
            )

            return RealKFAC(
                wave_function=wave_function,
                lr_schedule=get_schedule(lr_schedule),
                damping=damping,
                norm_constraint=norm_constraint,
                ema=ema,
                optimizer=optimizer,
                apply_mask_to_fisher=apply_mask_to_fisher,
            )

        def init(
            self,
            key: Array,
            params: WaveFunctionParameters,
            systems: Systems,
        ) -> RealKFACState:
            """Initialize the kfac_jax state; a scalar ``True`` stands in for the sample mask."""
            return RealKFACState(
                self.optimizer.init(
                    (params, (systems, jnp.ones((), dtype=jnp.bool))),
                    key,
                ),  # type: ignore[type-args]
                step=jnp.zeros((), dtype=jnp.int32),
            )

        @jit(static_argnames=('sample_reweighting',))
        def apply(
            self,
            params: WaveFunctionParameters,
            systems: Systems,
            dL_dlogpsi: Float[Array, 'batch_size n_mols'],
            reweighting_factor: ReweightingFactor,
            sample_mask: SampleMask,
            state: RealKFACState,
            auxiliary_grads: WaveFunctionParameters,
            *,
            sample_reweighting: bool = False,
        ):
            """Precondition the lr-scaled plain gradient with kfac_jax.

            Steps:
            1. Compute the ``Identity`` gradient and scale it by the current
               learning rate.
            2. Update the curvature estimate.
            3. Increment the kfac_jax counter.
            4. Apply the preconditioner.
            """
            gradient, _, identity_grad_aux = Identity(self.wave_function).apply(
                params,
                systems,
                dL_dlogpsi,
                reweighting_factor,
                sample_mask,
                None,
                auxiliary_grads,
                sample_reweighting=sample_reweighting,
            )
            gradient = tree_mul(
                gradient,
                jnp.array(self.lr_schedule(state.step), dtype=jnp.float32),
            )

            # _loss_fn expects an unreweighted mask.
            if sample_reweighting:
                sample_mask = convert_reweighted_tensor_to_unreweighted(
                    sample_mask,
                    systems,
                )

            kfac_state = state.state
            kfac_state = self.optimizer.maybe_update(
                kfac_state,
                (params, (systems, sample_mask)),  # type: ignore[arg-type]
                # NOTE(review): fixed RNG key; presumably unused with
                # 'fisher_exact' estimation. TODO: verify.
                jax.random.key(0),
            )
            kfac_state = self.optimizer.increment_count(kfac_state)
            updates = self.optimizer.apply(gradient, kfac_state)  # type: ignore[arg-type]
            return (
                cast('WaveFunctionParameters', updates),
                RealKFACState(
                    state=kfac_state,
                    step=jnp.array(state.step + 1, dtype=jnp.int32),
                ),
                identity_grad_aux,
            )

    PRECONDITIONER['realkfac'] = RealKFAC
except ImportError:
    # kfac_jax (or the kfac integration modules) could not be imported;
    # 'realkfac' is then simply not registered.
    pass
