def run_scan_reference(
    *,
    checkpoint_dir: str,
    molecule: str,
    nstates: int,
    conv_tol: float,
    basis: str | None,
    grid_level: int | None,
    cycles: int | None,
):
    """Run TDA and full TDDFT linearized around a SCAN reference ground state.

    The ground state comes from the repository's own SCF driver
    (``_scan_reference_from_custom_scf``). The excitation energies come from the
    EGXC Casida matrix-vector products (``build_cassida_mv``) fed to the
    ``Davidson`` / ``Davidson_Casida`` solvers. Results are printed; nothing is
    returned.

    Args:
        checkpoint_dir: Directory containing the checkpoint YAML. It supplies the
            alignment and the fallback basis, grid level and SCF cycle count.
        molecule: Molecule name forwarded to ``_build_system``.
        nstates: Number of excitation energies to compute.
        conv_tol: Davidson convergence tolerance.
        basis: Basis-set override; ``None`` uses ``cfg['basis']['name']``.
        grid_level: Quadrature-level override; ``None`` uses
            ``cfg['quadrature']['level']``.
        cycles: SCF-cycle override; ``None`` uses
            ``cfg['solver']['kwargs']['cycles']``, or 15 when that entry is absent.
    """
    # Same conversion factor as run_deixc. It is kept local so this function does
    # not rely on a module-level constant that may not exist.
    Hartree_to_eV = 27.211385050

    cfg = _load_checkpoint_config(Path(checkpoint_dir))
    a = cfg['alignment']
    alignment = Alignment(int(a['atom']), int(a['basis']), int(a['grid']))

    basis_eff = basis if basis is not None else cfg['basis']['name']
    grid_level_eff = (
        int(grid_level) if grid_level is not None else int(cfg['quadrature']['level'])
    )
    if cycles is not None:
        cycles_eff = int(cycles)
    else:
        # `or {}` also tolerates keys that are present but null in the YAML.
        solver_kwargs = (cfg.get('solver') or {}).get('kwargs') or {}
        cycles_eff = int(solver_kwargs.get('cycles', 15))

    system = _build_system(
        molecule,
        basis=basis_eff,
        grid_level=grid_level_eff,
        alignment=alignment,
    )

    functional = get_functional(
        'scan',
        spin_restricted=True,
        use_density_fitting=False,
    )
    xc_module = XCModule(functional, DensityFeatures(spin_restricted=True))

    # The custom SCF returns the XC parameters it used, and the response operator
    # is linearized with exactly those. A separate `xc_module.init(...)` here
    # would only be discarded, so none is performed.
    (
        P_ref,
        orbo,
        orbv,
        e_ia,
        hdiag,
        xc_params,
        total_e,
    ) = _scan_reference_from_custom_scf(
        sys=system,
        basis=basis_eff,
        cycles=cycles_eff,
        xc_module=xc_module,
    )
    print(f'- Custom SCF final E_HJ + E_XC (Ha): {total_e}')

    # Convert the reference quantities to JAX arrays once, for both operators.
    orbo_j = jnp.asarray(orbo)
    orbv_j = jnp.asarray(orbv)
    e_ia_j = jnp.asarray(e_ia)
    P_ref_j = jnp.asarray(P_ref)

    def _casida_mv(tda_approx: bool):
        """Build the EGXC response mat-vec (TDA when ``tda_approx`` is True)."""
        return build_cassida_mv(
            system,
            xc_module,
            xc_params,
            orbo_j,
            orbv_j,
            e_ia_j,
            P_ref_j,
            tda_approx=tda_approx,
        )

    tda_mv = _casida_mv(True)
    tddft_mv = _casida_mv(False)

    print(f'- custom-SCF reference (SCAN/{basis_eff}, grid_level={grid_level_eff})')
    e_tda, _ = Davidson(
        lambda X: np.asarray(tda_mv(jnp.asarray(X))),
        hdiag,
        N_states=nstates,
        conv_tol=conv_tol,
    )
    print('  TDA energies (eV):', e_tda * Hartree_to_eV)

    def mv_xy(X: np.ndarray, Y: np.ndarray):
        # Davidson_Casida works with NumPy arrays; the operator works with JAX arrays.
        U1, U2 = tddft_mv(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tddft, _, _ = Davidson_Casida(
        mv_xy,
        hdiag,
        N_states=nstates,
        conv_tol=conv_tol,
    )
    print('  TDDFT energies (eV):', e_tddft * Hartree_to_eV)


def run_deixc(
    checkpoint_dir: str,
    molecule: str = 'water',
    nstates: int = 5,
    conv_tol: float = 1e-5,
    n_samples: int = 10,
    seed: int = 42,
    data_dir: str = 'ANONYMOUS_DIR',
    use_density_fitting: bool = False,
):
    """Run DEI-XC-ground-state-linearized TDA and TDDFT for a demo molecule or QM5 samples.

    Flow:
    1) Load checkpoint config and params.
    2) Build system (basis/grid/alignment) from the YAML.
    3) Delegate to ``_run_tddft_for_molecule``. Per the pipeline description,
       that helper runs the DEI-XC SCF for the reference (P_ref, F_ref), solves
       the generalized eigenproblem for orbital energies and MOs, and runs the
       Davidson solvers for TDA and full TDDFT with the EGXC response operators.

    Args:
        checkpoint_dir: Directory containing `best_dynamic_train_params.flax` and one YAML.
        molecule: Name accepted by `egxc.systems.examples.get(...)` (e.g. "water"), or "qm5" for QM5 batch.
        nstates: Number of excitation energies to compute.
        conv_tol: Davidson residual tolerance.
        n_samples: Number of QM5 samples to process (only used if molecule=="qm5").
        seed: Random seed for sampling QM5 molecules (only used if molecule=="qm5").
        data_dir: Path to datasets directory (only used if molecule=="qm5").
        use_density_fitting: Forwarded both to system construction and to
            `_run_tddft_for_molecule`.
    """
    ckp_dir = Path(checkpoint_dir)
    cfg = _load_checkpoint_config(ckp_dir)
    params_solver = unpickle_dictionary(str(ckp_dir / 'best_dynamic_train_params.flax'))
    xc_params = _extract_xc_module_params(params_solver)

    # System settings from the checkpoint config.
    basis = cfg['basis']['name']
    grid_level = int(cfg['quadrature']['level'])
    a = cfg['alignment']
    alignment = Alignment(int(a['atom']), int(a['basis']), int(a['grid']))

    def _run(system, molecule_idx: int) -> None:
        """Run the TDDFT pipeline for one system (shared by both code paths)."""
        _run_tddft_for_molecule(
            system,
            basis,
            cfg,
            params_solver,
            xc_params,
            nstates,
            conv_tol,
            checkpoint_id='',
            molecule_idx=molecule_idx,
            use_density_fitting=use_density_fitting,
        )

    if molecule.lower() == 'qm5':
        # NOTE(review): the log previously claimed heavy_atoms_thresh=5 while 4
        # was passed. The message now reports the value actually used. Confirm
        # against QM9's threshold semantics which value "QM5" requires.
        heavy_atoms_thresh = 4
        print(
            f'Loading QM5 dataset (heavy_atoms_thresh={heavy_atoms_thresh}, '
            'exclude_fluorine=True)'
        )
        dataset = QM9(
            data_dir=data_dir,
            heavy_atoms_thresh=heavy_atoms_thresh,
            exclude_fluorine=True,
        )

        total_samples = len(dataset)
        print(f'Total QM5 samples: {total_samples}')

        # Seeds NumPy's global RNG. It stays global in case downstream code also
        # draws from np.random and relies on this seeding for reproducibility.
        np.random.seed(seed)
        sample_indices = np.random.choice(
            total_samples, min(n_samples, total_samples), replace=False
        )
        n_selected = len(sample_indices)
        print(f'Processing {n_selected} molecules with seed={seed}')

        rule = '=' * 60
        for count, idx in enumerate(sample_indices, 1):
            print(f'\n{rule}')
            print(f'[{count}/{n_selected}] Processing QM5 molecule {idx}')
            print(rule)

            # Only the nuclear positions and atomic numbers of the sample are used.
            _, (nuc_pos, atom_z, _, _, _), _ = dataset[idx]
            nuc_pos = np.asarray(nuc_pos)
            atom_z = np.asarray(atom_z)
            print(f'Atoms: {len(atom_z)}, Elements: {np.unique(atom_z)}')

            system = _build_system_from_sample(
                nuc_pos,
                atom_z,
                basis,
                grid_level,
                alignment,
                use_density_fitting=use_density_fitting,
            )
            print(f'Electrons: {int(system.n_electrons)}, Basis: {basis}')

            _run(system, int(idx))

        return

    # Single-molecule path.
    system = examples.get(
        molecule,
        basis=basis,
        alignment=alignment,
        use_density_fitting=bool(use_density_fitting),
        spin_restricted=True,
        include_grid=True,
        grid_level=grid_level,
    )
    _run(system, -1)


def _scan_reference_from_pyscf(
    *,
    mol,
    grid_level: int,
    xc: str,
):
    """Converge a restricted Kohn-Sham ground state with PySCF.

    Args:
        mol: PySCF molecule object passed to ``dft.RKS``.
        grid_level: Integration grid level assigned to ``mf.grids.level``.
        xc: Functional name; it is lower-cased before being passed to PySCF.

    Returns:
        Tuple ``(P_ref, orbo, orbv, e_ia, hdiag, mf)``:
        the density matrix from ``mf.make_rdm1()``; the occupied and virtual MO
        coefficient columns (selected by ``mo_occ > 0`` and ``mo_occ == 0``);
        the orbital-energy differences ``eps_a - eps_i`` laid out as
        (n_occ, n_vir); their flattened copy (used as the Davidson diagonal);
        and the PySCF mean-field object itself. Convergence is not checked here.
    """
    ks = dft.RKS(mol, xc=xc.lower())
    ks.grids.level = int(grid_level)
    ks.kernel()

    mo_coeff = np.asarray(ks.mo_coeff)
    mo_energy = np.asarray(ks.mo_energy)
    mo_occ = np.asarray(ks.mo_occ)

    occupied = np.flatnonzero(mo_occ > 0)
    virtual = np.flatnonzero(mo_occ == 0)

    # Rows index occupied orbitals, columns index virtual orbitals.
    gaps = mo_energy[virtual] - mo_energy[occupied][:, None]

    return (
        np.asarray(ks.make_rdm1()),
        mo_coeff[:, occupied],
        mo_coeff[:, virtual],
        gaps,
        gaps.reshape(-1),
        ks,
    )


def run_compare_three(
    *,
    checkpoint_dir: str,
    molecule: str,
    nstates: int,
    conv_tol: float,
    scan_basis: str = 'def2-TZVPD',
    scan_grid_level: int | None = None,
):
    """Compare 3 end-to-end pipelines:

    1) DEI-XC checkpoint SCF + EGXC TDDFT operator
    2) SCAN (repo SCF) + EGXC TDDFT operator
    3) SCAN (PySCF SCF) + PySCF-native TDDFT

    Args:
        checkpoint_dir: Checkpoint directory (config YAML and DEI-XC params).
        molecule: Molecule name used by all three pipelines.
        nstates: Number of excitation energies per pipeline.
        conv_tol: Davidson tolerance for pipelines (1) and (2). Pipeline (3)
            keeps PySCF's own TDA/TDDFT defaults.
        scan_basis: Basis set for the two SCAN pipelines.
        scan_grid_level: Grid level for the SCAN pipelines; ``None`` uses
            ``cfg['quadrature']['level']`` from the checkpoint config.
    """
    # Same factor as run_deixc. It is kept local so the printout does not rely on
    # a module-level constant that may not exist.
    Hartree_to_eV = 27.211385050

    cfg = _load_checkpoint_config(Path(checkpoint_dir))
    grid_level_eff = (
        int(scan_grid_level)
        if scan_grid_level is not None
        else int(cfg['quadrature']['level'])
    )

    def _banner(title: str) -> None:
        """Print a section header for one pipeline."""
        print('==============================')
        print(title)
        print('==============================')

    _banner('= (1) DEI-XC checkpoint (repo SCF + repo TDDFT operator)')
    run_deixc(
        checkpoint_dir=checkpoint_dir,
        molecule=molecule,
        nstates=nstates,
        conv_tol=conv_tol,
    )

    # Headers report the basis actually used (previously hard-coded to def2-TZVPD).
    _banner(f'= (2) SCAN/{scan_basis} (repo SCF + repo TDDFT operator)')
    run_scan_reference(
        checkpoint_dir=checkpoint_dir,
        molecule=molecule,
        nstates=nstates,
        conv_tol=conv_tol,
        basis=scan_basis,
        grid_level=grid_level_eff,
        cycles=None,
    )

    _banner(f'= (3) SCAN/{scan_basis} (PySCF SCF + PySCF TDDFT)')
    # Build the same EGXC System solely to get a consistent PySCF `mol` object (basis/alignment).
    a = cfg['alignment']
    alignment = Alignment(int(a['atom']), int(a['basis']), int(a['grid']))
    system = _build_system(
        molecule,
        basis=scan_basis,
        grid_level=grid_level_eff,
        alignment=alignment,
    )
    mol = system.to_pyscf(scan_basis)
    # Reuse the shared PySCF ground-state helper (RKS + grid level + kernel).
    *_, mf = _scan_reference_from_pyscf(mol=mol, grid_level=grid_level_eff, xc='scan')
    if not mf.converged:
        print('  WARNING: PySCF SCAN ground state did not converge')

    print('- PySCF native excitations')
    tda_obj = pyscf_tddft.TDA(mf)
    tda_obj.nstates = nstates
    tda_obj.kernel()
    print('  TDA energies (eV):', np.asarray(tda_obj.e) * Hartree_to_eV)

    tddft_obj = pyscf_tddft.TDDFT(mf)
    tddft_obj.nstates = nstates
    tddft_obj.kernel()
    print('  TDDFT energies (eV):', np.asarray(tddft_obj.e) * Hartree_to_eV)
