## NN model NeurVec

import torch


from integrator.RK4 import RK4
from pde.base import PDE
import torch

class NeurVec(RK4):
    """Classical RK4 stepper augmented with a learned correction term.

    The author's note marks this for 1D problems with Dirichlet or periodic
    boundary conditions. Each step adds the output of the first network in
    ``self.models`` to the usual RK4 stage sum (see ``step_solution``).
    """

    def __init__(self, nns, pde: PDE = None, device = 'auto'):
        """Initialise the RK4 base and move every correction network to its device.

        Args:
            nns: indexable, iterable collection of networks (presumably
                ``torch.nn.Module`` instances -- TODO confirm). All entries are
                moved to ``self.device``; only ``nns[0]`` is used by
                ``step_solution``.
            pde: problem definition, forwarded unchanged to ``RK4.__init__``.
            device: device spec, forwarded unchanged to ``RK4.__init__``,
                which presumably resolves it into ``self.device`` -- verify in RK4.
        """
        super().__init__(pde, device)
        self.models = nns
        for net in self.models:
            # Called for its effect; the return value is discarded, as before.
            net.to(self.device)

    def step_solution(self, dt):
        """Return the state advanced by one corrected RK4 step.

        Computes::

            x_next = x + dt/6 * (k1 + 2*k2 + 2*k3 + k4 + NN(x.T).T)

        Here ``k1``..``k4`` are the classical RK4 stages of ``self.rhs``
        evaluated from ``(self.x, self.t)``, and ``NN`` is ``self.models[0]``.
        The learned correction is scaled by ``dt/6`` together with the stage
        sum. This method does not assign ``self.x`` or ``self.t``; it only
        returns the new state.

        Args:
            dt: step size.

        Returns:
            The candidate next state.
        """
        slope_start = self.rhs(self.x, self.t)
        slope_mid_a = self.rhs(self.x + dt * slope_start / 2, self.t + dt / 2)
        slope_mid_b = self.rhs(self.x + dt * slope_mid_a / 2, self.t + dt / 2)
        slope_end = self.rhs(self.x + dt * slope_mid_b, self.t + dt)

        # The network sees the transposed state and its output is transposed
        # back. This assumes self.x is 2-D with the feature axis first, e.g.
        # (n_points, batch) -- TODO confirm against the PDE/RK4 code.
        correction = self.models[0](self.x.T).T

        weighted_sum = (
            slope_start + 2 * slope_mid_a + 2 * slope_mid_b + slope_end + correction
        )
        return self.x + dt / 6 * weighted_sum
    
    
    
