import torch
import numpy as np

from data import get_the_data
from train import train
from model import define_the_model
from test import test
from optim import define_optimizer
import input_args
from stamp import param_stamp

def run(args):
    """Run one complete experiment described by ``args``.

    Steps, in order:
      1. Compute the parameter stamp via ``param_stamp``, store it on
         ``args.param_stamp`` and print it.
      2. Choose the torch device. CUDA is used only when
         ``torch.cuda.is_available()`` is true AND ``args.cuda`` is truthy.
         The chosen device is stored on ``args.device``.
      3. Seed NumPy and PyTorch with ``args.seed``, plus the CUDA generator
         when CUDA is in use.
      4. Load the data, build the model and optimizer, train, then test.

    Args:
        args: Options object (as produced by ``handle_inputs``). It is mutated
            in place: ``param_stamp`` and ``device`` are written onto it.

    Returns:
        None.
    """
    # The stamp is computed before ``args.device`` exists; keep that ordering.
    args.param_stamp = param_stamp(args)
    print(args.param_stamp)

    use_cuda = torch.cuda.is_available() and args.cuda
    args.device = torch.device("cuda" if use_cuda else "cpu")

    # Seed every RNG source this script touches, for repeatable runs.
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if use_cuda:
        torch.cuda.manual_seed(seed)

    # Data -> model -> optimizer -> training -> evaluation.
    train_data, test_data = get_the_data(args)
    model = define_the_model(args)
    optimizer = define_optimizer(args, model)
    train(args, model, train_data, optimizer)
    test(args, model, test_data)


## Function for specifying input-options and organizing / checking them
def handle_inputs():
    """Build the command-line parser, parse the options and fill in defaults.

    The parser is created by ``input_args.define_args`` and extended by
    ``input_args.add_options``. ``parse_args()`` is then called with no
    explicit argument list. Finally, ``input_args.set_defaults`` is applied
    to the result; its return value is not used, so it presumably updates
    the namespace in place (TODO confirm in ``input_args``).

    Returns:
        The parsed options object after ``input_args.set_defaults`` has run.
    """
    base_parser = input_args.define_args(
        filename="main",
        description='Train & test the generative classifier.',
    )
    parser = input_args.add_options(base_parser)

    parsed = parser.parse_args()
    input_args.set_defaults(parsed)
    return parsed



if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    # ``args`` is kept as a module-level name, as in the original script.
    args = handle_inputs()
    run(args=args)