from functools import partial

import flax.linen as nn
import jax.numpy as jnp
import numpy as onp
import pytest
from pyscf import dft
from utils import PyscfSystemWrapper as PySys
from utils import (
    assert_is_close,
    call_module_as_function,
    set_jax_testing_config,
    system_from_preloaded,
    vec_basis_fns_from_preloaded,
)

from egxc.discretization import GTOBasis, get_gto_grid_eval_fn, get_gto_preloader
from egxc.systems import System, examples
from egxc.utils.typing import FloatBxB
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.features import ueg_e_x
from egxc.xc_energy.functionals.classical import (
    gga,
    lda,
    lsda,
    mgga,
)

# Apply the shared JAX test configuration at import time, before any test in
# this module runs.
set_jax_testing_config()
# Module-level pytest marker: every test collected from this file carries the
# ``quick`` mark.
pytestmark = pytest.mark.quick


# Basis set name shared by every system built in this module.
BASIS = '6-31G(d)'
# NOTE(review): {1, 6, 7, 8} presumably lists the atomic numbers (H, C, N, O)
# the preloader must support — confirm against ``get_gto_preloader``.
# ``L_MAX`` is consumed by ``FeatureFactory.setup`` when building the grid
# evaluation function.
L_MAX, pvec_basis_fn_factory = get_gto_preloader(BASIS, {1, 6, 7, 8})


class FeatureFactory(nn.Module):
    """Test helper module: evaluate the GTO basis on a system's grid and turn the
    result into density features.

    Calling the module returns ``(grid weights, mask, features)``, where ``mask``
    and ``features`` are whatever ``DensityFeatures`` produces.
    """

    # Forwarded as the single positional argument of ``DensityFeatures``.
    spin_restricted: bool

    def setup(self):
        """Build the basis evaluation function and the feature submodule."""
        # NOTE(review): the leading ``1`` is presumably the derivative order
        # (the function yields values and gradients) — confirm.
        self.basis_fn = get_gto_grid_eval_fn(1, L_MAX)
        self.features = DensityFeatures(self.spin_restricted)

    def __call__(self, sys: System, vec_basis_fns: GTOBasis, density_matrix: FloatBxB):
        """Return the grid weights together with the mask and density features.

        Args:
            sys: system providing ``grid.coords``, ``grid.weights`` and ``_nuc_pos``.
            vec_basis_fns: basis description providing ``radial_primitives`` and
                ``compile_statics``.
            density_matrix: density matrix handed unchanged to ``DensityFeatures``.
        """
        grid = sys.grid
        ao_values, ao_grads = self.basis_fn(
            grid.coords,
            sys._nuc_pos,
            vec_basis_fns.radial_primitives,
            vec_basis_fns.compile_statics,
        )
        mask, feats = self.features(density_matrix, ao_values, ao_grads)
        return grid.weights, mask, feats


@pytest.mark.parametrize('align', [1, 4], ids=['without_padding', 'with_padding'])
def test_density_features(align: int, spin_restricted: bool = True):
    """Check density, reduced gradient and kinetic energy density for H2 against
    the values reported by the PySCF wrapper, for both alignment settings.
    """
    factory = FeatureFactory(spin_restricted=spin_restricted)
    preloaded = examples.get_preloaded(
        'h2', BASIS, alignment=align, spin_restricted=spin_restricted
    )
    _, basis_fns = vec_basis_fns_from_preloaded(preloaded, BASIS)
    system = system_from_preloaded(preloaded, BASIS, 1, 1)
    density_matrix = preloaded.initial_density_matrices[0]

    weights, mask, feats = call_module_as_function(  # type: ignore
        factory, system, basis_fns, density_matrix
    )
    n, zeta, s, tau = feats
    # The features are expected to carry no spin polarization in this setup.
    assert jnp.all(zeta == 0)

    reference = PySys(system, BASIS, grid_level=1, spin_restricted=spin_restricted)
    # NOTE(review): ``basis_mask`` presumably marks the real (non-padding) basis
    # functions; the reference only receives that sub-block — confirm.
    keep = system.fock_tensors.basis_mask
    reduced_dm = density_matrix[:, keep][keep, :]
    ref_n, _, ref_s, ref_tau = reference.get_density_features(
        reduced_dm,
        coords=system.grid.coords,  # type: ignore
        format='egxc',
    )

    # Extend the reference density to the evaluated length, then apply the same
    # lower bound of 1e-15 before comparing.
    missing = n.shape[0] - ref_n.shape[0]  #  type: ignore
    ref_n = onp.clip(onp.pad(ref_n, (0, missing)), 1e-15, None)

    # Only grid points with a positive quadrature weight take part.
    valid = mask & (weights > 0)

    comparisons = (
        (n, ref_n, 'density'),
        (s, ref_s, 'reduced gradient'),
        (tau, ref_tau, 'kinetic energy density'),
    )
    for computed, expected, label in comparisons:
        assert_is_close(computed, expected, valid, name=label)  # type: ignore


