import torch
from torch import nn
from torchdp import PrivacyEngine
from torch import autograd
from torchdp import PerSampleGradientClipper
from torch.utils.data import DataLoader, TensorDataset

class MLPNet(nn.Module):
    """Multi-layer perceptron with a single hidden layer and a sigmoid output.

    The hidden-layer activation is chosen through ``options['activation']``.

    Args:
        options: Configuration dict with the keys

            * ``'num_feats'``: number of input features,
            * ``'width'``: number of hidden units,
            * ``'activation'``: ``'relu'`` or ``'tanh'``; any other value
              selects a sigmoid activation.

            When omitted, ``{'num_feats': 20, 'activation': 'relu',
            'width': 20}`` is used.
    """

    def __init__(self, options=None):
        super(MLPNet, self).__init__()
        if options is None:
            # Build the defaults per call instead of sharing one mutable
            # dict object across every instantiation.
            options = {'num_feats': 20, 'activation': 'relu', 'width': 20}

        # The branches must be mutually exclusive; otherwise the trailing
        # ``else`` would replace a requested ReLU with a sigmoid.
        activation = options['activation']
        if activation == 'relu':
            self.act_func = nn.ReLU()
        elif activation == 'tanh':
            self.act_func = nn.Tanh()
        else:
            self.act_func = nn.Sigmoid()

        self.sigmoid = nn.Sigmoid()
        self.input_layer = nn.Linear(options['num_feats'], options['width'])
        # The output layer consumes the hidden representation, so its input
        # size has to equal the hidden width.
        self.o_layer = nn.Linear(options['width'], 1)

    def forward(self, x):
        """Apply hidden layer, activation, output layer and final sigmoid to ``x``."""
        output = self.act_func(self.input_layer(x))
        output = self.o_layer(output)
        return self.sigmoid(output)


class LRNet(nn.Module):
    """Logistic regression model: one linear layer followed by a sigmoid.

    There is no hidden layer, hence no configurable activation function.

    Args:
        options: Configuration dict; only ``options['num_feats']`` (the
            number of input features) is read.
    """

    def __init__(self, options):
        super().__init__()
        in_features = options['num_feats']
        self.fc1 = nn.Linear(in_features, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid of the linear projection of ``x``."""
        logits = self.fc1(x)
        return self.sigmoid(logits)