import torch
import torch.nn as nn

class ODEfunc(nn.Module):
    """Wrap a dynamics callable so it can be evaluated on a 5-tuple of states.

    Only the first state ``y`` gets a derivative from ``diffeq``. The other
    four states (``conditional_state``, ``edge_idx_all``, ``location``,
    ``t_past``) are passed to ``diffeq`` as inputs, and their returned time
    derivatives are zeros. The ``_num_evals`` buffer counts ``forward`` calls
    since the last :meth:`before_odeint`.
    """

    # Number of states ``forward`` expects:
    # (y, conditional_state, edge_idx_all, location, t_past).
    _NUM_STATES = 5

    def __init__(self, diffeq):
        """
        Args:
            diffeq: callable invoked as
                ``diffeq(t, y, conditional_state, edge_idx_all, location, t_past)``;
                its return value is used as the derivative of ``y``.
        """
        super(ODEfunc, self).__init__()
        self.diffeq = diffeq
        # Registered as a buffer (not a parameter): it follows .to()/.cuda()
        # and appears in state_dict, but is never optimized.
        self.register_buffer("_num_evals", torch.tensor(0.))

    def before_odeint(self):
        """Reset the evaluation counter to zero."""
        self._num_evals.fill_(0)

    @staticmethod
    def _can_require_grad(tensor):
        """Return True if ``tensor``'s dtype supports autograd (floating point or complex)."""
        return tensor.is_floating_point() or tensor.is_complex()

    def forward(self, t, states):
        """Compute time derivatives for all five states.

        Args:
            t: current time, forwarded unchanged to ``diffeq``.
            states: sequence of exactly five tensors
                ``(y, conditional_state, edge_idx_all, location, t_past)``.

        Returns:
            A 5-tuple ``(dy, 0, 0, 0, 0)``. Here ``dy`` is the output of
            ``diffeq``, and each zero is ``torch.zeros_like`` of the matching
            non-``y`` state.

        Raises:
            ValueError: if ``states`` does not contain exactly five tensors.
        """
        # Explicit check (not ``assert``) so it survives ``python -O``. It runs
        # before any side effect on the counter or the input tensors.
        if len(states) != self._NUM_STATES:
            raise ValueError(
                f"ODEfunc expects {self._NUM_STATES} states "
                "(y, conditional_state, edge_idx_all, location, t_past), "
                f"got {len(states)}"
            )

        self._num_evals += 1

        # Mark inputs as requiring grad so diffeq builds a graph through them.
        # Integer tensors (e.g. edge indices) cannot require grad, so skip them.
        for state in states:
            if self._can_require_grad(state):
                state.requires_grad_(True)

        y, conditional_state, edge_idx_all, location, t_past = states

        # Force grad mode on even when the caller (e.g. an ODE solver) runs
        # this under torch.no_grad().
        with torch.set_grad_enabled(True):
            dy = self.diffeq(t, y, conditional_state, edge_idx_all, location, t_past)

            # The non-y states are held constant, so their derivatives are zero.
            # The zeros are marked requires_grad where the dtype allows.
            # NOTE(review): presumably so autograd.grad over all outputs (e.g. an
            # adjoint solver) does not reject them; confirm against the caller.
            zero_derivs = []
            for state in states[1:]:
                zero = torch.zeros_like(state)
                if self._can_require_grad(zero):
                    zero.requires_grad_(True)
                zero_derivs.append(zero)

        return (dy, *zero_derivs)
