import torch
import torch.nn as nn

class T_VRNN(nn.Module):
    """GRU sequence model with learned time embeddings and a per-step latent head.

    Each step of ``x`` is concatenated with an embedding of its integer time
    index and fed through a unidirectional GRU. Every GRU output step is
    encoded into the parameters (``mu``, ``logvar``) of a diagonal Gaussian.
    A latent ``z`` is sampled from it and decoded back to input-feature space.
    The GRU's final hidden state is mapped to the prediction. Note that ``z``
    does not feed back into the GRU or the predictor.

    Args:
        input_dim: Feature size of each step of ``x``.
        hidden_dim: GRU hidden size.
        latent_dim: Size of the latent variable ``z``.
        output_dim: Size of the prediction returned by ``forward``.
        num_time_indices: Number of rows in the time-embedding table.
        time_embedding_dim: Size of each time embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim,
                 num_time_indices, time_embedding_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.time_embedding = nn.Embedding(num_time_indices, time_embedding_dim)

        # Maps one GRU output step to the concatenation [mu, logvar].
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),
        )

        # Maps a latent sample back to input-feature space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

        self.rnn = nn.GRU(input_dim + time_embedding_dim, hidden_dim, batch_first=True)

        self.predictor = nn.Linear(hidden_dim, output_dim)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, times, *, return_reconstruction=False):
        """Run the model on a batch of sequences.

        Args:
            x: Float tensor of shape ``(batch, seq_len, input_dim)``.
            times: Integer tensor of shape ``(batch, seq_len)`` indexing the
                time-embedding table.
            return_reconstruction: If True, also return the decoded latent
                samples.

        Returns:
            ``(out, mu, logvar)``. ``out`` has shape ``(batch, output_dim)``.
            ``mu`` and ``logvar`` have shape ``(batch * seq_len, latent_dim)``.
            With ``return_reconstruction=True``, a fourth element is appended:
            the reconstruction, of shape ``(batch, seq_len, input_dim)``.
        """
        batch_size, sequence_length, _ = x.size()

        time_embeds = self.time_embedding(times)
        rnn_input = torch.cat([x, time_embeds], dim=-1)

        # new_zeros follows x's dtype and device, so float64/half inputs work.
        h0 = x.new_zeros(1, batch_size, self.hidden_dim)
        enc_output, enc_hidden = self.rnn(rnn_input, h0)

        # A batch_first GRU output may be non-contiguous, in which case
        # .view() raises; reshape() copies only when it has to.
        enc_output = enc_output.reshape(batch_size * sequence_length, self.hidden_dim)

        # Reparameterization trick: differentiable sample z ~ N(mu, exp(logvar)).
        mu, logvar = torch.chunk(self.encoder(enc_output), 2, dim=-1)
        z = self.reparameterize(mu, logvar)

        dec_output = self.decoder(z).view(batch_size, sequence_length, self.input_dim)

        # enc_hidden is (1, batch, hidden); predict from the last layer's state.
        out = self.predictor(enc_hidden[-1])

        if return_reconstruction:
            return out, mu, logvar, dec_output
        return out, mu, logvar
class Bi_T_VRNN(nn.Module):
    """Bidirectional variant of the time-embedded GRU model with a per-step latent head.

    Each step of ``x`` is concatenated with an embedding of its integer time
    index and fed through a single-layer bidirectional GRU. Every GRU output
    step (forward and backward halves, ``2 * hidden_dim`` features) is encoded
    into the parameters (``mu``, ``logvar``) of a diagonal Gaussian. A latent
    ``z`` is sampled from it and decoded back to input-feature space. The
    prediction is made from the concatenated final hidden states of both
    directions. ``z`` does not feed back into the GRU or the predictor.

    Args:
        input_dim: Feature size of each step of ``x``.
        hidden_dim: GRU hidden size per direction.
        latent_dim: Size of the latent variable ``z``.
        output_dim: Size of the prediction returned by ``forward``.
        num_time_indices: Number of rows in the time-embedding table.
        time_embedding_dim: Size of each time embedding.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, output_dim,
                 num_time_indices, time_embedding_dim):
        super().__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        self.time_embedding = nn.Embedding(num_time_indices, time_embedding_dim)

        # Input is one bidirectional GRU step (forward ++ backward features).
        self.encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim * 2),
        )

        # Maps a latent sample back to input-feature space.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

        self.rnn = nn.GRU(input_dim + time_embedding_dim, hidden_dim,
                          bidirectional=True, batch_first=True)

        # Consumes the concatenated final states of both directions.
        self.predictor = nn.Linear(hidden_dim * 2, output_dim)

    def reparameterize(self, mu, logvar):
        """Return ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, times, *, return_reconstruction=False):
        """Run the model on a batch of sequences.

        Args:
            x: Float tensor of shape ``(batch, seq_len, input_dim)``.
            times: Integer tensor of shape ``(batch, seq_len)`` indexing the
                time-embedding table.
            return_reconstruction: If True, also return the decoded latent
                samples.

        Returns:
            ``(out, mu, logvar)``. ``out`` has shape ``(batch, output_dim)``.
            ``mu`` and ``logvar`` have shape ``(batch * seq_len, latent_dim)``.
            With ``return_reconstruction=True``, a fourth element is appended:
            the reconstruction, of shape ``(batch, seq_len, input_dim)``.
        """
        batch_size, sequence_length, _ = x.size()

        time_embeds = self.time_embedding(times)
        rnn_input = torch.cat([x, time_embeds], dim=-1)

        # One initial state per direction; new_zeros follows x's dtype/device.
        h0 = x.new_zeros(2, batch_size, self.hidden_dim)
        enc_output, enc_hidden = self.rnn(rnn_input, h0)

        # A batch_first GRU output may be non-contiguous, in which case
        # .view() raises; reshape() copies only when it has to.
        enc_output = enc_output.reshape(batch_size * sequence_length, self.hidden_dim * 2)

        # Reparameterization trick: differentiable sample z ~ N(mu, exp(logvar)).
        mu, logvar = torch.chunk(self.encoder(enc_output), 2, dim=-1)
        z = self.reparameterize(mu, logvar)

        dec_output = self.decoder(z).view(batch_size, sequence_length, self.input_dim)

        # enc_hidden is (2, batch, hidden): index -2 is the forward direction's
        # final state, -1 the backward one. The predictor expects both
        # (2 * hidden features), so concatenate them.
        final_state = torch.cat([enc_hidden[-2], enc_hidden[-1]], dim=-1)
        out = self.predictor(final_state)

        if return_reconstruction:
            return out, mu, logvar, dec_output
        return out, mu, logvar