import torch


class torchdyn_wrapper(torch.nn.Module):

    """Wraps model to torchdyn compatible format.

    Exposes ``model.evaluate_vector_field`` through a ``forward(t, x, args)``
    method. The masks and context supplied at construction time are stored
    and forwarded unchanged on every call.

    Args:
        model: Object providing ``evaluate_vector_field(t, x, node_mask,
            edge_mask, context, fix_noise=...)``.
        n_samples: Number of rows of the per-sample time tensor built in
            ``forward``.
        node_mask: Passed through to ``evaluate_vector_field`` as-is.
        edge_mask: Passed through to ``evaluate_vector_field`` as-is.
        context: Passed through to ``evaluate_vector_field`` as-is.
    """

    def __init__(self, model, n_samples, node_mask, edge_mask, context):
        super().__init__()
        self.model = model
        self.n_samples = n_samples
        self.node_mask = node_mask
        self.edge_mask = edge_mask
        self.context = context

    def forward(self, t, x, args=None):
        """Evaluate the wrapped model's vector field at time ``t`` and state ``x``.

        Args:
            t: Scalar time, given either as a one-element tensor or as a plain
                Python number.
            x: Current state; only its ``device`` is read here before it is
                handed to the model.
            args: Unused; accepted only so the call signature stays compatible.

        Returns:
            Whatever ``model.evaluate_vector_field`` returns.
        """
        # float() accepts one-element tensors as well as Python numbers,
        # whereas .item() exists only on tensors.
        t_value = float(t)
        # Broadcast the scalar time to one entry per sample, on x's device.
        # NOTE(review): dtype is torch's default float dtype, not x.dtype --
        # confirm that is what the model expects.
        t_array = torch.full(
            (self.n_samples, 1), fill_value=t_value, device=x.device
        )
        return self.model.evaluate_vector_field(
            t_array,
            x,
            self.node_mask,
            self.edge_mask,
            self.context,
            fix_noise=False,
        )
