import json

from src.runner import RunnerBase
from src.poisson.poisson2d_ms.pdedata import Poisson2D_MS
from utils.fourier_mlp import FourierMLP


class Poisson2D_MSRunner(RunnerBase):
    """
    Run the 2D Poisson equation with Many Subdomains (Poisson2d-MS).

    Pairs the ``Poisson2D_MS`` problem data with a ``FourierMLP`` network
    whose input/output sizes are read from a JSON config file.
    """

    # Location of the JSON run configuration. A relative path is resolved by
    # ``open()`` against the current working directory, so this default
    # assumes the process is started from the repository root. Subclasses
    # (or callers, before ``load_config`` runs) may override it.
    CONFIG_PATH = "src/poisson/poisson2d_ms/conf.json"

    def __init__(self, device: str = "cuda") -> None:
        """
        Args:
            device: Device string forwarded to ``RunnerBase.__init__``.
                ``init_model`` later moves the network to ``self.device``,
                presumably set from this value by ``RunnerBase`` -- confirm.
        """
        super().__init__(device=device)
        self.pde_data = Poisson2D_MS()

    def load_config(self) -> None:
        """Load the JSON run configuration from ``CONFIG_PATH`` into ``self.config``."""
        # JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the
        # result does not depend on the platform's locale default.
        with open(self.CONFIG_PATH, "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def init_model(self) -> None:
        """
        Build the network and store it in ``self.model``.

        Reads ``"Spatial Temporal Dimension"`` and ``"Output Dimension"`` from
        ``self.config``, so ``load_config`` must have been called first (a
        missing key raises ``KeyError``). The remaining architecture
        hyperparameters are fixed for this problem.
        """
        self.model = FourierMLP(
            input_dim=self.config["Spatial Temporal Dimension"],
            output_dim=self.config["Output Dimension"],
            n_layers=5,
            n_hidden=128,
            _type="gaussian",
            fourier_dim=128,
            sigma=0.5,
        ).to(self.device)
