import argparse
import multiprocessing as mp
import os.path as path
import time, datetime
import logging
from typing import List

def parse_arguments():
    parser = argparse.ArgumentParser()

    # path
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--output_dir', type=str, default='output/')

    parser.add_argument('--log_dir', type=str, default='logs/')
    parser.add_argument('--preprocess_dir', default='preprocessed/guacamol/500code')
    parser.add_argument('--model_save_dir', type=str, default='ckpt/guacamol_500code')
    parser.add_argument('--tensorboard_dir', type=str, default='tensorboard/')
    
    parser.add_argument('--train_file', type=str, default='guacamol/train.smiles')
    parser.add_argument('--valid_file', type=str, default='valid.txt')
    parser.add_argument('--test_file', type=str, default='debug_tmp.txt')
    parser.add_argument('--tune_file', type=str, default='antibiotics.txt')
    parser.add_argument('--mols_pkl_path', type=str, default='mol_graphs.pkl')
    parser.add_argument('--train_processed_path', type=str, default='train.pth')
    parser.add_argument('--valid_processed_path', type=str, default='valid.pth')
    parser.add_argument('--vocab_processed_path', type=str, default='vocab.pth')

    parser.add_argument('--operation_path', type=str, default='guacamol_100k_code.txt')
    parser.add_argument('--vocab_path', type=str, default='guacamol_vocab_500code.txt')
    parser.add_argument('--generated_path', type=str, default='generated.smiles')
    parser.add_argument('--json_output_path', type=str, default='output_distribution_learning.json')

    parser.add_argument('--operation_learning_log_path', type=str, default='operation_learning.log')
    parser.add_argument('--vocab_construct_log_path', type=str, default='get_vocab.log')
    parser.add_argument('--train_log_file', type=str, default='train.log')
    parser.add_argument('--generate_log_file', type=str, default='generate.log')

    # hyperparameters
    ## learn_bpe, get_vocab
    parser.add_argument('--num_iters', type=int, default=3000)
    parser.add_argument('--min_frequency', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=min(mp.cpu_count(), 6))
    parser.add_argument('--mp_thd', type=int, default=1e5)
    parser.add_argument('--num_operations', type=int, default=500)

    ## props
    # parser.add_argument('--props', type=List[str], default=[])

    ## networks
    parser.add_argument('--hidden_size', default=256)
    parser.add_argument('--atom_embed_size', type=List[int], default=[192, 16, 16, 16, 16])
    parser.add_argument('--edge_embed_size', type=int, default=256)
    parser.add_argument('--motif_embed_size', type=int, default=[256, 256])
    parser.add_argument('--latent_size', type=int, default=256)
    parser.add_argument('--depth', type=int, default=15)
    parser.add_argument('--motif_depth', type=int, default=6)
    parser.add_argument('--num_props', type=int, default=4)
    parser.add_argument('--dropout', type=float, default=0.1)

    ## training
    parser.add_argument('--load_model_path', type=str, default=None)
    parser.add_argument('--load_from_data', action='store_true')
    parser.add_argument('--load_all', action='store_true')
    parser.add_argument('--hidden_layers', type=int, default=3)

    parser.add_argument('--batch_size', type=int, default=128)

    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_anneal_iter', type=int, default=500)
    parser.add_argument('--lr_anneal_rate', type=float, default=0.99)
    parser.add_argument('--grad_clip_norm', type=float, default=1.0)

    parser.add_argument('--beta_schedule_mode', type=str, default="sigmoid")
    parser.add_argument('--beta_warmup', type=int, default=0)
    parser.add_argument('--beta_min', type=float, default=1e-3)
    parser.add_argument('--beta_max', type=float, default=0.6)
    parser.add_argument('--beta_anneal_period', type=int, default=20000)
    parser.add_argument('--beta_num_cycles', type=int, default=3)
    parser.add_argument('--prop_weight', type=float, default=0.5)
    
    # inference
    parser.add_argument('--greedy', action='store_true')
    parser.add_argument('--beam_top', type=int, default=0)
    parser.add_argument('--temperature', type=float, default=1.0)
    
    parser.add_argument('--seed', type=int, default=2)
    parser.add_argument('--train_batch_num', type=int, default=-1)
    parser.add_argument('--valid_batch_num', type=int, default=1)
    parser.add_argument('--test_batch_num', type=int, default=1)
    parser.add_argument('--job_name', type=str, default="")
    parser.add_argument('--epoch', type=int, default=5)
    parser.add_argument('--save_iter', type=int, default=1000)
    parser.add_argument('--eval_iter', type=int, default=10000)
    parser.add_argument('--num_samples', type=int, default=10000)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--ckpt', type=str, default='best')
    parser.add_argument('--iter', type=int, default=1)

    # goal_directed
    parser.add_argument('--goal_id', type=int, default=0)
    
    args = parser.parse_args()
    return args