# Module-level ethanol test system, built once at import time and shared by the
# ``ethanol_features`` fixture. ``alignment=1`` is the setting labelled
# 'without_padding' in ``test_density_features``.
ethanol_psys = examples.get_preloaded('ethanol', BASIS, alignment=1)
# NOTE(review): the meaning of the positional ``1, 512`` is not visible from
# this file — confirm against ``system_from_preloaded``.
ethanol = system_from_preloaded(ethanol_psys, BASIS, 1, 512)
_, ethanol_vec_basis_fns = vec_basis_fns_from_preloaded(ethanol_psys, BASIS)
# PySCF-backed wrapper of the same system; supplies the density matrix and the
# reference density features used by the ethanol tests.
ethanol_pyscf_sys = PySys(
    ethanol,  # type: ignore
    BASIS,
    grid_level=1,
    spin_restricted=False,
)


@pytest.fixture
def ethanol_features():
    """Provide ``(weights, mask, feats, ref_feats)`` for the ethanol system.

    ``feats`` come from the spin-unrestricted ``FeatureFactory`` evaluated with
    the PySCF wrapper's density matrix; ``ref_feats`` are the wrapper's own
    density features on the same grid coordinates.
    """
    computed = call_module_as_function(  # type: ignore
        FeatureFactory(spin_restricted=False),
        ethanol,
        ethanol_vec_basis_fns,
        ethanol_pyscf_sys.density_matrix,
    )
    weights, mask, feats = computed
    reference = ethanol_pyscf_sys.get_density_features(
        coords=ethanol.grid.coords  # type: ignore
    )
    yield weights, mask, feats, reference


def __truncate_feats(feats, int):
    """Keep the first ``int`` entries of both spin channels.

    ``feats`` must unpack into exactly two sequences (up, down); each is sliced
    to its leading ``int`` entries and the pair is returned as a tuple.

    The original note lists the entry order as ``n, dn_dx, dn_dy, dn_dz, tau``;
    callers in this file pass 1 for LDA inputs and 4 for GGA inputs.

    NOTE: the parameter name shadows the builtin ``int``; it is left unchanged
    so existing calls keep working.
    """
    spin_up, spin_down = feats
    leading = slice(None, int)
    return spin_up[leading], spin_down[leading]


def __ref_feats_to_n(ref_feats):
    """Add the leading entry of the up and down reference features.

    Both channels are cut to their first entry with ``__truncate_feats`` (so the
    leading axis is kept) and then summed.
    """
    up_density, down_density = __truncate_feats(ref_feats, 1)
    total = up_density + down_density
    return total


def assert_total_xc_energy(mask, e, n, e_target, n_target, weights, tolerance=1e-7):
    """Assert that two grid-summed energies agree within ``tolerance``.

    Each side is the sum over all entries of ``mask * energy * density * weights``;
    the computed side uses ``e`` and ``n``, the reference side ``e_target`` and
    ``n_target``. The comparison itself is delegated to ``assert_is_close``.
    """

    def grid_sum(energy, density):
        # Factors are multiplied in the order mask, energy, density, weights.
        return (mask * energy * density * weights).sum()

    assert_is_close(
        grid_sum(e, n), grid_sum(e_target, n_target), tolerance=tolerance
    )


# LDA test cases keyed by pytest id. Each value pairs an egxc energy-density
# callable (invoked as ``xcfunc(n)`` by the LDA tests) with the functional
# string handed to ``dft.libxc.eval_xc`` for the reference value.
# NOTE(review): the strings appear to follow PySCF's "exchange,correlation"
# layout, so a trailing/leading comma selects only one part — confirm.
lda_cases = {
    'lda_exchange': (ueg_e_x, 'slater,'),
    'lda_correlation_pz': (lda.pz81_correlation_energy_density, ',lda_c_pz'),
    'lda_correlation_vwn5': (
        # ``zeta`` is fixed to 0 so this entry can be called with ``n`` alone,
        # like the other LDA entries.
        partial(lsda.vwn5_correlation_energy_density, zeta=0),  # type: ignore
        ',vwn5',
    ),
}


