import torch 
from torch import nn 
import torch.nn.functional as F


class SpamNN(nn.Module):
    """Three-layer fully connected binary classifier with tanh hidden units.

    The network maps ``in_features`` inputs through two tanh-activated hidden
    layers to a single sigmoid output, i.e. ``forward`` returns values in
    (0, 1), not raw logits. Pair it with a loss that expects probabilities
    (e.g. ``nn.BCELoss``) rather than one that applies its own sigmoid.

    Args:
        in_features: Number of input features per sample (default 57).
        hidden1: Width of the first hidden layer (default 50).
        hidden2: Width of the second hidden layer (default 10).
    """

    def __init__(self, in_features=57, hidden1=50, hidden2=10):
        super().__init__()
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys are stable.
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden1)
        self.fc2 = nn.Linear(in_features=hidden1, out_features=hidden2)
        self.fc3 = nn.Linear(in_features=hidden2, out_features=1)
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.init_model()

    def forward(self, x):
        """Return the sigmoid output for a batch ``x``.

        ``x`` must have ``in_features`` as its last dimension; the result has
        the same leading dimensions and a last dimension of size 1.
        """
        x = self.tanh(self.fc1(x))
        x = self.tanh(self.fc2(x))
        return self.sigmoid(self.fc3(x))

    def init_model(self):
        """Re-initialise all parameters in place.

        Weight matrices (more than one dimension) are drawn from a standard
        normal distribution; bias vectors are set to zero.
        """
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.normal_(param, mean=0, std=1)
            else:
                nn.init.constant_(param, 0)


class BCELossWithL2(nn.Module):
    """Binary cross-entropy on probabilities plus an L2 penalty on a model.

    The total loss is ``BCE(outputs, labels) + 0.5 * lambda_reg * sum(p ** 2)``
    where the sum runs over every element of every parameter of ``model``
    (weights and biases alike).

    ``outputs`` are expected to be probabilities in [0, 1] (i.e. already
    passed through a sigmoid), so ``nn.BCELoss`` is used. Using
    ``nn.BCEWithLogitsLoss`` here would apply a second sigmoid to outputs
    that are already probabilities.

    Args:
        model: Module whose parameters are regularised. Assigning it as an
            attribute registers it as a submodule of this criterion.
        lambda_reg: Regularisation strength (default 1e-3).
    """

    def __init__(self, model, lambda_reg=1e-3):
        super().__init__()
        self.model = model
        self.lambda_reg = lambda_reg
        self.bce_loss = nn.BCELoss()

    def forward(self, outputs, labels):
        """Return the regularised loss as a scalar tensor.

        Args:
            outputs: Predicted probabilities; compared against ``labels``
                after a trailing dimension is added to the latter.
            labels: Target values; dimension 1 is inserted so that a 1-D
                label vector lines up with column-shaped ``outputs``.
        """
        # Cast so integer 0/1 labels are accepted; BCELoss needs matching dtypes.
        targets = torch.unsqueeze(labels, 1).to(outputs.dtype)
        data_loss = self.bce_loss(outputs, targets)
        # Sum of squares directly: same value as ||p||_2 ** 2 without the sqrt.
        l2_reg = sum(param.pow(2).sum() for param in self.model.parameters())
        return data_loss + 0.5 * self.lambda_reg * l2_reg