"""
This module contains positional encodings for the Material-to-Context (M2C) encoder.
"""

import sys
import os

sys.path.append(os.path.dirname(__file__))

import torch
import torch.nn as nn


class CumulativeDepthEncoding(nn.Module):
    """
    Implementation of the cumulative depth encoding for the Material-to-Context (M2C) encoder.
    It is intended to be added to the layer encodings.
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 5,
        ini_freq_scale: float = 1.0,
        tunable_freq_scale: bool = True,
        dropout: float = 0.0,
    ):
        """
        __init__ method for the CumulativeDepthEncoding class.

        Args:
            dim (int): The dimension of the encoding.
            max_seq_len (int): The maximum sequence length.
            ini_freq_scale (float): The initial frequency scale.
            tunable_freq_scale (bool): Whether to use a tunable frequency scale.
            dropout (float): The dropout rate.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.dropout = nn.Dropout(dropout)

        if tunable_freq_scale:
            self.freq_scale = nn.Parameter(torch.ones(1) * ini_freq_scale)
        else:
            self.freq_scale = ini_freq_scale

        div_term = torch.exp(
            torch.arange(0, dim, 2).float() * (-torch.log(torch.tensor(10000.0)) / dim)
        )
        self.register_buffer("div_term", div_term)

    def forward(
        self, thicknesses: torch.Tensor, cumulative: bool = False
    ) -> torch.Tensor:
        """
        Forward method for the CumulativeDepthEncoding class.

        Args:
            thicknesses (torch.Tensor): thicknesses (batch_size, max_seq_len).
            cumulative (bool): Whether the input is already cumulative.

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, max_seq_len, dim).
        """
        assert thicknesses.shape[1] == self.max_seq_len, (
            "thicknesses should be of shape (batch_size, max_seq_len), please use padding."
        )

        if not cumulative:
            cumulative_depth = torch.cumsum(
                thicknesses, dim=1
            )  # (batch_size, max_seq_len)
        else:
            cumulative_depth = thicknesses

        cumulative_depth = (
            cumulative_depth.unsqueeze(-1) * self.freq_scale
        )  # (batch_size, max_seq_len, 1)

        # Sinusoidal encoding
        pe = torch.zeros_like(
            cumulative_depth.repeat(1, 1, self.dim)
        )  # (batch_size, max_seq_len, dim)
        pe[..., 0::2] = torch.sin(cumulative_depth * self.div_term)
        pe[..., 1::2] = torch.cos(cumulative_depth * self.div_term)

        return self.dropout(pe)


class TrainableCumulativeDepthEncoding(nn.Module):
    """
    Implementation of a trainable cumulative depth encoding for the Material-to-Context (M2C) encoder.
    It is intended to be added to the layer encodings.

    Instead of fixed sinusoids, each position's (scaled) cumulative depth is mapped
    to a ``dim``-dimensional vector by a small MLP:
    Linear(1, 2*dim) -> GELU -> LayerNorm(2*dim) -> Linear(2*dim, dim).
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 5,
        ini_freq_scale: float = 1.0,
        tunable_freq_scale: bool = True,
        dropout: float = 0.0,
    ):
        """
        __init__ method for the TrainableCumulativeDepthEncoding class.

        Args:
            dim (int): The dimension of the encoding.
            max_seq_len (int): The expected sequence length of the input.
            ini_freq_scale (float): The initial scale that multiplies the
                cumulative depth before it is fed to the MLP.
            tunable_freq_scale (bool): If True, the scale is a learnable
                ``nn.Parameter``; otherwise it is a fixed Python number.
            dropout (float): The dropout rate applied to the output encoding.
        """
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.dropout = nn.Dropout(dropout)

        if tunable_freq_scale:
            self.freq_scale = nn.Parameter(torch.ones(1) * ini_freq_scale)
        else:
            self.freq_scale = ini_freq_scale

        self.mlp = nn.Sequential(
            nn.Linear(1, 2 * dim),
            nn.GELU(),
            nn.LayerNorm(2 * dim),
            nn.Linear(2 * dim, dim),
        )

    def forward(
        self, thicknesses: torch.Tensor, cumulative: bool = False
    ) -> torch.Tensor:
        """
        Forward method for the TrainableCumulativeDepthEncoding class.

        Args:
            thicknesses (torch.Tensor): thicknesses (batch_size, max_seq_len).
            cumulative (bool): Whether the input is already cumulative. If False
                (the default, matching the previous behavior), a cumulative sum
                along dim 1 is taken first.

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, max_seq_len, dim).
        """
        assert thicknesses.shape[1] == self.max_seq_len, (
            "thicknesses should be of shape (batch_size, max_seq_len), please use padding."
        )

        # nn.Linear cannot consume integer tensors, which arise when the input
        # and a non-tunable freq scale are both integers.
        if not torch.is_floating_point(thicknesses):
            thicknesses = thicknesses.float()

        cumulative_depth = (
            thicknesses if cumulative else torch.cumsum(thicknesses, dim=1)
        ) * self.freq_scale  # (batch_size, max_seq_len)

        encoding = self.mlp(
            cumulative_depth.unsqueeze(-1)
        )  # (batch_size, max_seq_len, dim)

        # Previously the configured dropout was built but never applied.
        return self.dropout(encoding)
