# Main entry-point script for training NeurVec.
import os

import torch
import numpy as np

from utils import mkdir_p
from utils.parser import parse

from training import Train
from networks.stage_corr import StageCorr
from integrator.edtrk4_neurvec import NeurVecEDTRK4
from integrator.edtrk4 import EDTRK4


# Make float64 the default floating-point dtype for tensors created after this
# point without an explicit dtype (presumably so the networks/integrators built
# in main() work in double precision -- confirm against StageCorr/NeurVecEDTRK4).
torch.set_default_dtype(torch.float64)

# Parse command-line options at import time. `args` is read by main()
# (args.train_dir) and below (args.ckpt); `state` is not used in this file.
args, state = parse()



# Ensure the checkpoint directory exists before training starts
# (runs at import time, not only when executed as a script).
if not os.path.isdir(args.ckpt):
    mkdir_p(args.ckpt)


def main():
    """Train a NeurVec-corrected EDTRK4 integrator against the plain EDTRK4 one.

    Reads the discretisation parameters ``N``, ``M``, ``L`` and ``dt`` from the
    ``.npz`` archive at ``args.train_dir`` (``args`` is the module-level result
    of ``parse()``), builds the ``StageCorr`` correction network(s), and hands
    the corrected integrator and the uncorrected reference integrator to a
    ``Train`` session.
    """
    # PDE info. np.load on an .npz archive returns an NpzFile that holds the
    # underlying file open; using it as a context manager releases the handle.
    # torch.tensor copies the data, so the tensors stay valid after closing.
    with np.load(args.train_dir) as data:
        N = torch.tensor(data['N'])
        M = torch.tensor(data['M'])
        L = torch.tensor(data['L'])
        dt = torch.tensor(data['dt'])

    # Networks: NeurVecEDTRK4 receives a list of correction networks; a single
    # StageCorr with 1024 hidden units is used here.
    nns = [StageCorr(N, nhidden=1024) for _ in range(1)]

    # Setup training: the NeurVec integrator is fitted against the plain
    # EDTRK4 integrator built from the same parameters.
    session = Train(args)
    integrator = NeurVecEDTRK4(L, N, M, dt, nns)
    reference = EDTRK4(L, N, M, dt)

    # Train models.
    session.fit(integrator, reference)


if __name__ == '__main__':
    main()

