"""
This module contains sequence encoding components for the different architectures.
"""

import sys
import os
from typing import Optional

sys.path.append(os.path.dirname(__file__))

import torch
import torch.nn as nn
from pos_encodings import CumulativeDepthEncoding, TrainableCumulativeDepthEncoding


class ResidualConnection(nn.Module):
    """
    Residual wrapper placed around every encoder sub-layer.

    The wrapped sub-layer output is passed through dropout, added back onto
    its own input, sent through a GELU and finally layer-normalized:

        out = LayerNorm(GELU(x + Dropout(sub_layer(x))))
    """

    def __init__(self, features: int, dropout: float = 0.0):
        """
        __init__ method for the ResidualConnection class.

        Args:
            features (int): Size of the last dimension, normalized by LayerNorm.
            dropout (float): Dropout probability applied to the sub-layer output.
        """
        super().__init__()
        # Registration order (dropout, norm, act) is kept stable so the module
        # tree and state_dict layout match previously saved checkpoints.
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(features)
        self.act = nn.GELU()

    def forward(self, x, sub_layer):
        """
        Forward pass for the ResidualConnection class.

        Args:
            x (torch.Tensor): Input tensor; its last dimension must equal `features`.
            sub_layer (callable): Function applied to `x`; its output is added to `x`.

        Returns:
            torch.Tensor: Normalized tensor with the shape of `x + sub_layer(x)`.
        """
        branch = self.dropout(sub_layer(x))
        merged = self.act(x + branch)
        return self.norm(merged)


