import torch
import torch.nn as nn
from .grucell import GRUCell
from ..seq2seqbase import Seq2SeqAttrs
from ..operator.operatornet import OperatorNet

class RecDecoderModel(nn.Module, Seq2SeqAttrs):
    """Recurrent decoder that rolls an ``OperatorNet`` forward over time.

    At every requested evaluation time the operator network produces a new
    hidden state from the history seen so far; that state is appended to the
    history before the next step, and all produced states are finally mapped
    to the output space by a linear projection.
    """

    def __init__(self, level_sizes, **model_kwargs):
        """Build the operator network and the output projection.

        Args:
            level_sizes: Forwarded unchanged to ``Seq2SeqAttrs`` and
                ``OperatorNet``.
            **model_kwargs: Forwarded unchanged to ``Seq2SeqAttrs`` and
                ``OperatorNet``.
        """
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)
        self.net = OperatorNet(level_sizes, **model_kwargs)
        # embed_size / output_dim are not assigned in this class; presumably
        # they are set by Seq2SeqAttrs.__init__ -- TODO confirm.
        self.projection_layer = nn.Linear(self.embed_size, self.output_dim)

    def forward(self, eval_t, conditional_state, edge_idx_all, location, t_past, hidden_state=None):
        """Decode one hidden state per entry of ``eval_t`` and project them.

        Args:
            eval_t: Iterable of evaluation times; each one is handed to the
                operator network in turn.
            conditional_state: Tensor holding the history the decoder
                conditions on. Dimension 1 indexes the history steps and must
                have the same length as ``t_past``.
            edge_idx_all: Passed through to the operator network unchanged
                (presumably graph connectivity -- TODO confirm).
            location: Passed through to the operator network unchanged.
            t_past: Tensor of the times belonging to the history steps of
                ``conditional_state``.
            hidden_state: Optional initial hidden state with the shape of
                ``conditional_state[:, 0, ...]``. Defaults to zeros.

        Returns:
            The projection of all decoded hidden states, stacked along
            dimension 1 in the order of ``eval_t``.

        Raises:
            AssertionError: If ``len(t_past)`` differs from
                ``conditional_state.shape[1]``.
        """
        assert len(t_past) == conditional_state.shape[1], (
            f"t_past has {len(t_past)} entries but conditional_state has "
            f"{conditional_state.shape[1]} steps along dim 1"
        )

        if hidden_state is None:
            # Zero-fill only a single step; allocating zeros for the whole
            # history and slicing afterwards wastes memory.
            hidden_state = torch.zeros_like(conditional_state[:, 0, ...])

        hidden_states = []
        for t in eval_t:
            hidden_state = self.net(
                t, hidden_state, conditional_state, edge_idx_all, location, t_past
            )
            hidden_states.append(hidden_state)

            # Feed the new state back in as the most recent history step.
            conditional_state = torch.cat(
                [conditional_state, hidden_state[:, None, ...]], dim=1
            )
            # Keep t_past aligned with the grown history. The increment stays
            # on t_past's device and dtype, avoiding a host round-trip.
            # NOTE(review): the appended time is t_past[-1] + 1, not ``t`` --
            # confirm that a unit step is what the operator network expects.
            next_t = (t_past[-1] + 1).unsqueeze(0)
            t_past = torch.cat([t_past, next_t])

        hidden_states = torch.stack(hidden_states, dim=1)

        return self.projection_layer(hidden_states)