import torch
from architectures.recurrent.rnn_module_custom import RNNModule

class RNN(torch.nn.Module):
    """Recurrent encoder with a linear read-out.

    The model is built from three parts:

    * ``gc2hidden`` -- a linear layer that turns grid-cell input into an
      initial hidden state when the caller does not supply one,
    * ``rnn`` -- the project's ``RNNModule``, which consumes the scene input
      together with the hidden state,
    * ``decoder_lin`` (+ optional ``decoder_do``) -- a linear layer, followed
      by dropout when the last dropout rate is positive, applied to the
      hidden states returned by ``rnn``.
    """

    def __init__(self,
        device,
        scene_dim, gridcells_dim, output_dim,
        latent_dim=500,
        nonlinearity='sigmoid',
        dropouts=None,
        bias=False
    ):
        """Build the encoder and decoder sub-modules.

        Args:
            device: Forwarded unchanged to ``RNNModule``.
            scene_dim: Forwarded to ``RNNModule`` (presumably the size of the
                scene input features -- confirm against ``RNNModule``).
            gridcells_dim: Input feature size of ``gc2hidden``.
            output_dim: Output feature size of the decoder.
            latent_dim: Size of the hidden state; output size of
                ``gc2hidden`` and input size of the decoder.
            nonlinearity: Forwarded unchanged to ``RNNModule``.
            dropouts: Sequence of dropout rates. The whole sequence is
                forwarded to ``RNNModule``; the last entry is also used for
                the decoder dropout. ``None`` (the default) means
                ``[0, 0, 0]``.
            bias: Whether the decoder and the ``RNNModule`` input/hidden
                transforms use a bias term.
        """
        super().__init__()

        # A fresh list per instance: a list literal in the signature would be
        # one shared object handed to every RNNModule built with the default.
        if dropouts is None:
            dropouts = [0, 0, 0]

        # ENCODER

        # NOTE(review): this layer keeps torch's default bias and ignores the
        # ``bias`` flag -- confirm that is intended. It is left as-is because
        # changing it would alter the module's parameter set.
        self.gc2hidden = torch.nn.Linear(gridcells_dim, latent_dim)

        self.rnn = RNNModule(
            device,
            scene_dim, latent_dim,
            nonlinearity=nonlinearity,
            dropouts=dropouts,
            input_bias=bias, hidden_bias=bias
        )

        # DECODER

        self.decoder_lin = torch.nn.Linear(latent_dim, output_dim, bias=bias)

        # Only the last dropout rate applies to the decoder output.
        self.add_do = dropouts[-1] > 0
        self.decoder_do = torch.nn.Dropout(dropouts[-1])

    def encode(self, scene, gc, hidden):
        """Run the recurrent module and return its first output.

        Args:
            scene: Input passed to ``self.rnn`` as its first argument.
            gc: Grid-cell tensor. Only read when ``hidden`` is ``None``, and
                then only index 0 along axis 1 is used (presumably the first
                time step -- confirm against the caller).
            hidden: Optional initial hidden state. When given, a leading axis
                of size 1 is prepended before it is passed to ``self.rnn``.

        Returns:
            The first element of whatever ``self.rnn`` returns.
        """
        if hidden is None:
            # only grab the first instance of gridcells to init hidden state
            hidden = self.gc2hidden(gc[:, 0, ...])
        else:
            # NOTE(review): a supplied hidden state gets an extra leading
            # axis while the gc-derived one above does not -- confirm that
            # RNNModule accepts both layouts.
            hidden = hidden[None, ...]

        return self.rnn(scene, hidden)[0]

    def decode(self, x):
        """Apply the linear decoder, then dropout if it is enabled.

        Args:
            x: Tensor whose last dimension matches the decoder's input size.

        Returns:
            The decoded tensor.
        """
        decoded_lin = self.decoder_lin(x)
        if self.add_do:
            decoded_lin = self.decoder_do(decoded_lin)

        return decoded_lin

    def forward(self, scene, gc, hidden=None):
        """Encode the inputs and decode the resulting hidden states.

        Args:
            scene: See ``encode``.
            gc: See ``encode``.
            hidden: Optional initial hidden state; see ``encode``.

        Returns:
            A 3-tuple ``(output, hidden_all, last_hidden)`` where
            ``hidden_all`` is the result of ``encode``, ``output`` is
            ``decode(hidden_all)`` and ``last_hidden`` is
            ``hidden_all[:, -1, :]`` (the final entry along axis 1).
        """
        hidden_all = self.encode(scene, gc, hidden)

        output = self.decode(hidden_all)

        return output, hidden_all, hidden_all[:, -1, :]