import torch
import torch.nn as nn
from ..recurrent.grucell import GRUCell
from ..seq2seqbase import Seq2SeqAttrs

class EncoderModel(nn.Module, Seq2SeqAttrs):
    """Multi-layer GRU encoder that advances every layer by one time step.

    Each call to :meth:`forward` reads the slice ``batch.data[:, t, ...]``,
    projects its last axis from ``input_dim`` to ``embed_size`` with a linear
    layer, and feeds the result through a stack of ``enc_layer_num``
    :class:`GRUCell` layers, where each layer's new hidden state is the next
    layer's input.

    The hyperparameters ``input_dim``, ``embed_size`` and ``enc_layer_num``
    are set by :class:`Seq2SeqAttrs`.
    """

    def __init__(self, conv_ru, conv_c, level_sizes, **model_kwargs):
        """Initialise the attribute mixin, the GRU stack and the projection.

        Args:
            conv_ru: Passed unchanged to every GRUCell. Presumably the
                reset/update-gate convolution -- confirm against GRUCell.
            conv_c: Passed unchanged to every GRUCell. Presumably the
                candidate-state convolution -- confirm against GRUCell.
            level_sizes: Forwarded to ``Seq2SeqAttrs.__init__``.
            **model_kwargs: Forwarded to ``Seq2SeqAttrs.__init__``.
        """
        # nn.Module must be initialised before any submodule is assigned.
        nn.Module.__init__(self)
        Seq2SeqAttrs.__init__(self, level_sizes, **model_kwargs)
        self.conv_ru = conv_ru
        self.conv_c = conv_c
        self.gru_layers = self.init_gru_layers()
        self.projection_layer = nn.Linear(self.input_dim, self.embed_size)

    def init_gru_layers(self):
        """Build the stack of ``enc_layer_num`` GRU cells.

        The same ``conv_ru`` / ``conv_c`` objects are handed to every cell, so
        if they are ``nn.Module`` instances their parameters are shared by all
        layers.

        Returns:
            nn.ModuleList: One GRUCell per encoder layer, in layer order.
        """
        return nn.ModuleList(
            GRUCell(
                embed_size=self.embed_size,
                conv_ru=self.conv_ru,
                conv_c=self.conv_c,
            )
            for _ in range(self.enc_layer_num)
        )

    def forward(self, batch, t, hidden_state=None):
        """Advance every encoder layer by one time step.

        Args:
            batch: Object with a ``data`` attribute indexed as
                ``data[:, t, ...]``. The resulting slice must unpack into
                exactly three dimensions ``(batch_size, node_num, input_dim)``.
                ``batch`` and ``t`` are also forwarded to each GRUCell.
            t: Time index selected from ``batch.data``.
            hidden_state: Optional per-layer hidden states, indexed by layer
                along the first axis. When ``None``, zeros of shape
                ``(enc_layer_num, batch_size, node_num, embed_size)`` are used,
                matching the projected input's device and dtype.

        Returns:
            tuple: ``(output, hidden_states)`` where ``output`` is the last
            layer's new hidden state and ``hidden_states`` stacks every
            layer's new hidden state along a new leading (layer) axis.
        """
        inputs = batch.data[:, t, ...]
        batch_size, node_num, _ = inputs.shape

        inputs = self.projection_layer(inputs)

        if hidden_state is None:
            # Allocate directly on the input's device with its dtype, rather
            # than building a CPU float32 tensor and copying it every step.
            hidden_state = inputs.new_zeros(
                (self.enc_layer_num, batch_size, node_num, self.embed_size)
            )

        hidden_states = []
        output = inputs
        for layer_idx, gru_layer in enumerate(self.gru_layers):
            # Each layer's new hidden state is the input of the next layer.
            output = gru_layer(output, batch, t, hidden_state[layer_idx])
            hidden_states.append(output)

        # One stack over enc_layer_num tensors; cost is linear in layer count.
        return output, torch.stack(hidden_states)