from functools import partial

import jax

from egxc.utils import linalg
from egxc.utils.typing import Float1, Float2xBxB, FloatBxB


@partial(jax.jit, static_argnames=['n_elec', 'spin_restricted'])
def homo_lumo_gap_fn(
    fock_matrix: FloatBxB | Float2xBxB,
    diagonal_overlap: FloatBxB,
    n_elec: int,
    spin_restricted: bool = True,
) -> Float1:
    """
    Compute the Kohn-Sham HOMO-LUMO gap for a spin-restricted (RKS) system.

    The orbital energies ``eps`` are obtained from
    ``linalg.modified_generalized_eigenvalue_problem(fock_matrix, diagonal_overlap)``;
    at SCF convergence they correspond to ``eps = diag(C^T F C)``.

    Args:
        fock_matrix: Fock matrix, 2D for the restricted case. The signature also
            admits a 3D (spin-stacked) array for the unrestricted case, which is
            not implemented yet.
        diagonal_overlap: Matrix forwarded unchanged to
            ``linalg.modified_generalized_eigenvalue_problem`` (presumably the
            overlap in a diagonal/orthogonalized representation -- confirm
            against ``egxc.utils.linalg``).
        n_elec: Total number of electrons. Static under ``jax.jit``.
        spin_restricted: Only ``True`` is supported. Static under ``jax.jit``.

    Returns:
        Scalar JAX array ``eps_LUMO - eps_HOMO``. Assumes ``eps`` is returned in
        ascending order (TODO confirm in ``linalg``).

    Raises:
        ValueError: If ``fock_matrix`` has the wrong rank for the requested spin
            treatment, or if ``n_elec`` leaves no occupied or no virtual orbital.
        NotImplementedError: If ``spin_restricted`` is ``False`` (after the rank
            check passes).
    """
    # ndim, shapes and n_elec are all static under jit, so these checks run at
    # trace time; raise real exceptions instead of `assert` (stripped by -O).
    if spin_restricted:
        if fock_matrix.ndim != 2:
            raise ValueError('fock_matrix must be a 2D array')
    else:
        if fock_matrix.ndim != 3:
            raise ValueError('fock_matrix must be a 3D array')
        raise NotImplementedError('Spin-unrestricted systems are not supported yet.')

    eps, _ = linalg.modified_generalized_eigenvalue_problem(fock_matrix, diagonal_overlap)

    # Index of the LUMO == number of (at least singly) occupied spatial orbitals.
    # Ceil division leaves even counts unchanged; for an odd count it treats the
    # singly occupied orbital as the HOMO rather than the orbital below it.
    n_occ = (n_elec + 1) // 2
    n_orb = eps.shape[-1]
    # JAX clamps out-of-bounds indices instead of raising, and eps[-1] wraps
    # around, so guard the static indices to avoid a silently wrong gap.
    if not 1 <= n_occ < n_orb:
        raise ValueError(
            f'n_elec={n_elec} requires both a HOMO and a LUMO among {n_orb} orbitals.'
        )
    return eps[n_occ] - eps[n_occ - 1]
