import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchdiffeq import odeint_adjoint as odeint
import numpy as np
import pdb

class ODEFunc(nn.Module):
    """Learned vector field dh/dt = net(h), passed to ``odeint`` as the dynamics.

    The field is autonomous: ``forward`` takes a time argument ``t`` to match
    the ``func(t, y)`` signature used by ``odeint``, but never reads it.

    Args:
        hidden_dim: size of the state vector (input and output size of ``net``).
        width: number of units in the MLP's single hidden layer. Defaults to
            50, the value that was previously hard-coded.
    """

    def __init__(self, hidden_dim, width=50):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_dim, width),
            nn.Tanh(),
            nn.Linear(width, hidden_dim),
        )

    def forward(self, t, x):
        """Return the time derivative of the state ``x``; ``t`` is ignored."""
        return self.net(x)

class ODERNN(nn.Module):
    """GRU-based RNN whose hidden state follows a learned ODE between observations.

    For every observation after the first, the hidden state is first integrated
    with ``odeint`` over the time elapsed since the previous observation (each
    batch row uses its own elapsed time). It is then updated by a ``GRUCell``
    with the current input. Each step's hidden state is projected back to
    ``input_dim`` by a linear output layer.

    Args:
        input_dim: feature size of each observation (and of each output).
        hidden_dim: size of the hidden state.
        args: stored on the instance as ``self.args``; not read by this class.
    """

    def __init__(self, input_dim, hidden_dim, args):
        super(ODERNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.args = args

        self.gru_cell = nn.GRUCell(input_dim, hidden_dim)
        self.ode_func = ODEFunc(hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, self.input_dim)

    def _evolve(self, h, time_deltas):
        """Integrate each row of ``h`` forward by that row's own elapsed time.

        Args:
            h: hidden states, shape (batch, hidden_dim).
            time_deltas: elapsed time per batch row, shape (batch,). Assumed to
                be >= 0, i.e. ``times`` is non-decreasing along the sequence --
                NOTE(review): confirm callers guarantee this (e.g. for padding).

        Returns:
            Tensor of shape (batch, hidden_dim). Row b is the ODE solution
            started from ``h[b]`` at t=0 and evaluated at ``time_deltas[b]``.
        """
        # odeint starts integrating at t_eval[0] (where the state equals h), so
        # 0 must be the first evaluation time. Otherwise every row would only
        # be evolved by (delta - min(delta)), and a batch with a single shared
        # delta would not be evolved at all. torch.unique returns sorted
        # values, so t_eval is strictly increasing as odeint requires.
        t_eval = torch.cat([time_deltas.new_zeros(1), time_deltas]).unique()
        if t_eval.numel() == 1:
            # Every delta is zero: there is nothing to integrate.
            return h
        # One solve for the whole batch at every distinct delta:
        # trajectory has shape (len(t_eval), batch, hidden_dim).
        trajectory = odeint(self.ode_func, h, t_eval)
        # Each batch row reads the state at the time index of its own delta.
        time_idx = torch.searchsorted(t_eval, time_deltas).to(h.device)
        batch_idx = torch.arange(h.size(0), device=h.device)
        return trajectory[time_idx, batch_idx]

    def forward(self, x, times):
        """Run the ODE-RNN over a sequence and return one output per step.

        Args:
            x: observations, indexed as ``x[:, i, :]``; presumably
                (batch, seq_len, input_dim) -- TODO confirm with callers.
            times: observation times, indexed as ``times[:, i]``; presumably
                (batch, seq_len) and non-decreasing along dim 1.

        Returns:
            Tensor of shape (batch, seq_len, input_dim): the output layer
            applied to the hidden state after each step's GRU update.
        """
        batch_size, seq_len = x.size(0), x.size(1)
        # Match x's dtype and device. The previous hard-coded float64 state
        # only worked when x (and the model) were double precision.
        h = torch.zeros(batch_size, self.hidden_dim, dtype=x.dtype, device=x.device)

        outputs = []
        for step in range(seq_len):
            if step > 0:
                h = self._evolve(h, times[:, step] - times[:, step - 1])
            h = self.gru_cell(x[:, step, :], h)
            outputs.append(self.output_layer(h))

        return torch.stack(outputs, dim=1)