@pytest.mark.parametrize('xcfunc,libxcstr', lda_cases.values(), ids=lda_cases.keys())
def test_lda_energy_densities(ethanol_features, xcfunc, libxcstr):
    """Compare the summed LDA energy on ethanol with the libxc reference."""
    weights, mask, feats, ref_feats = ethanol_features
    density = feats[0]
    # The reference is evaluated spin-resolved from the up/down densities only.
    spin_densities = __truncate_feats(ref_feats, 1)
    reference_e = dft.libxc.eval_xc(libxcstr, spin_densities, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-6,
    )


# Spin-resolved LDA (LSDA) test cases keyed by pytest id. Each value pairs an
# egxc callable (invoked as ``xcfunc(n, zeta)`` by the LSDA tests) with the
# functional string handed to ``dft.libxc.eval_xc``.
lsda_cases = {
    'lsda_correlation_pw92': (lsda.pw92_correlation_energy_density, ',lda_c_pw'),
}


@pytest.mark.parametrize('xcfunc,libxcstr', lsda_cases.values(), ids=lsda_cases.keys())
def test_lsda_energy_densities(ethanol_features, xcfunc, libxcstr):
    """Compare the summed LSDA energy on ethanol with the libxc reference."""
    weights, mask, feats, ref_feats = ethanol_features
    density, polarization, _, _ = feats
    # The reference is evaluated spin-resolved from the up/down densities only.
    spin_densities = __truncate_feats(ref_feats, 1)
    reference_e = dft.libxc.eval_xc(libxcstr, spin_densities, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density, polarization),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-7,
    )


# GGA test cases keyed by pytest id. Each value pairs an egxc callable (invoked
# as ``xcfunc(n, zeta, s)`` by the GGA tests) with the functional string handed
# to ``dft.libxc.eval_xc``.
gga_cases = {
    'gga_exchange_b88': (gga.e_x_b88, 'gga_x_b88,'),
    'gga_exchange_pbe': (gga.e_x_pbe, 'gga_x_pbe,'),
    'gga_correlation_pbe': (gga.e_c_pbe, ',gga_c_pbe'),
    'gga_correlation_lyp': (gga.e_c_lyp, ',gga_c_lyp'),
}


@pytest.mark.parametrize('xcfunc,libxcstr', gga_cases.values(), ids=gga_cases.keys())
def test_gga_energy_densities(ethanol_features, xcfunc, libxcstr):
    """Compare the summed GGA energy on ethanol with the libxc reference."""
    # FIXME: why is the error that large? (still within 0.0X mHartree)
    weights, mask, feats, ref_feats = ethanol_features
    density, polarization, reduced_grad, _ = feats
    # The reference receives the first four entries of each spin channel.
    gga_inputs = __truncate_feats(ref_feats, 4)
    reference_e = dft.libxc.eval_xc(libxcstr, gga_inputs, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density, polarization, reduced_grad),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-5,
    )


# Meta-GGA test cases keyed by pytest id. Each value pairs an egxc callable
# (invoked with all features unpacked, ``xcfunc(*feats)``) with the functional
# string handed to ``dft.libxc.eval_xc``.
mgga_cases = {
    'scan': (mgga.e_xc_scan, 'scan'),
}


@pytest.mark.parametrize('xcfunc,libxcstr', mgga_cases.values(), ids=mgga_cases.keys())
def test_mgga_energy_densities(ethanol_features, xcfunc, libxcstr):
    """Compare the summed meta-GGA energy on ethanol with the libxc reference."""
    weights, mask, feats, ref_feats = ethanol_features
    density = feats[0]
    # The reference receives the full, untruncated reference features.
    reference_e = dft.libxc.eval_xc(libxcstr, ref_feats, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(*feats),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-7,
    )


