import torch
import torch.nn.functional as F
import math


class MexicanHat(torch.nn.Module):
    """Difference-of-Gaussians ("Mexican hat") activation that saturates at 1.

    A narrow centre Gaussian minus a wider surround Gaussian, both centred at
    ``delta``, rescaled so that ``y(0) == 0`` and ``y(delta) == 1``. Every
    input greater than ``delta`` is mapped to exactly 1.

    Args:
        normalize: ``None`` (no normalization), ``'softmax'`` (softmax over
            dim 1) or ``'sum'`` (add the absolute minimum over dim 1, then
            divide by the sum over dim 1).

    Raises:
        ValueError: if ``normalize`` is not ``None``, ``'softmax'`` or ``'sum'``.
    """

    def __init__(self, normalize=None):
        super().__init__()

        self.pi = torch.tensor(math.pi)

        # Standard deviation of the centre Gaussian, and the ratio of the
        # surround Gaussian's width to the centre's.
        self.width = torch.tensor(0.9)
        self.surround_scale = torch.tensor(2.0)
        # Variance of the surround Gaussian: (width * surround_scale)**2.
        self.ws2 = (self.width*self.surround_scale)**2
        # Peak position, chosen so that y(x=0)=0 and y(x=delta)=1.
        # Solving s**2 * exp(-d**2 / (2 w**2)) == exp(-d**2 / (2 w**2 s**2))
        # for d gives delta**2 = 4 ln(s) w**2 s**2 / (s**2 - 1).
        self.delta = torch.sqrt(
            -4*torch.log(self.surround_scale)*self.ws2 /
            (1 - self.surround_scale**2)
        )

        if normalize is not None and normalize not in ['softmax', 'sum']:
            raise ValueError("normalize must be 'softmax' or 'sum'")
        self.normalize = normalize

        self.softmax = torch.nn.Softmax(dim=1)

    def _normalize(self, y):
        """Apply the configured normalization over dim 1 (no-op when None)."""
        if self.normalize == 'softmax':
            y = self.softmax(y)
        elif self.normalize == 'sum':
            # Tensor.min(dim=...) returns a (values, indices) namedtuple, so
            # take .values. keepdim=True keeps the reduced axis, so the result
            # broadcasts against y for inputs of any rank >= 2.
            row_min = y.min(dim=1, keepdim=True).values
            # Out-of-place updates so autograd never sees y modified in place.
            y = y + torch.abs(row_min)
            y = y / y.sum(dim=1, keepdim=True)

        return y

    def _compute_dog(self, x):
        """Evaluate the rescaled difference of Gaussians, set to 1 above delta.

        Raises:
            ValueError: if the curve contains NaN (e.g. from NaN inputs).
        """
        # Narrow centre Gaussian (std = width), centred at delta.
        E = torch.exp(
            -(x-self.delta)**2 /
            (2* self.width**2)
        )
        # Wide surround Gaussian (variance = ws2), centred at delta.
        I = torch.exp(
            -(x-self.delta)**2 /
            (2* self.ws2)
        )

        y = (
            E / (2*self.pi* self.width**2) -
            I / (2*self.pi*self.ws2)
        ) * ( # rescale so that y=1 when x=delta
            2*self.pi*self.ws2 /
            (self.surround_scale**2 -1)
        )

        if y.isnan().any():
            raise ValueError("MexicanHat activation function produced NaN values.")

        # Saturate: everything to the right of the peak is exactly 1.
        return torch.where(x<=self.delta, y, torch.ones_like(y))

    def forward(self, x):
        """Apply the activation, then the optional normalization over dim 1."""
        y = self._compute_dog(x)
        return self._normalize(y)


class MexicanHatStandard(torch.nn.Module):
    """Closed-form Mexican hat activation ``(1 - x**2) * exp(-x**2 / 2)``.

    This is proportional to the Ricker wavelet. Its value is 1 at x=0 and 0 at
    x=±1, and it is negative for |x| > 1.

    Args:
        normalize: ``None`` (no normalization), ``'softmax'`` (softmax over
            dim 1) or ``'sum'`` (add the absolute minimum over dim 1, then
            divide by the sum over dim 1).

    Raises:
        ValueError: if ``normalize`` is not ``None``, ``'softmax'`` or ``'sum'``.
    """

    def __init__(self, normalize=None):
        super().__init__()

        if normalize is not None and normalize not in ['softmax', 'sum']:
            raise ValueError("normalize must be 'softmax' or 'sum'")
        self.normalize = normalize

        self.softmax = torch.nn.Softmax(dim=1)

    def _normalize(self, y):
        """Apply the configured normalization over dim 1 (no-op when None)."""
        if self.normalize == 'softmax':
            y = self.softmax(y)
        elif self.normalize == 'sum':
            # Tensor.min(dim=...) returns a (values, indices) namedtuple, so
            # take .values. keepdim=True keeps the reduced axis, so the result
            # broadcasts against y for inputs of any rank >= 2.
            row_min = y.min(dim=1, keepdim=True).values
            # Out-of-place updates so autograd never sees y modified in place.
            y = y + torch.abs(row_min)
            y = y / y.sum(dim=1, keepdim=True)

        return y

    def forward(self, x):
        """Apply the activation, then the optional normalization over dim 1."""
        y = (1 - x**2)*torch.exp(-x**2 / 2)
        return self._normalize(y)


class HardSoftmax(torch.nn.Module):
    """Softmax over dim 1 that zeroes out low-probability entries.

    Entries whose softmax probability is below ``1 / (0.5 * latent_dim)``
    (twice the uniform probability ``1 / latent_dim``) are set to 0. The
    surviving entries are NOT renormalized.

    Args:
        latent_dim: size used to compute the threshold; presumably the size of
            dim 1 of the input — TODO confirm against callers.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

    def forward(self, x):
        probs = F.softmax(x, dim=1)
        # Cast the mask to the probabilities' dtype rather than float32, so
        # half/bfloat16 inputs are not silently promoted to float32 output.
        keep = (probs >= 1/(0.5*self.latent_dim)).to(probs.dtype)
        return probs * keep

class HardSigmoid(torch.nn.Module):
    """Piecewise-linear sigmoid approximation: ``clamp(0.2 * x + 0.5, 0, 1)``.

    The output is 0 for x <= -2.5, 1 for x >= 2.5, and linear in between.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        linear = 0.2*x + 0.5
        return linear.clamp(min=0.0, max=1.0)

class _STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        return F.hardtanh(grad_output)

class StraightThroughEstimator(torch.nn.Module):
    """Module wrapper around ``_STEFunction``.

    The forward pass outputs 1.0 where ``x > 0`` and 0.0 elsewhere. The
    backward pass returns the upstream gradient clipped to [-1, 1].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return _STEFunction.apply(x)