class FeedForwardBlock(nn.Module):
    """
    Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Each embedding is expanded from `dim` to `d_ff` features and projected
    back to `dim`. In this module it follows the self-attention sub-layer of
    an encoder block.
    """

    def __init__(self, dim: int, d_ff: int, dropout: float):
        """
        __init__ method for the FeedForwardBlock class.

        Args:
            dim (int): Dimension of the embedding vectors (input and output size).
            d_ff (int): Hidden width of the feed forward layer.
            dropout (float): Dropout probability applied after the activation.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.GELU()

        # linear_1 is built before linear_2 so random initialization consumes
        # the RNG in the same order as before.
        self.linear_1 = nn.Linear(dim, d_ff, bias=True)
        self.linear_2 = nn.Linear(d_ff, dim, bias=True)

    def forward(self, x):
        """
        Forward pass for the FeedForwardBlock class.

        Args:
            x (torch.Tensor): Input tensor.
            [batch, seq_len, dim]

        Returns:
            torch.Tensor: Output tensor after applying the feedforward block.
            [batch, seq_len, dim]
        """
        hidden = self.linear_1(x)
        hidden = self.dropout(self.activation(hidden))
        return self.linear_2(hidden)


class SequentialEncoderBlock(nn.Module):
    """
    One encoder layer: multi-head self-attention followed by a feed-forward
    block, each wrapped in its own ResidualConnection.
    """

    def __init__(self, dim: int, num_heads: int, d_ff: int, dropout: float = 0.0):
        """
        __init__ method for the SequentialEncoderBlock class.

        Args:
            dim (int): Dimension of the embedding vectors.
            num_heads (int): Number of attention heads (must divide `dim`).
            d_ff (int): Hidden width of the feed forward layer.
            dropout (float): Dropout rate used by attention, FFN and residuals.
        """
        super().__init__()

        # Sub-modules are created in the original order (attention, FFN,
        # residuals) so parameter initialization and state_dict layout match.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
            bias=True,
        )
        self.ffn = FeedForwardBlock(dim=dim, d_ff=d_ff, dropout=dropout)
        self.residual = nn.ModuleList(
            ResidualConnection(features=dim, dropout=dropout) for _ in range(2)
        )

    def forward(self, x, src_mask):
        """
        Forward pass for the SequentialEncoderBlock class.

        Args:
            x (torch.Tensor): Input tensor.
            [batch, seq_len, dim]
            src_mask (torch.Tensor): Key padding mask forwarded to the attention
            layer; positions flagged by it are ignored as keys.
            [batch, seq_len]

        Returns:
            torch.Tensor: Output tensor after applying the encoder block.
            [batch, seq_len, dim]
        """
        attn_residual, ffn_residual = self.residual

        def attend(seq):
            # Only the attended values are needed, so weights are not computed.
            attended, _ = self.self_attn(
                seq, seq, seq, key_padding_mask=src_mask, need_weights=False
            )
            return attended

        hidden = attn_residual(x, attend)
        return ffn_residual(hidden, self.ffn)


class SequentialEncoder(nn.Module):
    """
    Implementation of the Sequential Encoder.

    This encoder follows the idea of the encoding path used in Transformers:
    a cumulative-depth positional encoding (computed from layer thicknesses)
    is added to the input embeddings, an optional learnable CLS token is
    prepended, the sequence passes through a stack of SequentialEncoderBlock
    layers, and a final LayerNorm + MLP head maps every position to `out_dim`.
    """

    def __init__(
        self,
        dim: int,
        out_dim: int,
        max_seq_len: int,
        num_heads: int,
        num_layers: int,
        d_ff: int,
        dropout: float = 0.1,
        kwargs_pe: Optional[dict] = None,
        use_cls: bool = False,
        trainable_pe: bool = False,
    ):
        """
        __init__ method for the SequentialEncoder class.

        Args:
            dim (int): Dimension of the embedding vectors.
            out_dim (int): Output feature size produced by the MLP head.
            max_seq_len (int): Maximum sequence length (forwarded to the
                positional encoding).
            num_heads (int): Number of attention heads.
            num_layers (int): Number of encoder layers.
            d_ff (int): Dimension of the feed forward layer.
            dropout (float): Dropout rate used inside the encoder blocks.
            kwargs_pe (dict | None): Extra keyword arguments forwarded to the
                positional-encoding constructor. None (default) means
                {"ini_freq_scale": 1.0, "tunable_freq_scale": True, "dropout": 0.0}.
            use_cls (bool): If True, prepend a learnable CLS token to every
                sequence.
            trainable_pe (bool): If True, use TrainableCumulativeDepthEncoding
                instead of CumulativeDepthEncoding.

        Notes:
            dropout for pe should be set through kwargs_pe.
        """
        super().__init__()
        if kwargs_pe is None:
            # Built per call instead of as a default argument, so no dict is
            # ever shared between calls/instances.
            kwargs_pe = {
                "ini_freq_scale": 1.0,
                "tunable_freq_scale": True,
                "dropout": 0.0,
            }

        self.use_cls = use_cls

        # Module creation order (norm, pe, layers, mlp_head, cls_token) is
        # preserved so parameter initialization and state_dict layout match.
        self.norm = nn.LayerNorm(dim)

        pe_class = (
            TrainableCumulativeDepthEncoding if trainable_pe else CumulativeDepthEncoding
        )
        self.pe = pe_class(dim=dim, max_seq_len=max_seq_len, **kwargs_pe)

        self.layers = nn.ModuleList(
            [
                SequentialEncoderBlock(
                    dim=dim,
                    num_heads=num_heads,
                    d_ff=d_ff,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, out_dim * 2),
            nn.GELU(),
            nn.LayerNorm(out_dim * 2),
            nn.Linear(out_dim * 2, out_dim),
        )

        if self.use_cls:
            self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

    def forward(
        self,
        x: torch.Tensor,
        thicknesses: torch.Tensor,
        src_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Forward pass for the SequentialEncoder class.

        Args:
            x (torch.Tensor): Input tensor (embeddings of the layers).
            [batch, seq_len, dim]
            thicknesses (torch.Tensor): Thicknesses tensor fed to the
            positional encoding.
            [batch, seq_len]
            src_mask (torch.Tensor | None): Key padding mask for attention, or
            None when no position is padding.
            [batch, seq_len]

        Returns:
            torch.Tensor: Per-position output of the MLP head.
            [batch, seq_len, out_dim], or [batch, seq_len + 1, out_dim] when
            use_cls is True (index 0 along dim 1 is the CLS position).

        Notes:
            A true value in src_mask indicates that the corresponding
            position is a padding token and should be ignored.
        """
        # The positional encoding is added before the CLS token is prepended,
        # so the CLS token itself carries no positional information.
        x = x + self.pe(thicknesses)

        if self.use_cls:
            batch_size = x.size(0)
            # expand() is a zero-copy broadcast view; torch.cat materializes it.
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch, 1, dim)
            x = torch.cat((cls_tokens, x), dim=1)  # (batch, seq_len+1, dim)

            if src_mask is not None:
                # The CLS position is never padding: zeros (False) created
                # directly with the mask's own dtype and device.
                cls_mask = src_mask.new_zeros((batch_size, 1))  # (batch, 1)
                src_mask = torch.cat((cls_mask, src_mask), dim=1)  # (batch, seq_len+1)

        for layer in self.layers:
            x = layer(x, src_mask)

        return self.mlp_head(self.norm(x))
