from torch import nn
import torch
import torch.nn.utils.weight_norm as weightNorm
from torch.distributions import Bernoulli, RelaxedBernoulli

def init_weights(m):
    """Initialise the parameters of a single module in place.

    Intended to be used with ``nn.Module.apply``. The module kind is chosen
    by substring match on the class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weight, zero bias.
    * ``BatchNorm*``: weight drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weight, zero bias.

    Any other module is left untouched. A missing or ``None`` weight/bias
    (e.g. ``bias=False`` layers or ``BatchNorm`` with ``affine=False``) is
    skipped instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    # getattr with a default also covers modules that merely have a matching
    # class name but do not own ``weight``/``bias`` attributes.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        # Not a module kind this helper knows how to initialise.
        return

    if bias is not None:
        nn.init.zeros_(bias)



class BasicNet(nn.Module):
    """Two-layer MLP encoder with a learnable stochastic feature mask and a
    weight-normalised linear classifier.

    The pipeline is ``Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d``,
    followed by an element-wise binary mask over the ``n_hidden_2`` features
    and a classifier head.

    Args:
        in_dim: Size of each input sample.
        n_hidden_1: Width of the first hidden layer.
        n_hidden_2: Width of the second hidden layer (the masked feature size).
        num_class: Number of classifier outputs.
    """

    def __init__(self, in_dim: int, n_hidden_1: int, n_hidden_2: int, num_class: int):
        super(BasicNet, self).__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer1.apply(init_weights)
        self.bn1 = nn.BatchNorm1d(n_hidden_1, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer2.apply(init_weights)
        self.bn2 = nn.BatchNorm1d(n_hidden_2, affine=True)
        # Mask logits, one per feature, drawn uniformly from [0, 1).
        self.maskI = nn.Parameter(torch.rand((n_hidden_2), requires_grad=True))
        # Initialise the linear layer BEFORE wrapping it with weight norm.
        # After wrapping, ``weight`` is a derived tensor recomputed from
        # ``weight_g``/``weight_v`` before every forward, so an in-place init
        # applied to it afterwards would be silently discarded.
        classifier = nn.Linear(n_hidden_2, num_class)
        classifier.apply(init_weights)
        self.classifier = weightNorm(classifier, name='weight')

    def forward(self, x):
        """Encode ``x``, apply the sampled mask and classify.

        The mask is re-sampled on every call; this method does not branch on
        ``self.training``.

        Args:
            x: Input batch; presumably shaped ``(batch, in_dim)`` -- confirm
                against the caller.

        Returns:
            A tuple ``(out, feat, x, maskI)``: classifier output, masked
            features, unmasked features after the second batch norm, and the
            raw mask logits parameter.
        """
        x = self.layer1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer2(x)
        x = self.bn2(x)

        probs = torch.sigmoid(self.maskI)
        hard_mask = Bernoulli(probs).sample()
        # Straight-through estimator: the forward value is the hard 0/1
        # sample, while the gradient flows through the sigmoid probabilities.
        mask = (hard_mask - probs).detach() + probs
        feat = mask * x
        out = self.classifier(feat)

        return out, feat, x, self.maskI


class Decoder(nn.Module):
    """MLP that decodes the concatenation of two feature tensors.

    Structure: ``Linear -> BatchNorm1d -> ReLU -> Linear``. Both linear
    layers are initialised with ``init_weights``.

    Args:
        in_dim: Size of the concatenated input (``xs`` and ``xi`` joined
            along dimension 1).
        n_hidden_1: Width of the hidden layer.
        n_hidden_2: Size of the output.
    """

    def __init__(self, in_dim, n_hidden_1, n_hidden_2):
        super().__init__()
        # ``Module.apply`` returns the module itself, so construction and
        # initialisation can be chained.
        self.layer1 = nn.Linear(in_dim, n_hidden_1).apply(init_weights)
        self.bn1 = nn.BatchNorm1d(n_hidden_1, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2).apply(init_weights)

    def forward(self, xs, xi):
        """Concatenate ``xs`` and ``xi`` along dim 1 and decode the result."""
        joined = torch.cat([xs, xi], dim=1)
        hidden = self.relu(self.bn1(self.layer1(joined)))
        return self.layer2(hidden)

class MaskSNet(nn.Module):
    """Learnable stochastic binary mask applied element-wise to features.

    Args:
        feat_dim: Number of mask logits (one per feature).
    """

    def __init__(self, feat_dim):
        super().__init__()
        # Logits start uniformly in [0, 1); nn.Parameter tracks gradients.
        initial_logits = torch.rand(feat_dim)
        self.mask = nn.Parameter(initial_logits, requires_grad=True)

    def forward(self, feat):
        """Sample a binary mask and multiply it into ``feat``.

        The mask is re-sampled on every call.

        Returns:
            A tuple ``(masked_feat, mask_logits)`` where ``mask_logits`` is
            the raw ``self.mask`` parameter.
        """
        keep_prob = torch.sigmoid(self.mask)
        sampled = Bernoulli(keep_prob).sample()
        # Straight-through estimator: hard sample in the forward value,
        # gradient of the sigmoid probabilities in the backward pass.
        straight_through = (sampled - keep_prob).detach() + keep_prob
        return straight_through * feat, self.mask