import argparse

import itertools
from datetime import datetime

import Models.pl_gnn as gnn
import Models.pl_rnn as rnn
from config.experiments import *

if __name__ == "__main__":
    # Stage 1: pre-parse only --gnn to decide which model family's arguments
    # to register. add_help=False keeps -h/--help from being handled here,
    # before the model-specific arguments exist (otherwise the help text would
    # list only --gnn and the script would exit). Unknown arguments are ignored.
    pre_parser = argparse.ArgumentParser(add_help=False)
    pre_parser.add_argument("--gnn", action='store_true')
    temp_args, _ = pre_parser.parse_known_args()

    # Stage 2: the full parser -- --gnn plus whatever the selected model
    # module registers via its add_args hook.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gnn", action='store_true')
    if temp_args.gnn:
        parser = gnn.add_args(parser)
    else:
        parser = rnn.add_args(parser)

    args = parser.parse_args()
    dict_args = vars(args)

    # One timestamp per run; all models trained below share the same folder.
    # NOTE(review): ':' in the folder name is not a valid path character on
    # Windows; kept as-is so existing log-folder naming does not change.
    timestamp = datetime.now().strftime('%d.%m.%y_%H:%M')

    if args.gnn:
        out_folder = 'logs_gnn/' + timestamp
        # Drop a CLI value that would collide with the explicit keyword below
        # ("got multiple values for keyword argument"); a no-op if absent.
        dict_args.pop('out_folder', None)
        gnn.train_gnn(out_folder=out_folder, **dict_args)
    else:
        out_folder = 'logs/' + timestamp
        # model and model_size are swept over all_models x all_sizes, and
        # out_folder is passed explicitly, so remove any CLI values for them.
        # pop() with a default: these keys exist only if rnn.add_args added them.
        for key in ('model', 'model_size', 'out_folder'):
            dict_args.pop(key, None)
        for model, size in itertools.product(all_models, all_sizes):
            rnn.train(model=model, model_size=size, out_folder=out_folder, **dict_args)
