import torch
import torch.nn as nn
import numpy as np
import math
import torch.fft as fourier
# Make float64 the process-wide default floating-point dtype. This runs at
# import time and is global: every tensor created afterwards without an
# explicit dtype (in this module or in any code that imports it) becomes
# float64 instead of torch's usual float32.
torch.set_default_dtype(torch.float64)

class EDTRK4:
    """Fourth-order exponential time differencing Runge-Kutta (ETDRK4) solver.

    Integrates the Kuramoto-Sivashinsky equation

        u_t = -u * u_x - u_xx - u_xxxx

    on a periodic domain of length ``L`` with a Fourier pseudo-spectral
    discretisation, following Kassam & Trefethen (2005). The ETDRK4
    coefficients are evaluated with their contour-integral trick, which avoids
    cancellation error when ``dt * L_k`` is small.

    (The class name transposes the letters of "ETDRK4"; it is kept as-is for
    compatibility.)

    The state ``u`` has shape ``(N, batch)``. Dimension 0 holds the ``N`` grid
    values of one period; the spectral operators assume equispaced points.
    Each column is an independent trajectory, and all columns are advanced
    together.
    """

    def __init__(self, L, N, M, dt, device='auto', models=None):
        """Precompute the spectral ETDRK4 coefficients.

        Args:
            L: Length of the periodic domain.
            N: Number of grid points / Fourier modes.
            M: Number of contour points used to evaluate the coefficients.
            dt: Time step.
            device: ``'auto'`` (CUDA if available, else CPU) or any device
                accepted by ``Tensor.to``.
            models: Object(s) stored for the caller and returned by
                ``get_models``. A non-iterable value is wrapped in a list.
                No method of this class uses them.
        """
        if isinstance(device, str) and device == 'auto':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # Use a fresh list per instance; a ``[]`` default would be shared.
        if models is None:
            models = []
        else:
            try:
                iter(models)
            except TypeError:
                models = [models]
        self.models = models

        self.N = N
        self.dt = dt
        h = dt
        # Build everything in double precision, independent of the global
        # default dtype.
        real_dtype = torch.float64
        j = torch.tensor(1j, dtype=torch.complex128)

        # Wavenumbers in FFT order: 0, 1, ..., N/2-1, then 0 in place of the
        # Nyquist mode, then -N/2+1, ..., -1; all scaled by 2*pi/L.
        modes = torch.cat([
            torch.arange(0, N // 2),
            torch.zeros(1, dtype=torch.long),
            torch.arange(-N // 2 + 1, 0),
        ])
        k = (modes.to(real_dtype) * (2 * math.pi / L)).view(-1, 1)

        # Fourier symbol of the linear part -u_xx - u_xxxx (shape (N, 1)).
        lin = k ** 2 - k ** 4
        self.E = torch.exp(h * lin).to(self.device)
        self.E2 = torch.exp(h * lin / 2).to(self.device)

        # M points on the upper half of the unit circle. Because h*lin is
        # real, Re(mean over this half-circle) equals the mean over the full
        # circle of 2M points.
        theta = math.pi * (torch.arange(1, M + 1, dtype=real_dtype) - 0.5) / M
        r = torch.exp(j * theta).view(1, -1)
        # (N, M): each mode's h*L_k shifted onto its own contour.
        LR = h * lin + r

        def contour_mean(values):
            # h * Re(mean over the contour points), kept as an (N, 1) column.
            return h * torch.real(torch.mean(values, dim=1, keepdim=True))

        LR3 = LR ** 3
        expLR = torch.exp(LR)
        self.Q = contour_mean((torch.exp(LR / 2) - 1) / LR).to(self.device)
        self.f1 = contour_mean(
            (-4 - LR + expLR * (4 - 3 * LR + LR ** 2)) / LR3
        ).to(self.device)
        self.f2 = contour_mean(
            (2 + LR + expLR * (-2 + LR)) / LR3
        ).to(self.device)
        self.f3 = contour_mean(
            (-4 - 3 * LR - LR ** 2 + expLR * (4 - LR)) / LR3
        ).to(self.device)
        # Fourier symbol of the nonlinear term -u*u_x = -(u**2)_x / 2,
        # applied to fft(u**2).
        self.g = (-0.5 * j * k).to(self.device)

        self.t = 0
        self.n_step = 0
        self.u = None

    def _as_state(self, ic):
        """Return ``ic`` as an ``(N, batch)`` tensor on ``self.device``.

        A 1-D input is treated as a single trajectory and becomes ``(N, 1)``.
        Without this, broadcasting against the ``(N, 1)`` coefficients would
        silently produce an ``(N, N)`` result.

        Raises:
            ValueError: If the input is not ``(N,)`` or ``(N, batch)``.
        """
        u = torch.as_tensor(ic)
        if u.ndim == 1:
            u = u.reshape(-1, 1)
        if u.ndim != 2 or u.shape[0] != self.N:
            raise ValueError(
                f'initial condition must have shape ({self.N},) or '
                f'({self.N}, batch), got {tuple(u.shape)}'
            )
        return u.to(self.device)

    def _nonlinear(self, w_hat):
        """Fourier transform of the nonlinear term for spectral state ``w_hat``."""
        w = torch.real(fourier.ifft(w_hat, dim=0))
        return self.g * fourier.fft(w ** 2, dim=0)

    def step_solution(self):
        """Return the state after one ETDRK4 step, without modifying ``self.u``.

        Raises:
            RuntimeError: If no state has been set yet.
        """
        if self.u is None:
            raise RuntimeError(
                'no state to advance; call set_ic() or pass u0 to simulate() first'
            )
        v = fourier.fft(self.u, dim=0)
        Nv = self.g * fourier.fft(self.u ** 2, dim=0)

        E2v = self.E2 * v  # shared by stages a and b
        a = E2v + self.Q * Nv
        Na = self._nonlinear(a)

        b = E2v + self.Q * Na
        Nb = self._nonlinear(b)

        c = self.E2 * a + self.Q * (2 * Nb - Nv)
        Nc = self._nonlinear(c)

        v = self.E * v + Nv * self.f1 + 2 * (Na + Nb) * self.f2 + Nc * self.f3
        return torch.real(fourier.ifft(v, dim=0))

    def set_ic(self, ic):
        """Set the initial condition and reset ``t`` and ``n_step`` to 0.

        ``ic`` may be ``(N,)`` (one trajectory) or ``(N, batch)``.
        """
        self.u = self._as_state(ic)
        self.t = 0
        self.n_step = 0

    def step(self):
        """Advance ``self.u`` by one time step and update ``t``/``n_step``."""
        self.u = self.step_solution()
        self.t += self.dt
        self.n_step += 1

    def get_models(self):
        """Return the ``models`` collection given at construction."""
        return self.models

    def simulate(self, u0=None, T=None, n_step=None, callback=None, *args, **kwargs):
        """Advance the solution by a number of steps or up to a target time.

        Args:
            u0: Optional new state, with the same shapes as ``set_ic``. Unlike
                ``set_ic``, this does NOT reset ``t`` or ``n_step``.
            T: Target time. Steps are taken until ``t`` first reaches or
                passes ``T``. Takes precedence over ``n_step``.
            n_step: Number of steps to take when ``T`` is None.
            callback: Called as ``callback(self, *args, **kwargs)`` after
                every step.

        Raises:
            ValueError: If both ``T`` and ``n_step`` are None.
        """
        if T is None and n_step is None:
            raise ValueError('both T and n_step are None')
        if u0 is not None:
            self.u = self._as_state(u0)

        if T is None:
            # ceil(n - 0.5): exactly n steps for an integer n.
            n_todo = math.ceil(float(n_step) - 0.5)
        else:
            # Count the steps once rather than comparing an accumulated
            # self.t against T. Repeated "t += dt" drifts (ten steps of 0.1
            # sum to 0.9999999999999999), which would cost an extra step.
            # The small slack absorbs rounding in this single division.
            n_todo = math.ceil(float(T - self.t) / float(self.dt) - 1e-6)

        for _ in range(max(n_todo, 0)):
            self.step()
            if callback is not None:
                callback(self, *args, **kwargs)
