import numpy as onp
from pyscf import gto, scf
from pyscf.dispersion import dftd3

from deixc.data_generation.hessian import linear_response
from deixc.data_generation.utils import BaseDftDeixcTargetGenerator, DeiXCTargets, Timer
from egxc.data_generation.generator import get_pyscf_mol_and_ks_meanfield
from egxc.dataloading import io


def compute_sample_targets(
    mol: gto.Mole,
    mf: scf.hf.SCF,
    with_forces: bool,
    with_d3_correction: bool,
    d3_xc: str = 'B97M',
) -> DeiXCTargets:
    """
    Compute DEI-XC targets for a single sample.

    Runs the SCF of ``mf`` while recording every iteration through PySCF's
    ``callback`` hook, then computes linear-response quantities from the
    recorded trajectory and, optionally, nuclear gradients and a D3(BJ)
    dispersion correction.

    Parameters
    ----------
    mol: pyscf.gto.Mole
        Molecule. Within this function it is only used for the D3 correction.
    mf: scf.hf.SCF
        A pyscf Restricted (RKS) or Unrestricted (UKS) Kohn-Sham mean-field
        object. ``mf.kernel()`` is run here. ``mf.callback`` is replaced for
        the duration of the SCF and restored to its previous value afterwards.
    with_forces: bool
        If True, also run ``mf.nuc_grad_method().kernel()`` and store the
        result as ``forces``; otherwise ``forces`` is None.
    with_d3_correction: bool
        If True, compute a D3(BJ) dispersion energy and gradient via
        ``pyscf.dispersion.dftd3``; otherwise both are stored as None.
    d3_xc: str, optional
        Functional name used to select the D3 parameters. Defaults to
        ``'B97M'``, the value that was previously hard-coded.

    Returns
    -------
    DeiXCTargets
        Recorded SCF trajectory (MO coefficients, total and XC energies, XC
        potential matrices), final orbital energies, linear-response targets,
        optional forces / D3 terms, and per-stage timings.

    Raises
    ------
    RuntimeError
        If the SCF callback never fired, i.e. no iteration was recorded.
    """
    # logging of scf trajectory (one entry per SCF iteration)
    densities = []
    mo_coeffs = []
    fock_matrices = []
    total_energies = []
    xc_energies = []
    xc_potential_matrices = []

    def callback(envs: dict) -> None:
        densities.append(envs['dm'])  # type: ignore
        mo_coeffs.append(envs['mo_coeff'])  # type: ignore
        fock_matrices.append(envs['fock'])  # type: ignore
        total_energies.append(envs['e_tot'])  # type: ignore

        # pure XC energy and potential:
        vhf = envs['vhf']  # tagged array J + V_xc
        xc_energies.append(vhf.exc)  # exchange-correlation energy  # type: ignore
        # Subtract out the Coulomb J. NOTE(review): for hybrid functionals the
        # veff also contains -K, which would remain in this matrix — confirm
        # this is intended if hybrids are ever used.
        pure_vxc = vhf - vhf.vj  # type: ignore
        xc_potential_matrices.append(pure_vxc)  # type: ignore

    previous_callback = getattr(mf, 'callback', None)
    mf.callback = callback  # type: ignore
    # main computation
    timer = Timer().set()
    try:
        mf.kernel()  # type: ignore
    finally:
        # The lists captured by ``callback`` are rebound to numpy arrays below;
        # leaving the closure installed would make any later SCF run on ``mf``
        # fail inside it, so always restore the caller's callback.
        mf.callback = previous_callback  # type: ignore
    timer.log('scf_calculation')

    if not mo_coeffs:
        raise RuntimeError('SCF callback was never invoked; no SCF iterations were recorded.')

    mo_coeffs = onp.array(mo_coeffs)
    densities = onp.array(densities)
    fock_matrices = onp.array(fock_matrices)
    total_energies = onp.array(total_energies)
    xc_energies = onp.array(xc_energies)
    xc_potential_matrices = onp.array(xc_potential_matrices)

    # Orbital energies = diag(C^T F C) of the last iteration. The einsum only
    # contracts the trailing (AO, MO) axes, so any leading spin axis (PySCF's
    # UKS layout) is carried through; a plain ``.T`` would also transpose the
    # spin axis and break the UKS case.
    final_coeff = mo_coeffs[-1]
    final_fock = fock_matrices[-1]
    orbital_energies = onp.einsum(
        '...pi,...pq,...qi->...i', final_coeff, final_fock, final_coeff
    )

    timer.set()
    xc_minimization_direction, linear_responses_xc_pot, hessian_diagonal = (
        linear_response.get(
            mf,  # type: ignore
            mo_coeffs,
            fock_matrices,
            xc_potential_matrices,
        )
    )
    timer.log('linear_response_from_ground_state')
    if with_forces:
        timer.set()
        # NOTE(review): PySCF gradient kernels return dE/dR (a gradient, not
        # -dE/dR); confirm the sign convention expected by consumers.
        forces = mf.nuc_grad_method().kernel()  # TODO: Should we include grid_response?
        timer.log('pyscf_dft_forces')
    else:
        forces = None

    # D3 correction
    if with_d3_correction:
        timer.set()
        disp = dftd3.DFTD3Dispersion(mol, xc=d3_xc, version='d3bj')
        d3_dispersion = disp.get_dispersion(grad=True)
        timer.log('d3_dispersion')
    else:
        d3_dispersion = {'energy': None, 'gradient': None}

    out = DeiXCTargets(
        mo_coeffs=mo_coeffs,
        total_energies=total_energies,
        xc_energies=xc_energies,
        xc_potential_matrices=xc_potential_matrices,
        linear_response_xc_pot=linear_responses_xc_pot,
        orbital_energies=orbital_energies,
        density_hessian_diagonal=hessian_diagonal,
        forces=forces,
        d3_dispersion_energy=d3_dispersion['energy'],  # type: ignore
        d3_dispersion_forces=d3_dispersion['gradient'],  # type: ignore
        compute_costs=timer.get_timings(),
    )
    out.test_recomputations(xc_minimization_direction, densities, xc_potential_matrices)
    return out


