import torch
import torch.nn as nn


class RecurrentGRU(nn.Module):
    """Batch-first GRU encoder for padded, variable-length sequences.

    Inputs are packed with ``pack_padded_sequence`` before being fed to the
    GRU so that padded time steps do not influence the recurrent state.

    Args:
        input_dim: Size of the last dimension of the input features.
        device: Optional preferred device for hidden states created by
            :meth:`zero_hidden` when no device is given explicitly. The
            module's own parameters are NOT moved to this device.
        rnn_hidden_dim: Size of the GRU hidden state.
        rnn_layer_num: Number of stacked GRU layers.
    """

    def __init__(self, input_dim, device=None, rnn_hidden_dim=128, rnn_layer_num=1):
        super().__init__()
        self.input_dim = input_dim
        self.rnn_hidden_dim = rnn_hidden_dim
        self.rnn_layer_num = rnn_layer_num
        self.device = device

        self.GRU = nn.GRU(
            input_dim,
            rnn_hidden_dim,
            rnn_layer_num,
            batch_first=True,
        )

    def forward(self, x, lens, pre_hidden=None):
        """Run the GRU over a padded batch of sequences.

        Args:
            x: Padded input of shape ``(batch, seq, input_dim)``.
            lens: Valid length of every sequence in the batch (list or 1-D
                tensor). Sequences do not need to be sorted by length.
            pre_hidden: Optional initial hidden state, either
                ``(rnn_layer_num, batch, rnn_hidden_dim)`` or, as a
                shorthand, ``(batch, rnn_hidden_dim)`` which gets a leading
                layer dimension of size 1. Defaults to zeros.

        Returns:
            Tuple ``(output, hidden)``. ``output`` is batch-first and
            zero-padded; its time dimension equals ``max(lens)``, which may
            be shorter than ``x.shape[1]``. ``hidden`` is the final hidden
            state of shape ``(rnn_layer_num, batch, rnn_hidden_dim)``.
        """
        if pre_hidden is None:
            # Allocate the default state next to the input so the module keeps
            # working after ``.to(device)`` / ``.double()`` / ``.half()``.
            pre_hidden = self.zero_hidden(
                batch_size=x.shape[0], device=x.device, dtype=x.dtype
            )
        elif pre_hidden.dim() == 2:
            pre_hidden = pre_hidden.unsqueeze(0)

        # pack_padded_sequence rejects length tensors that live on an
        # accelerator, so bring them back to the CPU first.
        if torch.is_tensor(lens):
            lens = lens.cpu()

        packed = nn.utils.rnn.pack_padded_sequence(
            x, lens, batch_first=True, enforce_sorted=False
        )
        packed_output, hidden = self.GRU(packed, pre_hidden)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        return output, hidden

    def zero_hidden(self, batch_size, device=None, dtype=None):
        """Return an all-zero initial hidden state.

        Args:
            batch_size: Number of sequences in the batch.
            device: Target device. Falls back to ``self.device`` and, if that
                is ``None`` as well, to the device of the GRU weights.
            dtype: Target dtype. Falls back to the dtype of the GRU weights.

        Returns:
            Tensor of shape ``(rnn_layer_num, batch_size, rnn_hidden_dim)``.
        """
        reference = self.GRU.weight_ih_l0
        if device is None:
            device = self.device if self.device is not None else reference.device
        if dtype is None:
            dtype = reference.dtype
        # Allocate directly on the target device instead of CPU-then-copy.
        return torch.zeros(
            self.rnn_layer_num,
            batch_size,
            self.rnn_hidden_dim,
            device=device,
            dtype=dtype,
        )

    def get_hidden(self, obs, last_actions, lens):
        """Encode concatenated observations and previous actions.

        Kept for compatibility with callers that expect this entry point.
        ``obs`` and ``last_actions`` are concatenated along the last
        dimension and encoded starting from a zero hidden state.

        Returns:
            The per-step GRU outputs only (the final hidden state is dropped).
        """
        x = torch.cat([obs, last_actions], dim=-1)
        output, _ = self(x, lens)
        return output
