## NN model NeurVec

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

from integrator.RK4 import RK4
from pde.base import PDE

class StageCorrRK4(RK4):
    """RK4 integrator in which every stage slope receives a learned correction.

    Each of the four RK4 slopes ``k`` is replaced by ``k + models[i](k.T).T``
    before it is used to build the next stage state and the final weighted
    update. Stages index ``models[0]`` through ``models[3]``, so ``nns`` must
    provide at least four correction models.

    Per the original author's note, this is intended for 1D
    Dirichlet/periodic problems.
    """

    def __init__(self, nns, pde: PDE = None, device='auto'):
        """Initialise the base integrator and attach the correction models.

        Args:
            nns: indexable, iterable collection of callables. Entry ``i``
                corrects stage ``i + 1`` in :meth:`step_solution`. Each entry
                must support ``.to(device)``; presumably these are
                ``nn.Module`` instances (TODO confirm with callers).
            pde: problem definition, forwarded unchanged to ``RK4.__init__``.
            device: device specification, forwarded to ``RK4.__init__``. The
                models are then moved to whatever ``self.device`` the base
                class sets.
        """
        super().__init__(pde, device)
        self.models = nns
        # The return value of ``.to`` is discarded. This relies on it moving
        # the model in place, which holds for ``nn.Module``.
        for net in self.models:
            net.to(self.device)

    def step_solution(self, dt):
        """Return the state advanced by one corrected RK4 step of size ``dt``.

        This method reads ``self.x`` and ``self.t`` but does not itself assign
        them. The result is
        ``self.x + dt/6 * (k1 + 2*k2 + 2*k3 + k4)``, where every ``k`` is a
        corrected slope.

        Args:
            dt: step size.

        Returns:
            The candidate next state, computed from the four corrected slopes.
        """
        def corrected_slope(stage, state, time):
            # Raw vector-field evaluation plus the additive learned correction.
            # The model sees the transposed slope and its output is transposed
            # back, presumably to match the model's expected layout.
            # TODO confirm the tensor layout.
            slope = self.rhs(state, time)
            return slope + self.models[stage](slope.T).T

        k1 = corrected_slope(0, self.x, self.t)
        k2 = corrected_slope(1, self.x + dt*k1/2, self.t + dt/2)
        k3 = corrected_slope(2, self.x + dt*k2/2, self.t + dt/2)
        k4 = corrected_slope(3, self.x + dt*k3, self.t + dt)
        weighted_sum = k1 + 2 * k2 + 2 * k3 + k4
        return self.x + dt/6 * weighted_sum
    
    
    
