import torch
import torch.nn as nn
import torch.nn.functional as F


class TrainableSigmoid(nn.Module):
    """Logistic sigmoid whose input slope is a learnable parameter.

    Computes ``sigmoid(alpha * x)``.

    Args:
        alpha: initial slope. It is wrapped directly in ``nn.Parameter``,
            so it must already be a tensor (e.g. ``torch.tensor(1.0)``).
    """

    def __init__(self, alpha):
        super().__init__()
        self.alpha = nn.Parameter(alpha)

    def forward(self, x):
        scaled = x * self.alpha
        return scaled.sigmoid()


class TrainableReLU(nn.Module):
    """Piecewise-linear ("hard") sigmoid with a learnable slope.

    Computes ``clamp(relu(alpha * x + 0.5), min=0, max=1)``. The output is
    0.5 at ``x == 0``. For a positive ``alpha`` it ramps linearly from 0 to 1
    over ``x`` in ``[-0.5 / alpha, 0.5 / alpha]``.

    Args:
        alpha: initial slope. It is wrapped directly in ``nn.Parameter``,
            so it must already be a tensor (e.g. ``torch.tensor(1.0)``).
    """

    def __init__(self, alpha):
        super().__init__()
        self.alpha = nn.Parameter(alpha)

    def forward(self, x):
        shifted = x * self.alpha + 0.5
        # Keep the relu -> clamp order so the autograd ops (and boundary
        # gradients) match the original formulation exactly.
        return shifted.relu().clamp(min=0.0, max=1.0)


class NoisySpike(nn.Module):
    """Heaviside spike activation with a stochastic, trainable surrogate.

    Eval mode: returns the hard spike ``(x >= 0).float()``.

    Training mode: returns::

        s(x) + (((x >= 0).float() - s(x)) * mask).detach()

    where ``s`` is a trainable surrogate (``TrainableSigmoid`` for
    ``sig='sigmoid'``, ``TrainableReLU`` for ``sig='clamp'``). The correction
    term is detached, so gradients always flow through ``s``; only the
    forward value depends on ``mask``.

    ``mask`` comes from ``F.dropout`` applied to a tensor of ones. Each entry
    is 0 with probability ``p`` and ``1 / (1 - p)`` otherwise; with
    ``p == 1`` every entry is 0. Consequences:

    - Zero entries emit the surrogate value ``s(x)``.
    - With ``p == 0`` every entry emits the hard spike (straight-through
      estimator).
    - For ``0 < p < 1`` the surviving entries are rescaled by
      ``1 / (1 - p)``, so those outputs are not restricted to {0, 1}.

    NOTE(review): confirm the dropout rescaling is intended. It makes the
    forward value equal the hard spike in expectation.

    The mask is cached and reused across forward calls until ``reset_mask``
    is called. Presumably this shares one mask across the time steps of a
    sample; confirm with callers. It is also regenerated automatically if the
    input shape changes, e.g. a smaller final batch.

    Args:
        p: dropout probability used to build the mask (see above).
        sig: surrogate type, either ``'sigmoid'`` or ``'clamp'``.
        spike: unused; accepted only for backward compatibility.

    Raises:
        ValueError: if ``sig`` is not ``'sigmoid'`` or ``'clamp'``.
    """

    def __init__(self, p=1, sig='clamp', spike=None):
        super(NoisySpike, self).__init__()
        self.p = p
        if sig == 'sigmoid':
            self.sig = TrainableSigmoid(torch.tensor(1.0))
        elif sig == 'clamp':
            self.sig = TrainableReLU(torch.tensor(1.0))
        else:
            # Previously an unknown value left ``self.sig`` unset and failed
            # later, inside the first training forward pass.
            raise ValueError(
                "sig must be 'sigmoid' or 'clamp', got {!r}".format(sig))
        self.mask = None

    def create_mask(self, x: torch.Tensor) -> torch.Tensor:
        """Return a dropout mask shaped like ``x`` (0 or ``1/(1-p)`` entries)."""
        # ones_like does not track gradients, so ``.data`` is unnecessary.
        return F.dropout(torch.ones_like(x), self.p, training=True)

    def forward(self, x):
        hard_spike = (x >= 0).float()
        if not self.training:
            return hard_spike
        # Rebuild a stale mask whose shape no longer matches the input;
        # reusing it would raise or silently broadcast to the wrong shape.
        if self.mask is None or self.mask.shape != x.shape:
            self.mask = self.create_mask(x)
        surrogate = self.sig(x)
        noise = (hard_spike - surrogate) * self.mask
        # Forward value follows the mask; gradient flows only via surrogate.
        return surrogate + noise.detach()

    def reset_mask(self):
        """Discard the cached mask so the next training forward draws a new one."""
        self.mask = None


class MutiStepNoisyRateScheduler:
    """Multiply a model's noise rate by a fixed factor at milestone epochs.

    Milestones are given as fractions of ``num_epoch`` and converted to
    integer epochs at construction. When ``__call__`` receives an epoch equal
    to ``milestone + start_epoch``, the following happens:

    - ``p`` is multiplied by ``reduce_ratio``.
    - A message is printed.
    - The new rate is pushed to the model via ``set_noisy_rate``.

    If several milestones map to the same epoch, the reduction is applied
    only once for that epoch.

    Args:
        init_p: initial noise rate.
        reduce_ratio: multiplicative factor applied at each milestone.
        milestones: fractions of ``num_epoch`` at which to reduce the rate.
        num_epoch: total epoch count used to convert the fractions.
        start_epoch: offset added to every milestone when matching epochs.
    """

    def __init__(self, init_p=1, reduce_ratio=0.9,
                 milestones=(0.3, 0.7, 0.9, 0.95), num_epoch=100,
                 start_epoch=0):
        self.reduce_ratio = reduce_ratio
        self.p = init_p
        # Round away float noise before truncating. Otherwise a product such
        # as 0.29 * 100 == 28.999999999999996 would land one epoch early.
        self.milestones = [int(round(m * num_epoch, 6)) for m in milestones]
        self.num_epoch = num_epoch
        self.start_epoch = start_epoch

    @staticmethod
    def _unwrap(model):
        """Return the inner module of a (Distributed)DataParallel wrapper, else ``model``."""
        if isinstance(model, (nn.DataParallel,
                              nn.parallel.DistributedDataParallel)):
            return model.module
        return model

    def __call__(self, epoch, model):
        """Apply one reduction if ``epoch`` hits a milestone.

        Args:
            epoch: current epoch index.
            model: object exposing ``set_noisy_rate(p)``, optionally wrapped
                in ``nn.DataParallel`` / ``DistributedDataParallel``.
        """
        for milestone in self.milestones:
            if milestone + self.start_epoch == epoch:
                self.p *= self.reduce_ratio
                print('change noise rate as ' + str(self.p))
                self._unwrap(model).set_noisy_rate(self.p)
                break
