import jax
import jax.numpy as jnp
import optax
import pytest
from jax import random
from utils import (
    PyscfSystemWrapper,
    assert_is_close,
    call_module_as_function,
    set_jax_testing_config,
)

from egxc.solver import scf
from egxc.systems import examples
from egxc.systems.base import nuclear_energy_fn
from egxc.xc_energy import XCModule
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals import EGXC, MetaGGA, XCDiff
from egxc.xc_energy.functionals.learnable.nn import (
    NequIP,
    Nequix,
    NumericDecoder,
    NumericEncoder,
    PaiNN,
    SpatialReweighting,
)

# Project test helper from the local `utils` module; presumably sets the JAX
# configuration used by these tests (e.g. precision flags) — see utils to confirm.
set_jax_testing_config()
# Module-level marks: every test in this file is tagged `modelling` and `slow`,
# so they can be selected or deselected with `pytest -m`.
pytestmark = [pytest.mark.modelling, pytest.mark.slow]


def __get_initial_values(
    mol_str: str, spin_restricted, use_density_fitting, basis='6-31G(d)'
):
    """Build an example system and collect PySCF-wrapper reference quantities.

    The system is taken from ``egxc.systems.examples`` and then handed to
    ``PyscfSystemWrapper`` configured with the SCAN functional.

    Args:
        mol_str: Name of the example molecule (e.g. ``'water'``, ``'3bpa'``).
        spin_restricted: Whether to build a spin-restricted system.
        use_density_fitting: Whether the system uses density fitting.
        basis: Basis-set name used for both the system and the wrapper.

    Returns:
        Tuple ``(P0, P, sys, approximate_e_xc)``: the wrapper's initial density
        matrix, its density matrix, the example system, and the wrapper's
        (SCAN) XC energy.
    """
    system = examples.get(
        mol_str,
        basis=basis,
        use_density_fitting=use_density_fitting,
        spin_restricted=spin_restricted,
        alignment=1,  # FIXME: align throws error at the moment
    )
    reference = PyscfSystemWrapper(
        system, basis, spin_restricted=spin_restricted, xc='SCAN'
    )
    # Attribute reads are kept in the original order in case they are lazy.
    return (
        reference.initial_density_matrix,
        reference.density_matrix,
        system,
        reference.xc_energy,
    )


def test_embedding_integral_truncation():
    """Check that the default NumericEncoder matches a higher-resolution reference.

    Two ``NumericEncoder`` instances are applied to the same unrestricted,
    density-fitted '3bpa' system (STO-6G):
    - a reference built with ``_quadrature_points_per_atom_scaling=100``;
    - one with default settings.
    Their scalar ('0e') features must agree to within 1e-3 and their vector
    ('1o') features to within 4e-3.
    """
    spin_restricted = False
    _, density_matrix, system, _ = __get_initial_values(
        '3bpa', spin_restricted, True, basis='sto-6g'
    )
    grid = system.grid
    mask, feats = call_module_as_function(
        DensityFeatures(spin_restricted), density_matrix, grid.aos, grid.grad_aos
    )
    density = feats[0]  # type: ignore
    # Quadrature weights multiplied by the feature mask; shared by both encoders.
    masked_weights = grid.weights * mask

    def encode(encoder):
        """Apply ``encoder`` to the shared inputs and return only its features."""
        features, _ = call_module_as_function(
            encoder,
            system._nuc_pos,
            system.atom_mask,
            grid.coords,
            masked_weights,
            density,
        )
        return features

    def irrep_array(features, irreps):
        """Return the ``irreps`` channels of ``features`` as a plain array."""
        return features.filter(irreps).mul_to_axis().array  # type: ignore

    reference_encoder = NumericEncoder(
        '0e + 1o', 5.0, _quadrature_points_per_atom_scaling=100, num_radial_filters=33
    )
    default_encoder = NumericEncoder('0e + 1o', 5.0, num_radial_filters=33)
    reference = encode(reference_encoder)
    truncated = encode(default_encoder)

    scalars = irrep_array(reference, '0e')[..., 0]
    scalars_truncated = irrep_array(truncated, '0e')[..., 0]
    assert_is_close(scalars_truncated, scalars, tolerance=1e-3)

    vectors = irrep_array(reference, '1o')
    vectors_truncated = irrep_array(truncated, '1o')
    assert_is_close(vectors_truncated, vectors, tolerance=4e-3)


