import jax
import numpy as np
import pytest
from pyscf import dft
from utils import assert_either_abs_or_rel_close, assert_is_close, set_jax_testing_config

from deixc.data_generation import pyscf_based
from deixc.data_generation.custom import compute_sample_targets_fn_factory
from deixc.orbital_transforms import ao_to_mo, mo_to_ao
from deixc.scf import DerivativeInformedSelfConsistentFieldSolver
from egxc.systems import examples
from egxc.training.loss.density import get_density_mean_field_error_fn
from egxc.utils.linalg import coeffs_to_density_matrix
from egxc.xc_energy import DensityFeatures, XCModule, get_functional

# Apply the shared JAX test configuration (see utils.set_jax_testing_config)
# before any of the module-level setup below runs.
set_jax_testing_config()
# Every test in this module carries the `quick` and `data` pytest markers.
pytestmark = [pytest.mark.quick, pytest.mark.data]

# Minimal basis for speed; a larger alternative is 'def2-svp'.
BASIS = 'sto-3g'
# Exchange-correlation functional, shared by the PySCF reference and the egxc pipeline.
XC = 'B3LYP'
# Density fitting is disabled for both ERIs and exchange in every pipeline below.
USE_DENSITY_FITTING = False

# egxc representation of the water example in BASIS; `alignment=1` is passed
# through as-is (its meaning is defined by `examples.get`).
SYSTEM = examples.get(
    'water', basis=BASIS, use_density_fitting=USE_DENSITY_FITTING, alignment=1
)
# Matching PySCF molecule and spin-restricted Kohn-Sham mean-field object,
# built from SYSTEM's atomic numbers and nuclear positions (neutral singlet).
# Note that MF uses a coarse quadrature grid (level 1), whereas the `dft.RKS`
# objects created inside the tests below do not set a grid level.
# NOTE(review): reads the private attribute `SYSTEM._nuc_pos`; prefer a public
# accessor if egxc provides one.
MOL, MF = pyscf_based.get_pyscf_mol_and_ks_meanfield(
    SYSTEM.atom_z,
    SYSTEM._nuc_pos,
    BASIS,
    spin=0,
    charge=0,
    xc_str=XC,
    use_eri_density_fitting=USE_DENSITY_FITTING,
    use_exchange_density_fitting=USE_DENSITY_FITTING,
    spin_restricted=True,
    quadrature_grid_level=1,
)
# PySCF reference targets (forces and D3 correction requested). The tests read
# the final-iteration entries of the per-iteration fields via `[-1]`.
TARGETS = pyscf_based.compute_sample_targets(
    MOL, MF, with_forces=True, with_d3_correction=True
)

# Metric comparing two density matrices for SYSTEM. The tests multiply its
# result by 1e3 and report it in mHa, so it presumably returns Hartree
# (confirm in egxc.training.loss.density).
density_mean_field_error_fn = get_density_mean_field_error_fn(
    spin_restricted=True,
    use_density_fitting=USE_DENSITY_FITTING,
    scale_per_electron=False,
)


def test_convergence():
    """Restart PySCF from the converged reference density and check nothing moves.

    A fresh ``dft.RKS`` object (no grid level set, unlike the level-1 grid used
    for ``MF``) is seeded with the final reference density matrix. The test
    checks that

    * the SCF callback fires fewer than ``max_callback_calls`` times,
    * the resulting density matches the seed density within ``density_tolerance``,
    * the total energy matches the reference within ``energy_tolerance``.

    Errors are scaled by 1e3 and reported in mHa.
    """
    # Fewer than 3 callback invocations are tolerated; presumably this leaves room
    # for PySCF calling the callback again in its final convergence check.
    max_callback_calls = 3
    density_tolerance = 5e-3  # mHa
    energy_tolerance = 5e-2  # mHa

    mf = dft.RKS(MOL, xc=XC)
    cycle_count = 0

    def callback(envs):
        # Invoked by PySCF during the SCF loop; the argument is not needed here.
        nonlocal cycle_count
        cycle_count += 1
        print(f'SCF cycle {cycle_count}')

    mf.callback = callback  # type: ignore
    P = np.asarray(coeffs_to_density_matrix(TARGETS.mo_coeffs[-1], TARGETS.occupancies))
    mf.kernel(dm0=P)
    assert cycle_count < max_callback_calls, (
        f'Ran {cycle_count} SCF cycles (expected fewer than {max_callback_calls}), '
        'but there should be no additional SCF cycles required when starting from '
        'the converged density'
    )
    density_error = density_mean_field_error_fn(SYSTEM, P, mf.make_rdm1()) * 1e3  # type: ignore
    # Use enough digits that a failing value is distinguishable from the tolerance.
    assert density_error < density_tolerance, (
        f'Density error is {density_error:.6f} mHa '
        f'(tolerance {density_tolerance} mHa), which is too large'
    )
    energy_error = abs(mf.e_tot - TARGETS.total_energies[-1]) * 1e3
    assert energy_error < energy_tolerance, (
        f'Total energy mismatch: difference {energy_error:.6f} mHa '
        f'(tolerance {energy_tolerance} mHa)'
    )


