import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer that pools the decoder output into one vector per sample.

    All layers are built with ``batch_first=True``, so ``src`` and ``tgt`` are
    ``(batch, seq, input_dim)`` tensors. No positional encoding is added in
    this class; callers must add one to ``src``/``tgt`` if token order matters.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, num_layers):
        """Build the model.

        Args:
            input_dim: Model width (``d_model``); must be divisible by ``num_heads``.
            hidden_dim: Width of the feed-forward sublayer in every layer.
            output_dim: Size of the final prediction vector.
            num_heads: Number of attention heads.
            num_layers: Number of stacked encoder layers and of decoder layers.
        """
        super(TransformerModel, self).__init__()
        # NOTE(review): nn.TransformerEncoder/Decoder deep-copy the layer they are
        # given, so ``encoder_layer``/``decoder_layer`` themselves are never run in
        # forward() but still appear in parameters() and the state_dict. They are
        # kept so existing checkpoints continue to load.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim, batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Encode ``src``, decode ``tgt`` against it, pool over ``tgt`` positions and project.

        Despite their names, ``src_mask`` and ``tgt_mask`` are *key-padding*
        masks (one flag per token), not attention masks.

        Args:
            src: Source tensor, ``(batch, src_len, input_dim)``.
            tgt: Target tensor, ``(batch, tgt_len, input_dim)``.
            src_mask: Optional ``(batch, src_len)`` padding mask for ``src``
                (True, or -inf for float masks, marks padding). Applied to
                encoder self-attention and decoder cross-attention.
            tgt_mask: Optional ``(batch, tgt_len)`` padding mask for ``tgt``.
                Applied to decoder self-attention; flagged positions are also
                excluded from the mean pooling.
            memory_mask: Optional cross-attention mask forwarded to the decoder
                as ``memory_mask`` (e.g. ``(tgt_len, src_len)``).

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        memory = self.encoder(src, src_key_padding_mask=src_mask)
        output = self.decoder(tgt, memory, memory_mask=memory_mask,
                              tgt_key_padding_mask=tgt_mask,
                              memory_key_padding_mask=src_mask)
        pooled = self._masked_mean(output, tgt_mask)
        return self.fc(pooled)

    @staticmethod
    def _masked_mean(x, padding_mask):
        """Average ``x`` over dim 1, skipping positions flagged in ``padding_mask``."""
        if padding_mask is None:
            return x.mean(dim=1)
        # Nonzero entries mark padding: True for bool masks, -inf for additive float masks.
        pad = padding_mask.bool().unsqueeze(-1)
        summed = x.masked_fill(pad, 0.0).sum(dim=1)
        # clamp avoids 0/0 when a sequence is entirely padding (it pools to zeros).
        count = (~pad).sum(dim=1).clamp(min=1).to(x.dtype)
        return summed / count


class BidirectionalTransformerModel(nn.Module):
    """Encoder-decoder Transformer that also encodes the source in reverse order.

    The source is encoded once as given and once flipped along the sequence
    axis (then flipped back into source order). The shared decoder
    cross-attends to each of the two memories separately. The two decoder
    outputs are concatenated feature-wise (``2 * input_dim``), mean-pooled over
    target positions, and projected by ``fc``.

    All layers use ``batch_first=True``, so inputs are
    ``(batch, seq, input_dim)``.

    NOTE(review): no positional encoding is added in this class, and the
    encoder layers are permutation-equivariant. The reversed pass (after being
    flipped back) therefore matches the forward pass up to dropout and
    floating-point noise, unless callers add positional information to ``src``.
    Confirm this is intended.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_heads, num_layers):
        """Build the model.

        Args:
            input_dim: Model width (``d_model``); must be divisible by ``num_heads``.
            hidden_dim: Width of the feed-forward sublayer in every layer.
            output_dim: Size of the final prediction vector.
            num_heads: Number of attention heads.
            num_layers: Number of stacked encoder layers and of decoder layers.
        """
        super(BidirectionalTransformerModel, self).__init__()
        # NOTE(review): nn.TransformerEncoder/Decoder deep-copy the layer they are
        # given, so ``encoder_layer``/``decoder_layer`` themselves are never run in
        # forward() but still appear in parameters() and the state_dict. They are
        # kept so existing checkpoints continue to load.
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=input_dim, nhead=num_heads,
                                                        dim_feedforward=hidden_dim, batch_first=True)
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers)
        # 2 * input_dim: decoder outputs for the forward and reversed memories, concatenated.
        self.fc = nn.Linear(input_dim * 2, output_dim)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        """Encode ``src`` in both directions, decode ``tgt`` against each, pool and project.

        Despite their names, ``src_mask`` and ``tgt_mask`` are *key-padding*
        masks (one flag per token), not attention masks.

        Args:
            src: Source tensor, ``(batch, src_len, input_dim)``.
            tgt: Target tensor, ``(batch, tgt_len, input_dim)``.
            src_mask: Optional ``(batch, src_len)`` padding mask for ``src``
                (True, or -inf for float masks, marks padding). The mask is
                flipped together with ``src`` for the reversed encoding.
            tgt_mask: Optional ``(batch, tgt_len)`` padding mask for ``tgt``.
                Applied to decoder self-attention; flagged positions are also
                excluded from the mean pooling.
            memory_mask: Optional cross-attention mask forwarded to the decoder
                as ``memory_mask`` (e.g. ``(tgt_len, src_len)``). It is applied
                to both memories, which share source order.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        memory = self.encoder(src, src_key_padding_mask=src_mask)

        # The reversed sequence needs a reversed padding mask. Otherwise real
        # tokens would be masked out and padding attended to.
        flipped_mask = None if src_mask is None else torch.flip(src_mask, [1])
        reverse_memory = self.encoder(torch.flip(src, [1]), src_key_padding_mask=flipped_mask)
        # Flip back so position i of both memories refers to source token i,
        # letting src_mask / memory_mask apply to either memory unchanged.
        reverse_memory = torch.flip(reverse_memory, [1])

        # The decoder is d_model=input_dim wide, so it cannot cross-attend to a
        # feature-concatenated 2*input_dim memory. Instead, decode against each
        # direction and concatenate the results to match fc's 2*input_dim input.
        decoded = [
            self.decoder(tgt, mem, memory_mask=memory_mask,
                         tgt_key_padding_mask=tgt_mask,
                         memory_key_padding_mask=src_mask)
            for mem in (memory, reverse_memory)
        ]
        output = torch.cat(decoded, dim=2)
        pooled = self._masked_mean(output, tgt_mask)
        return self.fc(pooled)

    @staticmethod
    def _masked_mean(x, padding_mask):
        """Average ``x`` over dim 1, skipping positions flagged in ``padding_mask``."""
        if padding_mask is None:
            return x.mean(dim=1)
        # Nonzero entries mark padding: True for bool masks, -inf for additive float masks.
        pad = padding_mask.bool().unsqueeze(-1)
        summed = x.masked_fill(pad, 0.0).sum(dim=1)
        # clamp avoids 0/0 when a sequence is entirely padding (it pools to zeros).
        count = (~pad).sum(dim=1).clamp(min=1).to(x.dtype)
        return summed / count