import torch
import torch.fft as fourier
torch.set_default_dtype(torch.float64)

from integrator.edtrk4 import EDTRK4

class StageCorrEDTRK4(EDTRK4):
    """ETDRK4 integrator whose stage nonlinearities are corrected by learned models.

    Each of the four stage evaluations of the quadratic term ``w**2`` (formed in
    physical space) receives an additive correction from ``self.models[k]``.
    The corrected term is then transformed to Fourier space and scaled by
    ``self.g``. The stage and update formulas otherwise follow the ETDRK4
    structure, using the coefficients ``E``, ``E2``, ``Q``, ``f1``, ``f2`` and
    ``f3`` provided by ``EDTRK4``.
    """

    def __init__(self, L, N, M, dt, nns, device='auto'):
        """Build the base integrator and move every correction model to its device.

        ``L``, ``N``, ``M``, ``dt``, ``nns`` and ``device`` are forwarded
        unchanged to ``EDTRK4``, which defines their meaning. ``self.models`` is
        presumably populated by the parent from ``nns``; verify this against
        ``EDTRK4``.
        """
        # NOTE(review): `device` and `nns` are passed in the opposite order from
        # this signature; confirm this matches EDTRK4.__init__'s positional order.
        super().__init__(L, N, M, dt, device, nns)
        for net in self.models:
            net.to(self.device)

    def _corrected_nonlinearity(self, field, stage):
        """Return the corrected, Fourier-space quadratic term for one RK stage.

        ``field`` is a real physical-space tensor, and the FFT is taken along
        dim 0. ``stage`` selects which entry of ``self.models`` supplies the
        additive correction.
        """
        squared = field ** 2
        # The model is applied to the transposed tensor, and its output is
        # transposed back before being added.
        corrected = squared + self.models[stage](squared.T).T
        return self.g * fourier.fft(corrected, dim=0)

    def step_solution(self):
        """Compute one ETDRK4 step from ``self.u`` and return the new solution.

        Returns the real part of the inverse FFT of the updated spectrum. This
        method does not reassign ``self.u`` itself.
        """
        def to_physical(spectrum):
            # Back to physical space, keeping only the real part.
            return torch.real(fourier.ifft(spectrum, dim=0))

        v_hat = fourier.fft(self.u, dim=0)
        # Stage 0 squares self.u directly instead of round-tripping through ifft.
        n_v = self._corrected_nonlinearity(self.u, 0)

        a_hat = self.E2 * v_hat + self.Q * n_v
        n_a = self._corrected_nonlinearity(to_physical(a_hat), 1)

        b_hat = self.E2 * v_hat + self.Q * n_a
        n_b = self._corrected_nonlinearity(to_physical(b_hat), 2)

        c_hat = self.E2 * a_hat + self.Q * (2 * n_b - n_v)
        n_c = self._corrected_nonlinearity(to_physical(c_hat), 3)

        v_next = (
            self.E * v_hat
            + n_v * self.f1
            + 2 * (n_a + n_b) * self.f2
            + n_c * self.f3
        )
        return to_physical(v_next)
    