import torch.nn as nn
from einops import rearrange
from csiva.model.multi_head_self_attention import MHSA

###############################################
# -- Stack of alternating attention blocks -- #
###############################################
class AlternatingAttentionStack(nn.Module):
    """Sequence of alternating attention blocks applied one after another.

    AlternatingAttentionStack is used to instantiate the encoder of the NPT.
    Every block in the stack is built with the same hyperparameters.

    Parameters
    ----------
    num_layers : int
        How many alternating attention blocks to stack.
    d_model: int
        Dimension of the key, query, value input embeddings.
    dim_feedforward : int
        Hidden dimension of the MLP after MHSA in the encoder layer.
    num_heads : int
        Number of heads in multihead self-attention.
    eps_layer_norm: float
        LayerNorm epsilon in multi-head attention in the alternate attention blocks.
        Required for numerical stability.
    p_dropout : float
        Dropout probability forwarded to every block.
    encoder_layer_type: str
        Which encoder layer class each block uses: "custom" for csiva.model.MHSA,
        "torch" for torch.nn.TransformerEncoderLayer.
    rff_depth: int, default 1
        Number of feed-forward layers in the MLP of the encoder layer in alternate attention.
    """
    def __init__(
        self,
        num_layers: int,
        d_model: int,
        dim_feedforward: int,
        num_heads: int,
        eps_layer_norm: float,
        p_dropout: float,
        encoder_layer_type: str,
        rff_depth: int = 1
    ):
        super().__init__()

        # *** Define the architecture ***
        # Blocks are created in order, one per layer, all sharing the same configuration.
        blocks = nn.ModuleList()
        for _ in range(num_layers):
            blocks.append(
                AlternatingAttentionBlock(
                    d_model=d_model,
                    dim_feedforward=dim_feedforward,
                    num_heads=num_heads,
                    eps_layer_norm=eps_layer_norm,
                    p_dropout=p_dropout,
                    encoder_layer_type=encoder_layer_type,
                    rff_depth=rff_depth,
                )
            )
        self.stack = blocks

    def forward(
        self,
        X,
        attribute_mask=None,
        datapoints_mask=None
    ):
        """Feed ``X`` through every block in order.

        The same ``attribute_mask`` and ``datapoints_mask`` are handed to each block;
        the output of one block is the input of the next. Returns the output of the
        last block (or ``X`` itself when the stack is empty).
        """
        hidden = X
        for block in self.stack:
            hidden = block(hidden, attribute_mask, datapoints_mask)

        return hidden
    


