
import torch
import torch.nn as nn

class EvalConfig:
    """Static settings for evaluating the embedding decoder.

    All settings are plain class attributes and are read directly off the
    class (e.g. ``EvalConfig.embed_dim``).
    """

    # ---- Model weights ----------------------------------------------------
    # Placeholder location: replace with the real checkpoint path before use.
    checkpoint_path: str = "yourpath/decoder_epoch250.pth"

    # ---- Validation / model dimensions -------------------------------------
    embed_dim: int = 768       # embedding width; also the decoder's d_model
    vocab_size: int = 49408    # number of token ids (embedding rows / logits)
    max_seq_len: int = 77      # number of learned position embeddings
    batch_size: int = 256

    
class EmbeddingDecoder(nn.Module):
    """Transformer decoder that produces token logits conditioned on embeddings.

    The source embedding(s) are linearly projected and used as the
    cross-attention memory of a stack of ``nn.TransformerDecoderLayer``s
    (sequence-first layout). Target tokens are embedded, summed with learned
    absolute position embeddings, and decoded under a causal mask.

    Every constructor argument is keyword-only and optional. ``embed_dim``,
    ``vocab_size`` and ``max_seq_len`` fall back to ``EvalConfig`` at
    construction time. The remaining defaults equal the original hard-coded
    values, so ``EmbeddingDecoder()`` builds exactly the same parameters
    (and ``state_dict`` keys/shapes) as before.
    """

    def __init__(self, *, embed_dim=None, vocab_size=None, max_seq_len=None,
                 nhead=8, dim_feedforward=2048, dropout=0.1, num_layers=6):
        super().__init__()
        # Resolve lazily (not as def-time defaults) so changes made to
        # EvalConfig before construction are honoured.
        embed_dim = EvalConfig.embed_dim if embed_dim is None else embed_dim
        vocab_size = EvalConfig.vocab_size if vocab_size is None else vocab_size
        max_seq_len = EvalConfig.max_seq_len if max_seq_len is None else max_seq_len

        # Submodules are registered in the original order, so parameter
        # initialisation order and state_dict layout are unchanged.
        self.embed_proj = nn.Linear(embed_dim, embed_dim)
        self.token_embed = nn.Embedding(vocab_size, embed_dim)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.pos_encoder = nn.Embedding(max_seq_len, embed_dim)
        self.output = nn.Linear(embed_dim, vocab_size)

    def forward(self, src_embeds, tgt_tokens):
        """Return logits for ``tgt_tokens`` conditioned on ``src_embeds``.

        Args:
            src_embeds: ``[batch, D]`` (one memory vector per sample) or
                ``[batch, S, D]`` (a memory sequence per sample).
            tgt_tokens: integer token ids of shape ``[batch, L]``.

        Returns:
            Logits of shape ``[batch, V, L]`` (class dimension second).

        Raises:
            ValueError: if ``L`` exceeds the number of learned positions.
        """
        seq_len = tgt_tokens.size(1)
        max_len = self.pos_encoder.num_embeddings
        if seq_len > max_len:
            # Fail clearly instead of an opaque embedding index error
            # (or a CUDA device-side assert).
            raise ValueError(
                f"target length {seq_len} exceeds max_seq_len {max_len}"
            )

        memory = self.embed_proj(src_embeds)
        if memory.dim() == 2:
            memory = memory.unsqueeze(0)        # [1, batch, D]
        else:
            memory = memory.transpose(0, 1)     # [S, batch, D]

        tgt_embeds = self.token_embed(tgt_tokens)                   # [batch, L, D]
        positions = torch.arange(seq_len, device=tgt_tokens.device)
        pos_embeds = self.pos_encoder(positions).unsqueeze(0)       # [1, L, D]
        tgt = (tgt_embeds + pos_embeds).permute(1, 0, 2)            # [L, batch, D]

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere. This is
        # the same formula as nn.Transformer.generate_square_subsequent_mask.
        # The mask is built directly on tgt's device and in tgt's dtype, which
        # avoids a host->device copy and an attn_mask/query dtype mismatch in
        # reduced-precision models.
        tgt_mask = torch.triu(
            torch.full((seq_len, seq_len), float("-inf"),
                       device=tgt.device, dtype=tgt.dtype),
            diagonal=1,
        )
        out = self.decoder(tgt, memory, tgt_mask=tgt_mask)  # [L, batch, D]
        out = self.output(out)                              # [L, batch, V]
        return out.permute(1, 2, 0)                         # [batch, V, L]
