import torch
import torch.nn as nn
from ..seq2seqbase import Seq2SeqAttrs
from .odebase import ODEAttrs
from .odefunc import ODEfunc
from .operatornet import OperatorNet
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint as odeint_normal

class ODEDecoderModel(nn.Module, Seq2SeqAttrs, ODEAttrs):
    """Decoder that evolves a hidden state by integrating an ODE.

    The dynamics are an ``OperatorNet`` wrapped in an ``ODEfunc``. The
    torchdiffeq integrator (adjoint or regular) is chosen on every forward
    call from ``self.use_adjoint``.
    """

    def __init__(self, level_sizes, **model_kwargs):
        """Build the operator network and wrap it as the ODE right-hand side.

        Args:
            level_sizes: Forwarded to ``Seq2SeqAttrs.__init__`` and
                ``OperatorNet``.
            **model_kwargs: Forwarded to ``Seq2SeqAttrs.__init__``,
                ``ODEAttrs.__init__`` and ``OperatorNet``.
        """
        # The bases are initialised explicitly (not via super()) because each
        # one takes a different argument list.
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)
        # NOTE(review): presumably this is where use_adjoint / atol / rtol /
        # solver (read in forward) get set -- confirm in odebase.
        ODEAttrs.__init__(self, **model_kwargs)
        self.diffeq = OperatorNet(level_sizes, **model_kwargs)
        self.odefunc = ODEfunc(self.diffeq)

    def forward(self, eval_t, conditional_state, edge_idx_all, location, t_past, hidden_state=None):
        """Integrate the hidden state over ``eval_t``.

        Args:
            eval_t: Time points handed to the ODE solver.
            conditional_state: Tensor whose dimension 1 must have the same
                length as ``t_past``. Passed to the solver as part of the
                state tuple.
            edge_idx_all: Passed to the solver as part of the state tuple.
            location: Passed to the solver as part of the state tuple.
            t_past: Sized sequence with ``len(t_past)`` equal to
                ``conditional_state.shape[1]``. Passed to the solver as part
                of the state tuple.
            hidden_state: Initial hidden state. When ``None``, zeros with the
                shape, dtype and device of ``conditional_state[:, 0, ...]``
                are used.

        Returns:
            The solver's trajectory for the hidden-state component with its
            first two dimensions swapped, as a contiguous tensor.
            NOTE(review): presumably (batch, time, ...) -- confirm against the
            callers.
        """
        assert len(t_past) == conditional_state.shape[1], (
            f"len(t_past)={len(t_past)} does not match "
            f"conditional_state.shape[1]={conditional_state.shape[1]}"
        )
        if hidden_state is None:
            # Slice first, then allocate: avoids materialising (and keeping
            # alive) a zero tensor as large as the whole conditional history
            # when only one slice along dim 1 is needed.
            hidden_state = torch.zeros_like(conditional_state[:, 0, ...])

        # Only the first component is read back after integration; the rest
        # travel in the tuple so the ODE function receives them.
        states = (hidden_state, conditional_state, edge_idx_all, location, t_past)

        self.odefunc.before_odeint()
        odeint = odeint_adjoint if self.use_adjoint else odeint_normal

        state_t = odeint(
            self.odefunc,
            states,
            eval_t,
            atol=self.atol,
            rtol=self.rtol,
            method=self.solver,
        )
        # state_t[0] is the hidden-state trajectory; swap its first two
        # dimensions and make the result contiguous.
        states = state_t[0].transpose(0, 1).contiguous()

        return states
        
    