import numpy as np
import pandas as pd
import torch
import time
import argparse
import importlib
import torch.nn.functional as F
import torch.nn as nn

def args(argv=None):
    """Build the experiment configuration as an ``argparse.Namespace``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line style tokens to parse, e.g. ``['--epochs', '5']``.
        When omitted (``None``) an empty list is parsed, so the real
        ``sys.argv`` is ignored and every option keeps its default. Pass
        ``sys.argv[1:]`` explicitly to honour real command-line flags.

    Returns
    -------
    argparse.Namespace
        Attributes: ``device``, ``filename``, ``hidden_size``,
        ``batch_size``, ``epochs``, ``outime``, ``lr``, ``weight_decay``,
        ``L``, ``num``, ``sigma_sd``, populated from ``argv`` or the
        defaults declared below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--filename', type=str, default='drPDE')
    parser.add_argument('--hidden_size', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--outime', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--L', type=float, default=1)
    parser.add_argument('--num', type=int, default=1000)
    parser.add_argument('--sigma_sd', type=float, default=0.0)
    # Default to an empty argv (the original behavior) so stray interpreter
    # or kernel arguments in sys.argv never reach the parser unless the
    # caller opts in by passing them.
    namespace = parser.parse_args([] if argv is None else argv)
    return namespace

# Parse the configuration once at import time. This rebinds the
# module-level name ``args`` from the factory function to the Namespace
# it returns.
args = args()
filename = args.filename

# Resolve the model implementation dynamically: module ``utils.<model_ind>``.
model_ind = 'FNO'
mod = importlib.import_module(f'utils.{model_ind}')

# Use the GPU whenever one is visible. Note this unconditionally replaces
# whatever value the --device option held.
if torch.cuda.is_available():
    args.device = torch.device('cuda')
else:
    args.device = torch.device('cpu')

# Keep a handle on the model factory; it is called with ``args`` under __main__.
model = mod.Model


if __name__ == '__main__':

    # Build the experiment object from the module-level configuration.
    print("model(args)...")
    experiment = model(args)

    print("training...")
    # perf_counter() is monotonic and high-resolution, so the reported
    # duration cannot be skewed by system clock adjustments the way a
    # time.time() difference can. The timer brackets only the train() call.
    start = time.perf_counter()
    experiment.train()
    elapsed = time.perf_counter() - start
    print("time=", round(elapsed, 2))

    print("testing...")
    experiment.test()
    

    








