import json

from src.tmrunner import TimeSteppingRunner
from src.wave.wave2d_cg.pdedata import Wave2D_CG


class Wave2D_CGRunner(TimeSteppingRunner):
    """
    Run the 2D wave equation (Wave2D-CG).

    The PDE definition is provided by ``Wave2D_CG``; run settings are read
    from a JSON file by ``load_config``.
    """

    # Default configuration file. NOTE: this path is resolved relative to the
    # current working directory (not to this module), so it is only found
    # when the process is started from the directory that contains ``src/``.
    CONFIG_PATH = "src/wave/wave2d_cg/conf.json"

    def __init__(self, device="cuda") -> None:
        """
        Args:
            device: Device name forwarded to ``TimeSteppingRunner.__init__``.
        """
        super().__init__(device=device)
        self.pde_data = Wave2D_CG()

    def load_config(self, config_path=None):
        """
        Load the run configuration from a JSON file into ``self.config``.

        Args:
            config_path: Path of the JSON file to read. Defaults to
                ``CONFIG_PATH`` (same file as before this parameter existed).

        Raises:
            OSError: If the file cannot be opened (e.g. FileNotFoundError).
            json.JSONDecodeError: If the file content is not valid JSON.
        """
        if config_path is None:
            config_path = self.CONFIG_PATH
        # JSON text is UTF-8 by specification; don't rely on the
        # locale-dependent default encoding of open().
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)
