import argparse
import os
import torch
import sys


parser = argparse.ArgumentParser()

# ===================== gpu id ===================== #
# Exported verbatim as CUDA_VISIBLE_DEVICES further down this file; the
# default -1 exposes no GPU (presumably a CPU-only run -- confirm).
parser.add_argument('--gpu', type=int, default=-1)

# =================== random seed ================== #
parser.add_argument('--seed', type=int, default=1234)

# ===================== Retro* alg =================== #
parser.add_argument('--iterations', type=int, default=50)
parser.add_argument('--expansion_beam', type=int, default=50)
parser.add_argument('--expansion_topk', type=int, default=50)

# ==================== test time ==================== #
parser.add_argument('--alg', default='polyretro')        # random / uspto / polyretro

# ================= experiment config =============== #
parser.add_argument('--n_cores', type=int, default=12)
parser.add_argument('--start', type=int, default=0)
# argparse applies `type` only to string values; a non-string default is used
# as-is. A default of `1e10` would therefore make args.end a float, which
# cannot be used as a slice/range bound. Use the equivalent integer instead so
# args.end is always an int, whether or not --end is passed.
parser.add_argument('--end', type=int, default=10 ** 10)
parser.add_argument('--data_root', default='polyretro-data')

# Parses sys.argv at import time; every module importing this config shares
# the resulting namespace.
args = parser.parse_args()

# setup device
# Restrict visible CUDA devices to the requested id. NOTE(review): this runs
# after `import torch`; that only works if nothing has initialised CUDA yet --
# confirm no import-time CUDA use.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)

# root dirs
root = args.data_root

# Output directories are created with os.makedirs(..., exist_ok=True) instead
# of an exists()/mkdir() pair. This creates missing parents (including `root`
# itself). It also avoids a FileExistsError when another process (e.g. a
# parallel worker importing this config) creates the directory between the
# check and the mkdir.

# enumerative search related
args.enum_search_result_dir = '%s/enumerative_result' % root
os.makedirs(args.enum_search_result_dir, exist_ok=True)
args.output_dir = '%s/synthetic_polymer_result' % args.enum_search_result_dir
os.makedirs(args.output_dir, exist_ok=True)
args.test_output_dir = '%s/real_polymer_result' % args.enum_search_result_dir
os.makedirs(args.test_output_dir, exist_ok=True)

# processed result
args.final_result_dir = '%s/processed_result' % root
os.makedirs(args.final_result_dir, exist_ok=True)

args.monomer_result_dir = '%s/monomer_result' % root
os.makedirs(args.monomer_result_dir, exist_ok=True)

# ------------------------------------------------------------------ #
# Input file locations, all derived from the data root. These are     #
# plain path strings; nothing here checks that the files exist.       #
# ------------------------------------------------------------------ #

# polymer dataset files
args.dataset_dir = f'{root}/polymer_dataset'
args.real_dataset = f'{args.dataset_dir}/real_data.json'
args.test_real_polymers = f'{args.dataset_dir}/processed_real_polymer.csv'

# starting-molecule list
args.starting_molecules = f'{root}/uspto_rand_split_routes/all/origin_dict.csv'

# MLP checkpoints and their template files
MLP_root = f'{root}/mlp_model_dumps'
args.mlp_model_dump = f'{MLP_root}/uspto_train/saved_rollout_state_1_2048.ckpt'
args.mlp_templates = f'{MLP_root}/template_rules_1.dat'
# "_r0" variants: checkpoint/templates under uspto_all
args.mlp_model_dump_r0 = f'{MLP_root}/uspto_all/saved_rollout_state_all_r0_2048.ckpt'
args.mlp_templates_r0 = f'{MLP_root}/templates_rule_r0_all.dat'
