import torch
import torch.nn.functional as F
import numpy as np

from models.utils import Base


class RoundStraightThrough(torch.autograd.Function):

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(ctx, input):
        rounded = torch.round(input, out=None)
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        return grad_input


_round_straightthrough = RoundStraightThrough().apply


def _stacked_sigmoid(x, temperature, n_approx=3):

    x_ = x - 0.5
    rounded = torch.round(x_)
    x_remainder = x_ - rounded

    size = x_.size()
    x_remainder = x_remainder.view(size + (1,))

    translation = torch.arange(n_approx) - n_approx // 2
    translation = translation.to(device=x.device, dtype=x.dtype)
    translation = translation.view([1] * len(size) + [len(translation)])
    out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

    return out + rounded - (n_approx // 2)


class SmoothRound(Base):
    def __init__(self):
        self._temperature = None
        self._n_approx = None
        super().__init__()
        self.hard_round = None

    @property
    def temperature(self):
        return self._temperature

    @temperature.setter
    def temperature(self, value):
        self._temperature = value

        if self._temperature <= 0.05:
            self._n_approx = 1
        elif 0.05 < self._temperature < 0.13:
            self._n_approx = 3
        else:
            self._n_approx = 5

    def forward(self, x):
        assert self._temperature is not None
        assert self._n_approx is not None
        assert self.hard_round is not None

        if self.temperature <= 0.25:
            h = _stacked_sigmoid(x, self.temperature, n_approx=self._n_approx)
        else:
            h = x

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class StochasticRound(Base):
    def __init__(self):
        super().__init__()
        self.hard_round = None

    def forward(self, x):
        u = torch.rand_like(x)

        h = x + u - 0.5

        if self.hard_round:
            h = _round_straightthrough(h)

        return h


class BackRound(Base):

    def __init__(self, args, inverse_bin_width):
        """
        BackRound is an approximation to Round that allows for Backpropagation.

        Approximate the round function using a sum of translated sigmoids.
        The temperature determines how well the round function is approximated,
        i.e., a lower temperature corresponds to a better approximation, at
        the cost of more vanishing gradients.

        BackRound supports the following settings:
        * By setting hard to True and temperature > 0.25, BackRound
          reduces to a round function with a straight through gradient
          estimator
        * When using 0 < temperature <= 0.25 and hard = True, the
          output in the forward pass is equivalent to a round function, but the
          gradient is approximated by the gradient of a sum of sigmoids.
        * When using hard = False, the output is not constrained to integers.
        * When temperature > 0.25 and hard = False, BackRound reduces to
          the identity function.

        Arguments
        ---------
        temperature: float
            Temperature used for stacked sigmoid approximated. If temperature
            is greater than 0.25, the approximation reduces to the indentiy
            function.
        hard: bool
            If hard is True, a (hard) round is applied before returning. The
            gradient for this is approximated using the straight-through
            estimator.
        """
        super().__init__()
        self.inverse_bin_width = inverse_bin_width
        self.round_approx = args.round_approx

        if args.round_approx == 'smooth':
            self.round = SmoothRound()
        elif args.round_approx == 'stochastic':
            self.round = StochasticRound()
        else:
            raise ValueError

    def forward(self, x):
        if self.round_approx == 'smooth' or self.round_approx == 'stochastic':
            h = x * self.inverse_bin_width

            h = self.round(h)

            return h / self.inverse_bin_width

        else:
            raise ValueError
