from models.s4_model import S4BaseModel

class S4NDModel(S4BaseModel):
    """S4 model variant configured through an extra ``dim`` block argument.

    Thin wrapper around ``S4BaseModel``: it reserves one additional input
    channel for the grid and injects ``dim`` into the keyword arguments
    that are forwarded to ``S4BaseModel.__init__``.
    """

    def __init__(
        self,
        d_input=1,
        d_output=1,
        d_model=256,
        dim=1,
        n_layers=4,
        exo_dropout=0.0,
        prenorm=False,
        interlayer_act=None,
        s4block_args=None,
        input_processor="ConcatND",
        output_processor="UnflatTrans",
        **kwargs
    ):
        """Build the model by delegating to ``S4BaseModel``.

        Args:
            d_input: Number of input channels, NOT counting the grid
                channel; one is added here before calling the base class.
            d_output: Passed through to ``S4BaseModel`` as ``d_output``.
            d_model: Passed through to ``S4BaseModel`` as ``d_model``.
            dim: Stored as ``"dim"`` in the block arguments forwarded to
                the base class (presumably the number of spatial
                dimensions — confirm against S4BaseModel / the S4 block).
            n_layers: Passed through to ``S4BaseModel`` as ``n_layers``.
            exo_dropout: Passed through to ``S4BaseModel``.
            prenorm: Passed through to ``S4BaseModel``.
            interlayer_act: Passed through to ``S4BaseModel``.
            s4block_args: Optional mapping of extra keyword arguments that
                are unpacked into the ``S4BaseModel.__init__`` call. It is
                copied, never mutated. Any ``"dim"`` entry in it is
                overridden by the ``dim`` parameter.
            input_processor: Passed through to ``S4BaseModel``.
            output_processor: Passed through to ``S4BaseModel``.
            **kwargs: Accepted for call-site compatibility but NOT
                forwarded anywhere. NOTE(review): extra keyword arguments
                are silently dropped; confirm this is intended.
        """
        # +1 for the grid channel (presumably grid coordinates combined
        # with the input by the input processor — confirm).
        d_input = d_input + 1

        # Work on a private copy: mutating a shared default dict (the old
        # `s4block_args={}`) leaked "dim" across instantiations, and
        # mutating a caller-supplied dict is an unexpected side effect.
        block_args = dict(s4block_args) if s4block_args is not None else {}
        block_args["dim"] = dim

        super().__init__(
            d_input,
            d_output=d_output,
            d_model=d_model,
            n_layers=n_layers,
            exo_dropout=exo_dropout,
            prenorm=prenorm,
            interlayer_act=interlayer_act,
            input_processor=input_processor,
            output_processor=output_processor,
            **block_args)