import numpy as np
import pandas as pd
import torch
import time
import argparse
import importlib

def args(): 
    parser = argparse.ArgumentParser() 
    parser.add_argument('--device', type=str, default='cpu')  
    parser.add_argument('--filename', type=str, default='drPDE') 
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--batch_num', type=int, default=10)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--outime', type=int, default=50)
    parser.add_argument('--lr', type=float, default=1e-6)
    parser.add_argument('--thr', type=float, default=1e-7)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--rtol', type=float, default=1e-6)
    parser.add_argument('--atol', type=float, default=1e-8)
    parser.add_argument('--delta_t1', type=float, default=0.1)
    parser.add_argument('--delta_t2', type=float, default=0.001)
    parser.add_argument('--l1_alpha', type=float, default=0)
    parser.add_argument('--L', type=float, default=1)
    parser.add_argument('--method1', type=str, default="euler")
    parser.add_argument('--method2', type=str, default="dopri5")
    parser.add_argument('--method3', type=str, default="dopri5")
    parser.add_argument('--adjoint', type=bool, default=False)
    parser.add_argument('--pretr', type=bool, default=False)
    parser.add_argument('--tle1', type=int, default=2)
    parser.add_argument('--tle2', type=int, default=40)
    parser.add_argument('--warm_up', type=int, default=50)
    parser.add_argument('--num', type=int, default=10)
    parser.add_argument('--sigma_sd', type=float, default=0.0)
    args = parser.parse_args(args=[])
    return(args)


#{"dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun", "euler", 
#"midpoint", "rk4", "explicit_adams", "implicit_adams", "fixed_adams", "scipy_solver"}
    
args = args()
filename = args.filename
    
model_ind = 'PDENET'
mod = importlib.import_module('utils.'+model_ind)
model = mod.Model
    
if __name__ == '__main__':
    
    start = time.time()
    
    print("model(args)...")
    experiment = model(args)
    
    print("training...")
    experiment.train()
    
    print("testing...")
    experiment.test()
    
    end = time.time()  
    print("time=",round(end-start,2))
    
    
















