import psi4
from rdkit.Chem.rdDistGeom import ETKDGv2, EmbedMolecule
from rdkit.Chem.rdForceFieldHelpers import MMFFHasAllMoleculeParams, MMFFOptimizeMolecule
from rdkit import Chem
import math
from utils.mol_utils import smiles_to_mols

def _charge_and_multiplicity(mol):
    """Return ``(formal_charge, spin_multiplicity)`` for an RDKit molecule.

    An explicit ``FormalCharge`` / ``SpinMultiplicity`` molecule property takes
    precedence; otherwise the charge comes from ``Chem.GetFormalCharge`` and
    the multiplicity from the number of radical electrons.
    """
    fc = 'FormalCharge'
    if mol.HasProp(fc):
        charge = int(mol.GetProp(fc))
    else:
        charge = Chem.GetFormalCharge(mol)

    sm = 'SpinMultiplicity'
    if mol.HasProp(sm):
        multiplicity = int(mol.GetProp(sm))
    else:
        # Hund's rule of maximum multiplicity: with N unpaired electrons the
        # total spin is S = N / 2, so the multiplicity 2S + 1 equals N + 1.
        num_radical_electrons = sum(
            atom.GetNumRadicalElectrons() for atom in mol.GetAtoms()
        )
        multiplicity = num_radical_electrons + 1

    return charge, multiplicity


def _psi4_geometry_string(mol, charge, multiplicity):
    """Build the Psi4 geometry text: a charge/multiplicity line, then XYZ rows."""
    conf = mol.GetConformer()
    lines = ["%s %s" % (charge, multiplicity)]
    for atom in mol.GetAtoms():
        # Fetch the position once per atom instead of once per coordinate.
        pos = conf.GetAtomPosition(atom.GetIdx())
        lines.append(
            " " + " ".join([atom.GetSymbol(), str(pos.x), str(pos.y), str(pos.z)])
        )
    return "\n".join(lines)


def find_metric(smiles, target, *, level="CCSD/cc-pV[DT]Z", num_threads=1,
                memory="64GB"):
    """Compute a quantum-chemical property with Psi4 for each input molecule.

    Each molecule gets a 3D structure from RDKit (ETKDGv2 embedding with a
    fixed random seed, followed by an MMFF optimization). That geometry is
    passed to a Psi4 single-point energy calculation, and the requested
    property is read from the resulting wavefunction.

    Molecules that fail at any stage (embedding, force-field optimization,
    Psi4 geometry construction or the Psi4 calculation itself) are reported
    with ``print`` and skipped. The returned list can therefore be shorter
    than the input and is not index-aligned with it.

    Args:
        smiles: Input forwarded unchanged to ``smiles_to_mols``; presumably an
            iterable of SMILES strings yielding RDKit molecules (confirm
            against that helper).
        target: ``'mu'`` for the dipole moment (norm of Psi4's ``SCF DIPOLE``
            variable scaled by 2.5417464519, the atomic-units-to-Debye
            factor) or ``'homo'`` for the HOMO energy converted from Hartree
            to eV. Any other value still runs the calculations but appends
            nothing.
        level: Psi4 method/basis string passed to ``psi4.energy``.
        num_threads: Number of CPU threads Psi4 may use.
        memory: Memory limit handed to ``psi4.set_memory``.

    Returns:
        list: One float per successfully processed molecule.
    """
    mols_props = []

    # Hardware side settings (CPU thread number and memory used by Psi4).
    psi4.set_num_threads(nthread=num_threads)
    psi4.set_memory(memory)
    psi4.core.set_output_file('psi4_output.dat', False)

    for mol in smiles_to_mols(smiles):
        # Coarse 3D structure: add explicit hydrogens and embed with ETKDGv2.
        mol = Chem.AddHs(mol)
        params = ETKDGv2()
        params.randomSeed = 1  # fixed seed keeps the embedding reproducible
        try:
            conf_id = EmbedMolecule(mol, params)
        except Chem.rdchem.AtomValenceException:
            print('invalid chemistry')
            continue
        if conf_id < 0:
            # EmbedMolecule signals failure by returning -1; without a
            # conformer the force-field step below cannot run.
            print('Bad conformer ID')
            continue

        # Structural optimization with MMFF (Merck Molecular Force Field).
        try:
            status = MMFFOptimizeMolecule(mol)
        except Exception:
            print('Bad conformer ID')
            continue
        print(status)

        # Convert to the text format accepted by psi4.geometry.
        charge, multiplicity = _charge_and_multiplicity(mol)
        mol_input = _psi4_geometry_string(mol, charge, multiplicity)
        try:
            molecule = psi4.geometry(mol_input)
        except Exception:
            print('Can not calculate psi4 geometry')
            continue

        # Single-point energy calculation at the requested level of theory.
        print('Psi4 calculation starts!!!')
        try:
            _, wave_function = psi4.energy(level, molecule=molecule, return_wfn=True)
        except Exception as e:
            print(e)
            # Drop scratch files left behind by the failed job.
            psi4.core.clean()
            continue

        try:
            print('Chemistry information check!!!')

            if target == 'mu':
                dipole = wave_function.variable('SCF DIPOLE')
                dip_x, dip_y, dip_z = dipole[0], dipole[1], dipole[2]
                dipole_moment = math.sqrt(dip_x**2 + dip_y**2 + dip_z**2) * 2.5417464519
                print("Dipole moment", dipole_moment)
                mols_props.append(dipole_moment)
            elif target == 'homo':
                # The HOMO is the last occupied alpha orbital (unit: Hartree).
                homo_idx = wave_function.nalpha() - 1
                homo = wave_function.epsilon_a_subset("AO", "ALL").np[homo_idx]

                # Convert unit from a.u. to eV.
                homo = homo * 27.211324570273
                print("HOMO", homo)
                mols_props.append(homo)
        finally:
            # Release Psi4 scratch files before the next independent job.
            psi4.core.clean()

    return mols_props

