import jax.numpy as jnp
import numpy as onp
import pytest
from pyscf.dft import numint
from utils import assert_is_close, set_jax_testing_config
from utils.pyscf_wrapper import PyscfSystemWrapper, XcType

from egxc.systems.examples import get
from egxc.xc_energy.functionals.dispersion import vv10

# Apply the project's JAX testing configuration at import time, i.e. before any
# test in this module executes (see utils.set_jax_testing_config for details).
set_jax_testing_config()
# Module-level marker: every test in this file is tagged ``quick``.
pytestmark = pytest.mark.quick


# Basis set shared by every test in this module.
BASIS = 'def2-svp'


def test_vv10_kernel_against_pyscf():
    """Check the JAX VV10 kernel against PySCF's ``numint._vv10nlc``.

    The water molecule is wrapped for PySCF (``BASIS``, LDA, restricted,
    grid level 3). Its density and density-gradient magnitude are evaluated
    on the wrapper's quadrature grid and passed to both implementations. The
    energy density and both potential components (``vrho``, ``vsigma``) must
    agree.
    """
    water = get('water', basis=BASIS, include_grid=True)
    wrapper = PyscfSystemWrapper(
        water, basis=BASIS, xc='LDA', spin_restricted=True, grid_level=3
    )
    # PySCF-layout GGA features: row 0 is the density, rows 1-3 its gradient.
    rho = wrapper.get_density_features(xctype=XcType.gga, format='pyscf')
    coords, weights = wrapper.quadrature_points_and_weights
    params = vv10.VV10_PARAMS

    # Reference values. The same density and grid are passed both as the local
    # arguments and as the nonlocal ("vv") partner arguments.
    exc_ref, vxc_ref = numint._vv10nlc(rho, coords, rho, weights, coords, params)

    grad_rows = jnp.array(rho[1:4].T)  # type: ignore
    exc, vxc = vv10.vv10_kernel(
        jnp.array(rho[0].T),  # type: ignore
        jnp.linalg.norm(grad_rows, axis=-1),
        jnp.array(coords),
        jnp.array(weights),
        params,
    )

    assert_is_close(exc, jnp.array(exc_ref), name='exc', tolerance=2e-9)
    for component, label in enumerate(('vrho', 'vsigma')):
        assert_is_close(
            vxc[component], jnp.array(vxc_ref[component]), name=label
        )


@pytest.mark.xfail(
    reason='TODO: VV10 potential still disagrees with PySCF; '
    'the matrices look very similar but the discrepancy is unresolved',
    strict=False,
)
def test_vv10_energy_and_potential():
    """Compare the VV10 energy and AO potential matrix with PySCF's ``nr_nlc_vxc``.

    Uses water in ``BASIS`` with a restricted, grid-level-3 PySCF wrapper and
    checks, in order:

    * the VV10 energy to a relative error below ``1e-6``;
    * the potential matrix element-wise (relative and absolute tolerances);
    * the ascending eigenvalue spectra of both potential matrices;
    * the spectral gap at the occupation boundary of those spectra.

    Known issue: the potential does not yet match PySCF, so the test is marked
    as a non-strict ``xfail`` instead of ending in an unconditional failure.
    """
    water = get('water', basis=BASIS, include_grid=True)
    wrapper = PyscfSystemWrapper(
        water, basis=BASIS, xc='', spin_restricted=True, grid_level=3
    )

    mf = wrapper.mf
    P = wrapper.density_matrix
    grid = water.grid

    # Only the energy and the potential matrix are compared; the first return
    # value is not needed.
    _, e_vv10_ref, v_vv10_ref = mf._numint.nr_nlc_vxc(
        wrapper.mol,
        mf.nlcgrids,
        'vv10',
        P,
    )

    e_vv10 = vv10.vv10_energy(
        P,  # type: ignore
        grid.coords,
        grid.weights,
        grid.aos,
        grid.grad_aos,  # type: ignore
    )
    relative_error = abs(e_vv10 - e_vv10_ref) / abs(e_vv10_ref)
    assert relative_error < 1e-6, (
        f'energy mismatch: {e_vv10} != {e_vv10_ref}, relative error: {relative_error}'
    )

    v_vv10 = vv10.vv10_potential(
        P,  # type: ignore
        grid.coords,
        grid.weights,
        grid.aos,
        grid.grad_aos,  # type: ignore
    )
    assert_is_close(v_vv10, jnp.asarray(v_vv10_ref), name='potential', tolerance=6e-2)
    assert_is_close(
        v_vv10, jnp.asarray(v_vv10_ref), name='potential', tolerance=5e-7, absolute=True
    )

    # Relative and absolute element-wise errors are rather large, so compare
    # the spectra too. ``eigvalsh`` returns real eigenvalues in ascending order.
    # ``eigvals`` gives no ordering guarantee and may return complex values,
    # which made both the element-wise spectrum comparison and the index-based
    # gap below unreliable. This assumes both potential matrices are real
    # symmetric (eigvalsh reads only the lower triangle). TODO confirm for the
    # JAX potential.
    eigvals = onp.linalg.eigvalsh(onp.asarray(v_vv10))
    eigvals_ref = onp.linalg.eigvalsh(onp.asarray(v_vv10_ref))
    assert_is_close(
        eigvals, eigvals_ref, name='eigenvalues', tolerance=1e-2, absolute=True
    )

    # Gap between the last "occupied" and first "virtual" index of the sorted
    # spectra. These are eigenvalues of the VV10 potential matrix alone, not
    # orbital energies, so this is a consistency check rather than a physical
    # HOMO-LUMO gap. Restricted system: nelectron // 2 doubly occupied levels.
    n_occ = mf.mol.nelectron // 2
    homo_lumo_gap = eigvals[n_occ] - eigvals[n_occ - 1]
    homo_lumo_gap_ref = eigvals_ref[n_occ] - eigvals_ref[n_occ - 1]
    assert_is_close(
        homo_lumo_gap,
        homo_lumo_gap_ref,
        name='homo-lumo gap',
        tolerance=1e-8,
        absolute=True,
    )
    assert_is_close(
        homo_lumo_gap,
        homo_lumo_gap_ref,
        name='homo-lumo gap',
        tolerance=1e-8,
        absolute=False,
    )
