import torch.nn as nn
import torch

class TriTimeEmbedding(nn.Module):
    """Sinusoidal embedding of an interval through a learned linear projection.

    The input is projected by a bias-free linear map to ``embed_size // 2``
    values ``phi``; the embedding is ``cat([sin(phi), cos(phi)], dim=-1)``.

    Args:
        input_size: Size of the last dimension of the tensor fed to the
            linear projection.
        embed_size: Requested embedding width. Each of the sin/cos halves has
            ``embed_size // 2`` features, so the output width is
            ``2 * (embed_size // 2)`` (an odd value yields ``embed_size - 1``).
    """

    def __init__(self, input_size, embed_size=32):
        super().__init__()
        self.Wt = nn.Linear(input_size, embed_size // 2, bias=False)

    def forward(self, interval):
        """Embed ``interval``.

        Args:
            interval: Tensor whose last dimension is ``input_size``.
                - A 1-D tensor is treated as a single sample and gets a
                  leading batch dimension, except when ``input_size == 1``
                  and its length is not 1: then it is treated as a batch of
                  scalar intervals, shape ``(B,) -> (B, 1)``.
                - A 0-d scalar is treated as one sample of shape ``(1, 1)``.
                - Non-floating-point inputs are cast to the projection
                  weight's dtype.

        Returns:
            Tensor of shape ``(*batch, 2 * (embed_size // 2))``: the sine
            features followed by the cosine features.
        """
        if not torch.is_floating_point(interval):
            # nn.Linear rejects integer inputs; align with the weight dtype.
            interval = interval.to(self.Wt.weight.dtype)

        if interval.dim() == 0:
            # A lone scalar becomes a single one-feature sample.
            interval = interval.reshape(1, 1)
        elif interval.dim() == 1:
            if self.Wt.in_features == 1 and interval.shape[0] != 1:
                # 1-D batch of scalar intervals: one feature per sample.
                interval = interval.unsqueeze(-1)
            else:
                # Single sample: add a leading batch dimension.
                interval = interval.reshape(1, -1)

        phi = self.Wt(interval)
        return torch.cat([torch.sin(phi), torch.cos(phi)], dim=-1)
