from temporal.training import *
import random
import numpy.random as np_random
import torch

from config import parser
import logging

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Parse the command-line options and record them in the log.
args = parser.parse_args()
logger.info(args)

# Seed every random number generator used here with the same configured seed
# (Python's `random`, NumPy's global RNG, then PyTorch), in that order.
for _seed_rng in (random.seed, np_random.seed, torch.manual_seed):
    _seed_rng(args.seed)

# Select the training routine from `args.model_type` (plus the `test_gen` and
# `var_len` flags where relevant) and run it. The trainer classes are
# presumably provided by the star import from `temporal.training` -- confirm.
#
# An if/elif chain is used instead of a dict of classes so that each class
# name is only looked up in the branch that actually runs.
if args.model_type == 'neuraltlp':
    if args.test_gen:
        # With `test_gen` set, `var_len` picks between the two generated-data
        # trainers.
        if args.var_len:
            GeneratedTrainingVariable(args).train()
        else:
            GeneratedTraining(args).train()
    else:
        TemporalRelationTraining(args).train()
elif args.model_type == 'lstm':
    TemporalLSTMTraining(args).train()
elif args.model_type == 'mode':
    ModeBaseline(args).train()
elif args.model_type == 'map':
    # NOTE(review): unlike the 'neuraltlp' branch, `var_len` is not consulted
    # here; `test_gen` always selects the variable-length MAP trainer --
    # confirm this is intended.
    if args.test_gen:
        GeneratedTrainingVariableMAP(args).train()
    else:
        TemporalMAPTraining(args).train()
else:
    # Fixed message: the original had a typo ("unkown") and a stray literal
    # "f" before the placeholder, which was prepended to the reported value.
    raise ValueError(f'unknown model type {args.model_type}')
