import jax
import jax.numpy as jnp
import optax
import pytest
from jax import random
from utils import set_jax_testing_config, system_from_preloaded

from egxc.solver import fock, scf
from egxc.systems import examples
from egxc.systems.base import nuclear_energy_fn
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals.learnable import DEIXC, Nagai2020, Nagai2022, XCDiff

# Apply the shared JAX test configuration from the local `utils` helpers before
# anything else in this module touches JAX.
set_jax_testing_config()

# Basis set used both for the preloaded water system and the grid/system build.
BASIS = '6-31G(d)'
# Reference total energy the functionals are trained towards:
# RKS B3LYP/6-31G(d) from PySCF.
TARGET_ENERGY = -76.38566321214728

# Tag every test in this module with the `modelling` marker.
pytestmark = pytest.mark.modelling


@pytest.mark.slow
@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
@pytest.mark.parametrize(
    'functional',
    [
        Nagai2020(hidden_dim=8),
        XCDiff(hidden_dim=8),
        Nagai2022(hidden_dim=8),
        DEIXC(4, 8, False),
    ],
    ids=['Nagai2020', 'XCDiff', 'Nagai2022', 'DEIXC'],
)
def test_overfit_h2o(spin_restricted, functional):
    """Check that a learnable XC functional can be fitted to a reference energy.

    Wraps ``functional`` in an SCF solver, differentiates the squared error
    between the final total energy of water and ``TARGET_ENERGY`` through the
    solver, and takes a few Adam steps. The test passes if the loss decreases
    at every step, or if the last loss is below 25% of the first one.
    """
    use_density_fitting = True
    n_train_steps = 10
    learning_rate = 1e-3

    xc_mod = fock.XCModule(functional, DensityFeatures(spin_restricted))
    scf_solver = scf.SelfConsistentFieldSolver(
        # NOTE(review): 15 is presumably the number of SCF cycles — confirm
        # against SelfConsistentFieldSolver's signature.
        xc_mod, 15, use_density_fitting, spin_restricted, 'DIIS'
    )
    psys = examples.get_preloaded(
        'water',
        BASIS,
        use_density_fitting=use_density_fitting,
        spin_restricted=spin_restricted,
        alignment=1,
    )

    # Named `system` (not `sys`) so the stdlib module name is not shadowed.
    system = system_from_preloaded(psys, BASIS, 1, grid_alignment=512)
    initial_dm = psys.initial_density_matrices[0]

    params = scf_solver.init(random.PRNGKey(0), initial_dm, system)

    def loss_fn(params):
        """Return (squared energy error, final energy) for the given params."""
        energies, _ = scf_solver.apply(params, initial_dm, system)
        # Sum the two energy components returned by the solver, take the last
        # entry (presumably the final SCF iterate — confirm) and add the
        # nuclear energy term.
        e_final = (energies[0] + energies[1])[-1] + nuclear_energy_fn(
            system._nuc_pos, system
        )
        loss = (e_final - TARGET_ENERGY) ** 2
        return loss, e_final

    opt = optax.adam(learning_rate)
    opt_state = opt.init(params)

    @jax.jit
    def step(params, opt_state):
        # The returned loss/energy are evaluated at the params *before* this
        # update is applied.
        (loss, energy), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
        updates, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss, energy

    loss_list = []
    for i in range(n_train_steps):
        params, opt_state, loss, energy = step(params, opt_state)
        print(f'Iteration {i}, loss: {loss}, energy: {energy}')
        loss_list.append(loss)

    losses = jnp.array(loss_list)

    consecutive_decrease = bool(jnp.all(losses[1:] < losses[:-1]))
    # After n_train_steps steps, the loss should be reduced by at least 75%.
    total_decrease = bool(losses[-1] / losses[0] < 0.25)
    assert consecutive_decrease or total_decrease, (
        f'Loss is not decreasing: {losses.tolist()}'
    )
