import torch
import torch.nn as nn
import math


class StochasticGates(nn.Module):
    """Multiply the input element-wise by learnable gates clamped to [0, 1].

    Each gate has a learnable mean stored in ``mus``. In training mode,
    Gaussian noise with standard deviation ``sigma`` is added to the means
    before clamping. In eval mode the noise is multiplied by ``False`` (i.e.
    zeroed), so the gates are simply ``clamp(mus, 0, 1)``.

    Args:
        size: shape passed to ``torch.ones`` for the default gate means.
        sigma: standard deviation of the training-time noise; also used by
            ``get_reg``.
        lam: weight applied to the regulariser returned by ``get_reg``.
        gate_init: optional initial gate means, converted to a float32
            tensor. When omitted, every gate mean starts at 0.5.
    """

    def __init__(self, size, sigma, lam, gate_init=None):
        super().__init__()
        self.size = size
        initial_mus = (
            0.5 * torch.ones(size)
            if gate_init is None
            else torch.tensor(gate_init, dtype=torch.float32)
        )
        self.mus = nn.Parameter(initial_mus, requires_grad=True)
        self.sigma = sigma
        self.lam = lam

    def forward(self, x, L):
        """Gate ``x`` and pass ``L`` through untouched.

        Args:
            x: input tensor; assumed to broadcast against ``mus`` (TODO
                confirm the expected layout with callers).
            L: arbitrary value returned unchanged; this module never reads it.

        Returns:
            ``(x * z, L)``, where ``z`` is the clamped (noisy when training)
            gate tensor.
        """
        # Noise is sampled on the CPU in both modes; the boolean ``training``
        # flag zeroes it in eval mode before it is moved to x's device.
        noise = (self.sigma * torch.randn(self.mus.size()) * self.training).to(x.device)
        gates = self.make_bernoulli(self.mus + noise)
        return x * gates, L

    @staticmethod
    def make_bernoulli(z):
        """Clamp ``z`` into [0, 1] (a hard clip, not a Bernoulli sample)."""
        return z.clamp(min=0.0, max=1.0)

    def get_reg(self):
        """Return ``lam * sum(Phi(mus / sigma))``.

        ``Phi`` is the standard normal CDF, so each term is the probability
        that ``mu + N(0, sigma^2)`` is positive, i.e. that the gate is open.
        """
        scaled = (self.mus / self.sigma) / math.sqrt(2)
        open_probability = 0.5 * (1 + torch.erf(scaled))
        return self.lam * open_probability.sum()

    def get_gates(self):
        """Return the noise-free gate values ``clamp(mus, 0, 1)``."""
        return self.make_bernoulli(self.mus)


class StochasticGates2(nn.Module):
    """Split the input into a gated part and its complement using [0, 1] gates.

    Each gate has a learnable mean stored in ``mus``. In training mode,
    Gaussian noise with standard deviation ``sigma`` is added to the means
    before clamping to [0, 1]. In eval mode no noise is sampled and the
    gates are ``clamp(mus, 0, 1)``. The forward pass returns ``x * z`` and
    ``x * (1 - z)``, so the two outputs always sum to ``x``.

    Args:
        size: number of gates (an int; used as the shape ``(size,)`` for
            the default random initialisation).
        sigma: standard deviation of the training-time noise; also used by
            ``get_reg``.
        lam: weight applied to the regulariser returned by ``get_reg``.
        gate_init: optional initial gate means (list, array or tensor),
            copied into a float32 tensor. When omitted, each gate mean is
            drawn uniformly from {0.0, 1.0}.
    """

    def __init__(self, size, sigma, lam, gate_init=None):
        super().__init__()
        self.size = size
        if gate_init is None:
            # Random hard 0/1 starting point for each gate.
            mus = torch.randint(2, (size,), dtype=torch.float32)
        else:
            # Copy like torch.tensor() would, but without the
            # copy-construct warning when gate_init is already a tensor.
            mus = torch.as_tensor(gate_init, dtype=torch.float32).detach().clone()
        self.mus = nn.Parameter(mus, requires_grad=True)
        self.sigma = sigma
        self.lam = lam

    def forward(self, x, L):
        """Return the gated input, its complement, and ``L`` unchanged.

        Args:
            x: input tensor; assumed to broadcast against ``mus`` (TODO
                confirm the expected layout with callers).
            L: arbitrary value returned unchanged; this module never reads it.

        Returns:
            ``(x * z, x * (1 - z), L)``, where ``z`` is the clamped gate
            tensor (noisy in training mode only).
        """
        if self.training:
            # Sample directly on the parameters' device/dtype instead of on
            # the CPU followed by a host-to-device copy on every call.
            noise = self.sigma * torch.randn_like(self.mus)
            z = self.make_bernoulli(self.mus + noise)
        else:
            # No noise in eval mode, so skip sampling it entirely.
            z = self.make_bernoulli(self.mus)
        return x * z, x * (1 - z), L

    @staticmethod
    def make_bernoulli(z):
        """Clamp ``z`` into [0, 1] (a hard clip, not a Bernoulli sample)."""
        return torch.clamp(z, 0.0, 1.0)

    def get_reg(self):
        """Return ``lam * sum(Phi(mus / sigma) - Phi((mus - 1) / sigma))``.

        ``Phi`` is the standard normal CDF, so each term is the probability
        that ``mu + N(0, sigma^2)`` falls strictly inside (0, 1), i.e. that
        the gate is not saturated at 0 or 1.
        """
        # Phi(t) = (1 + erf(t / sqrt(2))) / 2. Both erf arguments must be
        # divided by sqrt(2) *inside* erf; the constant terms cancel.
        upper = torch.erf((self.mus / self.sigma) / math.sqrt(2))
        lower = torch.erf(((self.mus - 1) / self.sigma) / math.sqrt(2))
        return self.lam * torch.sum((upper - lower) / 2)

    def get_gates(self):
        """Return the noise-free gate values ``clamp(mus, 0, 1)``."""
        return self.make_bernoulli(self.mus)