import os
import torch
import numpy as np
from torch import nn
import torch.nn.functional as fcnal
import math



    
    
class mlleaks_cnn(nn.Module):
    """Two-block convolutional classifier.

    Each conv block is Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d ->
    ReLU -> MaxPool2d(2, 2). The padding preserves the spatial size and the
    pool floor-halves it. The flattened features pass through a Linear layer
    of width ``size`` and then a Linear output layer producing ``n_classes``
    values. No activation is applied between or after the two Linear layers.

    Args:
        n_in: number of input channels.
        n_classes: number of outputs.
        n_filters: channels of the first conv block (the second uses 2x).
        size: width of the hidden fully connected layer.
        input_size: height/width of the (square) input images. The default
            of 32 yields the 8x8 feature map the layer sizes previously
            hard-coded.
    """

    def __init__(self, n_in=3, n_classes=10, n_filters=64, size=128,
                 input_size=32):
        super(mlleaks_cnn, self).__init__()

        self.n_filters = n_filters
        # Two 2x2/stride-2 max-pools each floor-halve the spatial size.
        feat_side = input_size // 4

        self.conv_block_1 = nn.Sequential(
            nn.Conv2d(n_in, n_filters, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(n_filters),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.conv_block_2 = nn.Sequential(
            nn.Conv2d(n_filters, 2*n_filters, kernel_size=5, stride=1,
                      padding=2),
            nn.BatchNorm2d(2*n_filters),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(2*n_filters * feat_side * feat_side, size)
        # The output layer consumes fc's output, whose width is `size`.
        # The previous input width 2*n_filters only matched when
        # size == 2*n_filters (true for the defaults, false otherwise).
        self.output = nn.Linear(size, n_classes)

    def forward(self, x):
        """Return the raw output of the final Linear layer for batch ``x``."""
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        # Flatten per sample with the batch dimension kept explicit. An input
        # of the wrong resolution then fails in self.fc instead of having its
        # extra spatial positions silently folded into the batch dimension.
        x = x.flatten(1)
        x = self.fc(x)
        return self.output(x)

    
class mlleaks_mlp(nn.Module):
    """Two-layer perceptron with sigmoid activations.

    The layers are Linear(n_in -> n_filters) -> sigmoid ->
    Linear(n_filters -> n_classes) -> sigmoid, so every output lies in (0, 1).

    Args:
        n_in: number of input features.
        n_classes: number of output units. The default of 1 presumably means
            a single binary score; confirm against callers.
        n_filters: width of the hidden layer.
        size: unused; accepted only so existing call sites that pass it keep
            working.
    """

    def __init__(self, n_in=3, n_classes=1, n_filters=64, size=64):
        super(mlleaks_mlp, self).__init__()

        self.hidden = nn.Linear(n_in, n_filters)
        self.output = nn.Linear(n_filters, n_classes)

    def forward(self, x):
        """Return sigmoid(output(sigmoid(hidden(x))))."""
        x = torch.sigmoid(self.hidden(x))
        # Evaluate the output layer once. The previous version ran it twice
        # and threw away the first (un-squashed) result.
        return torch.sigmoid(self.output(x))
    
def weights_init(m):
    """
    Initializes weights of layers of model m.

    Operates on a single module and is meant to be used as
    ``model.apply(weights_init)``. Parameters are re-initialized in place for
    the handled layer types; any other module is left untouched.

    - ``nn.Conv2d``: Kaiming-normal weights (mode='fan_out', ReLU gain),
      zero bias.
    - ``nn.BatchNorm2d``: weight 1, bias 0.
    - ``nn.Linear``: Xavier-normal weights, zero bias.

    Layers built without a bias (``bias=False``) or without affine parameters
    (``affine=False``) are supported; their missing parameters are skipped.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm with affine=False has weight and bias set to None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Linear):
        # nn.init functions already run under no_grad, so `.data` is not needed.
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