@pytest.mark.parametrize(
    'functional',
    [
        pytest.param(
            EGXC(
                MetaGGA('scan'),
                NumericEncoder('0e + 1o', 5.0, num_radial_filters=33),
                PaiNN(
                    irreps_str='16x0e + 16x1o',
                    output_irreps_str='12x0e + 12x1o',
                    message_cutoff=5.0,
                    layers=1,
                    energy_graph_readout_hidden_dims=(16, 16, 1),
                    n_radial_basis=20,
                    init_graph_readout_to_zero=True,
                ),
                non_local_grid_feature_mode='reweighting_with_mGGA_feats',
                decoder=NumericDecoder(spatial_feature_dim=12),
                non_local_reweighting=SpatialReweighting(2, 8),
            ),
            id='PaiNN',
        ),
        pytest.param(
            EGXC(
                MetaGGA('scan'),
                NumericEncoder('0e + 1o + 2e', 5.0, num_radial_filters=33),
                NequIP(
                    output_irreps_str='8x0e + 4x1o + 2x2e',
                    message_cutoff=5.0,
                    layers=1,
                    energy_graph_readout_hidden_dims=(16, 16, 1),
                    n_radial_basis=8,
                    irreps_str='8x0e + 6x1o + 4x2e',
                    init_graph_readout_to_zero=True,
                ),
                non_local_grid_feature_mode='reweighting_with_mGGA_feats',
                decoder=NumericDecoder(spatial_feature_dim=8),
                non_local_reweighting=SpatialReweighting(2, 8),
            ),
            id='NequIP',
        ),
    ],
)
@pytest.mark.parametrize(
    'spin_restricted',
    [
        pytest.param(True, id='restricted'),
        pytest.param(False, id='unrestricted'),
    ],
)
def test_xc_energy_eval(functional, spin_restricted: bool):
    """Smoke-test the XC energy of a freshly initialised EG-XC functional on water.

    The functional is wrapped in an ``XCModule``, initialised with
    ``PRNGKey(0)``, and evaluated under ``jax.jit``. The result must lie within
    0.5 of the SCAN XC energy reported by the PySCF wrapper.
    """
    xc_module = XCModule(functional, DensityFeatures(spin_restricted))
    _, P, system, e_xc_ref = __get_initial_values(
        'water', spin_restricted, use_density_fitting=True
    )

    # Extra keyword inputs consumed by the non-local (graph) part of the model.
    non_local_kwargs = dict(
        atom_mask=system.atom_mask,
        nuc_pos=system._nuc_pos,
        grid_coords=system.grid.coords,
    )

    params = xc_module.init(random.PRNGKey(0), P, system.grid, **non_local_kwargs)
    compute_e_xc = jax.jit(
        lambda p: xc_module.apply(p, P, system.grid, **non_local_kwargs)
    )
    e_xc = compute_e_xc(params)

    assert abs(e_xc - e_xc_ref) < 0.5, (
        f'Implausible energy of untrained network {e_xc}, (reference {e_xc_ref})'
    )


def _overfit_painn():
    """Return the small PaiNN graph network shared by the PaiNN overfitting cases."""
    return PaiNN(
        irreps_str='12x0e + 12x1o',
        output_irreps_str='8x0e + 8x1o',
        message_cutoff=4.0,
        layers=1,
        energy_graph_readout_hidden_dims=(12, 12, 1),
        n_radial_basis=20,
        init_graph_readout_to_zero=True,
    )


def _overfit_egxc(encoder, graph_network):
    """Build the EG-XC configuration shared by every ``test_overfit_water`` case.

    Only the encoder and the graph network vary between cases. The remaining
    components are always:
    - ``XCDiff(n_layers=3, hidden_dim=8)``;
    - the ``'reweighting_without_mGGA_feats'`` grid-feature mode;
    - ``NumericDecoder(spatial_feature_dim=8)``;
    - ``SpatialReweighting(2, 4)``.
    """
    return EGXC(
        XCDiff(n_layers=3, hidden_dim=8),
        encoder,
        graph_network,
        'reweighting_without_mGGA_feats',
        NumericDecoder(spatial_feature_dim=8),
        SpatialReweighting(2, 4),
    )


