'''
Date: 2021-11-20 22:52:23
LastEditors: yuhhong
LastEditTime: 2022-08-06 15:01:47
'''
import pandas as pd
import numpy as np

from rdkit import Chem
# suppress rdkit warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

import argparse

# Element symbols a structure may contain; the script below skips any
# molecule that has an atom whose symbol is not in this list.
KEEP_ATOM = [
    'C', 'H', 'O', 'N', 'F', 'S',
    'Cl', 'P', 'B', 'Br', 'I',
]

# Adduct label -> integer code written to the 'AdductEncode' column.
# NOTE(review): the sign of the code does not track ion polarity
# (e.g. '[M+K-2H]-' is 4) -- treat the values as opaque category ids.
ADDs = {
    '[M-H2O-H]-': -4,
    '[M+Na-2H]-': -3,
    '[M+HCOO]-': -2,
    '[M-H]-': -1,
    '[M+H]+': 1,
    '[M+K]+': 2,
    '[M+H-H2O]+': 3,
    '[M+K-2H]-': 4,
    '[M-H+2Na]+': 5,
    '[M+NH4]+': 6,
    '[M+CH3COO]-': 7,
    '[M+Na]+': 8,
    '[M-H+2K]+': 9,
    '[M-2H+3Na]+': 10,
    'M+': 11,
}



def encode_adducts(df, adduct_codes):
    """Keep rows with a known adduct and add their integer code.

    Args:
        df: DataFrame with an 'Adduct' column.
        adduct_codes: mapping from adduct label to integer code.

    Returns:
        A new DataFrame (the input is not modified) containing only the rows
        whose 'Adduct' value is a key of ``adduct_codes``, with an extra
        'AdductEncode' column holding the mapped code. Rows with an unknown
        or missing adduct are dropped rather than raising ``KeyError``.
    """
    known = df['Adduct'].isin(list(adduct_codes))
    encoded = df[known].copy()
    encoded['AdductEncode'] = encoded['Adduct'].map(adduct_codes)
    return encoded


def has_only_kept_atoms(smiles, keep_atoms):
    """Return True if ``smiles`` parses and every atom symbol is allowed.

    Args:
        smiles: SMILES value; converted with ``str()`` before parsing.
        keep_atoms: container of allowed element symbols.

    Returns:
        False when RDKit cannot parse the SMILES or when any atom's symbol
        is not in ``keep_atoms``; True otherwise.
    """
    mol = Chem.MolFromSmiles(str(smiles))
    if mol is None:
        return False
    return all(atom.GetSymbol() in keep_atoms for atom in mol.GetAtoms())


def main():
    """Filter the input CSV and write the cleaned table to the output CSV.

    Steps: keep rows whose 'Type' is 'Experimental CCS', keep rows whose
    'Adduct' is listed in ``ADDs`` (adding 'AdductEncode'), keep rows whose
    'Structure' parses as SMILES and contains only ``KEEP_ATOM`` elements,
    then rename 'Structure' -> 'SMILES' and 'AllCCS ID' -> 'ID'.
    """
    parser = argparse.ArgumentParser(description='Preprocess the Data')
    # Both paths are mandatory: an empty default only fails later with an
    # opaque file error, so let argparse report the missing option instead.
    parser.add_argument('--input', type=str, required=True,
                        help='path to input data')
    parser.add_argument('--output', type=str, required=True,
                        help='path to output data')
    args = parser.parse_args()

    df = pd.read_csv(args.input)
    print(df)

    # filter by 'Type'
    experimental = df[df['Type'] == 'Experimental CCS']

    # filter by 'Adduct'
    encoded = encode_adducts(experimental, ADDs)
    dropped = len(experimental) - len(encoded)
    if dropped:
        print('Dropped {} rows with an unsupported adduct'.format(dropped))

    # filter by structure; an explicit bool dtype keeps the mask valid even
    # when no rows are left, so the output still carries the column header
    allowed = frozenset(KEEP_ATOM)
    flags = [has_only_kept_atoms(s, allowed) for s in encoded['Structure']]
    mask = pd.Series(flags, index=encoded.index, dtype=bool)
    kept = encoded[mask]

    # output the data
    kept = kept.rename(columns={'Structure': 'SMILES', 'AllCCS ID': 'ID'})
    kept.to_csv(args.output, index=False)


if __name__ == '__main__':
    main()