import numpy as np
import pandas as pd
import torch
import time
import argparse
import importlib
from torch.multiprocessing import Pool

def args(): 
    parser = argparse.ArgumentParser() 
    parser.add_argument('--device', type=str, default='cpu')  
    parser.add_argument('--filename', type=str, default='E2') 
    parser.add_argument('--tle', type=int, default=950)
    parser.add_argument('--ole', type=int, default=20)
    parser.add_argument('--numz', type=int, default=500)
    parser.add_argument('--sigma', type=str, default=0.0)
    parser.add_argument('--prob', type=int, default=8)
    parser.add_argument('--N_dim', type=int, default=4)
    
    parser.add_argument('--epochs', type=int, default=5000)
    parser.add_argument('--outime', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--hidden_size', type=int, default=64)
    parser.add_argument('--batch_size', type=int, default=3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-6)
    parser.add_argument('--atol', type=float, default=1e-8)
    parser.add_argument('--method', type=str, default="dopri5")
    parser.add_argument('--adjoint', type=bool, default=False)
    parser.add_argument('--num', type=int, default=100)
    parser.add_argument('--inte', type=int, default=1)
    parser.add_argument('--sigma_sd', type=float, default=0.0)
    args = parser.parse_args(args=[])
    return(args)

args = args()
filename = args.filename

args.model_ind = 'E2_FNODE'
mod = importlib.import_module('utils.'+args.model_ind)
model = mod.Model
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    experiment = model(args)
    
    start = time.time()
    print("training...")
    experiment.train()
    end = time.time()  
    print("Train time=",round(end-start,2))
    
    print("testing...")
    experiment.test()
    
    print("drawing...")
    experiment.draw()    
    










