# Run this script from the terminal to download and process SMILES files.

import logging
logging.getLogger().setLevel(logging.INFO)
import os
from os.path import join as join
import urllib
import urllib.request
import argparse

from qm9.data.prepare.qm9_synthetic_coordinates import smiles_to_xyz
from qm9.data.prepare.process_synthetic_coordinates import process_dataset
from qm9.data.prepare.prepare_dataset_info import compute_dataset_info


def prepare_dataset(datadir, dataset, subset=None, splits=None, cleanup=True, force_download=False, use_vocab_data=False):
    """
    Download and process dataset.

    Parameters
    ----------
    datadir : str
        Path to the directory where the data and calculations are, or will be, stored.
    dataset : str
        String specification of the dataset.  If it is not already processed, its
        (case-insensitive) name must currently start with "qm9" or "zinc250k".
    subset : str, optional
        Which subset of a dataset to use.  If given, it is appended as an extra
        directory level when locating the processed ``.npz`` files.
    splits : dict, optional
        Dataset splits to use.  Only the keys are used, as split names; defaults
        to train/valid/test when omitted.
    cleanup : bool, optional
        Currently unused by this function.
    force_download : bool, optional
        If true, forces a fresh download/processing of the dataset.
    use_vocab_data : bool, optional
        If true, also expect a ``vocab.npz`` file next to the split files.

    Returns
    -------
    datafiles : dict of strings
        Maps each split name to the path of its ``<split>.npz`` file.

    Raises
    ------
    ValueError
        If only some of the expected ``.npz`` files exist, or if a download is
        required for a dataset other than qm9*/zinc250k*.

    Notes
    -----
    TODO: Delete the splits argument?
    NOTE(review): ``download_and_process_dataset`` is called without ``subset``
    and only processes train/valid/test, so a fresh download presumably does not
    create ``vocab.npz`` or a subset directory -- confirm before relying on it.
    """
    # Processed data lives in datadir/dataset[/subset].
    dataset_dir = [datadir, dataset, subset] if subset else [datadir, dataset]

    # Names of splits: keys of ``splits`` if given, otherwise train/valid/test.
    # Materialize as a list: a dict.keys() view has no ``append`` method,
    # which previously crashed when combined with ``use_vocab_data``.
    split_names = list(splits.keys()) if splits is not None else ['train', 'valid', 'test']
    if use_vocab_data:
        split_names.append('vocab')

    # Assume one data file for each split.
    datafiles = {split: os.path.join(*dataset_dir, split + '.npz') for split in split_names}

    # Check if prepared dataset exists, and if not set flag to download below.
    # Probably should add more consistency checks, such as number of datapoints, etc...
    datafiles_checks = [os.path.exists(datafile) for datafile in datafiles.values()]
    new_download = False
    if all(datafiles_checks):
        logging.info('Dataset exists and is processed.')
    elif not any(datafiles_checks):
        # None of the expected files exist yet.
        new_download = True
    else:
        # Unpack the path components; joining the list object itself raises TypeError.
        raise ValueError(
            'Dataset only partially processed. Try deleting {} and running again to download/process.'.format(os.path.join(*dataset_dir)))

    # If need to download dataset, pass to appropriate downloader.
    if new_download or force_download:
        if new_download:
            logging.info('Dataset does not exist. Downloading!')
        else:
            logging.info('Forcing a fresh download of the dataset.')
        if dataset.lower().startswith(('qm9', 'zinc250k')):
            # TODO: pass rest of arguments
            download_and_process_dataset(datadir, dataset)
        else:
            raise ValueError(
                'Incorrect choice of dataset! Must choose qm9/zinc250k!')

    return datafiles


"""
This is what the structure of the dataset folders looks like:

main_project_folder/
├── data/
│   ├── qm9/
│   │   ├── smiles/
│   │   │   ├── all.txt
│   │   │   ├── train.txt
│   │   │   ├── valid.txt
│   │   │   └── test.txt
│   │   ├── synthetic_coords_rdkit/
│   │   │   ├── train/mol_0.xyz, mol_1.xyz, ...
│   │   │   ├── valid/mol_0.xyz, mol_1.xyz, ...
│   │   │   └── test/mol_0.xyz, mol_1.xyz, ...
│   │   ├── train.npz
│   │   ├── valid.npz
│   │   └── test.npz
└── README.md
"""

