# Use this model to validate that the reverse process is correct: it ignores its inputs and always returns the ground-truth x0.
import torch


class OmnipotentModel(torch.nn.Module):
    """Oracle model that always returns the ground-truth ``x0`` it is handed.

    ``forward`` ignores its actual inputs (``x``, ``t``, ``y``) and echoes the
    supplied ``x0``. This lets it stand in for a learned network when checking
    that the reverse process is correct.

    Args:
        input_dim: Input size of ``dummy_layer`` (unused by ``forward``).
        output_dim: Output size of ``dummy_layer`` (unused by ``forward``).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Never used in ``forward``. Presumably kept so the module owns at
        # least one parameter (e.g. for ``.parameters()`` or device lookups
        # done by the caller) -- TODO confirm.
        self.dummy_layer = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x, t, y=None, **kwargs):
        """Return the ground-truth ``x0``, ignoring ``x``, ``t`` and ``y``.

        ``x0`` is looked up in two places, in this order:

        1. A nested mapping passed as ``kwargs={"x0": ...}``. This is the
           original calling convention and takes precedence.
        2. A direct keyword argument ``x0=...``.

        Returns:
            The ``x0`` object exactly as supplied (not copied).

        Raises:
            KeyError: If ``x0`` is supplied in neither form.
        """
        nested = kwargs.get("kwargs")
        if nested is not None and "x0" in nested:
            return nested["x0"]
        if "x0" in kwargs:
            return kwargs["x0"]
        raise KeyError(
            "OmnipotentModel.forward needs the ground-truth 'x0', passed as "
            "kwargs={'x0': ...} or x0=..."
        )
