import math

import torch
from torch import Tensor, nn
from torch.nn import Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter

# Default compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Values accepted for the `aggregation` argument of `KTST` (checked by an
# assert in its constructor).
AGGREGATION_VARIANTS = [
    "q_mean_c",
    "self_attn_all",
]

# Values accepted for the `attn_variant` argument of `KTST` (checked by an
# assert in its constructor). The `*_q_k` variants and "akt_monotonic" share
# the key projection between queries and keys inside the attention forward.
ATTENTION_VARIANTS = [
    "standard",
    "akt_monotonic",
    "alibi_monotonic",
    "alibi_monotonic_q_k",
    "learnable_alibi_monotonic",
    "learnable_alibi_monotonic_q_k",
]


class KTST(nn.Module):
    """Encoder-decoder transformer that predicts per-step response correctness.

    Each timestep carries a question id, up to ``max_concepts`` concept ids
    and a response. Question and concept embeddings are aggregated into one
    vector ``x`` per step; ``y = x + response embedding`` feeds the encoder,
    ``x`` feeds the decoder, and a small MLP maps ``[decoder output, x]`` to
    one logit per step.

    Args:
        num_questions: Number of distinct question ids.
        num_concepts: Number of distinct concept ids.
        max_concepts: Number of concept slots per timestep in ``data``.
        max_len: Maximum sequence length (only used by the positional
            encoding of the ``"standard"`` attention variant).
        d_model: Embedding / model width.
        nhead_tf: Attention heads in the transformer.
        nhead_agg: Attention heads in the self-attention aggregation.
        num_layers_tf: Number of encoder layers and of decoder layers.
        num_layers_agg: Number of layers in the self-attention aggregation.
        dim_feedforward: Hidden width of the transformer feed-forward blocks.
        dropout: Dropout probability used throughout.
        dim_classifier: Width of the first classifier layer.
        aggregation: One of ``AGGREGATION_VARIANTS``.
        use_bias_emb: If true, add a learned bias vector to ``x`` and ``y``.
        use_zero_init: If true, initialise the question embedding to zero.
        attn_variant: One of ``ATTENTION_VARIANTS``.
    """

    def __init__(
        self,
        num_questions: int,
        num_concepts: int,
        max_concepts: int,
        max_len: int,
        d_model: int,
        nhead_tf: int,
        nhead_agg: int,
        num_layers_tf: int,
        num_layers_agg: int,
        dim_feedforward: int,
        dropout: float,
        dim_classifier: int,
        aggregation: str,
        use_bias_emb: bool,
        use_zero_init: bool,
        attn_variant: str,
    ):
        super().__init__()

        self.model_name = "ktst"
        self.emb_type = "qid"

        assert aggregation in AGGREGATION_VARIANTS
        assert attn_variant in ATTENTION_VARIANTS

        self.num_questions = num_questions
        self.num_concepts = num_concepts
        self.max_concepts = max_concepts
        self.max_len = max_len
        self.aggregation = aggregation
        self.attn_variant = attn_variant
        self.use_bias_emb = use_bias_emb
        self.use_zero_init = use_zero_init

        # Embeddings. Every table has one extra row because ids are shifted
        # by +1 in `forward`, so that the value `-1` maps to row 0.
        self.q_emb = nn.Embedding(num_questions + 1, d_model)
        self.c_emb = nn.Embedding(num_concepts + 1, d_model)
        self.r_emb = nn.Embedding(2 + 1, d_model)
        if self.use_bias_emb:
            self.b_emb = nn.Embedding(1, d_model)

        # add positional encoding when using standard attention
        if self.attn_variant == "standard":
            self.positional_encoding = PositionalEncoding(
                d_model=d_model, max_len=max_len, dropout=dropout
            )

        # set aggregation
        if self.aggregation == "q_mean_c":
            self.agg = AggregationMean(max_concepts=max_concepts)
        elif self.aggregation == "self_attn_all":
            self.agg = AggregationSelfAttention(
                embed_dim=d_model,
                num_heads=nhead_agg,
                num_layers=num_layers_agg,
                dropout=dropout,
                use_questions=True,
            )

        self.model = Transformer(
            d_model=d_model,
            nheads=nhead_tf,
            num_encoder_layers=num_layers_tf,
            num_decoder_layers=num_layers_tf,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            attn_variant=attn_variant,
        )

        # final classifier over the concatenation [decoder output, x]
        self.clf = nn.Sequential(
            nn.Linear(d_model + d_model, dim_classifier),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(dim_classifier, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
        )

        self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

        if self.use_zero_init:
            self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Set the question embedding table to all zeros."""
        torch.nn.init.constant_(self.q_emb.weight, 0.0)

    def forward(self, data: Tensor, padding_mask: Tensor):
        """Run the model and compute the training loss.

        Args:
            data: Rank-3 tensor whose last axis holds the concept ids in the
                first ``max_concepts`` columns, the question id in the
                second-to-last column and the response in the last column.
            padding_mask: Mask over the first two axes of ``data``; true
                marks real (non-padded) steps. Must support ``~`` and
                ``torch.masked_select`` (presumably a bool tensor).

        Returns:
            A tuple ``(probabilities, loss, 0.0)``: the sigmoid of the
            per-step logits, the mean BCE loss over non-padded steps from
            index 1 onwards, and a constant ``0.0``.
        """
        questions, responses, concepts = self._unpack_data(data, padding_mask)

        # Embeddings. Modules are invoked through `__call__` rather than
        # `.forward(...)` so that registered forward hooks are honoured.
        questions_embedded = self.q_emb(questions + 1)
        concepts_embedded = self.c_emb(concepts + 1)
        responses_embedded = self.r_emb(responses + 1)

        # aggregation: concept slots equal to -1 are treated as missing
        x = self.agg(
            mask=(concepts == -1),
            questions_embedded=questions_embedded,
            concepts_embedded=concepts_embedded,
        )
        y = x + responses_embedded

        if self.use_bias_emb:
            x = x + self.b_emb.weight[0]
            y = y + self.b_emb.weight[0]

        if self.attn_variant == "standard":
            x = self.positional_encoding(x)
            y = self.positional_encoding(y)

        x_tilde = self.model(tgt=x, src=y)

        clf_input = torch.cat([x_tilde, x], dim=-1)
        logits = self.clf(clf_input).squeeze(-1)

        # The first step is excluded from the loss; padded steps are dropped.
        loss = self.loss_fn(
            torch.masked_select(logits[:, 1:], padding_mask[:, 1:]),
            torch.masked_select(
                responses[:, 1:].type(dtype=torch.float32), padding_mask[:, 1:]
            ),
        )

        return torch.sigmoid(logits), loss, 0.0

    def _unpack_data(
        self, data: Tensor, padding_mask: Tensor
    ) -> tuple[Tensor, Tensor, Tensor]:
        """Split ``data`` into int64 question, response and concept tensors.

        Works on a copy, so the caller's ``data`` is left untouched. Padded
        steps are overwritten with ``0`` before splitting.

        Returns:
            ``(questions, responses, concepts)`` where the first two drop the
            last axis and ``concepts`` keeps the first ``max_concepts``
            columns.
        """
        data = data.clone()
        data[~padding_mask] = 0  # change padded entries to `0` (our default is `-1`)

        questions = data[:, :, -2:-1].type(dtype=torch.int64).squeeze(dim=-1)
        responses = data[:, :, -1:].type(dtype=torch.int64).squeeze(dim=-1)
        concepts = data[:, :, : self.max_concepts].type(dtype=torch.int64)

        return questions, responses, concepts


class Transformer(nn.Module):
    """Causal encoder-decoder stack.

    The encoder output ("memory") is shifted right by one step before the
    decoder sees it, so the decoder at step ``t`` only has access to memory
    produced from steps ``< t``.
    """

    def __init__(
        self,
        d_model: int,
        nheads: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        attn_variant: str,
    ):
        super().__init__()

        self.encoder = TransformerEncoder(
            d_model=d_model,
            num_layers=num_encoder_layers,
            nheads=nheads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            attn_variant=attn_variant,
        )

        self.decoder = TransformerDecoder(
            d_model=d_model,
            num_layers=num_decoder_layers,
            nheads=nheads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            attn_variant=attn_variant,
        )

    def forward(self, tgt: Tensor, src: Tensor) -> Tensor:
        """Encode ``src``, shift the memory by one step and decode ``tgt``.

        Args:
            tgt: Decoder input of shape ``(B, S, d)``.
            src: Encoder input; must have the same shape as ``tgt``.

        Returns:
            Decoder output of shape ``(B, S, d)``.
        """
        B, S, d = tgt.shape
        assert src.shape == (B, S, d)

        attn_mask = self._generate_autoregressive_attn_mask(S=S, device=src.device)

        # Call the sub-modules (not `.forward`) so forward hooks are honoured.
        mem = self.encoder(src=src, attn_mask=attn_mask)
        mem = self._shift_with_start_token(mem)

        return self.decoder(tgt=tgt, mem=mem, attn_mask=attn_mask)

    def _generate_autoregressive_attn_mask(self, S: int, device: torch.device):
        """Return an ``(S, S)`` additive mask: ``-inf`` above the diagonal, 0 elsewhere."""
        attn_mask = torch.triu(
            torch.full((S, S), float("-inf"), device=device),
            diagonal=1,
        )
        return attn_mask

    def _shift_with_start_token(self, x: Tensor) -> Tensor:
        """Shift ``x`` right by one step along axis 1, filling step 0 with ``-1``.

        The last step is dropped so the output keeps the shape ``(B, S, d)``.
        """
        B, S, d = x.shape
        # `new_full` keeps x's dtype and device. `torch.full(..., -1)` would
        # infer int64 from the integer fill value and lean on `torch.cat`
        # type promotion to end up with the right dtype.
        mem_start_token = x.new_full((B, 1, d), -1)
        x = torch.cat([mem_start_token, x], dim=1)[:, :-1]
        assert x.shape == (B, S, d)
        return x


class TransformerEncoder(nn.Module):
    """A sequential stack of ``num_layers`` identical ``TransformerEncoderLayer``s."""

    def __init__(
        self,
        d_model: int,
        num_layers: int,
        nheads: int,
        dim_feedforward: int,
        dropout: float,
        attn_variant: str,
    ):
        super().__init__()

        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerEncoderLayer(
                    d_model=d_model,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    nheads=nheads,
                    attn_variant=attn_variant,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, src: Tensor, attn_mask: Tensor) -> Tensor:
        """Feed ``src`` through every layer in order, passing ``attn_mask`` to each."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(src=hidden, attn_mask=attn_mask)
        return hidden


class TransformerDecoder(nn.Module):
    """A sequential stack of ``num_layers`` identical ``TransformerDecoderLayer``s."""

    def __init__(
        self,
        d_model: int,
        num_layers: int,
        nheads: int,
        dim_feedforward: int,
        dropout: float,
        attn_variant: str,
    ):
        super().__init__()

        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerDecoderLayer(
                    d_model=d_model,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    nheads=nheads,
                    attn_variant=attn_variant,
                )
            )
        self.layers = nn.ModuleList(stack)

    def forward(self, tgt: Tensor, mem: Tensor, attn_mask: Tensor) -> Tensor:
        """Feed ``tgt`` through every layer; each layer sees the same ``mem`` and mask."""
        hidden = tgt
        for decoder_layer in self.layers:
            hidden = decoder_layer(tgt=hidden, mem=mem, attn_mask=attn_mask)
        return hidden


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        d_model: int,
        nheads: int,
        dim_feedforward: int,
        dropout: float,
        attn_variant: str,
    ) -> None:
        super().__init__()

        self.self_attn = BaseMultiheadAttention(
            embed_dim=d_model,
            num_heads=nheads,
            dropout=dropout,
            attn_variant=attn_variant,
        )

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = F.relu

    def forward(self, src: Tensor, attn_mask: Tensor) -> Tensor:

        src = src + self._sa_block(x=src, attn_mask=attn_mask)
        src = self.norm1(src)

        src = src + self._ff_block(x=src)
        src = self.norm2(src)

        return src

    def _sa_block(self, x: Tensor, attn_mask: Tensor | None) -> Tensor:
        x, *_ = self.self_attn(query=x, key=x, value=x, attn_mask=attn_mask)
        return self.dropout1(x)

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nheads: int,
        dropout: float,
        attn_variant: str,
    ):
        super().__init__()

        self.is_std_key = attn_variant not in (
            "akt_monotonic",
            "alibi_monotonic_q_k",
            "learnable_alibi_monotonic_q_k",
        )

        self.self_attn = BaseMultiheadAttention(
            embed_dim=d_model,
            num_heads=nheads,
            dropout=dropout,
            attn_variant=attn_variant,
        )
        self.cross_attn = BaseMultiheadAttention(
            embed_dim=d_model,
            num_heads=nheads,
            dropout=dropout,
            attn_variant=attn_variant,
        )

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = F.relu

    def forward(self, tgt: Tensor, mem: Tensor, attn_mask: Tensor) -> Tensor:

        tgt = tgt + self._sa_block(tgt=tgt, attn_mask=attn_mask)
        tgt = self.norm1(tgt)

        tgt = tgt + self._ca_block(tgt=tgt, mem=mem, attn_mask=attn_mask)
        tgt = self.norm2(tgt)

        tgt = tgt + self._ff_block(tgt)
        tgt = self.norm3(tgt)

        return tgt

    def _sa_block(self, tgt: Tensor, attn_mask: Tensor | None) -> Tensor:
        x, *_ = self.self_attn(query=tgt, key=tgt, value=tgt, attn_mask=attn_mask)
        return self.dropout1(x)

    def _ca_block(self, tgt: Tensor, mem: Tensor, attn_mask: Tensor | None) -> Tensor:
        tgt, *_ = self.cross_attn(
            query=tgt,
            key=mem if self.is_std_key else tgt,
            value=mem,
            attn_mask=attn_mask,
        )
        return self.dropout2(tgt)

    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class AggregationMean(nn.Module):
    """Combine a question embedding with the mean of its concept embeddings.

    Args:
        max_concepts: Number of concept slots per timestep. When it is 1 the
            single concept embedding is added directly and ``mask`` is
            ignored.
    """

    def __init__(self, max_concepts: int) -> None:
        super().__init__()
        self.max_concepts = max_concepts

    def forward(
        self, mask: Tensor, questions_embedded: Tensor, concepts_embedded: Tensor
    ) -> Tensor:
        """Return ``questions_embedded`` plus the masked mean over concept slots.

        Args:
            mask: Bool tensor of shape ``(B, S, f)``; true marks a concept
                slot to ignore.
            questions_embedded: Tensor of shape ``(B, S, d)``.
            concepts_embedded: Tensor of shape ``(B, S, f, d)``. It is not
                modified.

        Returns:
            Tensor of shape ``(B, S, d)``. A timestep whose concept slots are
            all masked contributes a zero concept mean, not NaN.
        """
        if self.max_concepts == 1:
            concepts_embedded = concepts_embedded.squeeze(dim=2)
            assert concepts_embedded.shape == questions_embedded.shape

            return questions_embedded + concepts_embedded

        # Zero the ignored slots out-of-place so the caller's tensor is left
        # untouched.
        visible = concepts_embedded.masked_fill(mask.unsqueeze(-1), 0)
        # Clamp the count to 1 so a fully masked timestep gives 0 / 1 = 0
        # instead of 0 / 0 = NaN.
        denominator = (~mask).sum(dim=2).clamp(min=1).unsqueeze(-1)
        concepts_embedded_mean = visible.sum(dim=2) / denominator

        return questions_embedded + concepts_embedded_mean


