import torch
import torch.nn as nn
from torch.nn import functional as F


class SigmoidBias(nn.Module):
    """
    Element-wise sigmoid followed by a learnable additive bias.

    Computes ``sigmoid(input) + bias``. The bias has one entry per output
    feature and is broadcast against the trailing dimension of ``input``.

    Args:
        output_features: number of bias entries; must match (or broadcast
            against) the last dimension of ``input``.
        bias: if False, no bias parameter is created and ``forward`` returns
            the plain sigmoid.
    """

    def __init__(self, output_features=1, bias=True):
        super(SigmoidBias, self).__init__()
        if bias:
            # Small random init in (-0.1, 0.1]. Previously this value was
            # computed but discarded in favour of torch.Tensor(n), which is
            # uninitialized memory and can contain arbitrary values.
            uniform = 0.1 * (1 - 2 * torch.rand(output_features))
            self.bias = nn.Parameter(uniform)
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        output = torch.sigmoid(input)
        if self.bias is not None:
            # Broadcasting matches the old unsqueeze(0).expand_as for batched
            # input and additionally supports un-batched 1-D input.
            output = output + self.bias
        return output


class LogisticRegression(torch.nn.Module):
    """
    Logistic regression: one affine layer followed by a sigmoid.

    Args:
        input_dim: number of input features (last dimension of ``x``).
        output_dim: number of independent sigmoid outputs.
    """

    def __init__(self, input_dim, output_dim=1):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # Returns probabilities in (0, 1), not raw logits.
        logits = self.linear(x)
        return logits.sigmoid()


class MLP(nn.Module):
    """
    Fully connected network with a sigmoid output.

    ``layer_1`` maps ``in_channels`` -> ``hidden_dim``. The single ``layer_2``
    module is then applied five times in a row, so its weights are shared
    across those passes, with a ReLU after each pass. Finally ``layer_out``
    projects to ``out_channels`` and a sigmoid maps the result into (0, 1).

    ``dropout``, ``batchnorm1`` and ``batchnorm2`` are constructed but not
    used by ``forward``. They still contribute parameters/buffers to
    ``state_dict``.
    """

    # Number of consecutive applications of the shared ``layer_2``.
    _HIDDEN_PASSES = 5

    def __init__(self, in_channels, out_channels, hidden_dim=256):
        super().__init__()
        self.layer_1 = nn.Linear(in_channels, hidden_dim)
        self.layer_2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer_out = nn.Linear(hidden_dim, out_channels)

        self.relu = nn.ReLU()
        # NOTE(review): the three modules below are unused in forward().
        # Removing them would change the state_dict keys.
        self.dropout = nn.Dropout(p=0.1)
        self.batchnorm1 = nn.BatchNorm1d(hidden_dim)
        self.batchnorm2 = nn.BatchNorm1d(hidden_dim)

    def forward(self, inputs):
        hidden = self.relu(self.layer_1(inputs))
        for _ in range(self._HIDDEN_PASSES):
            hidden = self.relu(self.layer_2(hidden))
        return torch.sigmoid(self.layer_out(hidden))


class BinaryClassification(nn.Module):
    """
    Binary classifier mapping ``input_dim`` features to one probability.

    The pipeline is ``layer_1 -> ReLU -> batchnorm1``. It is followed by two
    stages of ``layer_2 -> ReLU -> batchnorm2 -> dropout``, which reuse the
    same modules. Then comes ``layer_out`` and a sigmoid. The output is a
    probability in (0, 1), not a logit.

    Args:
        input_dim: number of input features.
    """

    def __init__(self, input_dim):
        super().__init__()
        width = 128
        self.layer_1 = nn.Linear(input_dim, width)
        self.layer_2 = nn.Linear(width, width)
        self.layer_out = nn.Linear(width, 1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.1)
        self.batchnorm1 = nn.BatchNorm1d(width)
        self.batchnorm2 = nn.BatchNorm1d(width)

    def forward(self, inputs):
        hidden = self.batchnorm1(self.relu(self.layer_1(inputs)))
        # NOTE(review): layer_2 and batchnorm2 are shared between both stages.
        # In training mode batchnorm2's running statistics are updated twice
        # per forward call with activations from two different depths.
        # Confirm that this is intended.
        for _ in range(2):
            hidden = self.relu(self.layer_2(hidden))
            hidden = self.dropout(self.batchnorm2(hidden))
        return torch.sigmoid(self.layer_out(hidden))
