import numpy as np
import pandas as pd
import torch
import time
import argparse
import importlib
    
def _str2bool(value):
    """Convert a command-line token such as 'true'/'false' or '1'/'0' to a bool.

    argparse's ``type=bool`` calls ``bool(token)``, which is True for every
    non-empty string (including "False"), so an explicit parser is needed.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def args(argv=None):
    """Build the experiment configuration.

    Args:
        argv: optional list of command-line tokens to parse. When None
            (the default) an empty list is parsed, so only the defaults below
            are used and ``sys.argv`` is ignored, exactly as before.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--filename', type=str, default='drPDE')
    parser.add_argument('--batch_size', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10000)
    parser.add_argument('--outime', type=int, default=40)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--thr', type=float, default=1e-6)

    parser.add_argument('--tle', type=int, default=950)
    parser.add_argument('--numz', type=int, default=512)
    # Was type=str with a float default; a CLI value would have stayed a string.
    parser.add_argument('--sigma', type=float, default=0.0)
    parser.add_argument('--prob', type=int, default=8)

    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--L', type=float, default=1)
    # type=bool would turn the token "False" into True; parse it explicitly.
    parser.add_argument('--pretr', type=_str2bool, default=False)
    parser.add_argument('--num', type=int, default=100)
    parser.add_argument('--sigma_sd', type=float, default=0.0)
    parser.add_argument('--rtol', type=float, default=1e-6)
    parser.add_argument('--atol', type=float, default=1e-8)
    parser.add_argument('--method', type=str, default="dopri5")
    parser.add_argument('--adjoint', type=_str2bool, default=False)
    # Parse an empty list by default (notebook-friendly): defaults only.
    return parser.parse_args(args=[] if argv is None else list(argv))
    
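# Candidate values for --method (ODE solver names; presumably forwarded to
# torchdiffeq's odeint inside utils — confirm against the installed version):
# {"dopri8", "dopri5", "bosh3", "fehlberg2", "adaptive_heun", "euler",
#  "midpoint", "rk4", "explicit_adams", "implicit_adams", "fixed_adams", "scipy_solver"}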
    
# Build the configuration Namespace; from here on the module-level name
# `args` refers to that Namespace rather than to the args() function.
args = args()
# Choose the compute device at runtime, overriding whatever --device held.
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')
filename = args.filename

# Selectable model modules (loaded as utils.<model_ind>):
#   E3dr_FNODE: the proposed method
#   E3dr_NODE: NODE
#   E3dr_FNODE_resolution: resolution-invariant
model_ind = 'E3dr_NODE'
mod = importlib.import_module(f'utils.{model_ind}')
model = mod.Model
# Record which model module was chosen on the shared config object.
args.model_ind = model_ind

if __name__ == '__main__':
    # Script entry point: construct the experiment, train it (timed), then test.

    print("model(args)...")
    experiment = model(args)

    print("training...")
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that can jump (NTP/DST adjustments) and distort the measurement.
    start = time.perf_counter()
    experiment.train()
    end = time.perf_counter()
    print("training time=", round(end - start, 2))

    print("testing...")
    experiment.test()
    
    
    
    













