import json

from src.tmrunner import TimeSteppingRunner
from src.burgers.burgers2d_c.pdedata import Burgers2D_C
from utils.mlp import MLP


class Burgers2D_CRunner(TimeSteppingRunner):
    """
    Run the 2D Burgers' equation (Burgers2d-C).

    Plugs the ``Burgers2D_C`` problem data, its JSON configuration and an
    MLP model into the generic ``TimeSteppingRunner`` workflow.
    """

    def __init__(self, device="cuda") -> None:
        """
        Create the runner and attach the Burgers2D_C problem data.

        Args:
            device: Device identifier forwarded to ``TimeSteppingRunner``.
                ``init_model`` later moves the model to ``self.device``,
                presumably set from this value by the base class.
        """
        super().__init__(device=device)
        self.pde_data = Burgers2D_C()

    def load_config(self):
        """
        Load ``src/burgers/burgers2d_c/conf.json`` into ``self.config``.

        The path is relative, so it is resolved against the current working
        directory of the process.
        """
        # JSON text is UTF-8 (RFC 8259); state the encoding explicitly instead
        # of relying on the platform's locale-dependent default.
        with open(
            "src/burgers/burgers2d_c/conf.json", "r", encoding="utf-8"
        ) as f:
            self.config = json.load(f)

    def init_model(self):
        """
        Build the MLP and move it to ``self.device``.

        Architecture: ``config["Spatial Dimension"]`` inputs, five hidden
        layers of width 128, and
        ``config["Output Dimension"] * config["Grid Size (t-direction)"]``
        outputs.
        """
        in_dim = self.config["Spatial Dimension"]
        # NOTE(review): the output width is the product of the output
        # dimension and the t-direction grid size — presumably one prediction
        # per output component for every time-grid point; confirm.
        out_dim = (
            self.config["Output Dimension"]
            * self.config["Grid Size (t-direction)"]
        )
        hidden = [128] * 5
        self.model = MLP(
            layer_sizes=[in_dim] + hidden + [out_dim]
        ).to(self.device)

    def setup(self):
        """Run the base-class setup, then enable gradient tracking on ``self.X``."""
        super().setup()
        # NOTE(review): self.X is not assigned in this class; presumably the
        # base setup() creates it as a torch tensor. requires_grad_(True)
        # makes autograd track it, likely so derivatives w.r.t. the inputs
        # can be taken; confirm against TimeSteppingRunner.
        self.X.requires_grad_(True)
