import torch
from torch import nn

from kge.models.fusing_function.base import ParametricFusingFunction


class EncoderBottleneck(ParametricFusingFunction):
    """Fuse two embeddings with an element-wise product and a bottleneck MLP.

    The inputs (presumably subject ``s`` and relation ``r`` embeddings --
    confirm against the ``ParametricFusingFunction`` contract) are multiplied
    element-wise, passed through input dropout, and then through three
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stages with widths
    ``dimension -> bottleneck_dimension -> dimension -> dimension``.
    Because the final stage ends in ReLU (followed only by dropout), every
    entry of the returned tensor is non-negative.

    Args:
        dimension: Size of the last axis of ``s * r`` and of the output.
        bottleneck_dimension: Output width of the first (compressing) layer.
        input_drop: Dropout probability applied to ``s * r`` before the MLP.
        hidden_drop: Dropout probability after the first and second stages.
        output_drop: Dropout probability after the final stage.
    """

    def __init__(
        self,
        dimension: int,
        bottleneck_dimension: int,
        input_drop: float = 0.1,
        hidden_drop: float = 0.2,
        output_drop: float = 0.1,
    ):
        super().__init__()
        self.dimension = dimension
        self.bottleneck_dimension = bottleneck_dimension
        self.input_dropout = nn.Dropout(input_drop)
        # NOTE: the attribute names and the layer order inside ``self.mlp``
        # determine the ``state_dict`` keys (``mlp.0.weight``, ...). Keep them
        # stable so previously saved checkpoints still load.
        self.mlp = nn.Sequential(
            # First layer: dimension -> bottleneck_dimension
            nn.Linear(dimension, bottleneck_dimension),
            nn.BatchNorm1d(bottleneck_dimension),
            nn.ReLU(),
            nn.Dropout(hidden_drop),
            # Second layer: bottleneck_dimension -> dimension
            nn.Linear(bottleneck_dimension, dimension),
            nn.BatchNorm1d(dimension),
            nn.ReLU(),
            nn.Dropout(hidden_drop),
            # Third layer: dimension -> dimension
            nn.Linear(dimension, dimension),
            nn.BatchNorm1d(dimension),
            nn.ReLU(),
            nn.Dropout(output_drop),
        )

    def forward(self, s: torch.Tensor, r: torch.Tensor, **kwargs) -> torch.Tensor:
        """Return the fused representation of ``s`` and ``r``.

        Args:
            s: First embedding tensor; its last axis must have size
                ``dimension`` (after broadcasting with ``r``).
            r: Second embedding tensor, broadcastable against ``s``.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            A tensor with the broadcast shape of ``s * r``; its last axis has
            size ``dimension``.
        """
        x = self.input_dropout(s * r)
        # ``nn.Linear`` acts on the last axis, but ``nn.BatchNorm1d`` only
        # accepts (N, C) / (N, C, L) and normalizes axis 1. Collapse every
        # leading axis into a single batch axis so each feature is normalized
        # correctly for inputs of any rank, then restore the original shape.
        # For the usual 2-D (batch, dimension) input this is a no-op.
        leading_shape = x.shape[:-1]
        out = self.mlp(x.reshape(-1, x.shape[-1]))
        return out.reshape(*leading_shape, out.shape[-1])
