import argparse

from data.load_data import SupportedDatasets
from train.train_conf import train_conf

'''
Copy the qm8-conf logs from the remote machine into the local log/qm8-conf directory:
scp 1500011335@115.27.161.31:yangshuwen/GeometryEncoding/log/qm8-conf/* log/qm8-conf
'''
# Command-line options: a random seed plus the derive/compare variants that
# are forwarded to train_conf as DERIVE_TYPE / COMPARE_TYPE.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--derive', type=str, default='langevin')
parser.add_argument('--compare', type=str, default='adj3')
args = parser.parse_args()
seed, derive, compare = args.seed, args.derive, args.compare

# A truthy (non-zero) seed is appended as '@<seed>' to both run identifiers;
# the default seed 0 keeps the bare names.
_seed_suffix = f'@{seed}' if seed else ''
token = f'lstm-{derive}-{compare}{_seed_suffix}'
dataset_token = f'phi-psi{_seed_suffix}'

# Per-run overrides passed through to train_conf.
special_config = {
    'DERIVE_TYPE': derive,
    'COMPARE_TYPE': compare,
}

train_conf(
    dataset_name=SupportedDatasets.QM8,
    special_config=special_config,
    token=token,
    dataset_token=dataset_token,
    seed=seed,
    force_save=False,
    use_tqdm=False,
    use_cuda=True,
)
