import jax.numpy as jnp
import numpy as onp
import pytest
from pyscf import gto, scf
from utils import PyscfSystemWrapper as PySys
from utils import call_module_as_function, set_jax_testing_config

from egxc.systems import examples
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals.classical import hybrid, range_separated_hybrid
from egxc.xc_energy.xc_module import XCModule

# Apply the shared JAX test configuration at import time, before any test runs.
# NOTE(review): the concrete settings live in `utils.set_jax_testing_config`
# (presumably precision/platform flags) — confirm there.
set_jax_testing_config()
# Module-level mark: pytest attaches `quick` to every test collected from this file.
# NOTE(review): this also tags tests in this module that carry `@pytest.mark.slow`;
# confirm that is intended for `-m quick` selections.
pytestmark = pytest.mark.quick

# Basis set name used by `test_hybrid_functionals`.
BASIS = 'def2-SVP'

# Parametrization table for `test_hybrid_functionals`:
#   test id -> (key passed to `hybrid.Hybrid`, xc string passed to the PySCF wrapper)
hybrid_cases = {
    # NOTE(review): PySCF xc strings are written 'exchange,correlation', so 'hf,'
    # presumably selects HF exchange with no correlation part — verify.
    'hf_x': ('HF_X', 'hf,'),
    'pbe0': ('PBE0', 'pbe0'),
    'b3lyp': ('B3LYP', 'b3lyp'),
}


@pytest.mark.parametrize('key,libxcstr', hybrid_cases.values(), ids=hybrid_cases.keys())
def test_hybrid_functionals(key, libxcstr):
    """Compare the ``XCModule`` hybrid XC energy with the PySCF wrapper's value.

    The module is evaluated on the density matrix and electron repulsion
    tensor exposed by the PySCF wrapper for the 'h2o' example system, and the
    result must match the wrapper's ``xc_energy`` to a relative error of 1e-6.

    Args:
        key: Functional identifier handed to ``hybrid.Hybrid``.
        libxcstr: XC string handed to ``PyscfSystemWrapper`` for the reference.
    """
    restricted = True
    rel_tol = 1e-6

    # Module under test: hybrid functional on top of the density features.
    functional = hybrid.Hybrid(key, False, restricted)
    features = DensityFeatures(spin_restricted=restricted)
    xc_module = XCModule(functional, features)

    # Example system plus its PySCF-backed reference.
    system = examples.get('h2o', BASIS, alignment=1)
    reference = PySys(
        system,  # type: ignore
        BASIS,
        xc=libxcstr,
        grid_level=1,
        spin_restricted=restricted,
    )

    # Inputs are taken from the reference so both sides see the same density.
    eri_tensor = reference.electron_repulsion_tensor
    density = reference.density_matrix

    e_xc = call_module_as_function(  # type: ignore
        xc_module, density, system.grid, eri_tensor=eri_tensor
    )
    e_ref = reference.xc_energy

    rel_err = abs(e_xc - e_ref) / abs(e_ref)  # type: ignore
    assert rel_err < rel_tol, f'E_tot: {e_xc}, E_ref: {e_ref}'


