import torch
import torch.nn as nn
from einops import rearrange

from src.modules.mp_attention import SelfAttention, CrossAttention
from src.modules.mlp import MLP
from src.modules.embedder import PositionalEmbedding


class Decoder(nn.Module):
    """Stack of ``DecoderBlock`` layers followed by a final ``LayerNorm``.

    The submodules live in a single ``nn.ModuleDict`` under the keys
    ``pe``, ``drop``, ``layers`` and ``norm``.

    NOTE(review): ``pe`` is registered here but is not applied anywhere in
    ``forward`` -- confirm whether positional information is expected to be
    added by the caller before this module runs.
    """

    def __init__(self, config):
        """Build the submodules from ``config``.

        Reads ``config.d_model``, ``config.dropout`` and
        ``config.n_decoder_layers``; ``config`` is also handed to every
        ``DecoderBlock``.
        """
        super().__init__()
        self.config = config

        # Built in the same order as the dictionary keys so that any
        # RNG-dependent parameter initialisation is unaffected.
        positional = PositionalEmbedding(config.d_model)
        dropout = nn.Dropout(config.dropout)
        blocks = nn.ModuleList(
            [DecoderBlock(config) for _ in range(config.n_decoder_layers)]
        )
        final_norm = nn.LayerNorm(config.d_model)

        self.decoder = nn.ModuleDict({
            "pe": positional,
            "drop": dropout,
            "layers": blocks,
            "norm": final_norm,
        })

    def forward(self, x, enc_output=None):
        """Apply dropout, every decoder layer in order, then the final norm.

        Args:
            x: Input passed to dropout and then through the layers.
            enc_output: Forwarded unchanged as the second positional
                argument of every layer; defaults to ``None``.

        Returns:
            The output of the final ``LayerNorm``.
        """
        stack = self.decoder
        hidden = stack.drop(x)
        for block in stack.layers:
            hidden = block(hidden, enc_output)
        return stack.norm(hidden)


class DecoderBlock(nn.Module):
    """One decoder layer built from pre-norm residual sub-layers.

    Sub-layers, in order: self-attention (``SelfAttention`` constructed with
    ``is_causal=True``), an optional cross-attention over ``memory`` (only
    when ``config.has_encoder`` is truthy), and an MLP. Each sub-layer is
    applied as ``x = x + sublayer(norm(x))``.
    """

    def __init__(self, config):
        """Build the sub-layers from ``config``.

        Reads ``config.d_model`` and ``config.has_encoder`` here, and
        ``config.return_attention`` / ``config.has_encoder`` in ``forward``.
        ``norm2``, ``norm3`` and ``cross_attention`` exist only when
        ``config.has_encoder`` is truthy.
        """
        super().__init__()
        self.config = config
        self.norm1 = nn.LayerNorm(config.d_model)
        self.attention = SelfAttention(config, is_causal=True)
        if config.has_encoder:
            self.norm2 = nn.LayerNorm(config.d_model)
            self.norm3 = nn.LayerNorm(config.d_model)
            self.cross_attention = CrossAttention(config)
        self.norm4 = nn.LayerNorm(config.d_model)
        self.mlp = MLP(config)

    def forward(self, x, memory=None):
        """Run self-attention, optional cross-attention and the MLP.

        Args:
            x: Decoder input.
            memory: Encoder output attended to by the cross-attention
                sub-layer. Required when ``config.has_encoder`` is truthy,
                ignored otherwise.

        Returns:
            The transformed ``x``. Attention weights produced when
            ``config.return_attention`` is set are discarded.

        Raises:
            ValueError: If ``config.has_encoder`` is truthy and ``memory``
                is ``None``.
        """
        return_attention = self.config.return_attention
        has_encoder = self.config.has_encoder

        if has_encoder and memory is None:
            # Fail early with a clear message instead of an opaque error
            # from normalising ``None``.
            raise ValueError(
                "DecoderBlock was built with has_encoder=True but forward() "
                "received memory=None"
            )

        # With return_attention set, the attention modules return a pair
        # (output, weights); only the output is used by this block.
        self_out = self.attention(self.norm1(x))
        if return_attention:
            self_out, _ = self_out
        x = x + self_out

        if has_encoder:
            query = self.norm2(x)
            # Normalise the encoder memory once and reuse it.
            normed_memory = self.norm3(memory)
            # NOTE(review): the return_attention path passes two arguments
            # while the other path passes three; both are kept as they were.
            # Confirm against CrossAttention.forward which arity is intended.
            if return_attention:
                cross_out, _ = self.cross_attention(query, normed_memory)
            else:
                cross_out = self.cross_attention(
                    query, normed_memory, normed_memory
                )
            x = x + cross_out

        return x + self.mlp(self.norm4(x))
