## Main script for training NeurVec.
import os

import torch
import numpy as np

from utils import mkdir_p
from utils.dataloader_segments import Dataset
from utils.parser import parse

from pde.ep import EP
from training import Train
from networks.stageCorr import StageCorrFC
from integrator.neurvec import NeurVec
from integrator.RK4 import RK4


# Use double precision globally: floating-point tensors created after this
# call default to torch.float64.
torch.set_default_dtype(torch.float64)

# Parse the command line once, at import time. ``args`` is read as a
# module-level global by the training entry point below; ``state`` is
# returned by ``parse`` alongside it.
args, state = parse()

# Create the checkpoint directory up front, skipping the call when a
# directory is already present at that path.
if not os.path.isdir(args.ckpt):
    mkdir_p(args.ckpt)


def _load_pde_params(path):
    """Read the scalar EP parameters ``g``, ``k_over_mg`` and ``l0`` from *path*.

    Parameters
    ----------
    path : str
        File passed to ``np.load``. The loaded object must support lookup by
        the keys ``'g'``, ``'k_over_mg'`` and ``'l0'`` (presumably an ``.npz``
        archive -- confirm against the data generator).

    Returns
    -------
    dict
        ``{'g': float, 'k_over_mg': float, 'l0': float}``.

    Raises
    ------
    KeyError
        If one of the required keys is missing from the file.
    """
    data = np.load(path)
    try:
        return {key: float(data[key]) for key in ('g', 'k_over_mg', 'l0')}
    finally:
        # ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps
        # the file open until closed; plain arrays have no ``close`` method,
        # so only close when one exists.
        close = getattr(data, 'close', None)
        if close is not None:
            close()


def main():
    """Build the EP problem, the NeurVec integrator and its RK4 reference, then train.

    Reads the module-level ``args``: ``args.train_dir`` supplies the PDE
    parameters, and the whole ``args`` namespace is handed to ``Train``.
    """
    # PDE info (the data file is closed as soon as the scalars are read,
    # instead of staying open for the whole training run).
    params = _load_pde_params(args.train_dir)
    pde = EP(g=params['g'], k_over_mg=params['k_over_mg'], l0=params['l0'])

    # networks: NeurVec takes a list of networks; a single one is used here.
    # NOTE(review): the leading 4 is presumably the EP state dimension -- confirm.
    nns = [StageCorrFC(4, nhidden=128)]

    # setup training
    session = Train(args)
    integrator = NeurVec(nns, pde)
    reference = RK4(pde)

    # train models
    session.fit(integrator, reference)


if __name__ == '__main__':
    main()

