import math
import warnings

import torch
import torch.nn as nn
from torch import Tensor

# Using masked tensors (torch.masked) raises the prototype-stage warning below, and
# it gets printed a lot during torch compilation, so it is silenced here.
# The second positional argument is `message`: a regex matched against the start of
# the warning text. The literal below is that warning's full text.
warnings.filterwarnings(
    "ignore",
    "The PyTorch API of MaskedTensors is in prototype stage and will change in the "
    "near future. Please open a Github issue for features requests and see our "
    "documentation on the torch.masked module for further information about the project.",
)


class ALiBi(nn.Module):
    """ALiBi-style per-head linear bias computed from a distance matrix.

    Head ``h`` uses the slope ``scale * base**h``; the bias it produces is
    ``-slope * d``, so larger distances give a more negative bias.

    ---
    Args:
        n_heads: Number of heads, i.e. number of slopes.
        base: Ratio between the slopes of consecutive heads.
        scale: Slope of the first head (head 0).
    """

    def __init__(self, n_heads: int, base: float = math.pow(2, -0.5), scale: float = 10.0):
        super().__init__()

        # slopes[h] = scale * base**h for h = 0 .. n_heads - 1. Registered as a buffer so
        # it moves with the module (.to / .cuda) and is saved in the state dict, but is
        # not a trainable parameter.
        slopes = scale * base ** torch.arange(n_heads)
        self.register_buffer("slopes", slopes)

    def forward(self, d: Tensor) -> Tensor:
        """Applies the ALiBi slopes to the distance matrix.

        Also turns the distance matrix into a negative bias.

        ---
        Args:
            d: Distance matrix.
                Shape of [batch_size, L, S].

        ---
        Returns:
            The per-head bias, ``out[b, h] = -slopes[h] * d[b]``.
                Shape of [batch_size, n_heads, L, S].
        """
        return torch.einsum("bls,h->bhls", -d, self.slopes)

    @staticmethod
    def normalize(d: Tensor, m: Tensor) -> Tensor:
        """Makes d scaling invariant.

        Divides each batch element of ``d`` by the (unbiased) standard deviation of
        the entries selected by ``m``. The result therefore does not depend on the
        overall scale of ``d`` whenever that standard deviation is usable.

        A batch element is returned unscaled (divided by 1.0) when its standard
        deviation is unusable:
            - no entry is selected,
            - a single entry is selected (the unbiased std of one sample is NaN), or
            - all selected entries are equal (the std is 0).

        ---
        Args:
            d: Distance matrix.
                Shape of [batch_size, n_cities, n_cities].
            m: Boolean mask; True marks the entries of d used for the statistic.
                Shape of [batch_size, n_cities, n_cities].

        ---
        Returns:
            The normalized distance matrix.
                Shape of [batch_size, n_cities, n_cities].
        """
        std = torch.masked.masked_tensor(d, m).std(dim=(1, 2), keepdim=True).to_tensor(1.0)
        # to_tensor(1.0) only covers fully-masked batch elements. Also replace NaN
        # (a single selected entry) and 0 (constant selected entries), so d is never
        # divided by them.
        usable = torch.isfinite(std) & (std > 0)
        std = torch.where(usable, std, torch.ones_like(std))
        return d / std