def test_xc_potential():
    """Compare the stored XC potential matrix with one reconstructed from PySCF.

    PySCF's helpers are easy to misuse, so the reference is checked independently.
    The XC potential is isolated from the Fock matrix at the final reference
    density by removing the Coulomb matrix and the core Hamiltonian:

        V_xc = F - J - H_core

    The shortcut
    `_, _, v_xc = mf._numint.get_vxc(mol, mf.grids, mf.xc, dms=P)`
    is deliberately avoided: it only returns the semi-local contribution to the
    XC potential.
    """
    ks = dft.RKS(MOL, xc=XC)
    density = np.asarray(
        coeffs_to_density_matrix(TARGETS.mo_coeffs[-1], TARGETS.occupancies)
    )

    fock = ks.get_fock(dm=density)
    coulomb = ks.get_j(dm=density)
    core_hamiltonian = ks.get_hcore()
    reconstructed_v_xc = fock - coulomb - core_hamiltonian

    assert_either_abs_or_rel_close(
        TARGETS.xc_potential_matrices[-1],
        reconstructed_v_xc,
        name='consistency of v_xc',
        relative_tolerance=2e-5,
    )


def test_orbital_transforms():
    """Check that ``ao_to_mo`` inverts ``mo_to_ao`` on occupied-virtual matrices.

    A random ``(n_occ, n_virt)`` matrix is mapped to the AO basis using the
    final reference MO coefficients and the AO overlap matrix, then mapped back.
    The round trip must reproduce the input.
    """
    C = TARGETS.mo_coeffs[-1]
    S = MOL.intor('int1e_ovlp')

    n_occ = TARGETS.n_occ
    # NOTE(review): assumes the number of MOs equals the number of basis functions.
    n_virt = TARGETS.n_basis_functions - n_occ

    # Round trip: MO (lower dim) -> AO (full dim) -> MO (lower dim).
    # A fixed seed makes the input, and therefore any failure, reproducible.
    rng = np.random.default_rng(0)
    mo = rng.standard_normal((n_occ, n_virt))
    ao = mo_to_ao(mo, C, S, n_occ)
    mo_back = ao_to_mo(ao, C, n_occ)

    assert_is_close(
        mo,
        mo_back,
        name='ao_to_mo(mo_to_ao(.)) == identity on (O,V)',
    )


def test_custom_pipeline_consistency():
    """The custom derivative-informed SCF pipeline must reproduce the PySCF targets.

    An XC module for the same functional drives a 15-cycle
    ``DerivativeInformedSelfConsistentFieldSolver``, started from PySCF's initial
    guess. Its final total energy, XC energy, orbital energies and density are
    compared with ``TARGETS``. Differences are scaled by 1e3 and reported in mHa.
    """

    def to_mha(value):
        # Scale a difference by 1e3 so it can be reported in mHa.
        return value * 1e3

    restricted = True
    xc_module = XCModule(
        get_functional(
            XC,
            spin_restricted=restricted,
            use_density_fitting=USE_DENSITY_FITTING,
        ),
        DensityFeatures(spin_restricted=restricted),
    )
    solver = DerivativeInformedSelfConsistentFieldSolver(
        xc_module,
        spin_restricted=restricted,
        cycles=15,
        use_density_fitting=USE_DENSITY_FITTING,
    )
    solver.init(jax.random.PRNGKey(0), SYSTEM.fock_tensors.overlap, SYSTEM)
    custom = compute_sample_targets_fn_factory(solver)(MF.get_init_guess(), SYSTEM)  # type: ignore

    # Total energy: tightest tolerance.
    total_energy_error = to_mha(
        abs(custom.total_energies[-1] - TARGETS.total_energies[-1])
    )
    assert total_energy_error < 1e-6, (
        f'Total energy error is {total_energy_error:.6f} mHa, which is too large'
    )

    # XC energy. Without DIIS regularization the deviation is much lower: 2e-8 Hartree.
    xc_energy_error = to_mha(abs(custom.xc_energies[-1] - TARGETS.xc_energies[-1]))
    assert xc_energy_error < 2e-3, (
        f'XC energy error is {xc_energy_error:.6f} mHa, which is too large'
    )

    # Worst-case orbital energy. Without DIIS regularization the deviation is much
    # lower: 3e-8 Hartree.
    orbital_energy_error = to_mha(
        np.max(abs(custom.orbital_energies - TARGETS.orbital_energies))
    )
    assert orbital_energy_error < 6e-3, (
        f'Orbital energy error is {orbital_energy_error:.6f} mHa, which is too large'
    )

    # Density. Without DIIS regularization the deviation is much lower: 2e-8 Hartree.
    density_custom = coeffs_to_density_matrix(custom.mo_coeffs[-1], custom.occupancies)
    density_reference = coeffs_to_density_matrix(
        TARGETS.mo_coeffs[-1], TARGETS.occupancies
    )
    density_error = to_mha(
        density_mean_field_error_fn(SYSTEM, density_custom, density_reference)
    )
    assert density_error < 2e-3, (
        f'Density error is {density_error:.6f} mHa, which is too large'
    )
