import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from Models.Layers.Layers import MaskedLinear, MaskedConv2d

class mlp_3(nn.Module):
    """Masked MLP with three hidden blocks and a linear output head.

    Each hidden block is MaskedLinear -> BatchNorm1d -> ReLU, with widths
    ``input_dim // 2``, ``input_dim // 4`` and ``input_dim // 8``. A final
    MaskedLinear maps to ``num_classes`` outputs.

    Args:
        input_dim: Number of features in one flattened input sample.
        num_classes: Width of the output layer. Defaults to 10, the value
            that was previously hard-coded.
    """

    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        # Stored so forward() flattens to the width fc1 actually expects,
        # instead of assuming 28x28 inputs.
        self.input_dim = input_dim

        self.fc1 = MaskedLinear(input_dim, input_dim // 2)
        self.bn1 = nn.BatchNorm1d(input_dim // 2)
        self.fc2 = MaskedLinear(input_dim // 2, input_dim // 4)
        self.bn2 = nn.BatchNorm1d(input_dim // 4)
        self.fc3 = MaskedLinear(input_dim // 4, input_dim // 8)
        self.bn3 = nn.BatchNorm1d(input_dim // 8)
        self.fc4 = MaskedLinear(input_dim // 8, num_classes)

        self._initialize_weights()

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return the raw fc4 output (no softmax)."""
        # Hard-coding 28 * 28 here broke every input whose flattened size
        # differs from 784; reshape also tolerates non-contiguous inputs.
        x = x.reshape(-1, self.input_dim)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.fc2(x)))
        x = F.relu(self.bn3(self.fc3(x)))
        return self.fc4(x)

    def _initialize_weights(self):
        """Kaiming-normal init for masked layer weights; weight=1 / bias=0 for norm layers."""
        for m in self.modules():
            if isinstance(m, (MaskedConv2d, MaskedLinear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.GroupNorm)):
                # BatchNorm1d is the norm layer this model actually uses; without
                # it in the tuple, this branch never matched any module here.
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def set_masks(self, weight_mask, bias_mask):
        """Apply pruning masks to fc1..fc4 in order.

        Args:
            weight_mask: Indexable sequence; entry ``i`` goes to the i-th layer.
            bias_mask: Indexable sequence; entry ``i`` goes to the i-th layer.

        Entries are forwarded unchanged to each layer's ``set_mask``. A
        sequence shorter than four raises IndexError, as before.
        """
        layers = (self.fc1, self.fc2, self.fc3, self.fc4)
        for i, layer in enumerate(layers):
            layer.set_mask(weight_mask[i], bias_mask[i])


def mlp3(input_shape, num_classes):
    """Build an ``mlp_3`` sized for ``input_shape`` with ``num_classes`` outputs.

    ``input_shape`` is presumably the per-sample shape, e.g. (C, H, W) —
    TODO confirm against callers. Its dimensions are multiplied together to
    get the flattened input width.
    """
    input_dim = math.prod(input_shape)
    # num_classes used to be ignored (the head was always 10 wide).
    return mlp_3(input_dim, num_classes)