from spaghettini import quick_register
import torch
from torch import nn

from torch.nn import Linear, Bilinear


@quick_register
class MultiplicativeDense(nn.Module):
    """Dense layer that sums a bilinear (multiplicative) term and two linear terms.

    The output is ``bilinear(input1, input2) + in1_linear(input1) + in2_linear(input2)``.
    Only the bilinear component can carry a bias; both linear projections are
    bias-free, so at most one bias vector is added to the result.

    Args:
        in1_features: Size of the last dimension of ``input1``.
        in2_features: Size of the last dimension of ``input2``.
        out_features: Size of the last dimension of the output.
        bias: Whether the bilinear component learns an additive bias.
    """

    def __init__(self, in1_features, in2_features, out_features, bias=True):
        super().__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        # Exposed alongside the input sizes, mirroring torch.nn.Linear / Bilinear.
        self.out_features = out_features
        # NOTE: this is the boolean flag passed in, not a Parameter; the learned
        # bias (if any) lives at ``self.bilinear.bias``.
        self.bias = bias

        # Components (registration order kept stable for state_dict keys / init RNG).
        self.bilinear = Bilinear(in1_features=in1_features, in2_features=in2_features, out_features=out_features,
                                 bias=bias)
        # Bias-free on purpose: the bilinear term already supplies the single bias.
        self.in1_linear = Linear(in_features=in1_features, out_features=out_features, bias=False)
        self.in2_linear = Linear(in_features=in2_features, out_features=out_features, bias=False)

    def forward(self, input1, input2):
        """Return ``bilinear(input1, input2) + in1_linear(input1) + in2_linear(input2)``.

        Args:
            input1: Tensor whose last dimension is ``in1_features``.
            input2: Tensor whose last dimension is ``in2_features``.

        Returns:
            Tensor whose last dimension is ``out_features``.
        """
        # Submodules are called positionally: PyTorch forward (pre-)hooks only
        # receive positional inputs, so keyword calls would hide them from hooks.
        mult_out = self.bilinear(input1, input2)
        linear1_out = self.in1_linear(input1)
        linear2_out = self.in2_linear(input2)

        return mult_out + linear1_out + linear2_out


if __name__ == "__main__":
    # Smoke test for MultiplicativeDense. Run from the root directory (the one
    # containing `src/`):
    #     python -m src.dl.models.multiplicative
    test_num = 0

    if test_num == 0:
        # Create dummy inputs (a batch of one).
        in1_feats, in2_feats, out_feats = 5, 17, 12
        in1 = torch.ones(size=(1, in1_feats))
        in2 = torch.ones(size=(1, in2_feats))

        # Instantiate layer.
        mult_layer = MultiplicativeDense(in1_features=in1_feats, in2_features=in2_feats, out_features=out_feats)

        # Take a forward pass and check the output shape, so the smoke test
        # verifies more than "did not raise".
        out = mult_layer(in1, in2)
        expected_shape = (1, out_feats)
        if tuple(out.shape) != expected_shape:
            raise RuntimeError(f"Expected output shape {expected_shape}, got {tuple(out.shape)}.")
        print(f"MultiplicativeDense forward OK: output shape {tuple(out.shape)}")
