# model definitions

import torch
import torch.nn.functional as F

class TwoLayerMLP(torch.nn.Module):
    """Bias-free two-layer perceptron with a single sigmoid output.

    Architecture: ``Linear(input_size, num_hidden) -> ReLU ->
    Linear(num_hidden, 1) -> sigmoid``. Neither linear layer has a bias.

    Args:
        input_size: Size of the last dimension of the input tensor.
        num_hidden: Number of hidden units.
    """

    def __init__(self, input_size: int, num_hidden: int):
        super().__init__()
        self.input_size = input_size
        self.num_hidden = num_hidden
        self.hidden_layer = torch.nn.Linear(input_size, num_hidden, bias=False)
        self.output_layer = torch.nn.Linear(num_hidden, 1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the model's sigmoid output for ``x``.

        Args:
            x: Tensor whose last dimension is ``input_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1, with values in (0, 1).
        """
        hid = F.relu(self.hidden_layer(x))
        return torch.sigmoid(self.output_layer(hid))

    def feature_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return each hidden unit's weighted contribution to the output logit.

        Entry ``[..., j]`` is ``relu(hidden_layer(x))[..., j] * w[j]``, where
        ``w`` is the output-layer weight row. Summing over the last dimension
        reproduces the pre-sigmoid logit computed in ``forward``.

        No gradient flows into the output-layer weight through this method.
        Gradients still flow into the hidden layer.

        Args:
            x: Tensor whose last dimension is ``input_size``.

        Returns:
            For an input of shape ``(batch, input_size)``, a tensor of shape
            ``(batch, num_hidden)``.
        """
        # Must use the same activation as ``forward``; otherwise the features
        # do not add up to the logit the model actually produces.
        hid = F.relu(self.hidden_layer(x))
        # The output weight has shape (1, num_hidden) and broadcasts over the
        # batch. ``detach()`` excludes it from autograd, as ``.data`` did, but
        # keeps autograd's in-place version checks intact.
        features = hid * self.output_layer.weight.detach()
        return features


class MultiClassMLP(torch.nn.Module):
    """Bias-free two-layer perceptron producing raw multi-class scores.

    The network is ``Linear(input_size, num_hidden) -> ReLU ->
    Linear(num_hidden, output_size)``. No output activation is applied, so
    ``forward`` returns unnormalized scores (logits).

    Args:
        input_size: Size of the last dimension of the input tensor.
        num_hidden: Number of hidden units.
        output_size: Number of output scores per example.
    """

    def __init__(self, input_size, num_hidden, output_size):
        super().__init__()
        # Keep the layer sizes around so callers can inspect the model.
        self.input_size, self.num_hidden = input_size, num_hidden
        # Create the hidden layer before the output layer so that random
        # initialization consumes the RNG in a fixed order.
        self.hidden_layer = torch.nn.Linear(input_size, num_hidden, bias=False)
        self.output_layer = torch.nn.Linear(num_hidden, output_size, bias=False)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_size``) to ``output_size`` scores."""
        return self.output_layer(F.relu(self.hidden_layer(x)))
