import math

import torch
from torch import nn, Tensor

class PropertyEncoding(nn.Module):
    """Sinusoidal encoding of scalar property values.

    Each value ``p`` is mapped to an ``nf``-dimensional vector whose
    even-indexed features are ``sin(p * w_k)`` and whose odd-indexed
    features are ``cos(p * w_k)``, with ``w_k = 1 / (10000 ** (2k / nf))``.
    """

    def __init__(self, nf: int):
        """
        Arguments:
            nf: number of features in the encoding produced for each value.
        """
        super().__init__()

        self.nf = nf
        # more stable way to compute 1 / (10000 ** (2i / nf))
        BASE = 10000.0
        inv_periods = torch.exp(torch.arange(0, nf, 2) * (-math.log(BASE) / nf))
        # Registered as a buffer so ``.to()`` / ``.cuda()`` / ``.double()`` move
        # it together with the module.  ``persistent=False`` keeps it out of
        # ``state_dict()`` so the set of checkpoint keys does not change.
        self.register_buffer("inv_periods", inv_periods, persistent=False)

    def forward(self, prop: Tensor) -> Tensor:
        """
        Arguments:
            prop: Tensor, shape ``(batch_size,)``.  Any other shape is also
                accepted; the encoding is appended as a new last dimension.

        Returns:
            Tensor, shape ``(*prop.shape, nf)``, allocated on ``prop.device``
            with the dtype of the ``inv_periods`` buffer.
        """
        # No-op when the module already lives on the input's device; otherwise
        # this avoids a device-mismatch error in the multiplication below.
        inv_periods = self.inv_periods.to(prop.device)
        angles = prop.unsqueeze(-1) * inv_periods

        pe = torch.zeros(
            (*prop.shape, self.nf), dtype=inv_periods.dtype, device=prop.device
        )
        pe[..., 0::2] = torch.sin(angles)
        # For odd ``nf`` there is one fewer cosine slot than sine slot, so the
        # last frequency contributes only its sine component.
        pe[..., 1::2] = torch.cos(angles[..., : self.nf // 2])
        return pe