def _retrieve_to_file(url, path):
    """Download ``url`` into ``path`` via a temporary ``<path>.part`` file.

    The final file only appears after the download completes, so an
    interrupted or failed download never leaves a truncated file that a later
    run would mistake for a finished one and skip.
    """
    tmp_path = path + '.part'
    try:
        urllib.request.urlretrieve(url, filename=tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Clean up the partial file (also on Ctrl-C), then propagate the error.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def download_and_process_dataset(datadir, dataname, skip_conformer_generaion=False, only_explicit_H=False, prop=None):
    """
    Download and prepare the requested dataset with synthetic coordinates from SMILES strings.

    SMILES files for the 'all', 'train', 'valid' and 'test' partitions are
    downloaded into ``datadir/dataname/smiles`` (files already on disk are
    kept). Then each of train/valid/test is converted to coordinates and
    processed, and dataset info is computed once after the train split.

    Parameters
    ----------
    datadir : str
        Root data directory; output is written under ``datadir/dataname``.
    dataname : str
        Dataset name.  Selected by substring: it must contain "qm9",
        "zinc250k" or "guacamol" (checked in that order).
    skip_conformer_generaion : bool, optional
        If true, skip ``smiles_to_xyz`` and reprocess the existing
        ``synthetic_coords_rdkit`` folder.  (Spelling kept for compatibility.)
    only_explicit_H : bool, optional
        Forwarded to ``smiles_to_xyz``, ``process_dataset`` and ``compute_dataset_info``.
    prop : list of str, optional
        Property names forwarded to ``process_dataset``.

    Raises
    ------
    NotImplementedError
        If ``dataname`` matches none of the supported datasets.
    """
    # Define directory for which data will be output, e.g. 'data/qm9' or 'data/zinc250k'.
    dataset_dir = join(datadir, dataname)

    # exist_ok=True: do not fail if the directory already exists (e.g. created concurrently).
    os.makedirs(dataset_dir, exist_ok=True)
    logging.info(f'Downloading and processing {dataname} dataset. Output will be in directory: {dataset_dir}.')

    dataset_smiles_dir = join(dataset_dir, 'smiles')  # e.g. 'data/qm9/smiles'
    os.makedirs(dataset_smiles_dir, exist_ok=True)

    logging.info(f'Beginning download of {dataname} dataset!')
    if 'qm9' in dataname:
        dataset_smiles_url = 'https://raw.githubusercontent.com/THUNLP-MT/PS-VAE/main/data/qm9/'
    elif 'zinc250k' in dataname:
        dataset_smiles_url = 'https://raw.githubusercontent.com/THUNLP-MT/PS-VAE/main/data/zinc250k/'
    elif 'guacamol' in dataname:
        # taken from https://github.com/BenevolentAI/guacamol
        dataset_smiles_url = {
            'all': 'https://ndownloader.figshare.com/files/13612745',
            'train': 'https://ndownloader.figshare.com/files/13612760',
            'valid': 'https://ndownloader.figshare.com/files/13612766',
            'test': 'https://ndownloader.figshare.com/files/13612757',
        }
    else:
        raise NotImplementedError(f"Dataset {dataname} not supported yet")

    for partition in ['all', 'train', 'valid', 'test']:
        # File name on disk.
        smiles_file = join(dataset_smiles_dir, partition + '.txt')
        if os.path.exists(smiles_file):
            logging.info(f'The SMILES file {smiles_file} already exists. Skipping download!')
            continue
        # Either an explicit per-partition URL, or a base URL + '<partition>.txt'.
        if isinstance(dataset_smiles_url, dict):
            file_url = dataset_smiles_url[partition]
        else:
            file_url = dataset_smiles_url + partition + '.txt'
        _retrieve_to_file(file_url, smiles_file)
    logging.info(f'{dataname} SMILES dataset downloaded successfully!')

    logging.info('Computing synthetic coordinates for all splits...')
    for split in ['train', 'valid', 'test']:
        if not skip_conformer_generaion:
            logging.info(f'Computing synthetic coordinates for {split} split...')
            xyz_folder_name = smiles_to_xyz(dataset_dir, split, synthetic_coords_method='rdkit', only_explicit_H=only_explicit_H)
        else:
            # TODO: generalize
            xyz_folder_name = 'synthetic_coords_rdkit'

        logging.info(f' Processing {split} split...')
        process_dataset(dataset_dir, xyz_folder_name, split, only_explicit_H=only_explicit_H, prop=prop)
        if split == 'train':
            compute_dataset_info(datadir, dataname, only_explicit_H=only_explicit_H)
    logging.info('Finished Computing synthetic coordinates for all splits!')


# To run this, go to the main folder e3_diffusion and run
# PYTHONPATH="${PYTHONPATH}:." python qm9/data/prepare/download.py --datadir data/ --dataname qm9
# PYTHONPATH="${PYTHONPATH}:." python qm9/data/prepare/download.py --datadir data/ --dataname zinc250k_explicitH \
#       --only_explicit_H --skip_conformer_generaion --prop penalized_logP qed drd2 tpsa morgan_fingerprint

# For Guacamol dataset
# PYTHONPATH="${PYTHONPATH}:." python qm9/data/prepare/download.py --datadir data/ --dataname guacamol \
#       --only_explicit_H --prop penalized_logP qed drd2 tpsa morgan_fingerprint

if __name__ == "__main__":
    # Command-line entry point: parse options and run the download/processing pipeline.
    cli = argparse.ArgumentParser()
    cli.add_argument("--datadir", default="data/")
    cli.add_argument("--dataname", default="qm9")
    cli.add_argument(
        '--skip_conformer_generaion',
        action='store_true',
        default=False,
        help='Skip conformer generation, which takes a long time and only reprocess existing files',
    )
    cli.add_argument(
        '--only_explicit_H',
        action='store_true',
        default=False,
        help='Only Store explicit H atoms (that do not come from stereochemistry)',
    )
    cli.add_argument(
        "--prop",
        nargs='+',
        default=None,
        type=str,
        help="properties to compute: penalized_logP | qed | drd2 | morgan_fingerprint",
    )
    cli_args = cli.parse_args()

    download_and_process_dataset(
        cli_args.datadir,
        cli_args.dataname,
        skip_conformer_generaion=cli_args.skip_conformer_generaion,
        only_explicit_H=cli_args.only_explicit_H,
        prop=cli_args.prop,
    )
