import torch
from einops import rearrange, reduce
from torch import nn

from .base import ParametricFusingFunction


class ConvKB(ParametricFusingFunction):
    """ConvKB-style fusion of two embeddings via a shared 1-D convolution.

    For every feature position ``i`` the pair ``(s[..., i], r[..., i])`` is
    treated as a length-2 sequence and convolved with ``channels`` filters of
    width 2, so each filter yields exactly one value per position. The filter
    responses are averaged over channels, giving one fused value per feature.

    Args:
        channels: Number of convolution filters.
        input_dropout: Dropout probability applied to the stacked inputs.
        layer_dropout: Dropout probability applied to the convolution output.
    """

    def __init__(
        self,
        channels: int = 50,
        input_dropout: float = 0.1,
        layer_dropout: float = 0.1,
    ):
        super().__init__()
        # The kernel slides over the 2-element (s, r) axis. A width-3 kernel
        # (as in the three-input ConvKB) is longer than that axis and makes
        # Conv1d raise; width 2 gives exactly one output per feature position.
        self.conv = nn.Conv1d(in_channels=1, out_channels=channels, kernel_size=2)
        self.input_dropout = nn.Dropout(input_dropout)
        self.layer_dropout = nn.Dropout(layer_dropout)

    def forward(self, s: torch.Tensor, r: torch.Tensor, **kwargs) -> torch.Tensor:
        """Fuse ``s`` and ``r`` into a single tensor.

        Args:
            s: Tensor of shape ``(batch, feat_dim)`` or ``(batch, 1, feat_dim)``.
            r: Tensor with the same layout as ``s``.
            **kwargs: Accepted for interface compatibility; ignored here.

        Returns:
            Tensor of shape ``(batch, feat_dim)``.
        """
        batch = s.size(0)
        # Give 2-D inputs an explicit pair axis so concatenation produces
        # (batch, 2, feat_dim); a plain cat on 2-D tensors would instead give
        # (batch, 2 * feat_dim), which cannot match the "b two f" pattern.
        if s.dim() == 2:
            s = s.unsqueeze(1)
        if r.dim() == 2:
            r = r.unsqueeze(1)
        x = torch.cat([s, r], dim=1)  # (batch, 2, feat_dim)
        # Each feature position becomes its own length-2, single-channel sequence.
        x = rearrange(x, "b two f -> (b f) 1 two")
        x = self.input_dropout(x)
        x = self.conv(x)  # (batch * feat_dim, channels, 1)
        x = rearrange(x, "(b f) c 1 -> b c f", b=batch)
        x = self.layer_dropout(x)
        # NOTE(review): there is no nonlinearity between the conv and the
        # channel mean, so (dropout aside) the mean of the linear filters is
        # itself one linear filter. The ConvKB paper applies a ReLU after the
        # convolution — confirm whether omitting it here is intended.
        return reduce(x, "b c f -> b f", "mean")
