import torch
import numpy as np


class TimeEncode(torch.nn.Module):
	# Time Encoding proposed by TGAT
	def __init__(self, dimension):
		super(TimeEncode, self).__init__()

		self.dimension = dimension
		self.w = torch.nn.Linear(1, dimension)

		self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
																			 .float().reshape(dimension, -1))
		self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

	def forward(self, t):
		# t has shape [batch_size, seq_len]
		# Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1]
		t = t.unsqueeze(dim=2)

		# output has shape [batch_size, seq_len, dimension]
		output = torch.cos(self.w(t))

		return output
