import einops
import numpy as onp
import pytest
from pyscf import dft
from utils import (
    PyscfSystemWrapper,
    assert_either_abs_or_rel_close,
    set_jax_testing_config,
    system_from_preloaded,
)

set_jax_testing_config()

import jax.numpy as jnp

from egxc.discretization import get_grid_fn, get_gto_grid_eval_fn
from egxc.discretization.gto.containers import GTOBasis, get_gto_preloader
from egxc.systems import PreloadSystem, System, examples
from egxc.utils.typing import cast_to_integer_tuple


def eval_basis_fns(basis: str, psys: PreloadSystem, deriv: int):
    """Evaluate the GTO basis ``basis`` of ``psys`` on ``psys.grid.coords``.

    Args:
        basis: Name of the basis set passed to ``get_gto_preloader``.
        psys: Preloaded system providing atom numbers, nuclear positions and grid.
        deriv: Derivative order forwarded to ``get_gto_grid_eval_fn``.

    Returns:
        Whatever the evaluator built by ``get_gto_grid_eval_fn`` returns for
        ``deriv``.
    """
    max_l, assemble = get_gto_preloader(basis, set(psys.atom_z))
    evaluate = get_gto_grid_eval_fn(deriv=deriv, max_angular_momentum=max_l)
    gto_basis = GTOBasis.from_preloaded(assemble(psys.atom_z))
    return evaluate(
        psys.grid.coords,  # type: ignore
        psys.nuc_pos,  # type: ignore
        gto_basis.radial_primitives,
        gto_basis.compile_statics,
    )


def check_equal_ao_values(psys: PreloadSystem, deriv: int = 1):
    """Assert that egxc's GTO grid evaluation matches PySCF's ``eval_gto``.

    A quadrature grid is built for the atoms selected by ``psys.atom_mask``.
    The atomic orbitals of ``psys.basis`` are evaluated on it by both egxc and
    PySCF and compared. For ``deriv=1`` the comparison also includes their
    gradients.

    Args:
        psys: Preloaded system; ``psys.basis`` selects the GTO basis set.
        deriv: Derivative order, 0 (values only) or 1 (values + gradients).

    Raises:
        ValueError: If ``deriv`` is not 0 or 1.
        AssertionError: If shapes or values disagree with PySCF.
    """
    # PySCF evaluator name for each supported derivative order.
    pyscf_eval_names = {0: 'GTOval_sph', 1: 'GTOval_sph_deriv1'}
    if deriv not in pyscf_eval_names:
        # Fail fast, before building the grid, system and basis.
        raise ValueError(f'Derivative order {deriv} not supported')

    unique_elements = set(psys.atom_z)
    grid_fn = get_grid_fn(1, unique_elements, 512)
    system = System.from_preloaded(psys, grid='')  # type: ignore

    nuc_pos = jnp.asarray(psys.nuc_pos)
    # Build the grid only for atoms selected by atom_mask. The quadrature
    # weights are not needed for a pointwise comparison.
    coords, _ = grid_fn(
        nuc_pos[psys.atom_mask],  # type: ignore
        cast_to_integer_tuple(psys.atom_z[psys.atom_mask]),
    )

    # Basis-set assembly for the present elements, plus an evaluator sized to
    # the basis' maximum angular momentum.
    max_angular_momentum, vec_basis_assembly = get_gto_preloader(
        psys.basis, unique_elements
    )
    basis_fn = get_gto_grid_eval_fn(
        deriv=deriv, max_angular_momentum=max_angular_momentum
    )
    preloaded_vec_basis_fns = vec_basis_assembly(psys.atom_z)
    vec_basis_fns = GTOBasis.from_preloaded(preloaded_vec_basis_fns)

    result = basis_fn(
        coords,
        nuc_pos,
        vec_basis_fns.radial_primitives,
        vec_basis_fns.compile_statics,
    )
    if deriv == 0:
        aos = result
    else:
        ao_values, ao_grad = result  # type: ignore
        # Append the gradient components after the values on a trailing axis.
        # Then move that axis to the front to match PySCF's deriv1 layout.
        aos = jnp.concatenate((ao_values[..., None], ao_grad), axis=-1)  # type: ignore
        aos = einops.rearrange(aos, 'n b f -> f n b')

    pyscf_mol = system.to_pyscf(psys.basis)
    pyscf_target_aos = pyscf_mol.eval_gto(pyscf_eval_names[deriv], coords)  # type: ignore

    assert (
        aos.shape == pyscf_target_aos.shape  # type: ignore
    ), f'Shapes do not match: {aos.shape} != {pyscf_target_aos.shape}'  # type: ignore

    assert_either_abs_or_rel_close(onp.asarray(aos), pyscf_target_aos)


