import torch
import torch.nn as nn
from SourceCode.ModelModule.SparseSoftmax import Sparsemax


class EmbeddingNet(nn.Module):
    """Three-layer fully connected network that maps inputs to an embedding.

    Layer stack, in order:
        Linear(in_dim -> hidden) -> BatchNorm1d -> ReLU
        Linear(hidden -> mid)    -> BatchNorm1d -> ReLU
        Linear(mid -> out_dim)   -> ReLU

    where ``mid = (hidden + out_dim) // 2``. Because the last operation is a
    ReLU, every component of the returned embedding is >= 0.

    Args:
        source_input_dim: Size of the last dimension of the input features.
        source_embedding_dim: Size of the produced embedding.
        source_embedding_hidden_layer_size: Width of the first hidden layer.
    """

    def __init__(
        self,
        source_input_dim: int,
        source_embedding_dim: int,
        source_embedding_hidden_layer_size: int,
    ) -> None:
        super().__init__()
        n_in = source_input_dim
        n_out = source_embedding_dim
        n_hidden = source_embedding_hidden_layer_size
        # Second hidden layer sits halfway (floor) between the first hidden
        # width and the embedding width.
        n_mid = (n_hidden + n_out) // 2

        # Plain attributes kept so callers can inspect the configured sizes.
        self.in_dim = n_in
        self.out_dim = n_out
        self.hidden_layer_size = n_hidden

        # Stage 1: input -> hidden
        self.hidden1 = nn.Linear(n_in, n_hidden)
        self.bn1 = nn.BatchNorm1d(n_hidden)
        self.relu1 = nn.ReLU()

        # Stage 2: hidden -> mid
        self.hidden2 = nn.Linear(n_hidden, n_mid)
        self.bn2 = nn.BatchNorm1d(n_mid)
        self.relu2 = nn.ReLU()

        # Stage 3: mid -> embedding (no batch norm on the output stage)
        self.hidden3 = nn.Linear(n_mid, n_out)
        self.relu3 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the three stages and return the embedding.

        Args:
            x: Input features whose last dimension is ``in_dim``.
                NOTE(review): presumably shaped (batch, in_dim) so the
                BatchNorm1d layers see (N, C) -- confirm against callers.

        Returns:
            Tensor whose last dimension is ``out_dim``; values are
            non-negative because of the final ReLU.
        """
        stage1 = self.relu1(self.bn1(self.hidden1(x)))
        stage2 = self.relu2(self.bn2(self.hidden2(stage1)))
        return self.relu3(self.hidden3(stage2))
