'''
Date: 2022-07-12 18:14:40
LastEditors: yuhhong
LastEditTime: 2022-08-06 15:35:42
'''
import pandas as pd
from urllib.request import urlopen
from urllib.parse import quote

from rdkit import Chem
# suppress rdkit warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

import argparse

def CIRconvert(ids, timeout=30):
    '''Resolve a compound identifier to a SMILES string via the NCI/CADD
    Chemical Identifier Resolver (CIR) web service.

    Args:
        ids: identifier to resolve (this script passes values from the
            'Compound' column of the input sheet).
        timeout: seconds to wait for the HTTP request before giving up.

    Returns:
        The SMILES text returned by the service with surrounding whitespace
        stripped, or None when `ids` is not a non-empty string, the request
        fails (network error, timeout, HTTP error such as an unresolvable
        name), the body cannot be decoded as UTF-8, or the body is empty.
    '''
    # Local import keeps this helper self-contained; IncompleteRead and
    # similar protocol errors derive from HTTPException, not OSError.
    from http.client import HTTPException

    # pandas supplies NaN (a float) for empty cells; quote() would raise on
    # it, and an empty name cannot resolve to anything anyway.
    if not isinstance(ids, str) or not ids.strip():
        return None

    url = 'https://cactus.nci.nih.gov/chemical/structure/' + quote(ids) + '/smiles'
    try:
        # `with` closes the HTTP response; the timeout stops a single stalled
        # request from hanging the whole batch of lookups.
        with urlopen(url, timeout=timeout) as resp:
            ans = resp.read().decode('utf8')
    except (OSError, ValueError, HTTPException):
        # Best-effort lookup: URLError/HTTPError/timeouts are OSError
        # subclasses and UnicodeDecodeError is a ValueError. Naming the
        # exceptions (instead of a bare except) lets Ctrl+C interrupt the run.
        return None

    ans = ans.strip()
    # An empty body would otherwise survive dropna() and parse as an empty
    # (0-atom) molecule downstream.
    return ans or None

# Element symbols a molecule may contain; the __main__ filter drops any
# molecule with an atom whose symbol is not listed here.
KEEP_ATOM = [
    'C', 'H', 'O', 'N', 'F', 'S', 'Cl',
    'P', 'B', 'Br', 'I',
]

# Integer code for each accepted adduct label; the code is written to the
# 'AdductEncode' column and rows with any other label are dropped.
# NOTE(review): the codes presumably must match the encoding used by the
# downstream model / training data — confirm before changing them.
ADDs = {
    '[M-H]+': -1,
    '[M+H]+': 1,
    '[M+Na]+': 8,
    '[M]+': 11,
    'M+': 11,  # alternate spelling of '[M]+', mapped to the same code
    '[M+H-H2O]+': 3,
}



def _keep_row(row):
    '''Return True if a compound row should be kept.

    A row is kept when its 'Adduct' is a key of ADDs and its 'SMILES' parses
    with RDKit into a molecule whose atoms all have symbols in KEEP_ATOM.
    '''
    if row['Adduct'] not in ADDs:
        return False
    mol = Chem.MolFromSmiles(str(row['SMILES']))
    if mol is None:
        return False
    return all(atom.GetSymbol() in KEEP_ATOM for atom in mol.GetAtoms())


def _canonical_smiles(smiles):
    '''Return RDKit's canonical SMILES for `smiles`, or `smiles` unchanged
    when RDKit cannot parse it.

    Used by the train/test overlap check so that one molecule written as two
    different SMILES strings still counts as the same compound.
    '''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return smiles
    return Chem.MolToSmiles(mol)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Preprocess the Data')
    parser.add_argument('--input', type=str, default = '',
                        help='path to input data')
    parser.add_argument('--allccs', type=str, default = '',
                        help='path to allccs data')
    parser.add_argument('--output', type=str, default = '',
                        help='path to output data')
    args = parser.parse_args()

    # Read the compound sheet from the input workbook.
    df = pd.read_excel(args.input, sheet_name='MicroSource Collection', usecols=['Compound', 'CAS', 'Adduct', 'Ω(N2) / Å^2'])
    print(df)

    # Resolve each compound name to SMILES via CIR (None when unresolved).
    df['SMILES'] = df['Compound'].apply(CIRconvert)
    print(df)
    # dropna() removes rows with a missing value in ANY column, which
    # includes the rows whose name CIRconvert could not resolve.
    df.dropna(inplace=True)
    df.reset_index(inplace=True)
    df = df.rename(columns={'index':'ID', "Compound": "Name", "Ω(N2) / Å^2": "CCS"})

    # Filter out rows with an unsupported adduct, an unparsable SMILES, or an
    # element outside KEEP_ATOM. A boolean mask (rather than rebuilding the
    # frame from a list of rows) keeps the columns even when no row survives.
    keep = pd.Series([_keep_row(row) for _, row in df.iterrows()],
                     index=df.index, dtype=bool)
    # .copy() so the column assignment below acts on an independent frame.
    df = df[keep].copy()
    df['AdductEncode'] = df['Adduct'].map(ADDs)

    # Select the test set: compounds whose canonical SMILES does not occur in
    # the AllCCS 'SMILES' column. Use a set of plain strings; rows from
    # `.values.tolist()` are one-element lists and never equal a str.
    train_smiles = set(
        pd.read_csv(args.allccs, usecols=['SMILES'])['SMILES']
        .dropna()
        .astype(str)
        .map(_canonical_smiles)
    )
    df['Overlap'] = df['SMILES'].map(
        lambda s: _canonical_smiles(str(s)) in train_smiles
    ).astype(bool)
    print(df)
    test_df = df[~df['Overlap']].drop(columns=['Overlap'])
    print(test_df)

    test_df.to_csv(args.output, index=False)
    print('Save {} data to {}'.format(len(test_df), args.output))