import os
import argparse

import pandas as pd
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
# suppress rdkit warning
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from sklearn.model_selection import train_test_split


# refer:
# https://github.com/DeepGraphLearning/torchdrug/blob/9fac9129fd7e674c23cc6ee17abb2f4dab319467/torchdrug/utils/file.py#L10
def download(url, path, save_file=None, md5=None):
    """
    Download a file from the specified url.
    Skip the downloading step if there exists a file satisfying the given MD5.
    When ``md5`` is None there is no checksum to validate an existing copy
    against, so the file is always (re-)downloaded.
    Parameters:
        url (str): URL to download
        path (str): directory to store the downloaded file. It is created if
            missing; an empty string means the current working directory.
        save_file (str, optional): name of save file. If not specified, infer the file name from the URL
            (any query string such as ``?raw=true`` is stripped).
        md5 (str, optional): MD5 of the file
    Returns:
        str: path of the local file
    """
    # Standard-library equivalent of six.moves.urllib.request.urlretrieve.
    from urllib.request import urlretrieve

    # os.makedirs('') raises FileNotFoundError, and '' is a valid "current
    # directory" value (it is the script's default --path).
    if path:
        os.makedirs(path, exist_ok=True)

    if save_file is None:
        # Drop a trailing query string so it does not become part of the file name.
        save_file = os.path.basename(url).split("?", 1)[0]
    save_file = os.path.join(path, save_file)

    if not os.path.exists(save_file) or compute_md5(save_file) != md5:
        print("Downloading %s to %s" % (url, save_file))
        urlretrieve(url, save_file)
    return save_file

def compute_md5(file_name, chunk_size=65536):
    """
    Return the hexadecimal MD5 digest of a file.
    The file is read in binary mode, ``chunk_size`` bytes at a time, so large
    files never have to be held in memory at once.
    Parameters:
        file_name (str): path of the file to hash
        chunk_size (int, optional): number of bytes read per iteration
    Returns:
        str: lowercase hexadecimal digest
    """
    import hashlib

    digest = hashlib.md5()
    with open(file_name, "rb") as stream:
        # iter() with a b"" sentinel stops as soon as read() hits end of file.
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()

# Element symbols a molecule may contain. Preprocessing checks every atom
# after explicit hydrogens are added (hence 'H') and drops molecules that
# contain any element not listed here.
KEEP_ATOM = [
    'C', 'H', 'O', 'N', 'F', 'S',
    'Cl', 'P', 'B', 'Br', 'I',
]



_DATASET_BASE_URL = "http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/"

# Dataset name -> (download URL, expected MD5 of the downloaded file).
DATASETS = {
    'bbbp': (_DATASET_BASE_URL + "BBBP.csv", "66286cb9e6b148bd75d80c870df580fb"),
    'tox21': (_DATASET_BASE_URL + "tox21.csv.gz", "2882d69e70bba0fec14995f26787cc25"),
    'toxcast': (_DATASET_BASE_URL + "toxcast_data.csv.gz", "92911bbf9c1e2ad85231014859388cd6"),
    'sider': (_DATASET_BASE_URL + "sider.csv.gz", "77c0ef421f7cc8ce963c5836c8761fd2"),
    'clintox': (_DATASET_BASE_URL + "clintox.csv.gz", "db4f2df08be8ae92814e9d6a2d015284"),
    'muv': (_DATASET_BASE_URL + "muv.csv.gz", "9c40bd41310991efd40f4d4868fa3ddf"),
    'hiv': (_DATASET_BASE_URL + "HIV.csv", "9ad10c88f82f1dac7eb5c52b668c30a7"),
}


