import json
import torch

from src.tmrunner import TimeSteppingRunner
from src.ns.ns2d_lt.pdedata import NS2D_LT


class NS2D_LTRunner(TimeSteppingRunner):
    """
    Run the 2D Navier-Stokes equation (NS2D_LT).

    Specializes ``TimeSteppingRunner`` by supplying the NS2D_LT PDE data
    object, its JSON configuration, and the per-time-step right-hand side.
    """

    def __init__(self, device="cuda") -> None:
        """
        Create the runner and its NS2D_LT PDE data object.

        Args:
            device: Forwarded to ``TimeSteppingRunner.__init__``. Presumably
                stored by the base class as ``self.device``, which
                ``update_tm_rhs`` reads (TODO confirm in the base class).
        """
        super().__init__(device=device)
        self.pde_data: NS2D_LT = NS2D_LT()

    def load_config(self) -> None:
        """
        Load the runner configuration from ``src/ns/ns2d_lt/conf.json``.

        The parsed JSON document is stored in ``self.config``.
        """
        # NOTE: the path is relative, so it is resolved against the current
        # working directory (presumably the repository root), not this file.
        # JSON is UTF-8 by specification (RFC 8259). Pass the encoding
        # explicitly so parsing does not depend on the platform's locale
        # encoding.
        with open("src/ns/ns2d_lt/conf.json", "r", encoding="utf-8") as f:
            self.config = json.load(f)

    def update_tm_rhs(self, ts) -> None:
        """
        Rebuild the right-hand side ``self.b`` for time step ``ts``.

        Args:
            ts: Forwarded unchanged to ``NS2D_LT.init_rhs``. Presumably the
                current time-step index or time value (confirm with NS2D_LT).
        """
        self.pde_data.init_rhs(ts)
        # torch.tensor always copies its input, so self.b does not share
        # memory with the PDE data's RHS buffer.
        self.b = torch.tensor(
            self.pde_data.get_rhs(), dtype=torch.float32, device=self.device
        )

    def setup(self) -> None:
        """
        Run the base-class setup, then enable gradient tracking on ``self.X``.
        """
        super().setup()
        # NOTE(review): self.X is presumably created by the base-class
        # setup(). requires_grad_ works in place, so autograd records the
        # operations performed on it afterwards.
        self.X.requires_grad_(True)