############################################
# -- Alternating attention single block -- #
############################################
class AlternatingAttentionBlock(nn.Module):
    """Block alternating attention between datapoints and attributes.

    The block holds two encoder layers: the first performs Attention Between
    Datapoints (ABD), the second Attention Between Attributes (ABA).

    Parameters
    ----------
    d_model: int
        Dimension of the key, query, value input embeddings.
    dim_feedforward : int
        Hidden dimension of the MLP after MHSA in the encoder layer.
        Only forwarded to the layers when ``encoder_layer_type == "torch"``.
        TODO: replace with widening_factor applied to d_model
    num_heads : int
        Number of heads in multihead self-attention.
    eps_layer_norm: float
        LayerNorm epsilon in multi-head attention in the alternate attention blocks.
        Required for numerical stability.
    p_dropout : float
        Dropout probability forwarded to the encoder layers.
    encoder_layer_type: str
        Specify which class to use for the encoder layer. Use "custom" for csiva.model.MHSA,
        or "torch" for torch.nn.TransformerEncoderLayer.
    rff_depth: int, default 1
        Number of feed-forward layers in the MLP of the encoder layer in alternate attention.
        Only forwarded to the layers when ``encoder_layer_type == "custom"``.

    Raises
    ------
    ValueError
        If ``encoder_layer_type`` is neither "custom" nor "torch".
    """
    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        num_heads: int,
        eps_layer_norm: float,
        p_dropout: float,
        encoder_layer_type: str,
        rff_depth: int = 1
    ):
        super().__init__()
        self.embed_dim = d_model
        self.dim_feedforward = dim_feedforward
        self.num_heads = num_heads
        self.eps_layer_norm = eps_layer_norm
        self.p_dropout = p_dropout
        self.encoder_layer_type = encoder_layer_type

        # *** Define the architecture ***
        if encoder_layer_type == "torch":
            self.stack = nn.ModuleList([
                nn.TransformerEncoderLayer(
                    self.embed_dim, num_heads, self.dim_feedforward, p_dropout, nn.GELU(), eps_layer_norm, batch_first=True
                ), # abd
                nn.TransformerEncoderLayer(
                    self.embed_dim, num_heads, self.dim_feedforward, p_dropout, nn.GELU(), eps_layer_norm, batch_first=True
                ), # aba
            ])
        elif encoder_layer_type == "custom":
            self.stack = nn.ModuleList([
                MHSA(self.embed_dim, num_heads, eps_layer_norm, p_dropout, rff_depth=rff_depth, batch_first=True), # abd
                MHSA(self.embed_dim, num_heads, eps_layer_norm, p_dropout, rff_depth=rff_depth, batch_first=True) # aba
            ])
        else:
            raise ValueError(
                "The accepted values for encoder_layer_type input arguments are 'custom', 'torch'."
                f" Got instead {encoder_layer_type}."
            )

    def forward(
        self,
        X,
        attribute_mask=None,
        datapoints_mask=None
    ):
        """Apply ABD followed by ABA and restore the input layout.

        Parameters
        ----------
        X : tensor of shape (B, D, N, E)
            Input embeddings.
        attribute_mask : optional
            Key padding mask handed unchanged to the ABA layer.
            NOTE(review): the ABA layer sees inputs of shape (B*N, D, E), so this is
            presumably expected as (B*N, D) - confirm against the caller.
        datapoints_mask : optional, shape (B, D, N)
            Key padding mask for the ABD layer. It is flattened to (B*D, N) here.
            ``None`` disables masking.

        Returns
        -------
        Tensor of shape (B, D, N, E).
        """
        _, D, N, _ = X.shape

        # Reshape the datapoints mask to match ABD input shape. The mask is optional,
        # so it is only rearranged when one was actually supplied.
        if datapoints_mask is not None:
            datapoints_mask = rearrange(datapoints_mask, 'b d n -> (b d) n')

        # The two supported encoder layer classes name their padding-mask argument differently.
        if self.encoder_layer_type == "torch":
            mask_kwarg = "src_key_padding_mask"
        else:
            mask_kwarg = "key_padding_mask"

        # (reshape applied before the layer, mask given to the layer), in ABD -> ABA order.
        steps = (
            (ABDReshape(), datapoints_mask),
            (ABAReshape(D), attribute_mask),
        )

        # Attentions
        for attn_module, (reshape_layer, mask) in zip(self.stack, steps):
            X = attn_module(reshape_layer(X), **{mask_kwarg: mask})

        # Output of shape B x D x N x E
        return OriginalShape(N)(X)

###########################
# -- Reshape utilities -- #
###########################
class ABDReshape(nn.Module):
    """Reshapes a tensor of shape (B, D, N, E) to (B*D, N, E).
    Used in Attention Between Datapoints (ABD).
    """
    # Implemented as `forward` (not `__call__`) so that nn.Module's own call
    # machinery stays in charge; `layer(X)` keeps working as before.
    def forward(self, X):
        """Merge the first two axes of ``X`` into a single leading axis."""
        return rearrange(X, 'b d n e -> (b d) n e')

class ABAReshape(nn.Module):
    """Reshapes a tensor of shape (B*D, N, E) to (B*N, D, E).
    Used in Attention Between Attributes (ABA).

    Parameters
    ----------
    D : int
        Size of the ``d`` axis that is folded into the leading axis of the input.
    """
    def __init__(self, D):
        # nn.Module must be initialised before the instance is usable as a module
        # (repr, parameters(), state_dict(), .to(), hooks, ...).
        super().__init__()
        self.D = D

    # Implemented as `forward` (not `__call__`) so that nn.Module's own call
    # machinery stays in charge; `layer(X)` keeps working as before.
    def forward(self, X):
        """Split the leading axis into (b, d) and regroup it as (b, n)."""
        return rearrange(X, '(b d) n e -> (b n) d e', d=self.D)
    
class OriginalShape(nn.Module):
    """Reshapes a tensor of shape (B*N, D, E) to (B, D, N, E).
    Restore input original shape.

    Parameters
    ----------
    N : int
        Size of the ``n`` axis that is folded into the leading axis of the input.
    """
    def __init__(self, N):
        # nn.Module must be initialised before the instance is usable as a module
        # (repr, parameters(), state_dict(), .to(), hooks, ...).
        super().__init__()
        self.N = N

    # Implemented as `forward` (not `__call__`) so that nn.Module's own call
    # machinery stays in charge; `layer(X)` keeps working as before.
    def forward(self, X):
        """Split the leading axis into (b, n) and move ``n`` behind ``d``."""
        return rearrange(X, '(b n) d e -> b d n e', n=self.N)
    