import torch
import torch.nn as nn


class RefineNet(nn.Module):
    """Three-layer MLP with batch normalisation that refines source embeddings.

    Layout::

        Linear(in_dim -> hidden) -> BatchNorm1d -> LeakyReLU
        -> Linear(hidden -> mid) -> BatchNorm1d -> LeakyReLU
        -> Linear(mid -> out_dim)

    where ``mid = (hidden + out_dim) // 2``.

    Args:
        source_embedding_dim: Size of the input feature dimension (``in_dim``).
        source_refined_dim: Size of the output feature dimension (``out_dim``).
        source_refined_hidden_layer_size: Width of the first hidden layer.
    """

    def __init__(self, source_embedding_dim, source_refined_dim, source_refined_hidden_layer_size):
        super().__init__()
        self.in_dim = source_embedding_dim
        self.out_dim = source_refined_dim
        self.hidden_layer_size = source_refined_hidden_layer_size

        # Width of the second hidden layer: integer midpoint of hidden and output sizes.
        bottleneck = (self.hidden_layer_size + self.out_dim) // 2

        # Keep this registration order: it fixes the state_dict key order and the
        # order in which the Linear layers draw from the RNG during initialisation.
        self.hidden1 = nn.Linear(self.in_dim, self.hidden_layer_size)
        self.bn1 = nn.BatchNorm1d(self.hidden_layer_size)
        self.relu1 = nn.LeakyReLU()
        self.hidden2 = nn.Linear(self.hidden_layer_size, bottleneck)
        self.bn2 = nn.BatchNorm1d(bottleneck)
        self.relu2 = nn.LeakyReLU()
        self.hidden3 = nn.Linear(bottleneck, self.out_dim)

    def forward(self, x):
        """Run ``x`` through every layer in registration order and return the result.

        NOTE(review): BatchNorm1d normalises over dim 1, so this presumably expects
        ``x`` shaped (batch, in_dim); a (batch, seq, in_dim) input would normalise
        the wrong axis — confirm with callers.
        """
        pipeline = (
            self.hidden1, self.bn1, self.relu1,
            self.hidden2, self.bn2, self.relu2,
            self.hidden3,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out


class DestRefineNet(nn.Module):
    """Three-layer MLP that refines destination embeddings, with optional LayerNorm.

    Layout::

        Linear(in_dim -> hidden) -> [LayerNorm] -> LeakyReLU
        -> Linear(hidden -> mid) -> [LayerNorm] -> LeakyReLU
        -> Linear(mid -> out_dim)

    where ``mid = (hidden + out_dim) // 2``. The bracketed LayerNorm steps run only
    when ``use_layer_norm`` is True.

    Args:
        dest_embedding_dim: Size of the input feature dimension (``in_dim``).
        dest_refined_dim: Size of the output feature dimension (``out_dim``).
        dest_refined_hidden_layer_size: Width of the first hidden layer.
        use_layer_norm: If True, apply ``ln1``/``ln2`` after the first two Linear
            layers. Defaults to False, which matches the previous behaviour where
            those calls were commented out.
    """

    def __init__(
        self,
        dest_embedding_dim,
        dest_refined_dim,
        dest_refined_hidden_layer_size,
        *,
        use_layer_norm=False,
    ):
        super().__init__()
        self.in_dim = dest_embedding_dim
        self.out_dim = dest_refined_dim
        self.hidden_layer_size = dest_refined_hidden_layer_size
        self.use_layer_norm = use_layer_norm

        # Width of the second hidden layer: integer midpoint of hidden and output sizes.
        mid_dim = (self.hidden_layer_size + self.out_dim) // 2

        # ln1/ln2 are always registered so the state_dict keys are identical whether
        # or not LayerNorm is applied; existing checkpoints keep loading.
        # NOTE: with use_layer_norm=False their parameters receive no gradients, so
        # DistributedDataParallel would need find_unused_parameters=True.
        self.hidden1 = nn.Linear(self.in_dim, self.hidden_layer_size)
        self.ln1 = nn.LayerNorm(self.hidden_layer_size)
        self.relu1 = nn.LeakyReLU()
        self.hidden2 = nn.Linear(self.hidden_layer_size, mid_dim)
        self.ln2 = nn.LayerNorm(mid_dim)
        self.relu2 = nn.LeakyReLU()
        self.hidden3 = nn.Linear(mid_dim, self.out_dim)

    def forward(self, x):
        """Map ``x`` (last dimension ``in_dim``) to a tensor whose last dimension is ``out_dim``."""
        y = self.hidden1(x)
        if self.use_layer_norm:
            y = self.ln1(y)
        y = self.relu1(y)
        y = self.hidden2(y)
        if self.use_layer_norm:
            y = self.ln2(y)
        y = self.relu2(y)
        return self.hidden3(y)