import torch
import torch.nn as nn


class SchNetFilter(nn.Module):
    """Distance-conditioned filter generator (SchNet-style).

    Each distance in ``dis`` is expanded onto ``SLOT`` Gaussian radial basis
    functions centred at ``START + i * INTERVAL`` for ``i in range(SLOT)``
    (0.0, 0.2, ..., 10.0 with the class defaults):

        rbf_i = exp(-GAMMA * (dis - mu_i) ** 2)

    The expansion is then passed through two bias-free ``Linear`` layers,
    each followed by ``Softplus``, producing ``out_dim`` features.
    """

    START = 0        # first RBF centre
    INTERVAL = 0.2   # spacing between consecutive RBF centres
    SLOT = 51        # number of RBF centres
    GAMMA = 10       # width coefficient of each Gaussian

    def __init__(self, out_dim, use_cuda):
        """
        Args:
            out_dim: Number of output features of both Linear layers.
            use_cuda: If true, the RBF centres are created on the current CUDA
                device. The Linear layers are not moved here; call
                ``.cuda()`` / ``.to(...)`` on the module for that, which (since
                ``mu`` is a registered buffer) also moves the centres.
        """
        super().__init__()
        # Built from Python floats and then cast to float32, so the values are
        # identical to the original list-based FloatTensor construction.
        mu = torch.tensor(
            [[self.START + i * self.INTERVAL for i in range(self.SLOT)]],
            dtype=torch.float32,
        )  # shape (1, SLOT)
        if use_cuda:
            mu = mu.cuda()
        # A buffer (rather than a plain attribute) follows .to()/.cuda()/
        # .double() with the rest of the module. persistent=False keeps it out
        # of state_dict so previously saved checkpoints load unchanged.
        self.register_buffer("mu", mu, persistent=False)

        self.dense1 = nn.Linear(self.SLOT, out_dim, bias=False)
        self.act1 = nn.Softplus()
        self.dense2 = nn.Linear(out_dim, out_dim, bias=False)
        self.act2 = nn.Softplus()

    def forward(self, dis: torch.Tensor) -> torch.Tensor:
        """Map distances to ``out_dim`` filter features.

        Args:
            dis: Distances. Presumably shaped ``(..., 1)`` so that it
                broadcasts against ``mu`` of shape ``(1, SLOT)`` into
                ``(..., SLOT)`` -- TODO confirm with callers. NOTE(review): a
                1-D tensor of length ``SLOT`` would broadcast elementwise
                instead of being expanded, silently giving wrong results.

        Returns:
            Tensor whose last dimension is ``out_dim``.
        """
        x = torch.exp(-self.GAMMA * (dis - self.mu) ** 2)
        x = self.act1(self.dense1(x))
        x = self.act2(self.dense2(x))
        return x
