from integrator.base import Solver
from pde.base import PDE
import torch
import scipy.sparse as sps

class RK4(Solver):
    """Classical fourth-order Runge-Kutta time integrator.

    Originally written for 1D Dirichlet/periodic problems. The right-hand
    side ``rhs(x, t)`` is expected to be installed on this solver by the PDE
    when ``set_problem`` calls ``pde.attach_solver(self)``.
    NOTE(review): that installation mechanism is inferred from the
    placeholder ``rhs`` below -- confirm against ``PDE.attach_solver``.

    State attributes:
        x: current state tensor, moved to ``self.device``.
        t: current simulation time.
        n_step: step counter, reset by ``set_problem`` / ``set_ic``.
    """

    def __init__(self, pde: PDE = None, device='auto'):
        """Forward ``pde`` and ``device`` to ``Solver.__init__``.

        NOTE(review): presumably the base class calls ``set_problem`` when a
        PDE is given -- confirm in ``integrator.base.Solver``.
        """
        super().__init__(pde, device)

    def set_problem(self, pde: PDE):
        """Attach ``pde`` to this solver and reset the clock to ``pde.t0``."""
        # Remember the PDE's start time so set_ic can reset to it instead of
        # silently shifting a problem with nonzero t0 back to t = 0.
        self._t0 = pde.t0
        self.t = pde.t0
        self.n_step = 0

        pde.attach_solver(self)
        self.has_pde = True

    def _to_state(self, x):
        """Convert ``x`` to a tensor placed on ``self.device``."""
        return torch.as_tensor(x).to(self.device)

    def set_ic(self, ic):
        """Set the initial condition and reset time and step counter.

        Time is reset to the attached PDE's ``t0``, or to 0 if no PDE has
        been attached yet.
        """
        self.x = self._to_state(ic)
        self.t = getattr(self, '_t0', 0)
        self.n_step = 0

    def rhs(self, x, t):
        """Placeholder for the right-hand side ``dx/dt = rhs(x, t)``.

        Expected to be replaced when a PDE is attached via ``set_problem``.

        Raises:
            NotImplementedError: always, if it was never replaced.
        """
        raise NotImplementedError(
            'RK4.rhs has not been set; attach a PDE with set_problem() first.'
        )

    def step_solution(self, dt):
        """Return the state after one RK4 step of size ``dt``.

        Evaluates the four classical RK4 stages starting from
        ``(self.x, self.t)``. It does not modify ``self.x``, ``self.t`` or
        ``self.n_step``.
        """
        k1 = self.rhs(self.x, self.t)
        k2 = self.rhs(self.x + dt*k1/2, self.t + dt/2)
        k3 = self.rhs(self.x + dt*k2/2, self.t + dt/2)
        k4 = self.rhs(self.x + dt*k3, self.t + dt)
        return self.x + dt/6 * (k1 + 2 * k2 + 2 * k3 + k4)

    def simulate(self, dt, x0=None, T=None, n_step=None, callback=None, *args, **kwargs):
        """Advance the solution with a fixed step ``dt``.

        Args:
            dt: time step; must be positive.
            x0: optional new state. It replaces ``self.x`` without resetting
                ``self.t`` or ``self.n_step``.
            T: final time; ``self.step(dt)`` is called while ``self.t < T``.
            n_step: number of steps to take from the current time. It is only
                used when ``T`` is None.
            callback: optional callable, invoked as
                ``callback(self, *args, **kwargs)`` after every step.

        Raises:
            ValueError: if no PDE is attached, if both ``T`` and ``n_step``
                are None, or if ``dt`` is not positive.
        """
        # has_pde is only assigned by set_problem; default to False so a
        # never-configured solver gets the intended ValueError.
        if not getattr(self, 'has_pde', False):
            raise ValueError('No PDE attached.')
        if T is None:
            if n_step is None:
                raise ValueError('both T and n_step are None')
            # Target half a step before the exact end time. Round-off in the
            # accumulated t then cannot add or drop a step.
            T = self.t + dt * (n_step - 0.5)
        # With dt <= 0 the loop below could never reach T (assuming
        # Solver.step advances self.t by dt -- confirm in the base class).
        if dt <= 0:
            raise ValueError('dt must be positive')
        if x0 is not None:
            self.x = self._to_state(x0)
        while self.t < T:
            self.step(dt)
            if callback is not None:
                callback(self, *args, **kwargs)
    
    
    