import torch
import torch.nn as nn
from ..seq2seqbase import Seq2SeqAttrs

from .operatornet import OperatorNet
from torchdiffeq import odeint_adjoint
from torchdiffeq import odeint as odeint_normal

class IntDecoderModel(nn.Module, Seq2SeqAttrs):
    """Decoder that splits the horizon across a stack of OperatorNet layers.

    ``dec_layer_num`` OperatorNet modules are built. Layer ``i`` is evaluated
    on the ``i``-th contiguous chunk of ``horizon // dec_layer_num`` integer
    step indices. The hidden state returned by each layer is fed to the next
    one as its initial state, and all per-layer outputs are concatenated
    along dim 1.

    ``horizon``, ``dec_layer_num``, ``embed_size`` and ``output_dim`` are read
    from ``self``; presumably they are populated by ``Seq2SeqAttrs.__init__``
    from ``model_kwargs`` -- verify against that class.
    """

    def __init__(self, level_sizes, **model_kwargs):
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)
        # Every decoder layer must own an equally sized slice of the horizon.
        assert self.horizon % self.dec_layer_num == 0, (
            f"horizon ({self.horizon}) must be divisible by "
            f"dec_layer_num ({self.dec_layer_num})"
        )

        # One independent OperatorNet per horizon chunk.
        self.net = nn.ModuleList(
            [OperatorNet(level_sizes, **model_kwargs) for _ in range(self.dec_layer_num)]
        )

        # NOTE(review): not applied in ``forward``; presumably used by a caller.
        # Kept so parameter names / saved checkpoints remain compatible.
        self.projection_layer = nn.Linear(self.embed_size, self.output_dim)

    def forward(self, eval_t, conditional_state, edge_idx_all, location, t_past,
                hidden_state=None) -> torch.Tensor:
        """Run the stacked OperatorNet layers over consecutive horizon chunks.

        Args:
            eval_t: Tensor compared element-wise against each layer's integer
                step indices; assumed to be 1-D -- TODO confirm.
            conditional_state: Tensor whose dim 1 has length ``len(t_past)``.
                Passed through to every OperatorNet.
            edge_idx_all: Passed through unchanged to every OperatorNet.
            location: Passed through unchanged to every OperatorNet.
            t_past: Sized sequence whose length must equal
                ``conditional_state.shape[1]``; passed through to every layer.
            hidden_state: Initial state for the first layer. Defaults to zeros
                shaped like ``conditional_state[:, 0, ...]``.

        Returns:
            The outputs of all layers in ``self.net``, concatenated along dim 1.
        """
        assert len(t_past) == conditional_state.shape[1], (
            "len(t_past) must equal conditional_state.shape[1]"
        )
        if hidden_state is None:
            # Allocate only the slice that is needed rather than zeroing the
            # whole conditional_state and then slicing it.
            hidden_state = torch.zeros_like(conditional_state[:, 0, ...])

        chunk = self.horizon // self.dec_layer_num
        hidden_states = []
        for layer_i, layer in enumerate(self.net):
            start = layer_i * chunk
            # Step indices [start, start + chunk) owned by this layer, created
            # directly on eval_t's device and cast to eval_t's dtype.
            idxes = torch.arange(start, start + chunk, device=eval_t.device).to(eval_t)
            # Positions i of eval_t whose value equals one of this layer's
            # step indices (nonzero yields rows in ascending order).
            term_t = (eval_t.unsqueeze(1) == idxes).nonzero(as_tuple=False)[:, 0].float()
            # Each layer's output becomes the initial state of the next layer.
            hidden_state = layer(term_t, hidden_state, conditional_state,
                                 edge_idx_all, location, t_past)
            hidden_states.append(hidden_state)

        # Concatenation of the per-layer outputs along dim 1 (not a sum).
        return torch.cat(hidden_states, dim=1)
        
    