from typing import Optional

import torch
import torch.nn as nn

class TransformerDecoderWithProjection(nn.Module):
    """Transformer decoder with a linear projection on each of its inputs.

    The target sequence and the memory sequence are each mapped to
    ``d_model`` channels by their own ``nn.Linear`` layer and are then fed
    to a stack of ``nn.TransformerDecoderLayer`` modules. The layers are
    built with ``batch_first=True``, so every tensor handled here is laid
    out as ``(batch, sequence, feature)``.

    Args:
        d_model: Channel width used inside the decoder and of the output.
        nhead: Number of attention heads in each decoder layer.
        num_layers: Number of stacked decoder layers.
        memory_dim: Size of the last dimension of the memory input.
        input_dim: Size of the last dimension of the target input.
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside each decoder layer.
    """

    def __init__(
        self,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 6,
        memory_dim: int = 768,
        input_dim: int = 1024,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Bring both inputs to the decoder width so they can be attended
        # over jointly. Attribute names and creation order are kept stable
        # so existing checkpoints and seeded initialisation still match.
        self.input_projection = nn.Linear(input_dim, d_model)
        self.memory_projection = nn.Linear(memory_dim, d_model)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
        )

        self.decoder = nn.TransformerDecoder(
            decoder_layer=decoder_layer,
            num_layers=num_layers,
        )

    def forward(
        self,
        vert_feature: torch.Tensor,
        image_feature: torch.Tensor,
        tgt_mask: Optional[torch.Tensor] = None,
        memory_mask: Optional[torch.Tensor] = None,
        tgt_key_padding_mask: Optional[torch.Tensor] = None,
        memory_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode ``vert_feature`` while cross-attending to ``image_feature``.

        Args:
            vert_feature: Target sequence of shape
                ``(batch, num_targets, input_dim)``. The last dimension must
                equal the ``input_dim`` given at construction (1024 by
                default) because it is fed to ``input_projection``.
            image_feature: Memory sequence of shape
                ``(batch, num_memory, memory_dim)``.
            tgt_mask: Optional attention mask over the target sequence,
                forwarded unchanged to ``nn.TransformerDecoder``.
            memory_mask: Optional attention mask between target and memory
                positions, forwarded unchanged.
            tgt_key_padding_mask: Optional ``(batch, num_targets)`` mask
                marking target positions to ignore, forwarded unchanged.
            memory_key_padding_mask: Optional ``(batch, num_memory)`` mask
                marking memory positions to ignore, forwarded unchanged.

        Returns:
            Tensor of shape ``(batch, num_targets, d_model)``.
        """
        projected_vert_feature = self.input_projection(vert_feature)
        projected_memory = self.memory_projection(image_feature)

        # All masks default to None, which reproduces the original
        # unmasked call exactly.
        return self.decoder(
            tgt=projected_vert_feature,
            memory=projected_memory,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )