import numpy as np
import pandas as pd
import torch
import time
import argparse
import importlib

def args(argv=None):
    """Build the experiment configuration as an ``argparse.Namespace``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line style tokens to parse, e.g. ``sys.argv[1:]`` or
        ``['--epochs', '5']``. When omitted (``None``), an empty list is
        parsed, so every option takes its default value and the process's
        real command line is ignored. This matches the original behaviour and
        keeps the function usable from notebooks, where ``sys.argv`` holds
        the kernel's own flags.

    Returns
    -------
    argparse.Namespace
        The parsed options. The meaning of the model-specific options
        (``m``, ``hidden_size``, ``num``, ``outime``, ``sigma_sd``, ...) is
        defined by the model class that consumes them. NOTE(review): confirm
        against utils.DON.
    """
    parser = argparse.ArgumentParser()
    # The script's __main__ block overwrites `device` after parsing, based on
    # CUDA availability, so this default is effectively a placeholder there.
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--filename', type=str, default='drPDE')
    parser.add_argument('--m', type=int, default=100)
    parser.add_argument('--hidden_size', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=10000)
    parser.add_argument('--epochs', type=int, default=1000)
    parser.add_argument('--outime', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--num', type=int, default=10)
    parser.add_argument('--sigma_sd', type=float, default=0.0)
    # Default to an empty token list (not sys.argv) to preserve the original
    # "defaults only" behaviour when no argv is supplied.
    tokens = [] if argv is None else list(argv)
    parsed = parser.parse_args(args=tokens)
    return parsed

# Parse the configuration once, at import time. This rebinds the module-level
# name `args` from the parser function to the resulting Namespace, so the
# function cannot be called again after this point.
args = args()
filename = args.filename

# Load the model implementation dynamically from the `utils.<model_ind>`
# module. That module must expose a `Model` class, which the __main__ block
# constructs with the parsed args.
model_ind = 'DON'
mod = importlib.import_module(f"utils.{model_ind}")
model = getattr(mod, 'Model')

if __name__ == '__main__':
    # Auto-select the compute device. This unconditionally replaces the
    # parsed --device value with a torch.device.
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # perf_counter is monotonic and high-resolution, so the measured interval
    # is unaffected by system clock adjustments (unlike time.time()).
    start = time.perf_counter()

    # Announce each phase before it runs, so the log reflects what is
    # currently in progress.
    print("model(args)...")
    experiment = model(args)

    print("training...")
    experiment.train()
    end = time.perf_counter()
    # The reported time covers model construction plus training, not testing.
    print("time=", round(end - start, 2))

    print("testing...")
    experiment.test()
    

    
    