def is_kept_molecule(smiles, max_atoms=300):
    """
    Decide whether a SMILES string survives preprocessing.
    A molecule is dropped when RDKit cannot parse it, when it consists of a
    single atom, when it has more than ``max_atoms`` atoms once explicit
    hydrogens are added, or when it contains an element not in ``KEEP_ATOM``.
    Parameters:
        smiles: SMILES string (other values are converted with ``str``)
        max_atoms (int, optional): atom-count limit, hydrogens included
    Returns:
        bool: True if the molecule should be kept
    """
    mol = Chem.MolFromSmiles(str(smiles))
    if mol is None or mol.GetNumAtoms() == 1:
        return False
    # Make hydrogens explicit so they count toward max_atoms and are
    # checked against KEEP_ATOM like every other atom.
    mol = Chem.AddHs(mol)
    if mol.GetNumAtoms() > max_atoms:
        return False
    allowed = set(KEEP_ATOM)
    return all(atom.GetSymbol() in allowed for atom in mol.GetAtoms())


def scaffold_split(df, test_frac=0.1, scaffold_col='scaffold'):
    """
    Split rows into train/test so that no scaffold appears in both sets.
    Scaffold groups are visited in sorted scaffold order, and whole groups are
    moved to the test set until it holds at least ``int(len(df) * test_frac)``
    rows. The test set can therefore end up somewhat larger than that target.
    Note that MurckoScaffoldSmiles yields '' for acyclic molecules, so that
    group sorts first.
    Parameters:
        df (pandas.DataFrame): rows to split; must contain ``scaffold_col``
        test_frac (float, optional): target fraction of rows in the test set
        scaffold_col (str, optional): column holding each row's scaffold key
    Returns:
        tuple: (train, test) DataFrames, each keeping the original row order
    """
    cutoff = int(len(df) * test_frac)
    group_sizes = df.groupby(scaffold_col, sort=True).size()
    test_scaffolds = []
    n_test = 0
    for scaffold, size in group_sizes.items():
        if n_test >= cutoff:
            break
        test_scaffolds.append(scaffold)
        n_test += size
    is_test = df[scaffold_col].isin(test_scaffolds)
    return df[~is_test], df[is_test]


def main():
    """Download, filter and split the dataset selected on the command line."""
    parser = argparse.ArgumentParser(description='Preprocess the Data')
    parser.add_argument('--dataset', type=str, default='',
                        help='dataset type, one of: ' + ', '.join(sorted(DATASETS)))
    parser.add_argument('--path', type=str, default='',
                        help='path to output data')
    parser.add_argument('--split', type=str, default='scaffold',
                        choices=['scaffold', 'random'],
                        help='splitting strategy (default: scaffold)')
    parser.add_argument('--seed', type=int, default=None,
                        help='random seed used by --split random')
    args = parser.parse_args()

    # 0. download the dataset
    if args.dataset not in DATASETS:
        raise ValueError("Undefined dataset: {!r}. Choose from {}.".format(
            args.dataset, sorted(DATASETS)))
    url, md5 = DATASETS[args.dataset]
    file_name = download(url, args.path, md5=md5)

    # 1. preprocess the dataset
    # compression='infer' transparently handles the *.csv.gz downloads.
    df = pd.read_csv(file_name, compression='infer')
    # astype(bool) keeps the mask boolean even for an empty frame.
    keep = df['smiles'].apply(is_kept_molecule).astype(bool)
    # copy() so adding the scaffold column below does not write into a view of df.
    new_df = df.loc[keep].copy()

    # 2. split into train and test
    output_train_path = os.path.join(args.path, args.dataset + '_train.csv')
    output_test_path = os.path.join(args.path, args.dataset + '_test.csv')

    if args.split == 'random':
        print('random splitting...')
        train, test = train_test_split(new_df, test_size=0.1, random_state=args.seed)
    else:
        print('scaffold splitting...')
        new_df['scaffold'] = new_df['smiles'].apply(
            lambda x: MurckoScaffold.MurckoScaffoldSmiles(smiles=x, includeChirality=False))
        train, test = scaffold_split(new_df, test_frac=0.1)

    train.to_csv(output_train_path, index=False)
    test.to_csv(output_test_path, index=False)
    print('Save {} training data to {}, {} test data to {}.'.format(len(train), output_train_path, len(test), output_test_path))


if __name__ == '__main__':
    main()