class AggregationWeightedMean(nn.Module):
    """Combine a question embedding with an attention-weighted concept mean.

    A single learned query vector (initialised to zero) scores every concept
    slot by scaled dot product; masked slots get ``-inf`` before the softmax.
    With the zero-initialised query the weights start out uniform over the
    unmasked slots.

    Args:
        embed_dim: Embedding width ``d``.
        max_concepts: Number of concept slots per timestep. When it is 1 the
            single concept embedding is added directly and ``mask`` is
            ignored.
    """

    def __init__(self, embed_dim: int, max_concepts: int) -> None:
        super().__init__()
        self.max_concepts = max_concepts

        self.learnable_query = nn.Embedding(num_embeddings=1, embedding_dim=embed_dim)
        torch.nn.init.constant_(self.learnable_query.weight, 0.0)

    def forward(
        self, mask: Tensor, questions_embedded: Tensor, concepts_embedded: Tensor
    ) -> Tensor:
        """Return ``questions_embedded`` plus the weighted concept mean.

        Args:
            mask: Bool tensor of shape ``(B, S, f)``; true marks a concept
                slot to ignore. Every timestep must have at least one
                unmasked slot (checked by an assert).
            questions_embedded: Tensor of shape ``(B, S, d)``.
            concepts_embedded: Tensor of shape ``(B, S, f, d)``.

        Returns:
            Tensor of shape ``(B, S, d)`` in the dtype of the inputs.
        """
        if self.max_concepts == 1:
            concepts_embedded = concepts_embedded.squeeze(dim=2)
            assert concepts_embedded.shape == questions_embedded.shape

            return questions_embedded + concepts_embedded

        B, S, f, d = concepts_embedded.shape

        mask = mask.reshape(B * S, f)
        concepts_embedded = concepts_embedded.reshape(B * S, f, d)
        # `expand` creates a broadcast view; no per-row copy is needed here.
        query_token = self.learnable_query.weight.reshape(1, 1, -1).expand(
            B * S, -1, -1
        )

        assert mask.shape == (B * S, f)
        assert query_token.shape == (B * S, 1, d)

        # Drop timesteps without any unmasked concept; a softmax over a row
        # of only -inf would produce NaN.
        timestamp_has_entry = torch.any(~mask, dim=1)
        assert timestamp_has_entry.all()  # TODO: if so, then this is not needed
        query_token = query_token[timestamp_has_entry]
        concepts_embedded = concepts_embedded[timestamp_has_entry]
        mask = mask[timestamp_has_entry]

        # Convert the bool mask into an additive float mask (0 / -inf).
        attn_mask = F._canonical_mask(
            mask=mask,
            mask_name="key_padding_mask",
            other_type=None,
            other_name="attn_mask",
            target_type=query_token.dtype,
        )

        assert attn_mask is not None
        assert attn_mask.dim() > 1

        attn_mask = attn_mask.unsqueeze(1)

        q_scaled = query_token / math.sqrt(d)
        attn_output_weights = torch.baddbmm(
            attn_mask, q_scaled, concepts_embedded.transpose(-2, -1)
        )
        attn_output_weights = F.softmax(attn_output_weights, dim=-1)

        pooled = torch.bmm(attn_output_weights, concepts_embedded).squeeze(1)

        # Allocate the buffer in the dtype/device of the pooled values. A
        # default-dtype buffer makes the indexed assignment fail for
        # non-float32 modules (e.g. after `.double()`).
        output = pooled.new_zeros((B * S, d))
        output[timestamp_has_entry] = pooled
        output = output.reshape(B, S, d)

        assert output.shape == (B, S, d)

        return questions_embedded + output


