## main code file for training neurvec
import os

import torch
import numpy as np

from utils import mkdir_p
from utils.parser import parse

from pde.burgers import Burgers
from training import Train
from networks.stageCorrFC_modified import StageCorrFC
from integrator.stage_corr import StageCorrRK4
from integrator.RK4 import RK4



# Make float64 the default floating-point dtype for tensors created from here on.
torch.set_default_dtype(torch.float64)

# Parse the run configuration at import time; `args` is read as a module-level
# global by main(). NOTE(review): `state` is not referenced anywhere else in
# the visible part of this script -- confirm what utils.parser returns in it.
args, state = parse()


# Create the path named by args.ckpt (presumably the checkpoint output
# directory -- TODO confirm) when it is not already an existing directory.
if not os.path.isdir(args.ckpt):
    mkdir_p(args.ckpt)


def main():
    """Set up the Burgers problem, the correction networks, and run training.

    Loads the file named by the module-level ``args.train_dir``, builds a
    ``Burgers`` PDE from its ``xs`` and ``nu`` entries, creates four
    ``StageCorrFC`` networks sized by ``len(xs)``, and passes a
    ``StageCorrRK4`` integrator together with a plain ``RK4`` reference
    integrator to ``Train(args).fit``.
    """
    # PDE info
    # For .npz archives np.load returns an NpzFile that keeps the underlying
    # file open until it is closed; read what is needed inside a context
    # manager so the handle is released before the training run starts.
    with np.load(args.train_dir) as data:
        xs = data['xs']
        nu = torch.tensor(data['nu'])
    n = len(xs)
    pde = Burgers(xs, nu=nu)

    # networks (four of them -- presumably one per RK4 stage; TODO confirm)
    nns = [StageCorrFC(n, nhidden=256) for _ in range(4)]

    # setup training
    session = Train(args)
    integrator = StageCorrRK4(nns, pde)
    reference = RK4(pde)

    # train models
    session.fit(integrator, reference)


# Start training only when this file is executed as a script. Note that the
# module-level setup above (argument parsing, checkpoint directory creation)
# still runs on a plain import.
if __name__ == '__main__':
    main()

