import flax.linen as nn
import jax
from jax import config

from .errors import (
    assert_either_abs_or_rel_close,
    assert_is_close,
    is_close,
    relative_error,
)
from .pyscf_wrapper import PyscfSystemWrapper
from .sample_utils import (
    preload_grid_using_pyscf,
    system_from_preloaded,
    vec_basis_fns_from_preloaded,
)

# Root directory for test data. The literal below looks like an anonymized
# placeholder. NOTE(review): confirm how the real path is supplied per environment.
ROOT_DATA_DIR: str = "ANONYMOUS_DIR"


def set_jax_testing_config():
    """Apply the global JAX configuration used by the test suite.

    Enables 64-bit dtypes, sets the default matmul precision to ``'float32'``
    and selects the CPU platform. These are process-wide settings.
    """
    testing_options = (
        ('jax_enable_x64', True),
        ('jax_default_matmul_precision', 'float32'),
        ('jax_platform_name', 'cpu'),
    )
    # Applied in the same order as the options are listed above.
    for option_name, option_value in testing_options:
        config.update(option_name, option_value)


def call_module_as_function(
    module: nn.Module, *args, method: str | None = None, jit=False, **kwargs
):
    """
    Evaluate a Flax module as though it were a plain function.

    The module's variables are first created with ``module.init`` using the
    fixed key ``jax.random.PRNGKey(0)``. The module is then applied to the same
    inputs with those variables. This is intended for modules without learnable
    parameters, where the initialized variables carry no meaningful state.

    Args:
        module: The Flax module to evaluate.
        *args: Positional inputs, forwarded unchanged to ``init`` and ``apply``.
        method: Optional module method to invoke. It is passed straight through
            to Flax, so ``None`` means Flax's default call target.
        jit: If True, the apply step is wrapped in ``jax.jit`` before being
            called. A fresh jitted function is built on every call.
        **kwargs: Keyword inputs, forwarded unchanged to ``init`` and ``apply``.

    Returns:
        Whatever ``module.apply`` returns for the given inputs.
    """
    variables = module.init(jax.random.PRNGKey(0), *args, method=method, **kwargs)

    def run(*call_args, **call_kwargs):
        # `variables` is captured from the enclosing scope rather than passed in.
        return module.apply(variables, *call_args, method=method, **call_kwargs)

    runner = jax.jit(run) if jit else run
    return runner(*args, **kwargs)
