import torch

class torch_wrapper_tv(torch.nn.Module):
    """Wrap a model so it can be called as a time-varying vector field ``f(t, x)``.

    The wrapped model receives a single tensor: the state ``x`` with the time
    ``t`` appended as one extra trailing column, i.e. ``cat([x, t], dim=1)``.
    This adapts a plain ``model(input)`` network to the ``forward(t, x)``
    signature that torchdyn-style ODE solvers call.
    """

    def __init__(self, model):
        """Store ``model``, which is called on the time-augmented state."""
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``.

        Args:
            t: Time. A Python number, a 0-d or 1-element tensor (the same time
                for every row of ``x``), or a tensor with ``x.shape[0]``
                elements (one time per row). It is converted to ``x``'s dtype
                and device before use.
            x: State batch with the batch on dim 0. The time column is
                concatenated on dim 1, so ``x`` is presumably 2-D
                ``(batch, features)``; confirm against callers.
            *args, **kwargs: Accepted for solver-signature compatibility and
                ignored.

        Returns:
            The output of ``self.model`` applied to ``cat([x, t_col], dim=1)``,
            where ``t_col`` has shape ``(x.shape[0], 1)``.

        Raises:
            RuntimeError: if ``t`` has neither 1 nor ``x.shape[0]`` elements.
        """
        # Match x's dtype/device so that torch.cat neither fails on a device
        # mismatch nor silently promotes the model input to float64.
        # as_tensor keeps the autograd graph if t requires grad.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        batch = x.shape[0]
        if t.numel() == 1:
            # One shared time: broadcast it to a (batch, 1) column.
            t_col = t.reshape(1, 1).expand(batch, 1)
        else:
            # One time per sample: lay it out as a (batch, 1) column.
            t_col = t.reshape(batch, 1)
        return self.model(torch.cat([x, t_col], dim=1))