"""TDDFT/TDA regression test comparing EGXC HVP builders against PySCF."""

import os
from typing import Any, Tuple

os.environ.setdefault('JAX_PLATFORMS', 'cpu')  # Force CPU to avoid CUDA plugin noise.

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from pyscf import dft
from pyscf import tddft as _tddft
from utils import assert_is_close, set_jax_testing_config

from egxc.systems import examples
from egxc.xc_energy import DensityFeatures
from egxc.xc_energy.functionals import GGA, Hybrid
from egxc.xc_energy.functionals.classical import BaseRangeSeparatedHybrid
from egxc.xc_energy.xc_module import XCModule
from tddft import build_cassida_mv, davidson, davidson_casida

# Apply the shared JAX configuration provided by the local `utils` test helper.
# NOTE(review): the exact flags it sets are defined in `utils`, not in this file.
set_jax_testing_config()

# Basis set name used for both the EGXC system and the PySCF reference molecule.
BASIS = 'def2-SVP'
# Grid level passed to the EGXC example system and to PySCF's `mf.grids.level`.
GRID_LEVEL = 3
# Number of excited states requested from every solver (PySCF and EGXC).
NSTATES = 5
# Absolute tolerance on excitation energies in milli-Hartree; it is divided by
# 1000 at the comparison sites to obtain Hartree. Kept loose to avoid flakiness
# across platforms.
ENERGY_TOL_MHA = 50.0


def _setup_system(
    molecule: str, basis: str, grid_level: int, xc_functional_type: str
) -> Tuple[Any, Any, Any]:
    """Build the EGXC system, XC module and initialised parameters for a molecule.

    Args:
        molecule: Name of the example system, looked up via ``examples.get``.
        basis: Basis-set name, used for the EGXC system and the PySCF molecule.
        grid_level: Grid level forwarded to ``examples.get``.
        xc_functional_type: Functional name (case-insensitive); ``'pbe'`` and
            ``'b3lyp'`` are supported.

    Returns:
        A ``(system, xc_module, params)`` tuple, where ``params`` is the result
        of ``xc_module.init``.

    Raises:
        ValueError: If ``xc_functional_type`` is not one of the supported names.
    """
    from pyscf import scf as pyscf_scf

    system = examples.get(
        molecule,
        basis=basis,
        grid_level=grid_level,
        spin_restricted=True,
        use_density_fitting=False,
    )

    # PySCF's default initial guess serves as the density matrix used to
    # initialise the XC module.
    guess_dm = pyscf_scf.RKS(system.to_pyscf(basis)).get_init_guess()

    functional_key = xc_functional_type.lower()
    if functional_key == 'pbe':
        xc_functional = GGA(key='pbe')
    elif functional_key == 'b3lyp':
        xc_functional = Hybrid(
            key='b3lyp', use_density_fitting=False, spin_restricted=True
        )
    else:
        raise ValueError(f'Unsupported functional {xc_functional_type}')

    xc_module = XCModule(xc_functional, DensityFeatures(spin_restricted=True))

    # Hybrid-type functionals are additionally handed the electron-repulsion
    # tensor at initialisation; other functionals get no extra keyword.
    init_kwargs = {}
    if isinstance(xc_functional, (Hybrid, BaseRangeSeparatedHybrid)):
        init_kwargs['eri_tensor'] = system.fock_tensors.electron_repulsion_tensor

    params = xc_module.init(
        jax.random.PRNGKey(0), jnp.array(guess_dm), system.grid, **init_kwargs
    )
    return system, xc_module, params


