import math
import torch

# pylint:disable=no-member

def create_relative_position(
    query_len: int, key_len: int, bidirectional=True, num_buckets: int = 32, max_distance: int = 128
) -> torch.LongTensor:
    """Compute T5-style relative-position bucket ids for a (query_len, key_len) attention map.

    Queries are aligned to the end of the key sequence: query ``i`` sits at absolute
    position ``key_len - query_len + i`` and key ``j`` at position ``j``. Small
    offsets get one bucket each; larger offsets share logarithmically wider buckets.

    Args:
        query_len: number of query positions (must be >= 0).
        key_len: number of key positions (must be >= 0).
        bidirectional: whether the attention is bidirectional. If True, the buckets
            are split in half and keys after the query use the upper half. If False,
            keys after the query all map to bucket 0, the same as a zero offset.
        num_buckets: total number of buckets.
        max_distance: offsets of at least this distance all land in the last bucket
            of their direction.

    Returns:
        An int64 (``torch.long``) tensor of shape ``(query_len, key_len)`` with values
        in ``[0, num_buckets)``.

    Raises:
        ValueError: if ``query_len`` or ``key_len`` is negative.

    NOTE(review): assumes each direction has at least 2 buckets and that
    ``max_distance`` exceeds half the per-direction bucket count. Smaller values
    divide by zero or produce out-of-range ids.
    """
    if query_len < 0 or key_len < 0:
        raise ValueError(
            f"query_len and key_len must be non-negative, got {query_len} and {key_len}"
        )

    # Queries occupy the last query_len positions (e.g. the newest tokens during
    # incremental decoding). Building only those rows gives the correct shape when
    # query_len is 0 or larger than key_len. The alternative, slicing a
    # (key_len, key_len) grid with [-query_len:], does not, and it avoids O(key_len^2)
    # work when query_len is small.
    query_position = torch.arange(key_len - query_len, key_len, dtype=torch.long)[:, None]
    key_position = torch.arange(key_len, dtype=torch.long)[None, :]
    # shape (query_len, key_len)
    relative_position = key_position - query_position

    # n > 0 means the key lies before the query
    n = -relative_position
    buckets = torch.zeros_like(n)
    if bidirectional:
        num_buckets //= 2
        # keys after the query (n < 0) are offset into the upper half of the buckets
        buckets += (n < 0).to(torch.long) * num_buckets
        n = n.abs()
    else:
        # unidirectional: every key after the query collapses onto offset 0
        n = n.clamp(min=0)
    # now n is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = n < max_exact

    # The other half of the buckets are for logarithmically bigger bins in positions up to
    # max_distance. Clamp before the log: small offsets (including n == 0) are discarded by
    # torch.where below anyway. Without the clamp they would produce -inf, whose cast to
    # long is undefined.
    n_large = n.float().clamp(min=max_exact)
    val_if_large = max_exact + (
        torch.log(n_large / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
    ).to(torch.long)
    val_if_large = val_if_large.clamp(max=num_buckets - 1)

    # shape (query_len, key_len)
    buckets += torch.where(is_small, n, val_if_large)
    return buckets