# The trailing comments record losses observed at the last iteration.
@pytest.mark.parametrize(
    'functional,learning_rate',
    [
        pytest.param(
            _overfit_egxc(
                NumericEncoder('0e + 1o', 4.0, num_radial_filters=33),
                _overfit_painn(),
            ),
            1e-3,
            id='EG-XC:PaiNN+XCDiff:Default',  # Iteration 9, loss: 0.0844
        ),
        pytest.param(
            _overfit_egxc(
                NumericEncoder(
                    '0e + 1o',
                    4.0,
                    nuclei_partitioning='Exponential',
                    num_radial_filters=33,
                ),
                _overfit_painn(),
            ),
            1e-3,
            id='EG-XC:PaiNN+XCDiff:ExpPartitioning',  # Iteration 9, loss: 0.0971
        ),
        pytest.param(
            _overfit_egxc(
                NumericEncoder(
                    '0e + 1o', 4.0, nuclei_partitioning='Gaussian', num_radial_filters=33
                ),
                _overfit_painn(),
            ),
            1e-3,
            id='EG-XC:PaiNN+XCDiff:GaussPartitioning',  # Iteration 9, loss: 0.1467
        ),
        pytest.param(
            _overfit_egxc(
                NumericEncoder(
                    '0e + 1o', 4.0, num_radial_filters=11, radial_basis_type='polynomial'
                ),
                _overfit_painn(),
            ),
            1e-4,
            id='EG-XC:PaiNN+XCDiff:PolynomialRadialBasis',  # Iteration 9, loss: 0.30108
        ),
        pytest.param(
            _overfit_egxc(
                NumericEncoder('0e + 1o + 2e', 4.0, num_radial_filters=33),  # base irreps
                NequIP(
                    output_irreps_str='8x0e + 8x1o + 8x2e',
                    message_cutoff=4.0,
                    layers=1,
                    energy_graph_readout_hidden_dims=(12, 12, 1),
                    n_radial_basis=8,
                    irreps_str='12x0e + 12x1o + 12x2e',
                    init_graph_readout_to_zero=True,
                ),
            ),
            1e-3,
            id='EG-XC:NequIP+XCDiff:Default',  # Iteration 9, loss: 0.02278
        ),
        pytest.param(
            _overfit_egxc(
                NumericEncoder('0e + 1o', 4.0, num_radial_filters=16),  # base irreps
                Nequix(
                    output_irreps_str='8x0e + 8x1o',
                    message_cutoff=4.0,
                    layers=2,
                    energy_graph_readout_hidden_dims=(12, 1),
                    n_radial_basis=8,
                    irreps_str='12x0e + 12x1o',
                    init_graph_readout_to_zero=True,
                ),
            ),
            1e-5,  # Lower learning rate for Nequix stability
            # NOTE(review): this loss value is identical to the NequIP row and
            # looks copy-pasted; re-measure and update.
            id='EG-XC:Nequix+XCDiff:Default',  # Iteration 9, loss: 0.02278
        ),
    ],
)
def test_overfit_water(functional, learning_rate, spin_restricted=False):
    """Check that a few Adam steps through the SCF solver reduce the energy error.

    The functional is embedded in a 15-iteration DIIS SCF solver on water with
    density fitting. The squared difference between the final total energy
    (electronic energy at the last SCF iteration plus nuclear repulsion) and a
    fixed reference energy is minimised for 10 Adam steps. The test passes if
    either:
    - the loss decreases at every step; or
    - the final loss is below 10% of the initial loss.
    """
    # RKS B3LYP/6-31G(d) from PySCF. Water is closed-shell, so the unrestricted
    # default is presumably expected to reach the same energy.
    target_energy = -76.38566321214728
    n_steps = 10

    use_density_fitting = True
    P0, _, system, _ = __get_initial_values(
        'water', spin_restricted, use_density_fitting
    )
    xc_module = XCModule(functional, DensityFeatures(spin_restricted))
    scf_solver = scf.SelfConsistentFieldSolver(
        xc_module, 15, use_density_fitting, spin_restricted, 'DIIS'
    )

    params = scf_solver.init(random.PRNGKey(0), P0, system)

    def loss_fn(params):
        """Return the squared total-energy error and (as aux) the total energy."""
        energies, _ = scf_solver.apply(params, P0, system)
        # Sum the two energy terms, take the last SCF iteration, add nuclear repulsion.
        e_final = (energies[0] + energies[1])[-1] + nuclear_energy_fn(
            system._nuc_pos, system
        )
        return (e_final - target_energy) ** 2, e_final

    # Deliberately small learning rate, to guarantee loss decrease.
    opt = optax.adam(learning_rate)
    opt_state = opt.init(params)

    @jax.jit
    def step(params, opt_state):
        (loss, energy), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss, energy

    loss_history = []
    for i in range(n_steps):
        # `loss` is evaluated at the parameters *before* this step's update.
        params, opt_state, loss, energy = step(params, opt_state)
        print(f'Iteration {i}, loss: {loss}, energy: {energy}')
        loss_history.append(loss)

    losses = jnp.array(loss_history)
    consecutive_decrease = jnp.all(losses[1:] < losses[:-1])
    # After n_steps iterations, the loss should be reduced by at least 90%.
    total_decrease = losses[-1] / losses[0] < 0.1
    assert consecutive_decrease or total_decrease, (
        f'Loss is not decreasing: {losses}'
    )
