from abc import abstractmethod 
import torch


class Solver:
    """Base class for time-stepping solvers.

    Holds the shared solver state (current solution ``x``, time ``t``,
    step counter ``n_step``, ...) and the generic ``step`` routine.
    Subclasses are expected to supply ``set_problem``, ``rhs``,
    ``step_solution`` and ``simulate`` (and, where needed, ``Dx`` /
    ``Dxx``).

    NOTE: ``Solver`` does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are informational only; Python
    does not enforce them when instantiating ``Solver`` or a subclass.
    """

    def __init__(self, pde=None, device='auto', models=None):
        """Initialise empty solver state.

        Args:
            pde: Optional problem description; when not ``None`` it is
                forwarded to ``set_problem``.
            device: ``'auto'`` selects ``torch.device("cuda")`` when CUDA
                is available, else ``torch.device("cpu")``; any other
                value is stored unchanged.
            models: A single model or an iterable of models. Any object
                for which ``iter()`` succeeds is stored as-is (not
                copied); anything else is wrapped in a one-element list.
                ``None`` (the default) means "no models" and yields a
                fresh empty list for this instance.
        """
        self.x = None        # current solution (see set_ic / step)
        self.t = None        # current time; advanced by step()
        self.n = None        # NOTE(review): meaning set by subclasses — confirm
        self.n_step = None   # number of steps taken; incremented by step()
        self.dx = None       # presumably the spatial grid spacing — confirm
        self.has_pde = False  # presumably set by set_problem() in subclasses

        # Use a fresh list per instance: a shared ``[]`` default would be
        # the same object on every Solver built without explicit models.
        if models is None:
            models = []
        else:
            try:
                iter(models)
            except TypeError:
                # Not iterable: treat the argument as a single model.
                models = [models]
        self.models = models

        if device == 'auto':
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device
        if pde is not None:
            self.set_problem(pde)

    def Dx(self, x):
        """Spatial derivative hook; the base implementation returns ``None``.

        NOTE(review): presumably a first derivative overridden by
        subclasses — confirm.
        """
        pass

    def Dxx(self, x):
        """Spatial derivative hook; the base implementation returns ``None``.

        NOTE(review): presumably a second derivative overridden by
        subclasses — confirm.
        """
        pass

    def step(self, dt):
        """Advance the solution by one step of size ``dt``.

        Replaces ``self.x`` with ``self.step_solution(dt)``, adds ``dt``
        to ``self.t`` and increments ``self.n_step``. ``self.t`` and
        ``self.n_step`` are ``None`` after ``__init__``, so they must be
        given numeric values (e.g. by a subclass) before calling this.
        """
        self.x = self.step_solution(dt)
        self.t += dt
        self.n_step += 1

    def set_ic(self, ic):
        """Set the current solution ``self.x`` to the initial condition ``ic``."""
        self.x = ic

    def get_models(self):
        """Return the stored models object (not a copy)."""
        return self.models

    @abstractmethod
    def set_problem(self, pde):
        """Configure the solver for ``pde``; called from ``__init__`` when given."""
        pass

    @abstractmethod
    def rhs(self):
        """Evaluate the time derivative du/dt."""
        pass

    @abstractmethod
    def step_solution(self, dt):
        """Compute and return the solution after a step of size ``dt``.

        ``step`` assigns the returned value to ``self.x``; implementations
        should return the new state rather than rely on mutating ``self.x``.
        """
        pass

    @abstractmethod
    def simulate(self, dt, x0=None, T=None, n_step=None, callback=None):
        """Run a simulation with step size ``dt``.

        NOTE(review): semantics of ``x0``, ``T``, ``n_step`` and
        ``callback`` are defined by subclasses — confirm there.
        """
        pass