## Entry-point script for training neurvec: builds the stage-correction networks and fits them against the reference integrator.
import os

import torch
import numpy as np

from utils import mkdir_p
from utils.parser import parse

from training import Train
from networks.stage_corr import StageCorr
from integrator.edtrk4_stage import StageCorrEDTRK4
from integrator.edtrk4 import EDTRK4



# Global numerics / reproducibility configuration, applied at import time.
# Floating-point tensors created without an explicit dtype default to float64.
torch.set_default_dtype(torch.float64)
# Use deterministic implementations where available; PyTorch raises an error
# for operations that only have nondeterministic implementations.
# NOTE(review): per the PyTorch docs, CUDA runs may additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable — confirm for GPU training.
torch.use_deterministic_algorithms(True)
# Make cuDNN choose deterministic algorithms and disable TF32 in cuDNN
# convolutions so they run at full float precision.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False



# parse() (utils.parser) runs at import time; `args` is read as a module
# global by main(). `state` is not used anywhere in this file's visible code.
args, state = parse()


# Create args.ckpt (presumably the checkpoint directory) if it does not exist.
if not os.path.isdir(args.ckpt):
    mkdir_p(args.ckpt)


def main():
    """Build the stage-corrected EDTRK4 integrator and train it.

    Reads the entries ``N``, ``M``, ``L`` and ``dt`` from the archive at
    ``args.train_dir`` (``args`` is the module-level result of ``parse()``),
    creates four ``StageCorr`` networks with 256 hidden units each, wraps them
    in a ``StageCorrEDTRK4`` integrator, and calls
    ``Train(args).fit(integrator, reference)``, where ``reference`` is a plain
    ``EDTRK4`` built from the same parameters.
    """
    # PDE info. For .npz archives np.load returns a lazily-read NpzFile that
    # keeps the file open; the context manager closes it once the entries are
    # read. torch.tensor always copies, so the tensors outlive the file.
    with np.load(args.train_dir) as data:
        N = torch.tensor(data['N'])
        M = torch.tensor(data['M'])
        L = torch.tensor(data['L'])
        dt = torch.tensor(data['dt'])

    # networks: four independent correction networks, presumably one per
    # RK4 stage — confirm against StageCorrEDTRK4.
    nns = [StageCorr(N, nhidden=256) for _ in range(4)]

    # setup training
    session = Train(args)
    integrator = StageCorrEDTRK4(L, N, M, dt, nns)
    reference = EDTRK4(L, N, M, dt)

    # train models
    session.fit(integrator, reference)


# Run training only when executed as a script. Importing this module still
# runs the module-level setup (torch configuration, parse(), mkdir_p).
if __name__ == '__main__':
    main()

