# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class LayerSelect(nn.Module):
    """Compute samples (from a Gumbel-Sigmoid distribution) which are used as
    either (soft) weighting or (hard) selection of residual connections.
    https://arxiv.org/abs/2009.13102

    Args:
        num_layers: Number of layers; each call to :meth:`sample` draws one
            value per layer.
        num_logits: Number of independent rows of logits, selected by
            ``logit_idx`` in :meth:`sample`.
        soft_select: If True, samples are the relaxed sigmoid values in
            (0, 1). Otherwise they are hard 0/1 values with a
            straight-through gradient.
        sampling_tau: Temperature of the Gumbel-Sigmoid relaxation.
    """

    def __init__(self, num_layers, num_logits, soft_select=False, sampling_tau=5.0):
        super().__init__()
        # ``torch.Tensor(*shape)`` returns uninitialised memory, which may
        # contain arbitrary values (including NaN/inf). Start from zero
        # logits instead, i.e. a neutral sigmoid(0) = 0.5 prior per layer.
        # Loading a checkpoint (or any external init) overwrites this.
        self.layer_logits = torch.nn.Parameter(
            torch.zeros(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not soft_select
        self.tau = sampling_tau
        # When True, samples are drawn from detached logits, so no gradient
        # flows back into ``layer_logits``.
        self.detach_grad = False
        # Most recent samples drawn for each logit row (None until sampled).
        self.layer_samples = [None] * num_logits
        # Samples from the most recent ``sample`` call; read by ``forward``.
        self.samples = None

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independent of each other.

        The result is stored in ``self.samples`` (one value per layer) and
        also recorded in ``self.layer_samples[logit_idx]``.

        Args:
            logit_idx: The index of logit parameters used for sampling.
        """
        assert logit_idx is not None
        logits = self.layer_logits[logit_idx, :]
        if self.detach_grad:
            logits = logits.detach()
        self.samples = self._gumbel_sigmoid(
            logits,
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        """Return the sample for layer ``i`` from the last :meth:`sample` call.

        Raises:
            RuntimeError: if :meth:`sample` has not been called yet.
        """
        if self.samples is None:
            raise RuntimeError(
                "LayerSelect.sample(logit_idx) must be called before forward()"
            )
        return self.samples[i]

    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        """Draw a Gumbel-Sigmoid (binary Concrete) sample for each logit.

        Args:
            logits: Logits to perturb; the output has the same shape.
            tau: Temperature dividing the perturbed logits before the sigmoid.
            hard: If True, return 0/1 values (1 where the relaxed sample
                exceeds ``threshold``) whose gradient is that of the relaxed
                sample (straight-through estimator).
            eps: Unused; kept for interface compatibility.
            dim: Unused (the operation is element-wise); kept for interface
                compatibility.
            threshold: Cut-off applied to the relaxed sample when ``hard``.

        Returns:
            Tensor shaped like ``logits``.
        """
        # Two independent ~Gumbel(0,1) draws.
        gumbels1 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        gumbels2 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # The difference of two Gumbels is logistic noise, which is the right
        # perturbation for a sigmoid (binary) relaxation.
        perturbed = (logits + gumbels1 - gumbels2) / tau
        y_soft = perturbed.sigmoid()
        if not hard:
            # Reparametrization trick.
            return y_soft
        # Straight through: forward value is y_hard, gradient is y_soft's.
        y_hard = torch.zeros_like(
            logits, memory_format=torch.legacy_contiguous_format
        ).masked_fill(y_soft > threshold, 1.0)
        return y_hard - y_soft.detach() + y_soft
