from src.pdedata import PDEDataBase


class TimePDEData(PDEDataBase):
    """
    Data container for time-dependent PDE problems.

    The accessors below read the instance attributes ``K``, ``u0``,
    ``u0_prev`` and ``du0``. Where these attributes are assigned is not
    shown here (presumably a subclass or ``PDEDataBase`` — confirm).
    Each accessor raises ``AttributeError`` if its attribute is missing.
    """

    def _get_required_attr(self, attr_name, error_message):
        """
        Return the instance attribute ``attr_name``.

        Raises:
            AttributeError: with ``error_message`` if the attribute is
                not present on the instance.
        """
        if hasattr(self, attr_name):
            return getattr(self, attr_name)
        raise AttributeError(error_message)

    def get_mass_matrix(self):
        """
        Return the mass matrix, stored on the instance as ``K``.

        Raises:
            AttributeError: if ``K`` has not been set.
        """
        return self._get_required_attr("K", "The mass matrix is not computed.")

    def get_u0(self):
        """
        Return the initial state, stored on the instance as ``u0``.

        Raises:
            AttributeError: if ``u0`` has not been set.
        """
        return self._get_required_attr(
            "u0", "The initial state is not computed."
        )

    def get_u0_prev(self):
        """
        Return the state one step before the initial state (``u0_prev``).

        Raises:
            AttributeError: if ``u0_prev`` has not been set.
        """
        return self._get_required_attr(
            "u0_prev", "The pre-initial state is not computed."
        )

    def get_du0(self):
        """
        Return the time derivative of the initial state (``du0``).

        Raises:
            AttributeError: if ``du0`` has not been set.
        """
        return self._get_required_attr(
            "du0", "The initial time derivative is not computed."
        )

    def compute_tm_rhs(self, X, t, dt, u):
        """
        Hook for the time-dependent right-hand side on the current
        sub-interval, where the rhs has the form f(x, t, u).

        Args:
            X: Spatial points (shape not fixed here; presumably
                (n_points, input_dim) — confirm against callers).
            t: Time value(s) of the current sub-interval (format not
                fixed here — confirm against callers).
            dt: Presumably the time-step size — confirm against callers.
            u: State with shape (n_time, n_points, output_dim).

        Returns:
            None in this base implementation. Subclasses override this
            method to return the rhs values.

        Note:
            If the rhs does not take `u` as an input, leave this method
            as-is and implement `update_tm_rhs()` in your
            `TimeSteppingRunner` class instead.
            If the rhs splits as `f(x, t, u) = f1(x, t) + f2(x, t, u)`,
            implement `update_tm_rhs()` in your `TimeSteppingRunner`
            class to compute `f1`, and override `compute_tm_rhs()` in
            your `TimePDEData` class to compute `f2`.
        """
        return None