class AggregationSelfAttention(nn.Module):
    """Aggregate question and concept embeddings with stacked self-attention.

    For every timestep a token sequence
    ``[learned query, (question), concept_1, ..., concept_f]`` is built and
    passed through ``num_layers`` ``nn.MultiheadAttention`` blocks. The
    output at the learned-query position is the aggregated embedding.

    Args:
        embed_dim: Embedding width ``d``.
        num_heads: Attention heads per block.
        num_layers: Number of attention blocks.
        dropout: Attention dropout probability.
        use_questions: Whether the question embedding is included as a token.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        num_layers: int,
        dropout: float,
        use_questions: bool,
    ) -> None:
        super().__init__()

        self.use_questions = use_questions

        self.learnable_query = nn.Embedding(num_embeddings=1, embedding_dim=embed_dim)
        torch.nn.init.constant_(self.learnable_query.weight, 0.0)

        self.attn_blocks = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    dropout=dropout,
                    batch_first=True,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self, mask: Tensor, concepts_embedded: Tensor, questions_embedded: Tensor
    ) -> Tensor:
        """Return one aggregated embedding per timestep.

        Args:
            mask: Bool tensor of shape ``(B, S, f)``; true marks a concept
                slot that must not be attended to.
            concepts_embedded: Tensor of shape ``(B, S, f, d)``.
            questions_embedded: Tensor of shape ``(B, S, d)``.

        Returns:
            Tensor of shape ``(B, S, d)`` in the dtype of the token sequence.
        """
        B, S, f, d = concepts_embedded.shape
        rows = B * S

        mask = mask.reshape(rows, f)
        concepts_embedded = concepts_embedded.reshape(rows, f, d)
        questions_embedded = questions_embedded.reshape(rows, 1, d)
        # `expand` is enough: `torch.cat` below materialises the tokens.
        query_token = self.learnable_query.weight.reshape(1, 1, -1).expand(
            rows, -1, -1
        )

        if self.use_questions:
            prefix = [query_token, questions_embedded]
        else:
            prefix = [query_token]
        n_prefix = len(prefix)

        # The prefix tokens are never masked.
        mask_pad = torch.zeros(rows, n_prefix, dtype=torch.bool, device=mask.device)
        mask = torch.cat([mask_pad, mask], dim=1)
        assert mask.shape == (rows, f + n_prefix)
        tokens = torch.cat([*prefix, concepts_embedded], dim=1)
        assert tokens.shape == (rows, f + n_prefix, d)

        # Because the prefix columns are always unmasked, every row has at
        # least one visible key. No row filtering is required, and a
        # zero-filled scratch buffer (which would pin the output to the
        # default dtype) is unnecessary.
        for block in self.attn_blocks:
            tokens, _ = block(
                query=tokens,
                key=tokens,
                value=tokens,
                key_padding_mask=mask,
                need_weights=False,
            )

        return tokens[:, 0, :].reshape(B, S, d)


class BaseMultiheadAttention(Module):
    __constants__ = ["batch_first"]

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        batch_first=True,
        attn_variant: str = "learnable_alibi_monotonic",
        device=None,
        dtype=None,
    ) -> None:
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"embed_dim and num_heads must be greater than 0,"
                f" got embed_dim={embed_dim} and num_heads={num_heads} instead"
            )
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.embed_dim = embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        assert batch_first == True
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim

        self.in_proj_weight = Parameter(
            torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
        )
        self.register_parameter("q_proj_weight", None)
        self.register_parameter("k_proj_weight", None)
        self.register_parameter("v_proj_weight", None)

        self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        self.out_proj = NonDynamicallyQuantizableLinear(
            embed_dim, embed_dim, bias=True, **factory_kwargs
        )

        self._reset_parameters()

        if attn_variant == "standard":
            self.attn_variant = None
        elif attn_variant == "akt_monotonic":
            self.attn_variant = AKTMonotonicAttention(num_heads=num_heads)
        elif attn_variant in ("alibi_monotonic", "alibi_monotonic_q_k"):
            self.attn_variant = ALiBiMonotonicAttention(num_heads=num_heads)
        elif attn_variant in (
            "learnable_alibi_monotonic",
            "learnable_alibi_monotonic_q_k",
        ):
            self.attn_variant = LearnableALiBiMonotonicAttention(num_heads=num_heads)
        else:
            raise ValueError(f"{attn_variant=} not implemented")
        self.attn_variant_string = attn_variant

    def _reset_parameters(self):
        xavier_uniform_(self.in_proj_weight)
        constant_(self.in_proj_bias, 0.0)
        constant_(self.out_proj.bias, 0.0)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Tensor | None = None,
        attn_mask: Tensor | None = None,
        need_weights=True,
    ) -> tuple[Tensor, Tensor | None]:

        assert need_weights == True

        is_batched = query.dim() == 3

        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype,
        )

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        # MultiheadAttention does not support NestedTensor outside of its fast path
        assert not (query.is_nested or key.is_nested or value.is_nested)

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = (x.transpose(1, 0) for x in (query, key))
                    value = key
            else:
                query, key, value = (x.transpose(1, 0) for x in (query, key, value))

        attn_output, attn_output_weights = base_multi_head_attention_forward(
            query=query,
            key=key,
            value=value,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=self.in_proj_weight,
            in_proj_bias=self.in_proj_bias,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            attn_variant_string=self.attn_variant_string,
            training=self.training,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            attn_variant=self.attn_variant,
        )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights


def base_multi_head_attention_forward(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    embed_dim_to_check: int,
    num_heads: int,
    in_proj_weight: Tensor,
    in_proj_bias: Tensor,
    dropout_p: float,
    out_proj_weight: Tensor,
    out_proj_bias: Tensor,
    attn_variant_string: str,
    training: bool = True,
    key_padding_mask: Tensor | None = None,
    attn_mask: Tensor | None = None,
    attn_variant: Module | None = None,
) -> tuple[Tensor, Tensor]:
    """Functional multi-head attention with an optional score-modifier module.

    Inputs are sequence-first: ``query`` is ``(tgt_len, bsz, embed_dim)`` and
    ``key``/``value`` are ``(src_len, bsz, embed_dim)``; unbatched rank-2
    inputs get a temporary batch axis.

    Args:
        query, key, value: Attention inputs; ``key`` and ``value`` must have
            the same shape.
        embed_dim_to_check: Expected size of the last axis of ``query``.
        num_heads: Number of heads; must divide the embedding size.
        in_proj_weight, in_proj_bias: Packed q/k/v projection, split into
            three equal chunks along axis 0.
        dropout_p: Dropout probability on the attention weights (forced to 0
            when ``training`` is false).
        out_proj_weight, out_proj_bias: Output projection.
        attn_variant_string: Variant name; for ``"akt_monotonic"`` and the
            ``*_q_k`` variants the query is projected with the *key* weights
            and ``query`` must equal ``key``.
        training: Whether dropout is active.
        key_padding_mask: Optional ``(bsz, src_len)`` mask, merged into
            ``attn_mask``.
        attn_mask: Optional 2D ``(tgt_len, src_len)`` or 3D
            ``(bsz * num_heads, tgt_len, src_len)`` mask. Required when
            ``attn_variant`` is given.
        attn_variant: Optional module that receives the raw scores and the
            mask and returns the pre-softmax scores.

    Returns:
        ``(attn_output, attn_output_weights)``: the projected output of shape
        ``(tgt_len, bsz, embed_dim)`` and the per-head weights (after softmax
        and dropout) of shape ``(bsz * num_heads, tgt_len, src_len)``. They
        are not averaged over heads.

    Raises:
        RuntimeError: If ``attn_mask`` has an unsupported rank or shape.
    """

    is_batched = F._mha_shape_check(  # type: ignore
        query, key, value, key_padding_mask, attn_mask, num_heads
    )

    # For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
    # is batched, run the computation and before returning squeeze the
    # batch dimension so that the output doesn't carry this temporary batch dimension.
    if not is_batched:
        # unsqueeze if the input is unbatched
        query, key, value = query.unsqueeze(1), key.unsqueeze(1), value.unsqueeze(1)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.unsqueeze(0)

    # set up shape vars
    tgt_len, bsz, embed_dim = query.shape
    src_len, _, _ = key.shape

    # convert both masks to additive float masks in the query's dtype
    key_padding_mask = F._canonical_mask(
        mask=key_padding_mask,
        mask_name="key_padding_mask",
        other_type=F._none_or_dtype(attn_mask),
        other_name="attn_mask",
        target_type=query.dtype,
    )

    attn_mask = F._canonical_mask(
        mask=attn_mask,
        mask_name="attn_mask",
        other_type=None,
        other_name="",
        target_type=query.dtype,
        check_other=False,
    )

    # expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}
    assert embed_dim == embed_dim_to_check
    # embed_dim can be a tensor when JIT tracing
    if isinstance(embed_dim, torch.Tensor):
        head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
    else:
        head_dim = embed_dim // num_heads
    # embed_dim {embed_dim} not divisible by num_heads {num_heads}
    assert head_dim * num_heads == embed_dim
    assert isinstance(head_dim, int)
    # key shape {key.shape} does not match value shape {value.shape}
    assert key.shape == value.shape

    #
    # compute in-projection
    #
    w_q, w_k, w_v = in_proj_weight.chunk(3)
    b_q, b_k, b_v = in_proj_bias.chunk(3)
    if attn_variant_string in (
        "akt_monotonic",
        "alibi_monotonic_q_k",
        "learnable_alibi_monotonic_q_k",
    ):
        # these variants require identical query and key inputs
        assert torch.equal(query, key)
        # use same projection weights for query as for key
        q, k, v = (
            F.linear(query, w_k, b_k),
            F.linear(key, w_k, b_k),
            F.linear(value, w_v, b_v),
        )
    else:
        q, k, v = (
            F.linear(query, w_q, b_q),
            F.linear(key, w_k, b_k),
            F.linear(value, w_v, b_v),
        )

    # prep attention mask
    if attn_mask is not None:
        # ensure attn_mask's dim is 3
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(
                    f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}."
                )
            # add a leading axis that broadcasts over batch * heads
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(
                    f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
                )
        else:
            raise RuntimeError(
                f"attn_mask's dimension {attn_mask.dim()} is not supported"
            )

    #
    # reshape q, k, v for multihead attention and make em batch first
    #
    q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    # update source sequence length after adjustments
    src_len = k.size(1)

    # merge key padding and attention masks
    if key_padding_mask is not None:
        # expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}
        assert key_padding_mask.shape == (bsz, src_len)
        # replicate the per-batch padding mask for every head
        key_padding_mask = (
            key_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, num_heads, -1, -1)
            .reshape(bsz * num_heads, 1, src_len)
        )
        if attn_mask is None:
            attn_mask = key_padding_mask
        else:
            attn_mask = attn_mask + key_padding_mask

    # adjust dropout probability
    if not training:
        dropout_p = 0.0

    #
    # calculate attention and out projection
    #
    _, _, E = q.shape
    q_scaled = q / math.sqrt(E)

    if attn_variant is not None:
        assert attn_mask is not None
        attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        # here mask is to be handled by attn_variant because of gradient issues
        attn_output_weights = attn_variant.forward(
            attn_output_weights=attn_output_weights, attn_mask=attn_mask
        )
    else:
        # standard (i.e. unchanged) MHA
        if attn_mask is None:
            attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
        else:
            # fused: attn_mask + q_scaled @ k^T
            attn_output_weights = torch.baddbmm(
                attn_mask, q_scaled, k.transpose(-2, -1)
            )

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    # dropout_p was already zeroed above when not training
    if dropout_p > 0.0:
        attn_output_weights = F.dropout(attn_output_weights, p=dropout_p)

    attn_output = torch.bmm(attn_output_weights, v)

    # merge the heads back into the embedding axis, then project
    attn_output = (
        attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
    )
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
    attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))

    # optionally average attention weights over heads
    # attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)

    if not is_batched:
        # squeeze the output if input was unbatched
        attn_output = attn_output.squeeze(1)
        attn_output_weights = attn_output_weights.squeeze(0)

    return attn_output, attn_output_weights


class AKTMonotonicAttention(Module):

    def __init__(self, num_heads: int) -> None:
        super().__init__()

        self.num_heads = num_heads

        self.gamma = Parameter(torch.zeros(self.num_heads, 1, 1))

        xavier_uniform_(self.gamma)

    def forward(
        self,
        attn_output_weights: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        batch_size_times_num_head, S, _ = attn_output_weights.shape

        scores = attn_output_weights.reshape(-1, self.num_heads, S, S)
        mask = (attn_mask == 0).unsqueeze(0)

        x1 = torch.arange(S).expand(S, -1).to(attn_output_weights.device)
        x2 = x1.transpose(0, 1).contiguous()

        with torch.no_grad():
            scores_ = scores.masked_fill(mask == 0, -1e32)
            scores_ = F.softmax(scores_, dim=-1)
            scores_ = scores_ * mask.float()
            distcum_scores = torch.cumsum(scores_, dim=-1)
            disttotal_scores = torch.sum(scores_, dim=-1, keepdim=True)
            position_effect = (
                torch.abs(x1 - x2)[None, None, :, :].type(torch.FloatTensor).to(attn_output_weights.device)  # type: ignore
            )
            dist_scores = torch.clamp(
                (disttotal_scores - distcum_scores) * position_effect, min=0.0
            )
            dist_scores = dist_scores.sqrt().detach()

        gamma = -1.0 * nn.Softplus()(self.gamma).unsqueeze(0)
        total_effect = torch.clamp(
            torch.clamp((dist_scores * gamma).exp(), min=1e-5), max=1e5
        )
        scores = scores * total_effect
        scores.masked_fill_(mask == 0, -1e32)
        scores = scores.reshape(batch_size_times_num_head, S, S)
        return scores


def get_slopes(n):
    """Return ``n`` geometrically decreasing slopes.

    For a power-of-two ``n`` the slopes form the sequence
    ``base, base**2, ..., base**n`` with ``base = 2 ** (-8 / n)``. Otherwise
    the slopes of the largest power of two below ``n`` are extended with
    every second slope of the next power of two.
    """

    def geometric_slopes(count):
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    log_n = math.log2(n)
    if log_n.is_integer():
        return geometric_slopes(n)

    lower_power = 2 ** math.floor(log_n)
    head = geometric_slopes(lower_power)
    tail = get_slopes(2 * lower_power)[0::2][: n - lower_power]
    return head + tail


def inverse_softplus(x):
    """Return ``y`` such that ``softplus(y) == x``, i.e. ``log(exp(x) - 1)``.

    Computed as ``x + log(1 - exp(-x))`` using ``expm1``.
    """
    one_minus_exp_neg = -torch.expm1(-x)
    return x + torch.log(one_minus_exp_neg)


class ALiBiMonotonicAttention(Module):

    def __init__(self, num_heads: int, use_multi_theta: bool = True) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.use_multi_theta = use_multi_theta

        assert use_multi_theta == True

        thetas = torch.tensor(get_slopes(self.num_heads)).reshape(self.num_heads, 1, 1)

        self.register_buffer("thetas", thetas)

        # logger.info(f"{self.thetas.squeeze()=}")

    def forward(
        self,
        attn_output_weights: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        batch_size_times_num_head, S, _ = attn_output_weights.shape
        assert (batch_size_times_num_head % self.num_heads) == 0

        attn_output_weights = attn_output_weights + attn_mask

        seq = torch.arange(S).expand(S, -1).to(attn_output_weights.device)

        # .tril helps with NaNs b/c of multiplication with parameters
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(
                batch_size_times_num_head // attn_mask.shape[0], 1, 1
            )

        distance_matrix = torch.tril(seq - seq.t(), diagonal=-1)

        _position_effect = self.thetas * distance_matrix

        # clone _position_effect to prevent torch inplace modification error
        if self.use_multi_theta:
            batch_size = batch_size_times_num_head // self.num_heads
            position_effect = _position_effect.clone().repeat(batch_size, 1, 1)
        else:
            position_effect = _position_effect.clone().repeat(
                batch_size_times_num_head, 1, 1
            )

        # account for padding mask => prevent multiplication of -Inf and 0.0
        if attn_mask is not None:
            position_effect[attn_mask == float("-Inf")] = 1.0

        attn_output_weights = attn_output_weights + position_effect

        # ensure that code above does not escape masking
        attn_output_weights = attn_output_weights + attn_mask

        return attn_output_weights


class LearnableALiBiMonotonicAttention(Module):

    def __init__(self, num_heads: int, use_multi_theta: bool = True) -> None:
        super().__init__()

        self.num_heads = num_heads
        self.use_multi_theta = use_multi_theta

        assert use_multi_theta == True

        self.thetas = Parameter(
            inverse_softplus(
                torch.tensor(get_slopes(self.num_heads)).reshape(self.num_heads, 1, 1)
            )
        )

        # logger.info(f"{self.thetas.squeeze()=}")

    def forward(
        self,
        attn_output_weights: torch.Tensor,
        attn_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        batch_size_times_num_head, S, _ = attn_output_weights.shape
        assert (batch_size_times_num_head % self.num_heads) == 0

        attn_output_weights = attn_output_weights + attn_mask

        seq = torch.arange(S).expand(S, -1).to(attn_output_weights.device)

        # .tril helps with NaNs b/c of multiplication with parameters
        if attn_mask is not None:
            attn_mask = attn_mask.repeat(
                batch_size_times_num_head // attn_mask.shape[0], 1, 1
            )

        distance_matrix = torch.tril(seq - seq.t(), diagonal=-1)

        _position_effect = F.softplus(self.thetas) * distance_matrix

        # clone _position_effect to prevent torch inplace modification error
        if self.use_multi_theta:
            batch_size = batch_size_times_num_head // self.num_heads
            position_effect = _position_effect.clone().repeat(batch_size, 1, 1)
        else:
            position_effect = _position_effect.clone().repeat(
                batch_size_times_num_head, 1, 1
            )

        # account for padding mask => prevent multiplication of -Inf and 0.0
        if attn_mask is not None:
            position_effect[attn_mask == float("-Inf")] = 1.0

        attn_output_weights = attn_output_weights + position_effect

        # ensure that code above does not escape masking
        attn_output_weights = attn_output_weights + attn_mask

        return attn_output_weights


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added along the time axis.

    Adapted from https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    (very much alike the autobots implementation)

    Even feature indices ``2k`` hold ``sin(pos * w_k)`` and odd indices
    ``2k + 1`` hold ``cos(pos * w_k)`` with ``w_k = 10000 ** (-2k / d_model)``.
    The table is stored in the non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 100):
        """
        Args:
            d_model: size of the feature axis (even or odd).
            dropout: dropout probability applied after adding the encoding.
            max_len: number of positions precomputed; longest supported sequence.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # an odd d_model has one cosine column fewer than sine columns
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)  # adjust shape to data shape
        self.register_buffer("pe", pe)

    def forward(self, x: Tensor) -> Tensor:
        """Add the encoding of the first ``T`` positions to ``x`` and apply dropout.

        Args:
            x: either ``(B, T, d)`` or ``(B, T, f, d)``; axis 1 is the position
                axis.  In the 4-D case the same encoding is broadcast over
                axis 2.

        Raises:
            NotImplementedError: if ``x`` is neither 3-D nor 4-D.
        """
        if x.dim() == 3:
            T = x.shape[1]
            assert T <= self.max_len
            x = x + self.pe[:, :T]
        elif x.dim() == 4:
            _, T, _, d = x.shape
            assert T <= self.max_len
            x = x + self.pe[:, :T].view(1, T, 1, d)
        else:
            raise NotImplementedError()

        return self.dropout(x)
