import sys
import os
import time
import numpy as np
import torch
from torch import nn, optim
from model import SparseCL
from util import Logger
import config

def main(args):
    """Build a SparseCL model and run the stage selected by ``args.stage``.

    Stage 0 calls ``model.train_model(args)`` (the training stage); stage 1
    calls ``model.linear_model(args)``.

    Args:
        args: Parsed configuration object. This function reads ``args.stage``
            and ``args.device`` and forwards ``args`` unchanged to the model
            constructor and to the selected stage method.

    Raises:
        ValueError: If ``args.stage`` is neither 0 nor 1.
    """
    # Fail fast on an unsupported stage: without this check the model would be
    # built and the run would then finish silently without doing any work.
    if args.stage not in (0, 1):
        raise ValueError(
            "unsupported stage {0!r}; expected 0 or 1".format(args.stage)
        )

    model = SparseCL(args).to(args.device)
    if args.stage == 0:
        # Training stage.
        model.train_model(args)
    elif args.stage == 1:
        model.linear_model(args)
        
if __name__ == '__main__':
    args = config.parse_arg()
    # Route stdout through the project's Logger so the prints below (and any
    # later ones) are written via it.
    sys.stdout = Logger(args)
    # Echo every parsed option as one "name: value" line.
    for name, value in vars(args).items():
        print("{0}: {1}".format(name, value))

    main(args)