from typing import Literal

import jax.numpy as jnp
import numpy as onp
import pytest
from pyscf import dft
from scipy import linalg as ref_linalg
from utils import call_module_as_function, set_jax_testing_config

from egxc.solver import fock, scf
from egxc.systems import examples
from egxc.systems.base import nuclear_energy_fn
from egxc.utils import linalg
from egxc.xc_energy.features import DensityFeatures
from egxc.xc_energy.functionals.classical import mgga

# Apply the shared JAX test configuration from the local `utils` helper once, at
# import time, so that it is in effect for every test in this module.
# NOTE(review): presumably this enables float64 (the tests below build float64
# arrays and compare tightly against SciPy/PySCF) — confirm in utils.
set_jax_testing_config()


@pytest.mark.quick
def test_generalized_eigenvalue_solver():
    """Check the JAX generalized eigensolver against ``scipy.linalg.eigh``.

    Solves ``F C = S C diag(e)`` for a small random symmetric ``F`` and a
    symmetric positive definite ``S``. ``linalg.transformation_matrix`` builds
    the transform ``X`` from ``S``, which is then passed to
    ``linalg.modified_generalized_eigenvalue_problem``. Eigenvalues must match
    SciPy directly. Each eigenvector must match up to a single overall sign.
    """
    # Local generator: reproducible without mutating the global NumPy RNG
    # state that other tests may rely on.
    rng = onp.random.default_rng(0)
    n = 5
    F = rng.integers(-5, 5, size=(n, n)).astype(onp.float64)
    F = F + F.T  # symmetrize
    S = rng.integers(-2, 2, size=(n, n)).astype(onp.float64)
    # S^T S alone is only positive *semi*-definite: it is singular whenever the
    # random integer matrix is. Adding the identity bounds every eigenvalue
    # below by 1, so S is guaranteed symmetric positive definite.
    S = S.T @ S + onp.eye(n)
    X = linalg.transformation_matrix(jnp.array(S, dtype=jnp.float64))  # type: ignore
    e, C = linalg.modified_generalized_eigenvalue_problem(F, X)  # type: ignore
    e_ref, C_ref = ref_linalg.eigh(F, S)
    e, C = onp.asarray(e), onp.asarray(C)
    assert onp.allclose(e, e_ref), f'{e} != {e_ref}'
    # Each eigenvector is only defined up to an overall sign. Flip whole columns
    # to agree with the reference, rather than comparing absolute values
    # element-wise. The element-wise check would also accept vectors whose
    # individual components carry the wrong sign.
    signs = onp.sign(onp.sum(C * C_ref, axis=0))
    assert onp.allclose(C * signs, C_ref), f'{C} != {C_ref}'


@pytest.mark.slow
@pytest.mark.parametrize(
    'spin_restricted', [True, False], ids=['restricted', 'unrestricted']
)
@pytest.mark.parametrize('conv_acc_method', ['Vanilla', 'DIIS'], ids=['Vanilla', 'DIIS'])
def test_scf_method(spin_restricted: bool, conv_acc_method: Literal['Vanilla', 'DIIS']):
    """Compare the JAX SCF total energy with PySCF for ethanol with SCAN.

    Both codes start from the same initial density matrix ``P_0`` and are
    limited to ``CYCLES`` SCF iterations. DIIS is enabled or disabled in PySCF
    according to ``conv_acc_method``, which is also passed to the JAX solver.
    The total energy of the JAX solver's last cycle (electronic + nuclear) must
    agree with PySCF's ``e_tot`` within a method-dependent tolerance.
    """
    BASIS = '6-31G(d)'
    USE_DENSITY_FITTING = False
    CYCLES = 10
    # Allowed |E_jax - E_pyscf| in mHa. Vanilla SCF is divergent for this
    # system, so the two implementations only agree loosely after CYCLES
    # iterations. DIIS must agree tightly.
    TOLERANCE_MHA = {'Vanilla': 0.06, 'DIIS': 0.002}

    xc_mod = fock.XCModule(mgga.MetaGGA('scan'), DensityFeatures(spin_restricted))
    scf_solver = scf.SelfConsistentFieldSolver(
        xc_mod, CYCLES, USE_DENSITY_FITTING, spin_restricted, conv_acc_method
    )

    system = examples.get(
        'ethanol',
        BASIS,
        alignment=0,
        use_density_fitting=USE_DENSITY_FITTING,
        spin_restricted=spin_restricted,
    )
    mol = system.to_pyscf(BASIS)
    ks_class = dft.RKS if spin_restricted else dft.UKS
    mf = ks_class(mol, xc='SCAN')

    if conv_acc_method != 'DIIS':
        mf.diis = None  # plain fixed-point iterations in PySCF as well
    mf.max_cycle = CYCLES
    mf.grids.level = 1
    mf.verbose = 4
    P_0 = mf.get_init_guess()
    # Start PySCF explicitly from P_0, the same density matrix handed to the
    # JAX solver, instead of letting kernel() recompute its own initial guess.
    mf.kernel(dm0=P_0)

    (e_hj, e_xc), density_matrices = call_module_as_function(
        scf_solver, P_0, system, jit=True
    )

    e_tot = e_xc + e_hj + nuclear_energy_fn(system._nuc_pos, system)
    e_tot = e_tot[-1]  # energy of last cycle
    e_ref = mf.e_tot
    tolerance = TOLERANCE_MHA[conv_acc_method]
    assert abs(e_tot - e_ref) * 1e3 < tolerance, (  # mHa
        f'Total energy does not match {e_tot:.8e} != {e_ref:.8e}, difference {(e_tot - e_ref) * 1e3:.3e} mHa'
    )
