import json

from src.invrunner import InversePDERunner
from src.inverse.pinv.pdedata import PInv
from utils.mlp import MLP


class PInvRunner(InversePDERunner):
    """
    Run the Poisson inverse problem (PInv).

    Builds two MLPs on ``self.device``:

    * ``model_u`` with 3 hidden layers of width 64, and
    * ``model_a`` with 5 hidden layers of width 128.

    Both share the same input/output sizes, which are taken from the loaded
    config. Presumably ``model_u`` approximates the solution field and
    ``model_a`` the unknown coefficient; confirm against the training loop.
    """

    # Default config location. It is a relative path, so ``open`` resolves it
    # against the current working directory. The runner is presumably meant
    # to be launched from the repository root.
    CONFIG_PATH = "src/inverse/pinv/conf.json"

    def __init__(self, device="cuda") -> None:
        """Create the runner on ``device`` and attach the PInv problem data."""
        super().__init__(device=device)
        self.pde_data = PInv()

    def load_config(self, path=None):
        """Load the run configuration from a JSON file into ``self.config``.

        Args:
            path: Optional path to the JSON config. Defaults to
                ``self.CONFIG_PATH`` when ``None``.
        """
        config_path = self.CONFIG_PATH if path is None else path
        # JSON text is UTF-8, so do not rely on the locale-dependent default
        # encoding of ``open``.
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def _build_mlp(self, hidden_width, num_hidden):
        """Return an MLP with ``num_hidden`` hidden layers of ``hidden_width``.

        The input size comes from ``config["Spatial Temporal Dimension"]``
        and the output size from ``config["Output Dimension"]``. The network
        is moved to ``self.device``.
        """
        in_dim = self.config["Spatial Temporal Dimension"]
        out_dim = self.config["Output Dimension"]
        layers = [in_dim] + [hidden_width] * num_hidden + [out_dim]
        return MLP(layers).to(self.device)

    def init_model(self):
        """Construct ``model_u`` (3x64 hidden) and ``model_a`` (5x128 hidden).

        Requires ``self.config`` to be populated, e.g. by ``load_config``.
        """
        self.model_u = self._build_mlp(hidden_width=64, num_hidden=3)
        self.model_a = self._build_mlp(hidden_width=128, num_hidden=5)
