import torch
from torch import nn
from torch import autograd
from torch.nn.functional import gumbel_softmax, one_hot
from torch.distributions import Categorical


class ControlVariate(torch.nn.Module):
    """Neural control variate for RELAX.

    A network with a single hidden layer: a linear map from `num_latents`
    features to `hidden_size` features, a tanh non-linearity, and a linear
    map back to `num_latents` features.

    Parameters
    ----------
    num_latents : int
        Number of input features, which is also the number of output features.
    hidden_size : int, optional
        Number of features of the hidden layer. Defaults to 100.
    """

    def __init__(self, num_latents, hidden_size=100):
        super().__init__()
        # Attribute names are part of the state_dict keys, keep them stable.
        self.h1 = nn.Linear(in_features=num_latents, out_features=hidden_size)
        self.nonlin = nn.Tanh()
        self.out = nn.Linear(in_features=hidden_size, out_features=num_latents)

    def forward(self, z):
        """Evaluate the control variate on `z`.

        Parameters
        ----------
        z : torch.Tensor
            Input whose last dimension has `num_latents` features.

        Returns
        -------
        torch.Tensor
            Output with the same shape as `z`.
        """
        hidden = self.nonlin(self.h1(z))
        return self.out(hidden)


class RelaxGumbelSoftmax(autograd.Function):
    """Implements the forward and backward pass of GumbelSoftmax with the RELAX
    gradient estimator from [1]_

    The forward pass perturbs the logits with Gumbel noise and returns the
    one-hot encoding of the argmax. The backward pass builds the RELAX
    estimate of the gradient wrt the logits and, as a side effect, accumulates
    the gradient of a variance surrogate into the ``.grad`` of the control
    variate parameters.

    References
    ----------
    .. [1] Grathwohl et al. (2018). Backpropagation through the Void:
    Optimizing control variates for black-box gradient estimation. In
    Proceedings of ICLR.
    """

    @staticmethod
    def _sample_uniform(like):
        """Draw uniform noise with the shape of `like`, on its device.

        ``torch.rand`` samples from [0, 1). The draw is clamped away from zero
        so that every logarithm later taken on it stays finite.
        """
        noise = torch.rand(like.shape, device=like.device)
        return noise.clamp_(min=torch.finfo(noise.dtype).tiny)

    @staticmethod
    def _gumbel(uniform):
        """Map uniform noise in (0, 1) to standard Gumbel noise."""
        return -torch.log(-torch.log(uniform))

    @staticmethod
    def forward(ctx, logits, temp, control_variate):
        """Forward pass of RELAX Gumbel-Softmax distribution

        Parameters
        ----------
        logits : torch.Tensor
            Unnormalized log probabilities (1-D, one entry per category).
        temp : torch.Tensor
            Gumbel-softmax temperature (0-dim tensor).
        control_variate : torch.nn.Module
            Neural network used as a control variate in the RELAX method.

        Returns
        -------
        torch.Tensor
            Onehot encoded category sampled.
        """
        # Sample epsilon and build continuous surrogate random variable z
        # (gumbel softmax). torch.nn.functional.gumbel_softmax is not used
        # so that the epsilon disturbance can be stored for backward compute.
        # The perturbation has to be Gumbel distributed: only then is the
        # argmax distributed as softmax(logits), which is what the backward
        # pass assumes (log-probability of b and conditional sample z_tilde).
        epsilon = RelaxGumbelSoftmax._sample_uniform(logits)
        z = torch.softmax(
            (logits + RelaxGumbelSoftmax._gumbel(epsilon)) / temp,
            dim=0
        )
        # TODO: Check how to optimize the temp

        # Map z into categorical b using H. Autograd does not record
        # operations inside forward, so a straight-through term would have no
        # effect on gradients; an exact one-hot vector is built instead.
        b = one_hot(torch.argmax(z), num_classes=len(logits)).to(z.dtype)

        # Sample epsilon_tilde (for backward computation)
        epsilon_tilde = RelaxGumbelSoftmax._sample_uniform(logits)

        # Save elements for the gradients computation
        # (It is necessary to remove previous gradients to avoid infinite
        # loops when calling autograd.grad within backward)
        ctx._control_variate = control_variate
        ctx.save_for_backward(
            logits.detach(),
            temp.detach(),
            b.detach(),
        )
        ctx.epsilon = epsilon
        ctx.epsilon_tilde = epsilon_tilde

        # Return the onehot vector
        return b

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass of RELAX Gumbel-Softmax distribution

        This backward method is greatly inspired by the official relax
        implementation:
        https://github.com/duvenaud/relax/blob/master/pytorch_toy.py

        Besides returning the gradient wrt the logits, this method calls
        ``backward`` on a variance surrogate (squared mean of the estimated
        gradient), which accumulates gradients into the ``.grad`` attribute of
        the control variate parameters.

        Parameters
        ----------
        grad_output: torch.Tensor
            Gradients from the downstream wrt the onehot encoded vector from
            the forward pass.

        Returns
        -------
        torch.Tensor
            RELAX estimate of the gradient wrt the logits.
        None
            No gradient is computed wrt the temperature `temp`.
        None
            No gradient is returned for the control variate module.
        """
        # Recover saved variables
        saved_logits, temp, b = ctx.saved_tensors
        control_variate = ctx._control_variate
        epsilon_tilde = ctx.epsilon_tilde

        # Enable gradients back from here forward (necessary)
        with torch.enable_grad():
            # Fresh leaf, so that the saved tensor itself is left untouched
            logits = saved_logits.detach().requires_grad_()

            # Recompute z using epsilon and logits (because we had to drop
            # previous logits gradients)
            z = torch.softmax(
                (logits + RelaxGumbelSoftmax._gumbel(ctx.epsilon)) / temp,
                dim=0
            )

            # Build z_tilde using logits, b and epsilon_tilde (from forward)
            v_b = (b * epsilon_tilde).sum()
            probs = logits.softmax(dim=0)  # probabilities
            z_tilde = -torch.log(
                -torch.log(epsilon_tilde) * (b + (1 - b) / probs)
                - (1 - b) * torch.log(v_b)
            )
            z_tilde = torch.softmax(z_tilde / temp, dim=0)

            # Compute log probability of sampling b
            # (Usually, the log probability of a onehot encoded categorical
            # distribution should be a scalar, but here it is kept as a vector
            # with one entry per category for dim compatibility. This could
            # be an issue...)
            # log_softmax avoids log(0) * 0 = NaN when a probability underflows
            log_p = torch.log_softmax(logits, dim=0) * b

            # Compute its gradient wrt the logits
            grad_logpb = autograd.grad(
                [log_p], [logits],
                grad_outputs=grad_output,
            )[0]

            # Compute the forward of the neural control variate and grad wrt
            # the logits (graph kept to differentiate wrt the network weights)
            cv_z = control_variate(z)
            grad_cv_z = autograd.grad(
                [cv_z], [logits], grad_outputs=grad_output,
                create_graph=True,
            )[0]

            # Same for z_tilde
            cv_z_tilde = control_variate(z_tilde)
            grad_cv_z_tilde = autograd.grad(
                [cv_z_tilde], [logits], grad_outputs=grad_output,
                create_graph=True,
            )[0]

            # Compute the RELAX logits gradient
            diff = b - cv_z_tilde
            grad_logits = diff * grad_logpb + grad_cv_z - grad_cv_z_tilde

            # Compute the RELAX control variable weights gradient
            var_loss = (grad_logits.mean(0) ** 2).mean()

        # Side effect: accumulate the variance surrogate gradient into the
        # control variate parameters
        var_loss.backward()

        # The incoming gradients were already applied through `grad_outputs`
        return grad_logits, None, None
