import jax
import jax.numpy as jnp
import pytest
from utils import PyscfSystemWrapper as PySys
from utils import (
    assert_is_close,
    set_jax_testing_config,
    system_from_preloaded,
)

from egxc.solver import fock
from egxc.systems import examples
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.classical import MetaGGA

# Apply the project's shared JAX test configuration (defined in `utils`; see
# there for exactly which options it sets).
set_jax_testing_config()
# Make JAX raise as soon as a NaN is produced, so numerical blow-ups in the
# XC functional surface as errors instead of silently propagating.
jax.config.update('jax_debug_nans', True)

# Basis-set name shared by the preloaded system, the JAX system and the PySCF reference.
BASIS = '6-31G(d)'


@pytest.mark.quick
@pytest.mark.parametrize('align', [1, 4], ids=['without_padding', 'with_padding'])
def test_xc_potential(align: int):
    """Compare the SCAN XC potential from ``fock.XCModule`` with PySCF's.

    The system is the preloaded H2 molecule in the ``BASIS`` basis set. Per the
    parametrize ids, ``align=1`` runs without basis padding and ``align=4``
    with padding. Rows/columns of padded basis functions are cropped away with
    ``fock_tensors.basis_mask`` before comparing against the PySCF reference.

    Args:
        align: alignment forwarded to ``examples.get_preloaded`` and
            ``system_from_preloaded``.
    """
    spin_restricted = True
    xc_module = fock.XCModule(
        MetaGGA('scan'),
        DensityFeatures(spin_restricted),
    )
    psys = examples.get_preloaded('h2', BASIS, align)
    # Named `system` rather than `sys` to avoid shadowing the stdlib module name.
    system = system_from_preloaded(psys, BASIS, 1, align)
    mask = system.fock_tensors.basis_mask
    P = psys.initial_density_matrices[0]
    variables = xc_module.init(jax.random.PRNGKey(0), P, system.grid)

    # `x != jnp.nan` is always True (NaN compares unequal to everything), so
    # the original check could never fail; test finiteness explicitly instead.
    xc_output = xc_module.apply(variables, P, system.grid)
    assert jnp.all(jnp.isfinite(xc_output)), 'XCModule output is not finite'

    V_xc = xc_module.apply(
        variables, P, system.grid, mask, method=xc_module.xc_potential
    )

    def crop(matrix):
        """Keep only the rows/columns selected by the basis mask."""
        return matrix[mask, :][:, mask]

    pyscf_sys = PySys(
        system,
        BASIS,
        xc='SCAN',  # TODO try lda
        grid_level=1,
        spin_restricted=spin_restricted,
    )
    V_xc_ref = pyscf_sys.xc_potential(crop(P))
    # TODO: should the indices belonging to padded basis functions be zero of V_xc?
    assert_is_close(
        crop(V_xc),  # type: ignore
        V_xc_ref,  # type: ignore
        name='exchange correlation potential',
        tolerance=1e-6,
        absolute=True,
    )
