import torch
import torch.nn as nn

class VRNN(nn.Module):
    """GRU sequence encoder with a variational bottleneck on its final state.

    A single-layer, unidirectional GRU summarizes the input sequence. Its
    final hidden state is mapped to the mean and log-variance of a diagonal
    Gaussian, a latent sample is drawn with the reparameterization trick,
    decoded back to ``hidden_dim`` features and projected to ``output_dim``
    values (one prediction per sequence).

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the GRU hidden state and of the MLP hidden layers.
        latent_dim: Size of the latent variable ``z``.
        output_dim: Size of the final prediction.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Emits 2 * latent_dim values: the first half is mu, the second half
        # is logvar (split with torch.chunk in forward()).
        self.encoder = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.latent_dim * 2),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable with respect to ``mu``
        and ``logvar``.

        Args:
            mu: Mean of the latent Gaussian.
            logvar: Log-variance of the latent Gaussian, same shape as ``mu``.

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode a batch of sequences and predict from a sampled latent.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``.

        Returns:
            A tuple ``(out, mu, logvar)`` where ``out`` has shape
            ``(batch, output_dim)`` and ``mu`` / ``logvar`` have shape
            ``(batch, latent_dim)``.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected input of shape (batch, seq_len, input_dim), "
                f"got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # new_zeros follows x's dtype and device, so the initial state stays
        # compatible with the GRU when the model is run in e.g. float64.
        h0 = x.new_zeros(1, batch_size, self.hidden_dim)

        # Only the final hidden state is used; per-step outputs are discarded.
        _, enc_hidden = self.rnn(x, h0)
        # (1, batch, hidden) -> (batch, hidden)
        enc_hidden = enc_hidden.squeeze(0)

        # Reparameterization trick
        mu_logvar = self.encoder(enc_hidden)
        mu, logvar = torch.chunk(mu_logvar, 2, dim=-1)
        z = self.reparameterize(mu, logvar)

        dec_output = self.decoder(z)
        out = self.predictor(dec_output)

        return out, mu, logvar
class BiVRNN(nn.Module):
    """Bidirectional-GRU encoder with a variational bottleneck.

    A single-layer bidirectional GRU summarizes the input sequence. The final
    forward and backward hidden states of each sample are concatenated and
    mapped to the mean and log-variance of a diagonal Gaussian; a latent
    sample is drawn with the reparameterization trick, decoded to
    ``hidden_dim`` features and projected to ``output_dim`` values (one
    prediction per sequence).

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of each direction's GRU hidden state and of the MLP
            hidden layers.
        latent_dim: Size of the latent variable ``z``.
        output_dim: Size of the final prediction.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # Input is hidden_dim * 2 because the forward and backward final
        # states are concatenated. Output is 2 * latent_dim: mu then logvar.
        self.encoder = nn.Sequential(
            nn.Linear(self.hidden_dim * 2, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.latent_dim * 2),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        self.rnn = nn.GRU(input_dim, hidden_dim, bidirectional=True, batch_first=True)
        # The predictor consumes the decoder output, which has hidden_dim
        # features (not hidden_dim * 2), so its input size must match that.
        self.predictor = nn.Linear(hidden_dim, output_dim)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable with respect to ``mu``
        and ``logvar``.

        Args:
            mu: Mean of the latent Gaussian.
            logvar: Log-variance of the latent Gaussian, same shape as ``mu``.

        Returns:
            A sample with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode a batch of sequences and predict from a sampled latent.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)``.

        Returns:
            A tuple ``(out, mu, logvar)`` where ``out`` has shape
            ``(batch, output_dim)`` and ``mu`` / ``logvar`` have shape
            ``(batch, latent_dim)``.

        Raises:
            ValueError: If ``x`` is not a 3-D tensor.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected input of shape (batch, seq_len, input_dim), "
                f"got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # One initial state per direction; new_zeros follows x's dtype and
        # device so the state stays compatible with the GRU weights.
        h0 = x.new_zeros(2, batch_size, self.hidden_dim)

        # Only the final hidden states are used; per-step outputs are discarded.
        _, h_n = self.rnn(x, h0)
        # h_n is (2, batch, hidden) even with batch_first=True: index 0 is the
        # forward direction, index 1 the backward direction. Concatenate along
        # the feature axis so each row holds one sample's two states; a plain
        # view(batch, 2 * hidden) would mix states of different samples.
        enc_hidden = torch.cat((h_n[0], h_n[1]), dim=-1)

        # Reparameterization trick
        mu_logvar = self.encoder(enc_hidden)
        mu, logvar = torch.chunk(mu_logvar, 2, dim=-1)
        z = self.reparameterize(mu, logvar)

        dec_output = self.decoder(z)
        out = self.predictor(dec_output)

        return out, mu, logvar