import argparse

from data.load_multi_data import SupportedMultiDatasets
from train.train_multi_conf import train_multi_conf

'''
To copy the remote log/geom_m9 files into the local log/geom_m9 directory, run:
scp 1500011335@115.27.161.31:yangshuwen/GeometryEncoding/log/geom_m9/* log/geom_m9
'''
# Command-line options: a random seed plus the three component names that are
# forwarded as GENERATE_TYPE / DERIVE_TYPE / COMPARE_TYPE to the training config.
parser = argparse.ArgumentParser()
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--generate', default='rdkit', type=str)
parser.add_argument('--derive', default='newton', type=str)
parser.add_argument('--compare', default='equiv-trunc', type=str)
args = parser.parse_args()

# Expose the parsed options as plain module-level names.
seed, generate, derive, compare = (
    args.seed,
    args.generate,
    args.derive,
    args.compare,
)

# A non-zero seed is appended as an '@<seed>' suffix so that runs with different
# seeds get distinct run and dataset tokens. Seed 0 (the CLI default) adds nothing.
token = f'{generate}-{derive}-{compare}' + (f'@{seed}' if seed else '')
dataset_token = 'phi-psi' + (f'@{seed}' if seed else '')

# Launch training on SupportedMultiDatasets.GEOM_QM9_SMALL with the components
# selected on the command line. Artifacts are keyed by `token` / `dataset_token`.
train_multi_conf(
    dataset_name=SupportedMultiDatasets.GEOM_QM9_SMALL,
    special_config={
        'GENERATE_TYPE': generate,
        'DERIVE_TYPE': derive,
        'COMPARE_TYPE': compare,
        # Enabled whenever a derive component is given (i.e. --derive is non-empty).
        'COMPARE_MIDDLE': derive != '',
        # The 'cvgae' generator uses a 10x smaller learning rate than the others.
        'LR': 5e-4 if generate == 'cvgae' else 5e-3,
        # Previously written as `0.95 if generate == 'cvgae' else 0.95`; both
        # branches were identical, so this is a plain constant.
        # NOTE(review): if 'cvgae' was meant to use a different GAMMA, set it here.
        'GAMMA': 0.95,
    },
    token=token,
    dataset_token=dataset_token,
    seed=seed,
    force_save=False,
    use_tqdm=True,
    use_cuda=True
)
