import pytest
from utils import PyscfSystemWrapper, call_module_as_function, set_jax_testing_config

from egxc.systems import examples
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.learnable.nn import NumericDecoder, NumericEncoder, PaiNN

# Module-level setup, run once at import time:
#   * tag every test collected from this file with the ``modelling`` marker;
#   * apply the test-suite JAX configuration helper from the local ``utils``.
pytestmark = pytest.mark.modelling
set_jax_testing_config()


@pytest.mark.quick
def test_l1():
    """Check array shapes through the encoder -> PaiNN -> decoder pipeline.

    A water molecule in the sto-3g basis is used.

    Stages:
      1. Grid features ``n`` are computed with ``DensityFeatures`` from the
         PySCF density matrix.
      2. ``NumericEncoder`` with ``'0e + 1o'`` irreps maps them onto the
         nuclei.
      3. A single-layer ``PaiNN`` refines the node features.
      4. ``NumericDecoder`` maps them back onto the integration grid.

    The shape of the intermediate result is asserted after every stage.
    """
    cutoff = 5.0
    hidden, n_out, n_rbf = 128, 16, 32
    n_atoms = 3  # water: O + 2 H
    # Each channel carries one scalar (0e, dim 1) plus one vector (1o, dim 3).
    irrep_dim = 1 + 3

    encoder = NumericEncoder('0e + 1o', cutoff, n_rbf)
    density_features = DensityFeatures(spin_restricted=True)

    mol = examples.get('water', 'sto-3g', alignment=0)
    density_matrix = PyscfSystemWrapper(mol, basis='sto-3g').density_matrix
    _, (density, _) = call_module_as_function(
        density_features, density_matrix, mol.grid.aos, None
    )

    # Stage 1: grid -> nuclei. ``cache`` is reused by the decoder below.
    node_feats, cache = call_module_as_function(
        encoder,
        mol._nuc_pos,
        mol.atom_mask,
        mol.grid.coords,
        mol.grid.weights,
        density,
    )
    assert node_feats.shape == (n_atoms, n_rbf * irrep_dim)

    # Stage 2: message passing between nuclei.
    # The first returned value is not checked here.
    gnn = PaiNN(
        irreps_str=f'{hidden}x0e + {hidden}x1o',
        output_irreps_str=f'{n_out}x0e + {n_out}x1o',
        message_cutoff=cutoff,
        layers=1,
        energy_graph_readout_hidden_dims=(hidden, hidden, 1),
        init_graph_readout_to_zero=True,
    )
    _, gnn_feats = call_module_as_function(gnn, node_feats, mol._nuc_pos, mol.atom_mask)
    assert gnn_feats.shape == (n_atoms, n_out * irrep_dim)  # type: ignore

    # Stage 3: nuclei -> grid, one row per integration-grid point.
    decoder = NumericDecoder(n_out)
    spatial_feats = call_module_as_function(decoder, gnn_feats, cache)
    n_grid_points = len(mol.grid.weights)
    assert spatial_feats.shape == (n_grid_points, n_out)  # type: ignore
