import torch as th
import torch.nn as nn
import torch.nn.functional as F


class RNNDecoder(nn.Module):
    """GRU-based decoder mapping a state input plus an encoding to a state-sized vector.

    Each call embeds ``state_input`` with a linear layer, advances a GRU cell,
    concatenates the new hidden state with ``encoded_z`` and projects the
    result to a vector of size ``args.state_shape``.

    Attributes read from ``args``:
        state_shape: input and output feature size (used as ``nn.Linear`` sizes).
        dec_emb: embedding size and GRU hidden size of the decoder.
        enc_emb: feature size of ``encoded_z``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        input_dim, output_dim = self._get_shapes(args)
        self.fc1 = nn.Linear(input_dim, args.dec_emb)
        self.rnn = nn.GRUCell(args.dec_emb, args.dec_emb)
        # The GRU output is concatenated with encoded_z before this projection,
        # hence the dec_emb + enc_emb input width.
        self.fc2 = nn.Linear(args.dec_emb + args.enc_emb, args.dec_emb)
        self.output_layer = nn.Linear(args.dec_emb, output_dim)

    def init_hidden(self, batch_size=1):
        """Return an all-zero hidden state of shape ``[batch_size, dec_emb]``.

        The tensor is created from ``fc1.weight`` so it lives on the same
        device and has the same dtype as the model. With the default
        ``batch_size=1`` the result matches the previous single-row behaviour,
        which callers may expand to their batch size.
        """
        return self.fc1.weight.new_zeros(batch_size, self.args.dec_emb)

    def forward(self, encoded_z, state_input, hidden_state):
        """Run one decoder step.

        Args:
            encoded_z: tensor of shape ``[bs, enc_emb]``.
            state_input: tensor of shape ``[bs, state_shape]``.
            hidden_state: previous GRU hidden state, shape ``[bs, dec_emb]``.

        Returns:
            A tuple ``(decoded_output, hidden_state)`` where ``decoded_output``
            has shape ``[bs, state_shape]`` and ``hidden_state`` is the updated
            GRU state of shape ``[bs, dec_emb]``.
        """
        x = F.relu(self.fc1(state_input))
        hidden_state = self.rnn(x, hidden_state)

        # Fuse the recurrent state with the encoding along the feature axis.
        x = th.cat([hidden_state, encoded_z], dim=1)
        x = F.relu(self.fc2(x))

        decoded_output = self.output_layer(x)
        return decoded_output, hidden_state

    def _get_shapes(self, args):
        """Return ``(input_dim, output_dim)``; both equal ``args.state_shape``."""
        return args.state_shape, args.state_shape