import torch
import torch.nn as nn

class MLP(nn.Module):
    """Fully connected feed-forward network (multi-layer perceptron).

    Builds one ``nn.Linear`` between each pair of consecutive widths in
    ``hidden_dims`` and inserts an activation after every layer except the
    last one.

    Args:
        hidden_dims: Layer widths, input first and output last, e.g.
            ``[in_features, h1, ..., out_features]``. Needs at least two
            entries; with exactly two the model is a single linear map.
        net_act: Activation name (``'sigmoid'``, ``'relu'``, ``'leakyrelu'``
            or ``'tanh'``), or ``None`` for no activations between layers.
        bias: Whether every ``nn.Linear`` layer has a bias term.

    Raises:
        ValueError: If ``hidden_dims`` has fewer than two entries, or if
            ``net_act`` names an unknown activation. The activation name is
            only checked when the network has at least one hidden layer.
    """

    def __init__(self, hidden_dims, net_act, bias=False):
        super().__init__()
        if len(hidden_dims) < 2:
            raise ValueError(
                f'hidden_dims needs at least 2 entries (input and output '
                f'width), got {list(hidden_dims)}'
            )
        self.hidden_dims = hidden_dims
        self.net_act = net_act
        layers = []
        # Hidden layers: Linear followed by an (optional) activation.
        for d_in, d_out in zip(hidden_dims[:-2], hidden_dims[1:-1]):
            layers.append(nn.Linear(d_in, d_out, bias=bias))
            if net_act is not None:
                layers.append(self._get_activation())
        # Output layer: no activation, so the raw scores are returned.
        layers.append(nn.Linear(hidden_dims[-2], hidden_dims[-1], bias=bias))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten all non-batch dimensions of ``x`` and run the network.

        ``reshape`` is used instead of ``view`` so that non-contiguous inputs
        (e.g. transposed or permuted tensors) are accepted as well; it only
        copies when a zero-copy view is impossible.
        """
        return self.model(x.reshape(x.size(0), -1))

    def _get_activation(self) -> nn.Module:
        """Return a new activation module for ``self.net_act``.

        Raises:
            ValueError: If ``self.net_act`` is not a supported name.
        """
        # Factories rather than instances, so every layer gets its own module.
        factories = {
            'sigmoid': nn.Sigmoid,
            'relu': lambda: nn.ReLU(inplace=True),
            'leakyrelu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
            'tanh': nn.Tanh,
        }
        factory = factories.get(self.net_act)
        if factory is None:
            # Raise instead of exit(): library code must not terminate the
            # interpreter, and callers should be able to catch the error.
            raise ValueError(f'unknown activation function: {self.net_act}')
        return factory()

    def get_param(self):
        """Return the model's parameters in a convenient form.

        Returns:
            For a purely linear model (``len(hidden_dims) == 2``), a flattened
            CPU copy of the single layer's weight matrix, detached from the
            autograd graph. Otherwise, the underlying ``nn.Sequential`` itself
            (the live module, not a copy).
        """
        if len(self.hidden_dims) == 2:  # single linear layer
            return self.model[0].weight.detach().clone().cpu().view(-1)
        return self.model



class LogisticLoss(nn.Module):
    """Logistic loss ``log(1 + exp(-label * pred))`` for labels in {+1, -1}.

    Computed as ``logaddexp(0, -label * pred)``, which stays finite for large
    margins where a naive ``log(1 + exp(...))`` would overflow.

    Args:
        reduction: ``'mean'`` (default) or ``'sum'`` to reduce over all
            elements, or ``'none'`` / ``None`` to return element-wise losses.

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('mean', 'sum', 'none', None)

    def __init__(self, reduction='mean'):
        super().__init__()
        # Validate eagerly: a typo like 'meen' used to silently fall through
        # and return the unreduced loss.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f'unknown reduction: {reduction!r}')
        self.reduction = reduction

    def forward(self, pred, label):
        """Return the logistic loss of ``pred`` against ``label``.

        ``pred`` and ``label`` are combined element-wise with broadcasting.
        Mismatched shapes such as ``(N, 1)`` vs ``(N,)`` therefore silently
        broadcast to ``(N, N)``; callers should pass matching shapes.
        """
        loss = torch.logaddexp(torch.zeros_like(pred), -label * pred)
        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction in ('none', None):
            return loss
        # Reached only if self.reduction was reassigned after construction.
        raise ValueError(f'unknown reduction: {self.reduction!r}')
