import torch
import torch.nn as nn

class EncodingModel(nn.Module):
    """Encode inputs with a shared 3-layer MLP, then sum over dim 1.

    NOTE(review): summing over dim 1 suggests ``x`` is laid out as
    (batch, set_size, input_dim), so each set is pooled into one embedding of
    width ``hidden_dim`` -- TODO confirm with callers. For a 2-D input, dim 1
    would be the feature axis instead.
    """

    def __init__(self, input_dim, hidden_dim=128):
        """Build the MLP.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            hidden_dim: Width of both hidden layers and of the output
                embedding. Defaults to 128.
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x`` and sum over dim 1.

        Args:
            x: Input tensor whose last dimension has size ``input_dim``.

        Returns:
            ``self.model(x)`` summed over dim 1.
        """
        feat = self.model(x)
        return torch.sum(feat, dim=1)

class ComparisonModel(nn.Module):
    """Score query embeddings against a set embedding with an MLP.

    The set embedding is broadcast to the query embeddings' shape. The two are
    concatenated along dim 1, and an MLP maps each row to one probability.
    """

    def __init__(self, embed_dim=128):
        """Build the scoring MLP and its loss.

        Args:
            embed_dim: Width of each input embedding. The MLP consumes the
                concatenation of a set and a query embedding, i.e.
                ``2 * embed_dim`` features. Defaults to 128 (input width 256).
        """
        super().__init__()
        in_dim = 2 * embed_dim
        layers = [
            nn.Linear(in_dim, in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, 1),
            # Kept as the final layer so ``self.model(x)`` still returns
            # probabilities; forward() bypasses it to give the loss raw logits.
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)
        # Takes LOGITS, not probabilities. BCEWithLogitsLoss fuses sigmoid+BCE
        # via the log-sum-exp trick, so the loss stays exact and the gradient
        # stays non-zero even when the sigmoid would saturate to 0.0/1.0.
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, set_embed, query_embed, label):
        """Score each query against the set and compute the BCE loss.

        Args:
            set_embed: Set embedding; must be expandable to
                ``query_embed.shape`` (e.g. a single row to broadcast).
            query_embed: Query embeddings; concatenated with the expanded set
                embedding along dim 1.
            label: Float targets with the same shape as the returned ``pred``.

        Returns:
            ``(pred, loss)``. ``pred`` holds sigmoid probabilities from
            ``self.model``. ``loss`` is the mean binary cross-entropy between
            ``pred`` and ``label``, computed from logits for stability.
        """
        set_embed = set_embed.expand_as(query_embed)
        x = torch.cat((set_embed, query_embed), 1)
        # Everything except the trailing Sigmoid yields the raw logits.
        logits = self.model[:-1](x)
        pred = self.model[-1](logits)
        loss = self.loss(logits, label)
        return pred, loss