@pytest.mark.slow
@pytest.mark.parametrize('xc_functional_type', ['pbe', 'b3lyp'], ids=['pbe', 'b3lyp'])
def test_tddft_matches_pyscf(xc_functional_type: str):
    """Check that EGXC TDA and TDDFT excitation energies match PySCF.

    A PySCF RKS calculation on water supplies the reference orbitals, orbital
    energies and density matrix. The EGXC matrix-vector products are built on
    top of those quantities and solved with the local Davidson routines; the
    lowest ``NSTATES`` energies are compared with PySCF's own TDA/TDDFT results
    within ``ENERGY_TOL_MHA`` (absolute).

    Args:
        xc_functional_type: Functional name supplied by the parametrization.
    """
    system, xc_module, params = _setup_system(
        'water', basis=BASIS, grid_level=GRID_LEVEL, xc_functional_type=xc_functional_type
    )

    # --- PySCF reference ground state ------------------------------------
    mol = system.to_pyscf(BASIS)
    mf = dft.RKS(mol, xc=xc_functional_type.lower())
    mf.grids.level = GRID_LEVEL
    mf.kernel()
    # Everything below is derived from this SCF solution; fail early with a
    # clear message instead of reporting a confusing energy mismatch later.
    assert mf.converged, (
        f'PySCF reference SCF did not converge for {xc_functional_type}'
    )

    mo_coeff = np.asarray(mf.mo_coeff)
    mo_energy = np.asarray(mf.mo_energy)
    mo_occ = np.asarray(mf.mo_occ)
    occidx = np.where(mo_occ == 2)[0]  # doubly occupied orbitals
    viridx = np.where(mo_occ == 0)[0]  # empty orbitals
    C_o = mo_coeff[:, occidx]
    C_v = mo_coeff[:, viridx]
    # Orbital-energy differences; broadcasting yields shape (n_occ, n_vir).
    e_ia = mo_energy[viridx] - mo_energy[occidx, None]
    hdiag = e_ia.ravel()
    P_ref = mf.make_rdm1()

    # --- EGXC matrix-vector products -------------------------------------
    # Both builders share every input except `tda_approx`, so the arguments
    # are converted to JAX arrays and assembled only once.
    mv_args = (
        system,
        xc_module,
        params,
        jnp.asarray(C_o),
        jnp.asarray(C_v),
        jnp.asarray(e_ia),
        jnp.asarray(P_ref),
    )
    mv_kwargs = {'spin_restricted': True, 'use_density_fitting': False}
    tddft_mv_egxc = build_cassida_mv(*mv_args, tda_approx=False, **mv_kwargs)
    tda_mv_egxc = build_cassida_mv(*mv_args, tda_approx=True, **mv_kwargs)

    # --- PySCF reference excitation energies -----------------------------
    tda_ref = _tddft.TDA(mf)
    tda_ref.nstates = NSTATES
    tda_ref.kernel()
    e_tda_pyscf = np.asarray(tda_ref.e)

    tddft_ref = _tddft.TDDFT(mf)
    tddft_ref.nstates = NSTATES
    tddft_ref.kernel()
    e_tddft_pyscf = np.asarray(tddft_ref.e)

    # --- EGXC excitation energies ----------------------------------------
    # The solvers are driven with NumPy arrays, so the JAX products are
    # wrapped to convert on the way in and out.
    def _tda_mv_numpy(X: np.ndarray) -> np.ndarray:
        return np.asarray(tda_mv_egxc(jnp.asarray(X)))

    def _tddft_mv_numpy(X: np.ndarray, Y: np.ndarray):
        U1, U2 = tddft_mv_egxc(jnp.asarray(X), jnp.asarray(Y))
        return np.asarray(U1), np.asarray(U2)

    e_tda_egxc, _ = davidson(
        _tda_mv_numpy, hdiag, n_states=NSTATES, conv_tol=1e-5, max_iter=40
    )
    e_tddft_egxc, _, _ = davidson_casida(
        _tddft_mv_numpy, hdiag, n_states=NSTATES, conv_tol=1e-5, max_iter=60
    )

    # --- Comparison (tolerance converted from mHa to Ha) -----------------
    tolerance_ha = ENERGY_TOL_MHA / 1000
    assert_is_close(
        e_tda_egxc,
        e_tda_pyscf,
        tolerance=tolerance_ha,
        absolute=True,
        name='tda energies (Ha)',
    )
    assert_is_close(
        e_tddft_egxc,
        e_tddft_pyscf,
        tolerance=tolerance_ha,
        absolute=True,
        name='tddft energies (Ha)',
    )
