import torch
import torch.nn as nn

class TimeAwareTransformer(nn.Module):
    """Transformer encoder over step features concatenated with learned time embeddings.

    Each step's feature vector is concatenated with an embedding of its time
    index. The sequence is passed through a stack of ``nn.TransformerEncoder``
    layers, and the encoder outputs are mean-pooled over the sequence axis.
    A linear layer then maps the pooled vector to ``output_dim``.

    Args:
        input_dim: Size of the last dimension of ``x``.
        hidden_dim: Feed-forward width inside each encoder layer.
        output_dim: Size of the vector returned per sequence.
        num_heads: Attention heads per encoder layer. ``input_dim +
            time_embedding_dim`` must be divisible by it; PyTorch checks this
            when the layer is built.
        num_layers: Number of stacked encoder layers.
        num_time_indices: Number of rows in the time-embedding table.
            ``times`` values must lie in ``[0, num_time_indices)``.
        time_embedding_dim: Size of each time embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, num_layers, num_time_indices, time_embedding_dim):
        super(TimeAwareTransformer, self).__init__()

        # Width of every token the encoder sees: raw features + time embedding.
        model_dim = input_dim + time_embedding_dim

        # Time Embedding
        self.time_embedding = nn.Embedding(num_time_indices, time_embedding_dim)

        # Transformer Encoder with batch_first=True
        # NOTE(review): nn.TransformerEncoder clones this template layer, and
        # forward() never calls ``self.transformer_layer`` directly. Its parameters
        # are registered (and appear in the state_dict / optimizer) but get no
        # gradients. It is kept so existing "transformer_layer.*" checkpoint keys
        # still load; consider dropping it when checkpoints can be migrated.
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(self.transformer_layer, num_layers=num_layers)

        # Fully connected output layer
        self.fc = nn.Linear(model_dim, output_dim)

    def forward(self, x, times):
        """Encode a batch of sequences and return one output vector per sequence.

        Args:
            x: Tensor of shape ``(batch_size, sequence_length, input_dim)``.
            times: Integer time indices. Either per-sample, with shape
                ``(batch_size, sequence_length)``, or shared across the batch,
                with shape ``(sequence_length,)`` or ``(1, sequence_length)``.
                Shared indices are broadcast to ``(batch_size, sequence_length)``.

        Returns:
            Tensor of shape ``(batch_size, output_dim)``.
        """
        batch_size, sequence_length, _ = x.size()

        # Broadcast shared time indices over the batch. expand() returns a view and
        # is a no-op for per-sample (batch_size, sequence_length) input. Without it,
        # a 1-D ``times`` gives a 2-D embedding that cannot be concatenated with x.
        times = times.expand(batch_size, sequence_length)

        # (batch_size, sequence_length, time_embedding_dim)
        time_embeds = self.time_embedding(times)

        # (batch_size, sequence_length, input_dim + time_embedding_dim)
        encoder_input = torch.cat([x, time_embeds], dim=-1)

        encoded = self.transformer_encoder(encoder_input)

        # Average over the sequence axis -> (batch_size, input_dim + time_embedding_dim).
        # NOTE(review): every position is averaged; there is no padding mask.
        pooled = encoded.mean(dim=1)

        return self.fc(pooled)

class BidirectionalTimeAwareTransformer(nn.Module):
    """Bidirectional GRU followed by a Transformer encoder, with learned time embeddings.

    The forward pass runs in five stages:

    1. Concatenate each step's features with an embedding of its time index.
    2. Run a bidirectional GRU over the sequence.
    3. Refine the GRU outputs (width ``2 * hidden_dim``) with stacked Transformer
       encoder layers.
    4. Mean-pool over the sequence axis.
    5. Project the pooled vector to ``output_dim`` with a linear layer.

    Args:
        input_dim: Size of the last dimension of ``x``.
        hidden_dim: GRU hidden size per direction. The encoder width and its
            feed-forward width are both ``2 * hidden_dim``.
        output_dim: Size of the vector returned per sequence.
        num_heads: Attention heads per encoder layer. ``2 * hidden_dim`` must
            be divisible by it.
        num_layers: Number of stacked encoder layers.
        num_time_indices: Number of rows in the time-embedding table.
        time_embedding_dim: Size of each time embedding.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, num_layers, num_time_indices, time_embedding_dim):
        super().__init__()

        # The forward and backward GRU outputs are concatenated, doubling the width.
        model_dim = 2 * hidden_dim

        # Keep this construction order: each submodule draws its initial weights
        # from the global RNG as it is built.
        self.time_embedding = nn.Embedding(num_time_indices, time_embedding_dim)
        self.bidirectional_gru = nn.GRU(
            input_size=input_dim + time_embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.transformer_layer = nn.TransformerEncoderLayer(
            d_model=model_dim,
            nhead=num_heads,
            dim_feedforward=model_dim,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            self.transformer_layer, num_layers=num_layers
        )
        self.fc = nn.Linear(model_dim, output_dim)

    def forward(self, x, times):
        """Return one ``output_dim`` vector per sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, steps, input_dim)``.
            times: Integer time indices that can be broadcast (via ``expand``)
                to ``(batch, steps)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        batch, steps, _ = x.size()

        step_times = times.expand(batch, steps)
        features = torch.cat((x, self.time_embedding(step_times)), dim=-1)

        recurrent_out, _ = self.bidirectional_gru(features)
        pooled = self.transformer_encoder(recurrent_out).mean(dim=1)

        return self.fc(pooled)