import torch
import torch.nn as nn


class RNNEnc(nn.Module):
    """Encode a sequence into a single latent vector with a single-layer GRU.

    Each time step is linearly projected to ``latent_dim``, the projected
    sequence is run through a (optionally bidirectional) GRU, and the final
    hidden states of all directions are concatenated and projected back to
    ``latent_dim``.

    :param input_dim: size of the last dimension of ``obs``.
    :param latent_dim: GRU hidden size and size of the returned vector.
    :param bidirectional: whether the GRU runs in both directions.
    :param unused_kwargs: ignored (presumably so a shared config dict can be
        splatted into the constructor — confirm against callers).
    """

    def __init__(self,
                 input_dim: int,
                 latent_dim: int,
                 bidirectional: bool = True,
                 **unused_kwargs):
        super().__init__()

        # Number of final hidden states the GRU returns (one per direction).
        self.num_directions = 2 if bidirectional else 1
        self.enc = nn.Linear(input_dim, latent_dim)
        self.rnn = nn.GRU(latent_dim, latent_dim, batch_first=True,
                          bidirectional=bidirectional)
        # The decoder consumes the concatenated final states of all
        # directions; sizing it from num_directions keeps
        # bidirectional=False usable (the default shape is unchanged).
        self.dec = nn.Linear(self.num_directions * latent_dim, latent_dim)
        self.latent_dim = latent_dim

    def forward(self, obs):
        """
        :param obs: [ batch x time x input dim ]
        :return: h: [ batch x latent_dim ]
        """
        batch_size = obs.shape[0]

        x = self.enc(obs)
        # Explicit zero initial state, sized per direction and matching the
        # projected input's dtype/device so .double()/.half() models work.
        h0 = torch.zeros(self.num_directions, batch_size, self.latent_dim,
                         dtype=x.dtype, device=x.device)
        _, h = self.rnn(x, h0)  # [num_directions x batch x latent_dim]
        # h_n is not batch-first even with batch_first=True: move the batch
        # axis to the front and concatenate the per-direction states.
        h = h.transpose(1, 0).flatten(start_dim=1)
        return self.dec(h)


# Variant originally used in model "bb"; unlike RNNEnc, its final projection maps to a separate ``output_dim``.
class RNNEnc2(nn.Module):
    """Encode a sequence into a vector of size ``output_dim`` with a GRU.

    Each time step is linearly projected to ``latent_dim``, the projected
    sequence is run through a single-layer (optionally bidirectional) GRU,
    and the final hidden states of all directions are concatenated and
    projected to ``output_dim``.

    :param input_dim: size of the last dimension of ``obs``.
    :param latent_dim: GRU hidden size.
    :param output_dim: size of the returned vector.
    :param bidirectional: whether the GRU runs in both directions.
    :param unused_kwargs: ignored (presumably so a shared config dict can be
        splatted into the constructor — confirm against callers).
    """

    def __init__(self,
                 input_dim: int,
                 latent_dim: int,
                 output_dim: int,
                 bidirectional: bool = True,
                 **unused_kwargs):
        super().__init__()

        # Number of final hidden states the GRU returns (one per direction).
        self.num_directions = 2 if bidirectional else 1
        self.enc = nn.Linear(input_dim, latent_dim)
        self.rnn = nn.GRU(latent_dim, latent_dim, batch_first=True,
                          bidirectional=bidirectional)
        # Sized from num_directions so bidirectional=False works; the default
        # (bidirectional) shape is unchanged.
        self.dec = nn.Linear(self.num_directions * latent_dim, output_dim)
        self.latent_dim = latent_dim

    def forward(self, obs):
        """
        :param obs: [ batch, time, input dim ]
        :return: h: [ batch, output_dim ]
        """
        x = self.enc(obs)
        # No h0 passed: the GRU starts from a zero hidden state.
        _, h = self.rnn(x)  # [num_directions x batch x latent_dim]
        # h_n is not batch-first even with batch_first=True: move the batch
        # axis to the front and concatenate the per-direction states.
        h = h.transpose(1, 0).flatten(start_dim=1)
        return self.dec(h)


class RNNEnc2d(nn.Module):
    """Encoder for the RNN model in 2d.

    Each time step's trailing 2-D frame is flattened and linearly projected,
    the projected sequence is run through a single-layer (optionally
    bidirectional) GRU, and the final hidden states of all directions are
    concatenated and projected back to the latent size.

    :param enc_config: keyword arguments for the input ``nn.Linear``; its
        ``out_features`` is used as the GRU hidden size / latent size, and
        its ``in_features`` must equal the product of the last two dims of
        ``obs``.
    :param bidirectional: whether the GRU runs in both directions.
    :param unused_kwargs: ignored (presumably so a shared config dict can be
        splatted into the constructor — confirm against callers).
    """

    def __init__(self,
                 enc_config: dict,
                 bidirectional: bool = True,
                 **unused_kwargs):
        super().__init__()
        # NOTE: conv-based encoder/decoder layers could replace these linear
        # layers to exploit the 2-D structure of each frame.
        self.enc = nn.Linear(**enc_config)
        latent_dim = enc_config['out_features']
        # Number of final hidden states the GRU returns (one per direction).
        self.num_directions = 2 if bidirectional else 1
        self.rnn = nn.GRU(latent_dim, latent_dim, batch_first=True,
                          bidirectional=bidirectional)
        # Sized from num_directions so bidirectional=False works; the default
        # (bidirectional) shape is unchanged.
        self.dec = nn.Linear(self.num_directions * latent_dim, latent_dim)
        self.latent_dim = latent_dim

    def forward(self, obs):
        """
        :param obs: [ batch, time, d1, d2 ]; the last two dims are flattened,
            so d1 * d2 must equal ``enc_config['in_features']``.
        :return: h: [ batch, latent_dim ]
        """
        # Flatten each frame into a vector before the linear projection.
        x = self.enc(obs.flatten(start_dim=-2))
        # No h0 passed: the GRU starts from a zero hidden state.
        _, h = self.rnn(x)  # [num_directions x batch x latent_dim]
        # h_n is not batch-first even with batch_first=True: move the batch
        # axis to the front and concatenate the per-direction states.
        h = h.transpose(1, 0).flatten(start_dim=1)
        return self.dec(h)
