import torch
from einops import rearrange
from torch import nn

from .base import ParametricFusingFunction


class ConvE(ParametricFusingFunction):
    """ConvE-style fusing function combining two embeddings with a 2D convolution.

    Both inputs are reshaped into ``reshape_height x reshape_width`` grids,
    stacked along the height axis into a single-channel "image", passed through
    a convolutional block, flattened, and projected back to ``dimension``.
    """

    def __init__(
        self,
        dimension: int,
        channels: int = 32,
        kernel_size: tuple[int, int] = (3, 3),
        input_drop: float = 0.2,
        hidden_drop: float = 0.2,
        output_drop: float = 0.3,
        reshape_height: int = 8,
    ):
        """Build the ConvE layers.

        Args:
            dimension: Size of each input embedding and of the output.
            channels: Number of output channels of the convolution.
            kernel_size: ``(height, width)`` of the convolution kernel.
            input_drop: Dropout probability applied before the convolution.
            hidden_drop: Channel-wise dropout probability applied to the
                feature maps after the convolution.
            output_drop: Dropout probability applied after the linear projection.
            reshape_height: Height of the 2D grid each embedding is reshaped to;
                the width is ``dimension // reshape_height``.

        Raises:
            ValueError: If ``dimension`` or ``reshape_height`` is not positive,
                if ``dimension`` is not divisible by ``reshape_height``, or if
                ``kernel_size`` does not fit inside the stacked
                ``(2 * reshape_height, reshape_width)`` grid.
        """
        super().__init__()

        # Explicit exceptions instead of `assert`: asserts are stripped under
        # `python -O`, and a negative height/width pair can otherwise slip through
        # (two negative conv sizes multiply to a positive flatten size).
        if dimension <= 0:
            raise ValueError(f"dimension must be a positive integer, got {dimension}")
        if reshape_height <= 0:
            raise ValueError(f"reshape_height must be a positive integer, got {reshape_height}")
        reshape_width, remainder = divmod(dimension, reshape_height)
        if remainder != 0:
            raise ValueError(
                f"dimension ({dimension}) must be divisible by reshape_height ({reshape_height})"
            )

        self.dimension = dimension
        self.reshape_height = reshape_height
        self.reshape_width = reshape_width

        # Spatial size after an unpadded, stride-1 convolution over the stacked
        # (2 * reshape_height, reshape_width) input.
        conv_output_height = 2 * self.reshape_height - kernel_size[0] + 1
        conv_output_width = self.reshape_width - kernel_size[1] + 1
        if conv_output_height <= 0 or conv_output_width <= 0:
            raise ValueError(
                f"kernel_size {tuple(kernel_size)} does not fit the stacked input of shape "
                f"({2 * self.reshape_height}, {self.reshape_width}); "
                "reduce kernel_size or change reshape_height"
            )
        flatten_size = channels * conv_output_height * conv_output_width

        # Input normalization over the single stacked channel (always applied in forward)
        self.input_bn = nn.BatchNorm2d(1)

        # Convolutional feature extractor
        self.conv_block = nn.Sequential(
            nn.Dropout(input_drop),
            nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=kernel_size, padding=0),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Dropout2d(hidden_drop),
        )

        # Fully connected projection back to the embedding dimension
        self.linear = nn.Linear(flatten_size, dimension)

        # Output transformation
        self.output_block = nn.Sequential(
            nn.Dropout(output_drop),
            nn.BatchNorm1d(dimension),
            # NOTE: the ConvE paper adds a ReLU here; it is left out because it performed worse.
        )

    def forward(self, s: torch.Tensor, r: torch.Tensor, **kwargs) -> torch.Tensor:
        """Fuse ``s`` and ``r`` into a single embedding.

        Args:
            s: First input of shape ``(batch_size, dimension)``.
            r: Second input of shape ``(batch_size, dimension)``.
            **kwargs: Accepted but not used by this method.

        Returns:
            Tensor of shape ``(batch_size, dimension)``.
        """
        # s, r shape: (batch_size, dimension) -> (batch_size, reshape_height, reshape_width)
        s = rearrange(s, "b (h w) -> b h w", h=self.reshape_height)
        r = rearrange(r, "b (h w) -> b h w", h=self.reshape_height)
        stacked = torch.cat([s, r], dim=1)  # shape: (batch_size, 2 * reshape_height, reshape_width)
        stacked = rearrange(stacked, "b h w -> b 1 h w")  # add the single input channel

        # Apply input batch norm
        x = self.input_bn(stacked)
        x = self.conv_block(x)
        x = rearrange(x, "b c h w -> b (c h w)")  # flatten feature maps for the linear layer
        x = self.linear(x)
        x = self.output_block(x)
        return x
