from integrator.base import Solver
from pde.base import PDE
import torch
import scipy.sparse as sps

class RK4(Solver):
    """Classical fourth-order Runge-Kutta time stepper on a uniform 1D grid.

    Spatial derivatives are second-order central differences stored as
    sparse CSR tensors (``D1`` for d/dx, ``D2`` for d2/dx2).  The corner
    entries of both operators couple the first and last grid points, i.e.
    the stencil wraps around periodically.

    The right-hand side ``rhs(x, t)`` is not implemented here; per the
    original note it is expected to be installed when a PDE is attached
    through ``set_problem``.
    """

    # for 1D Dirichlet/Periodic problem
    def __init__(self, pde: PDE = None, device = 'auto'):
        """Forward ``pde`` and ``device`` unchanged to the base ``Solver``."""
        super().__init__(pde, device)

    def Dx(self, x):
        """Return the central-difference first derivative ``D1 @ x``."""
        return self.D1 @ x

    def Dxx(self, x):
        """Return the central-difference second derivative ``D2 @ x``."""
        return self.D2 @ x

    def _periodic_stencil(self, coeffs, scale):
        """Build the (n, n) tridiagonal operator ``coeffs / scale`` as CSR.

        ``coeffs`` holds the (lower, main, upper) diagonal values.  The
        matrix is assembled sparsely, so memory stays O(n) instead of
        materialising a dense n-by-n array first.  Returns a float
        ``torch`` sparse CSR tensor moved to ``self.device``.
        """
        mat = sps.diags(coeffs, [-1, 0, 1], shape=(self.n, self.n)) / scale
        # LIL supports cheap single-entry assignment for the two corners.
        mat = mat.tolil()
        # Wrap-around entries: copy the sub-/super-diagonal values into the
        # opposite corners so the first and last grid points are coupled.
        mat[0, -1] = mat[1, 0]
        mat[-1, 0] = mat[0, 1]
        mat = mat.tocsr()
        # Drop stored zeros (e.g. the zero main diagonal of D1) so only true
        # non-zeros are kept, as a dense -> sparse conversion would do.
        mat.eliminate_zeros()
        mat.sort_indices()
        operator = torch.sparse_csr_tensor(
            torch.as_tensor(mat.indptr, dtype=torch.int64),
            torch.as_tensor(mat.indices, dtype=torch.int64),
            torch.tensor(mat.data),
            size=mat.shape,
        )
        return operator.to(self.device)

    def set_problem(self, pde: PDE):
        """Read the grid from ``pde``, build ``D1``/``D2`` and attach the PDE.

        Reads ``pde.n``, ``pde.t0`` and ``pde.dx``, resets the step counter,
        then calls ``pde.attach_solver(self)`` and marks the solver as
        having a PDE.
        """
        self.n = pde.n  # does not include boundary data
        self.t = pde.t0
        self.dx = pde.dx
        self.n_step = 0

        # (x[i+1] - x[i-1]) / (2 dx)
        self.D1 = self._periodic_stencil([-1, 0, 1], 2 * self.dx)
        # (x[i+1] - 2 x[i] + x[i-1]) / dx**2
        self.D2 = self._periodic_stencil([1, -2, 1], self.dx**2)

        pde.attach_solver(self)
        self.has_pde = True

    def set_ic(self, ic):
        """Store ``ic`` as the current state on ``self.device`` and reset time.

        NOTE(review): time is reset to 0 here, whereas ``set_problem`` starts
        from ``pde.t0`` -- confirm this difference is intended.
        """
        self.x = torch.as_tensor(ic).to(self.device)
        self.t = 0

    def rhs(self, x, t):
        """Placeholder for the time derivative ``dx/dt = rhs(x, t)``.

        Raises:
            NotImplementedError: always; a real right-hand side is expected
                to replace this placeholder when a PDE is attached.
        """
        # to be updated by set_problem
        raise NotImplementedError(
            'rhs(x, t) is not defined; attach a PDE with set_problem first.'
        )

    def step_solution(self, dt):
        """Return the state one classical RK4 step of size ``dt`` ahead.

        Only reads ``self.x`` and ``self.t``; neither is modified here.
        """
        half = dt / 2
        k1 = self.rhs(self.x, self.t)
        k2 = self.rhs(self.x + dt*k1/2, self.t + half)
        k3 = self.rhs(self.x + dt*k2/2, self.t + half)
        k4 = self.rhs(self.x + dt*k3, self.t + dt)
        return self.x + dt/6 * (k1 + 2*k2 + 2*k3 + k4)

    def simulate(self, dt, x0 = None, T = None, n_step = None, callback = None, *args, **kwargs):
        """Repeatedly call ``self.step(dt)`` while ``self.t < T``.

        Args:
            dt: step size handed to ``self.step``.
            x0: optional state that replaces ``self.x`` before stepping.
            T: end time; when ``None`` it is derived from ``n_step``.
            n_step: number of steps, used only when ``T`` is ``None``.
            callback: optional callable invoked after every step as
                ``callback(self, *args, **kwargs)``.

        Raises:
            ValueError: if no PDE is attached, or if both ``T`` and
                ``n_step`` are ``None``.
        """
        # getattr: this class only sets the flag inside set_problem, so fall
        # back to "no PDE" if it was never initialised.
        if not getattr(self, 'has_pde', False):
            raise ValueError('No PDE attached.')
        if T is None:
            if n_step is None:
                raise ValueError('both T and n_step are None')
            # Stop half a step short of the nominal end time; presumably this
            # keeps round-off in the accumulated time from adding or dropping
            # a step (assumes self.step advances self.t by dt -- it is
            # defined outside this class).
            T = self.t + dt * (n_step - 0.5)
        if x0 is not None:
            self.x = torch.as_tensor(x0).to(self.device)
        while self.t < T:
            self.step(dt)
            if callback is not None:
                callback(self, *args, **kwargs)
    
    
    