import numpy as np
import pyscf
import pyscf.df

from dataset import MultipartLMDBDataset

# Orbital basis used to build the PySCF molecule in get_rho_auxbasis_denfit.
# The stored density matrices are contracted against integrals in this basis,
# so they must be expressed in it (shapes must agree).
BASIS = 'def2-svp'
# Auxiliary basis onto which the electron density is fitted (passed to
# pyscf.df.addons.make_auxmol).
AUXBASIS = 'def2-universal-jfit'

def get_rho_auxbasis_denfit(features):
    """Fit a molecule's electron density onto the auxiliary basis ``AUXBASIS``.

    The density rho(r) = sum_ij D_ij phi_i(r) phi_j(r), given in the orbital
    basis ``BASIS``, is projected onto the auxiliary functions chi_P by least
    squares in the overlap metric. The coefficients ``c`` solve

        S c = b,   S_PQ = <chi_P|chi_Q>,   b_P = sum_ij D_ij (ij|P)

    where ``(ij|P)`` are three-center one-electron overlap integrals
    (PySCF ``int3c1e``).

    Args:
        features: mapping with keys
            ``'atom_number'`` (per-atom nuclear charges, anything with
            ``.tolist()``), ``'atom_coords'`` (per-atom coordinates, passed
            to PySCF with ``unit='angstrom'``), ``'net_charge'``,
            ``'spin'`` (passed as ``Mole.spin``) and ``'density_matrix'``.
            The density matrix must be expressed in ``BASIS``. It may be a
            total ``(nao, nao)`` matrix or a spin-resolved ``(2, nao, nao)``
            stack (alpha, beta), in which case the channels are summed.

    Returns:
        dict mapping ``'aux_density_denfit'`` to the 1-D array of fitted
        auxiliary-basis coefficients (one per auxiliary function).

    Raises:
        ValueError: if the density matrix shape does not match ``BASIS``.
    """
    nums = features['atom_number']
    coords = features['atom_coords']
    charge = int(features['net_charge'])
    spin = int(features['spin'])
    dm = np.asarray(features['density_matrix'])

    mol = pyscf.M(atom=list(zip(nums.tolist(), coords.tolist())),
                  unit='angstrom', basis=BASIS, charge=charge, spin=spin)

    nao = mol.nao_nr()
    if dm.ndim == 3 and dm.shape[0] == 2:
        # Spin-resolved (alpha, beta) stack, as produced by PySCF UHF/UKS:
        # the total density is the sum of both channels.
        dm = dm[0] + dm[1]
    if dm.shape != (nao, nao):
        raise ValueError(
            f"density_matrix has shape {dm.shape}, expected ({nao}, {nao}) "
            f"for basis {BASIS!r}"
        )

    auxmol = pyscf.df.addons.make_auxmol(mol, auxbasis=AUXBASIS)
    # Three-center overlap integrals (ij|P), shape (nao, nao, naux).
    ints_3c1e = pyscf.df.incore.aux_e2(mol, auxmol, intor="int3c1e")
    # Right-hand side b_P = sum_ij D_ij (ij|P).
    proj = np.einsum("ij,ijp->p", dm, ints_3c1e)
    # Auxiliary-basis overlap (metric) matrix S_PQ.
    aux_ovlp = auxmol.intor('int1e_ovlp')
    aux_vec = np.linalg.solve(aux_ovlp, proj)

    return {'aux_density_denfit': aux_vec}


def main():
    """Load the first dataset record and print its fitted aux-basis coefficients."""
    parts = ['base', 'dm', 'auxdensity.denfit']
    ds = MultipartLMDBDataset('dataset/main', parts_to_load=parts)
    first_record = ds[0]
    fitted = get_rho_auxbasis_denfit(first_record)['aux_density_denfit']
    print(fitted)


if __name__ == '__main__':
    # Run the demo only when executed as a script, not when imported.
    main()