# Spin-polarized tests on a system with an unpaired electron: a single hydrogen
# atom. Everything below is built once at import time and shared by the
# ``hydrogen_features`` fixture.
#
# The preloaded data gets its own name (mirroring ``ethanol_psys``) so that
# ``hydrogen`` is only ever bound to the system object instead of being rebound
# from the preloaded data to the system.
hydrogen_psys = examples.get_preloaded('h', BASIS, alignment=1)
_, hydrogen_vec_basis_fns = vec_basis_fns_from_preloaded(hydrogen_psys, BASIS)
hydrogen = system_from_preloaded(hydrogen_psys, BASIS, 1, 1)
# PySCF-backed wrapper of the same system; supplies the density matrix and the
# reference density features used by the hydrogen tests.
hydrogen_pyscf_sys = PySys(
    hydrogen,  # type: ignore
    BASIS,
    xc='wb97m-v',
    spin_restricted=False,
    grid_level=1,
)


@pytest.fixture
def hydrogen_features():
    """Provide ``(weights, mask, feats, ref_feats)`` for the hydrogen system.

    ``feats`` come from the spin-unrestricted ``FeatureFactory`` evaluated with
    the PySCF wrapper's density matrix; ``ref_feats`` are the wrapper's own
    density features on the same grid coordinates.
    """
    computed = call_module_as_function(  # type: ignore
        FeatureFactory(spin_restricted=False),
        hydrogen,
        hydrogen_vec_basis_fns,
        hydrogen_pyscf_sys.density_matrix,
    )
    weights, mask, feats = computed
    reference = hydrogen_pyscf_sys.get_density_features(
        coords=hydrogen.grid.coords  # type: ignore
    )
    yield weights, mask, feats, reference


@pytest.mark.parametrize('xcfunc,libxcstr', lda_cases.values(), ids=lda_cases.keys())
def test_lda_spinpolarized_energy_densities(hydrogen_features, xcfunc, libxcstr):
    """Compare the summed LDA energy on hydrogen with the libxc reference."""
    weights, mask, feats, ref_feats = hydrogen_features
    density = feats[0]
    total_ref_density = __ref_feats_to_n(ref_feats)
    # LDA formulas here are unpolarized; compare against libxc with total density and spin=0
    reference_e = dft.libxc.eval_xc(libxcstr, total_ref_density, spin=0)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density),
        density,
        reference_e,
        total_ref_density,
        weights,
        tolerance=1e-6,
    )


@pytest.mark.parametrize('xcfunc,libxcstr', lsda_cases.values(), ids=lsda_cases.keys())
def test_lsda_spinpolarized_energy_densities(hydrogen_features, xcfunc, libxcstr):
    """Compare the summed LSDA energy on hydrogen with the libxc reference."""
    weights, mask, feats, ref_feats = hydrogen_features
    density, polarization, _, _ = feats
    # The reference is evaluated spin-resolved from the up/down densities only.
    spin_densities = __truncate_feats(ref_feats, 1)
    reference_e = dft.libxc.eval_xc(libxcstr, spin_densities, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density, polarization),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-7,
    )


@pytest.mark.parametrize('xcfunc,libxcstr', gga_cases.values(), ids=gga_cases.keys())
def test_gga_spinpolarized_energy_densities(hydrogen_features, xcfunc, libxcstr):
    """Compare the summed GGA energy on hydrogen with the libxc reference."""
    weights, mask, feats, ref_feats = hydrogen_features
    density, polarization, reduced_grad, _ = feats
    # The reference receives the first four entries of each spin channel.
    gga_inputs = __truncate_feats(ref_feats, 4)
    reference_e = dft.libxc.eval_xc(libxcstr, gga_inputs, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(density, polarization, reduced_grad),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-5,
    )


@pytest.mark.parametrize('xcfunc,libxcstr', mgga_cases.values(), ids=mgga_cases.keys())
def test_mgga_spinpolarized_energy_densities(hydrogen_features, xcfunc, libxcstr):
    """Compare the summed meta-GGA energy on hydrogen with the libxc reference."""
    weights, mask, feats, ref_feats = hydrogen_features
    density = feats[0]
    # The reference receives the full, untruncated reference features.
    reference_e = dft.libxc.eval_xc(libxcstr, ref_feats, spin=1)[0]
    assert_total_xc_energy(
        mask,
        xcfunc(*feats),
        density,
        reference_e,
        __ref_feats_to_n(ref_feats),
        weights,
        tolerance=1e-7,
    )