class PyscfDftTargetGenerator(BaseDftDeixcTargetGenerator):
    """DEI-XC target generator backed by PySCF Kohn-Sham DFT.

    Calling an instance processes one chunk of ``self.dataset``. Samples whose
    auxiliary data already exists in ``self.aux_dir`` are skipped. Every other
    sample gets a PySCF molecule and KS mean-field built from ``self.kwargs``,
    has its targets computed with ``compute_sample_targets`` and is saved.
    """

    aux_data_key = 'deixc'
    backend = 'pyscf'
    # Optional targets; both disabled by default.
    include_forces = False
    include_d3_correction = False

    def __call__(self, chunk_id: int, n_chunks: int) -> None:
        """Compute and save targets for chunk ``chunk_id`` of ``n_chunks``."""
        chunk_start, chunk_stop = self.get_chunk_indices(chunk_id, n_chunks)
        for dataset_pos in range(chunk_start, chunk_stop):
            sample_id, system, _ = self.dataset[dataset_pos]
            positions, atomic_numbers, mol_charge, mol_spin, _ = system

            # Samples with existing auxiliary data are not recomputed.
            if io.auxiliary_data_exists(self.aux_dir, sample_id):
                print(f'Skipping {sample_id}, already exists in {self.aux_dir}', flush=True)
                continue

            print(f'Recomputing sample {sample_id}...', flush=True)
            mol, mf = get_pyscf_mol_and_ks_meanfield(
                atomic_numbers,
                positions,
                self.kwargs['basis'],
                mol_spin,
                mol_charge,
                self.kwargs['xc_str'],
                self.kwargs['use_eri_density_fitting'],
                self.kwargs['use_exchange_density_fitting'],
                self.kwargs['spin_restricted'],
                self.kwargs['quadrature_grid_level'],
            )
            sample_targets = compute_sample_targets(
                mol,
                mf,
                with_forces=self.include_forces,
                with_d3_correction=self.include_d3_correction,
            )
            self.save_deixc_targets(int(sample_id), sample_targets)