@pytest.mark.quick
def test_range_separated_exact_exchange():
    """Check the range-separated exact-exchange contribution against a manual reference.

    A restricted HF density matrix for H2/STO-3G is computed with PySCF.
    Short- and long-range two-electron integrals at ``omega`` are then used to
    build the reference ``sr_frac * E_x[sr] + lr_frac * E_x[lr]`` with
    ``hybrid.exact_exchange``. The same quantity is requested from
    ``BaseRangeSeparatedHybrid.exact_exchange_contribution`` using the
    range-separated tensors of the 'h2' example system.

    NOTE(review): this assumes ``examples.get('h2', ...)`` describes the same
    geometry and AO ordering as the PySCF molecule built here — TODO confirm.
    """
    omega = 0.3
    sr_frac = 0.25
    lr_frac = 0.75
    # Tolerance for the sr + lr == full sanity check. The three energies come
    # from independently evaluated integral sets, so demanding agreement to
    # 1e-15 (a few float64 ulps for energies of order 1 Eh) is fragile;
    # 1e-12 is still far tighter than the 1e-6 used for the final comparison.
    split_tol = 1e-12

    mol = gto.M(atom='H 0 0 0; H 0 0 1.4', basis='sto-3g')
    mf = scf.RHF(mol)
    mf.kernel()
    P = onp.asarray(mf.make_rdm1())

    # Full, short-range and long-range Coulomb integrals from PySCF.
    eri_full = onp.asarray(mol.intor('int2e'))
    with mol.with_short_range_coulomb(omega):
        eri_sr = onp.asarray(mol.intor('int2e'))
    # A positive omega in `with_range_coulomb` selects the long-range operator.
    with mol.with_range_coulomb(omega):
        eri_lr = onp.asarray(mol.intor('int2e'))

    sys = examples.get(
        'h2',
        basis='sto-3g',
        alignment=1,
        include_grid=True,
        use_density_fitting=False,
        range_separation=omega,
    )

    # Manual reference built from the PySCF integrals.
    e_sr = float(hybrid.exact_exchange(jnp.array(P), jnp.array(eri_sr), True))
    e_lr = float(hybrid.exact_exchange(jnp.array(P), jnp.array(eri_lr), True))
    ref = sr_frac * e_sr + lr_frac * e_lr

    # Sanity check: the short- and long-range parts must add up to the full
    # exact-exchange energy.
    e_full = float(hybrid.exact_exchange(jnp.array(P), jnp.array(eri_full), True))
    residual = e_full - e_sr - e_lr
    assert abs(residual) < split_tol, f'e_full - e_sr - e_lr: {residual}'

    func = range_separated_hybrid.BaseRangeSeparatedHybrid(
        short_range_fraction=sr_frac,
        long_range_fraction=lr_frac,
        use_density_fitting=False,
        spin_restricted=True,
    )

    val = call_module_as_function(
        func,
        density_matrix=jnp.array(P, dtype=jnp.float64),
        eri_sr_tensor=sys.fock_tensors.eri_sr_tensor,
        eri_lr_tensor=sys.fock_tensors.eri_lr_tensor,
        method='exact_exchange_contribution',
    )
    assert abs(val - ref) < 1e-6  # type: ignore


@pytest.mark.slow
@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
@pytest.mark.parametrize('molecule', ['water'], ids=['water'])
@pytest.mark.parametrize(
    'basis', ['def2-SVP', 'def2-TZVPD'], ids=['def2-SVP', 'def2-TZVPD']
)
def test_wb97mv_energy(molecule: str, basis: str, spin_restricted: bool):
    """Compare the ``wB97M_V`` XC energy with a PySCF-derived reference.

    A PySCF Kohn-Sham calculation with ``xc='wb97m-v'`` provides the density
    matrix. The reference is the difference between PySCF's total energy for
    that density with the functional enabled and with an empty XC string.
    The module result must agree with it to within 0.1 mHa.

    Args:
        molecule: Name of the example system passed to ``examples.get``.
        basis: Basis set name used for both the example system and PySCF.
        spin_restricted: Selects RKS (True) or UKS (False) and the matching
            module configuration.
    """
    xc_name = 'wb97m-v'

    system = examples.get(
        molecule,
        basis=basis,
        alignment=1,
        include_grid=True,
        use_density_fitting=False,
        range_separation=0.3,
        spin_restricted=spin_restricted,
        grid_level=1,
    )

    # SCF run that supplies the density matrix used on both sides.
    ks_class = scf.RKS if spin_restricted else scf.UKS
    ks = ks_class(system.to_pyscf(basis))
    ks.xc = xc_name
    ks.kernel()
    dm = onp.asarray(ks.make_rdm1())

    # Spin-resolved features are enabled only for the unrestricted case; the
    # closed-shell case keeps them off to avoid artificial density splitting.
    functional = range_separated_hybrid.wB97M_V(
        use_density_fitting=False, spin_restricted=spin_restricted
    )
    features = DensityFeatures(
        spin_restricted=spin_restricted, spin_resolved=not spin_restricted
    )
    xc_module = XCModule(functional, features)

    grid = system.grid
    fock = system.fock_tensors
    energy = call_module_as_function(
        xc_module,
        dm,
        grid,
        eri_sr_tensor=fock.eri_sr_tensor,
        eri_lr_tensor=fock.eri_lr_tensor,
        grid_coords=grid.coords,
        grid_weights=grid.weights,
        grid_aos=grid.aos,
        grid_grad_aos=grid.grad_aos,
        jit=True,
    )

    # Reference XC energy: total energy with the functional minus the total
    # energy with the XC string cleared, both evaluated at the same density.
    e_with_xc = ks.energy_tot(dm=dm)
    ks.xc = ''
    e_no_xc = ks.energy_tot(dm=dm)
    ks.xc = xc_name  # restore the functional on the mean-field object
    e_ref = e_with_xc - e_no_xc

    # Absolute deviation converted from Hartree to milli-Hartree.
    error_mha = onp.abs(energy - e_ref) * 1e3  # type: ignore
    assert error_mha < 0.1, (
        f'E_tot: {energy:.6f} Eh, E_ref: {e_ref:.6f} Eh, error: {error_mha:.3f} mHa'
    )
