import torch


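# NOTE: the decorators below are left disabled because they apparently only
# work with version 1.5 (the original note said "Python 1.5"; presumably
# PyTorch 1.5 was meant). The reason is unknown.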
# @torch.jit.script
# @torch.no_grad()
def get_angular_buckets(X: torch.Tensor, nnodes: torch.Tensor, nbuckets: int, nhashes: int) -> torch.Tensor:
    """
    Assign every node embedding to an angular-LSH bucket, once per hash round.

    As used in: https://arxiv.org/abs/2001.04451
    Paper: https://arxiv.org/pdf/1509.02897.pdf
    Adapted from: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_reformer.py

    Args:
        X: Tensor of shape ``(batch_size, 2 * max_nodes, emb_size)``. The node
            axis is treated as two halves of ``max_nodes`` positions each; the
            first half is masked with ``nnodes[0]`` and the second half with
            ``nnodes[1]``.
        nnodes: Tensor indexed as ``nnodes[half, batch]`` giving the number of
            real (non-padding) positions in each half for every batch entry.
            Only rows 0 and 1 are read.
        nbuckets: Number of buckets; must be even.
        nhashes: Number of independent hashing rounds.

    Returns:
        Integer tensor of shape ``(batch_size, nhashes, 2 * max_nodes)``. Real
        positions hold a bucket id in ``[0, nbuckets)``; padding positions hold
        the sentinel id ``nbuckets``.

    Raises:
        AssertionError: If ``nbuckets`` is odd.
        ValueError: If the node axis of ``X`` has odd length and therefore
            cannot be split into two equal halves.

    Note:
        The rotations are drawn with ``torch.randn`` on every call, so results
        are only reproducible if the global RNG is seeded by the caller.
    """
    batch_size, total_nodes, emb_size = X.shape

    # Raised explicitly rather than via `assert` so the check survives
    # `python -O`; AssertionError is kept so existing callers see the same type.
    if nbuckets % 2 != 0:
        raise AssertionError(
            f"There should be an even number of buckets, but `nbuckets`: {nbuckets}"
        )

    # Fail early with a clear message instead of a mask/tensor shape mismatch
    # inside `masked_fill_` further down.
    if total_nodes % 2 != 0:
        raise ValueError(
            f"The node axis of `X` must have even length (two halves), but got {total_nodes}"
        )
    max_nodes = total_nodes // 2

    # We sample a different random rotation for each round of hashing to
    # decrease the probability of hash misses.
    # Shape: emb_size x nhashes x nbuckets/2
    rotations_shape = (emb_size, nhashes, nbuckets // 2)
    random_rotations = torch.randn(rotations_shape, device=X.device, dtype=X.dtype)

    # Output dim: Batch_Size x nhashes x Seq_Len x nbuckets/2
    rotated_vectors = torch.einsum("btd,dhr->bhtr", X, random_rotations)

    # Concatenating the projections with their negation yields `nbuckets`
    # scores per position; the bucket is the index of the largest one.
    rotated_vectors = torch.cat([rotated_vectors, -rotated_vectors], dim=-1)
    buckets = torch.argmax(rotated_vectors, dim=-1)

    # Mask non-nodes: position i of a half is padding when i >= node count of
    # that half. The mask does not depend on the hash round, so it is built
    # once per half and broadcast over the `nhashes` axis.
    positions = torch.arange(max_nodes, dtype=torch.long, device=X.device)
    padding_per_half = [
        (positions >= nnodes[half, :, None]).expand(batch_size, max_nodes)
        for half in (0, 1)
    ]
    mask_padding = torch.cat(padding_per_half, dim=-1).unsqueeze(1)
    buckets.masked_fill_(mask_padding, nbuckets)

    return buckets
