'''
Date: 2022-01-24 17:32:22
LastEditors: yuhhong
LastEditTime: 2022-07-23 16:09:04
'''
from rdkit import Chem
# Suppress RDKit log output: disable every RDKit logger matching 'rdApp.*'
# (this includes warnings and errors RDKit prints while parsing molecules).
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

import argparse

'''
preprocess:
  1. remove molecules whose 3D conformer is invalid
      (e.g. coordinates too large to fit the fixed-width V2000 atom-line columns)
  2. remove molecules containing unlabeled atoms
      (only these atoms are labeled: ['C', 'H', 'O', 'N', 'F', 'S', 'Cl', 'P', 'B', 'Br', 'I'])
  3. remove molecules with more than 120 atoms
      (the threshold can be changed; 120 is a temporary choice)
'''

# Element symbols that are labeled; a molecule containing any atom whose
# symbol is not listed here is dropped during preprocessing.
ATOM_LIST = [
    'C', 'H', 'O', 'N', 'F',
    'S', 'Cl', 'P', 'B', 'Br', 'I',
]

def count_v2000_atom_lines(mol_block):
    """Count the well-formed atom lines in a V2000 mol block.

    A well-formed atom line is taken to be exactly 69 characters long and to
    split into 16 whitespace-separated fields (x, y, z, symbol, then 12
    integer fields). When a coordinate is too large for its fixed-width
    column it runs into its neighbour (e.g. ``10001.078110001.2576``), so the
    line no longer splits into 16 fields and is not counted.

    Args:
        mol_block: mol block text, e.g. the output of ``Chem.MolToMolBlock``.

    Returns:
        The number of lines that are 69 characters long and have 16 fields.
    """
    return sum(1 for line in mol_block.split("\n")
               if len(line) == 69 and len(line.split()) == 16)


def main():
    """Filter the molecules of an SDF file and write the survivors to a new SDF.

    A molecule is dropped when RDKit cannot parse it, when its mol block has
    fewer well-formed atom lines than it has atoms (invalid coordinates), when
    it has more than ``--num_atoms`` atoms, or when it contains an atom whose
    symbol is not in ``ATOM_LIST``.
    """
    parser = argparse.ArgumentParser(description='Preprocess the Data')
    # Both paths are mandatory: an empty path only fails later inside RDKit
    # with a much less helpful error.
    parser.add_argument('--input', type=str, required=True,
                        help='path to input data')
    parser.add_argument('--output', type=str, required=True,
                        help='path to output data')
    parser.add_argument('--num_atoms', type=int, default=120,
                        help='drop molecules with more atoms than this '
                             '(default: 120)')
    # SDMolSupplier removes hydrogens by default; this flag keeps them so
    # that 'H' atoms (which are in ATOM_LIST) survive into the output.
    parser.add_argument('--keep_hs', action='store_true',
                        help='keep explicit hydrogens when reading the input')
    args = parser.parse_args()

    # Build the membership set once instead of scanning a list per atom.
    allowed_atoms = frozenset(ATOM_LIST)

    supp = Chem.SDMolSupplier(args.input, removeHs=not args.keep_hs)
    out_mols = []
    print('Get {} data from {}'.format(len(supp), args.input))
    for idx, mol in enumerate(supp):
        # RDKit yields None for records it fails to parse.
        if mol is None:
            print('Unparsable molecule at index', idx)
            continue

        # Invalid conformer: some atom line is malformed (see the helper),
        # e.g.
        # 22      RDKit          2D
        # 0
        # 39  25 28  0  0  0  0  0  0  0  0999 V2000
        # 69 10001.078110001.2576    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0
        # 69  9998.940610000.0077    0.0000 C   0  0  0  0  0  0  0  0  0  0  0  0
        # 69  9996.7988 9997.9394    0.0000 O   0  0  0  0  0  0  0  0  0  0  0  0
        num_atoms = mol.GetNumAtoms()
        num_atom_lines = count_v2000_atom_lines(Chem.MolToMolBlock(mol))
        if num_atom_lines < num_atoms:
            print(num_atom_lines, '<', num_atoms)
            continue
        if num_atoms > args.num_atoms:
            print('Too many atoms')
            continue
        if not all(atom.GetSymbol() in allowed_atoms
                   for atom in mol.GetAtoms()):
            print('Unlabeled atom')
            continue

        out_mols.append(mol)

    print('Writing {} data to {}'.format(len(out_mols), args.output))
    writer = Chem.SDWriter(args.output)
    try:
        for m in out_mols:
            writer.write(m)
    finally:
        # Close explicitly so the output file is flushed and complete.
        writer.close()
    print('Done!')


if __name__ == '__main__':
    main()