@pytest.mark.quick
def test_ao_values_single_atom_without_gradient(basis: str = 'sto-3g'):
    """Hydrogen atom: compare AO values only (deriv=0) against PySCF."""
    check_equal_ao_values(
        examples.get_preloaded('h', basis=basis, alignment=1),
        deriv=0,
    )


@pytest.mark.quick
def test_ao_values_single_atom(basis: str = 'sto-3g'):
    """Hydrogen atom: compare AO values and gradients (default deriv=1)."""
    check_equal_ao_values(examples.get_preloaded('h', basis=basis, alignment=1))


@pytest.mark.quick
def test_ao_values_multiple_atoms(basis: str = 'sto-3g'):
    """Multi-atom 'organic' example: compare AO values and gradients."""
    check_equal_ao_values(
        examples.get_preloaded('organic', basis=basis, alignment=1)
    )


@pytest.mark.quick
def test_ao_values_with_padding(basis: str = 'sto-3g', align: int = 4):
    """H2 preloaded with ``alignment=align``: compare AO values and gradients.

    NOTE(review): a non-unit alignment presumably pads the per-atom arrays
    (hence the masked grid construction); confirm in ``examples.get_preloaded``.
    """
    padded = examples.get_preloaded('h2', basis=basis, alignment=align)
    check_equal_ao_values(padded)


@pytest.mark.slow
def test_elements_basis_sets(z=18, basis='def2-TZVPD'):
    """Compare AO values and gradients for the example selected by ``z``.

    NOTE(review): the original note says this tests the first 18 elements,
    i.e. ``examples.get_preloaded(z)`` presumably builds elements 1..z.
    Confirm against ``examples``.
    """
    check_equal_ao_values(examples.get_preloaded(z, basis=basis, alignment=1))


# Basis sets exercised by test_larger_basis_sets on the 'organic' example.
_LARGER_BASIS_SETS = (
    'sto-6g',
    '6-31G(d)',
    'def2-SVP',
    'def2-TZVP',
    'def2-TZVPD',
    '6-31G(2df,p)',
    '6-311+G(2df,p)',
    # 'cc-pVQZ',  # currently disabled
)


@pytest.mark.slow
@pytest.mark.parametrize('basis', _LARGER_BASIS_SETS)
def test_larger_basis_sets(basis: str):
    """'organic' example: compare AO values and gradients for ``basis``."""
    organic = examples.get_preloaded('organic', basis=basis, alignment=1)
    check_equal_ao_values(organic)


@pytest.mark.quick
def test_combination_of_quadrature_and_atomic_orbitals(basis='def2-TZVPD', xc='LDA,VWN'):
    """Check that the XC energy integrated on egxc's grid matches PySCF's.

    The first initial density matrix of the water example is contracted with
    two sets of AOs:

    (a) the AOs stored on the egxc system grid, and
    (b) AOs evaluated by PySCF on ``PyscfSystemWrapper``'s quadrature points.

    In each case the LibXC energy density for ``xc`` is integrated with the
    corresponding weights, and the two totals are compared.
    """
    psys = examples.get_preloaded('water', basis=basis, alignment=1)
    system = system_from_preloaded(psys, basis, 1, 512)
    P = psys.initial_density_matrices[0]
    mol = PyscfSystemWrapper(system, basis=basis, xc=xc)

    def integrate_xc_energy(ao, weights):
        """Return sum(e_xc * rho * weights) for the density built from ``ao``."""
        rho = dft.numint.eval_rho(mol.pyscf, ao, P)
        e_xc = dft.libxc.eval_xc(xc, rho, deriv=0)[0]
        return onp.sum(e_xc * rho * weights)

    E_xc = integrate_xc_energy(onp.asarray(system.grid.aos), system.grid.weights)

    ref_c, ref_w = mol.quadrature_points_and_weights
    # TODO: mind that this assumes that the coordinates are in units of Bohr
    ref_E_xc = integrate_xc_energy(dft.numint.eval_ao(mol.pyscf, ref_c), ref_w)

    # TODO: is our implementation really that good?!
    # Two-sided on purpose: a one-sided `1 - E_xc / ref_E_xc < tol` passes
    # trivially whenever E_xc / ref_E_xc > 1.
    assert abs(1 - E_xc / ref_E_xc) < 1e-15, f'ref_E_xc: {ref_E_xc}, E_xc: {E_xc}'
    assert onp.abs(E_xc - ref_E_xc) < 1e-12, 'energy error should be less than 1e-12 Ha'
