import torch
import torch.nn.functional as F

from fairseq import utils

class SwitchGate():
    """Top-1 mixture-of-experts router with a sampled, straight-through gate.

    Every token is sent to exactly one expert. An expert is eligible for a token
    only if its logit is within a relative margin of ``2 * jitter_eps`` of that
    token's maximum logit. During training the expert is drawn with the
    Gumbel-max trick from the eligible logits; at evaluation time the argmax is
    used. The expert output is scaled by a per-token ``multiplier`` built from
    the selected expert's gate probability (see ``one_sample_softmax_st``).

    Args:
        num_experts: number of experts tokens are split across.
        compute_balance_loss: if True, ``__call__`` also returns the
            load-balancing loss ``num_experts * sum(P * f)``; otherwise ``None``.
        jitter: must be True; ``__call__`` asserts it.
        jitter_eps: relative margin used to mask out logits far below the row max.
        onnx_trace: forwarded to ``fairseq.utils.softmax``.
    """

    def __init__(self, num_experts, compute_balance_loss=False, jitter=False, jitter_eps=0.1, onnx_trace=False):
        self.num_experts = num_experts
        self.compute_balance_loss = compute_balance_loss
        self.onnx_trace = onnx_trace
        self.jitter = jitter
        self.jitter_eps = jitter_eps

    def one_sample_softmax_st(self, logits, onnx_trace=False, training=True):
        """Pick one expert per row of ``logits`` and build its gate multiplier.

        Args:
            logits: 2-D gate logits with experts on the last axis; ``__call__``
                passes ``(num_tokens, num_experts)``. The gathers below use
                ``dim=1``, so other ranks are not supported. ``logits`` is NOT
                modified.
            onnx_trace: forwarded to ``fairseq.utils.softmax``.
            training: True -> sample the expert (Gumbel-max); False -> argmax.

        Returns:
            ``(p, sample, multiplier)``:
            - ``p``: softmax over the masked logits, cast to ``logits.dtype``.
            - ``sample``: 1-D selected expert index per row.
            - ``multiplier``: ``(-1, 1)`` column of per-row scale factors.
        """
        # The mask only decides which experts are eligible; it is boolean, so no
        # gradient flows through it and there is no need to record a graph.
        with torch.no_grad():
            row_max, max_ind = logits.max(dim=-1, keepdim=True)
            # Relative gap to the row max, normalised by max(|logit|, row_max).
            factor = logits.abs().clamp(min=row_max)
            drop_mask = ((row_max - logits) / factor) > (2 * self.jitter_eps)

        # Out-of-place on purpose: masked_fill_ would overwrite the caller's tensor
        # (the gating network output), which can break that network's backward.
        logits_w_t = logits.masked_fill(drop_mask, float('-inf'))
        p = utils.softmax(logits_w_t.float(), dim=-1, onnx_trace=onnx_trace).type_as(logits)

        if training:
            # -log(Exp(1)) ~ Gumbel(0, 1); the argmax of logits + Gumbel noise is a
            # sample from softmax(logits_w_t).
            gumbels = (
                logits_w_t
                - torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
            )
            detached_sample = gumbels.max(dim=-1)[1]
            sample_col = detached_sample.unsqueeze(1)

            # Probability of the sampled expert: value p, gradient d p.
            multiplier_0 = p.gather(dim=1, index=sample_col).type_as(logits).view(-1, 1)
            # Value 0.5 * p, gradient still d p.
            multiplier_1 = multiplier_0 - multiplier_0.detach() * 0.5

            # Rows use multiplier_0 when the sample is the row argmax, or when a
            # uniform draw exceeds 0.5; all other rows use multiplier_1.
            keep_full = (detached_sample.type_as(logits).uniform_(0., 1.) > 0.5).unsqueeze(1)
            keep_full = torch.logical_or(keep_full, sample_col == max_ind)

            multiplier = keep_full * multiplier_0 + (~keep_full) * multiplier_1
            # Forward value unchanged; gradient w.r.t. p doubled.
            multiplier = multiplier * 2 - multiplier.detach()
            # NOTE(review): rows with keep_full == False carry 0.5 * p in the forward
            # pass here, whereas the eval branch always uses p -- confirm intended.
        else:
            detached_sample = max_ind.squeeze(-1)
            multiplier = p.gather(dim=1, index=detached_sample.unsqueeze(1)).type_as(logits).view(-1, 1)

        return p, detached_sample, multiplier

    def __call__(self, x, gating_network, fc1, fc2, activation_fn, activation_dropout_module, training=True):
        """Route each token of ``x`` to one expert FFN and combine the results.

        Args:
            x: 3-D input, unpacked as ``(seq_len, bsz, dim)``.
            gating_network: maps the flattened ``(seq_len * bsz, dim)`` tokens to
                gate logits (presumably ``(seq_len * bsz, num_experts)`` -- TODO confirm).
            fc1, fc2: per-expert first/second layers, indexed by expert id.
            activation_fn: applied to each ``fc1`` output.
            activation_dropout_module: applied after the activation.
            training: True -> sampled routing; False -> argmax routing.

        Returns:
            ``(out, num_tokens, balance_loss)``:
            - ``out``: same shape as ``x``, so ``fc2`` must map back to ``dim``.
            - ``num_tokens``: number of tokens routed to each expert.
            - ``balance_loss``: ``num_experts * sum(P * f)``, or ``None`` when
              ``compute_balance_loss`` is False.

        Raises:
            AssertionError: if ``jitter`` is False.
        """
        seq_len, bsz, dim = x.shape
        # reshape (unlike view) also accepts a non-contiguous x.
        tokens = x.reshape(-1, dim)
        logits_gate = gating_network(tokens)

        assert self.jitter, "SwitchGate requires jitter=True"

        prob_gate, sample, multiplier = self.one_sample_softmax_st(logits_gate, self.onnx_trace, training)

        # Sort tokens by expert id so each expert receives one contiguous chunk.
        order = sample.argsort(0)
        num_tokens = F.one_hot(sample, self.num_experts).gt(0).sum(0)
        chunks = tokens[order].split(num_tokens.tolist(), dim=0)  # one chunk per expert

        balance_loss = None
        if self.compute_balance_loss:
            P = prob_gate.mean(0)  # mean gate probability per expert
            temp = num_tokens.float()
            f = temp / (temp.sum(0, keepdim=True) + 1e-6)  # fraction of tokens per expert
            balance_loss = self.num_experts * torch.sum(P * f)

        def forward_fc(input_x, expert_idx):
            # Experts with no tokens are skipped; the empty chunk passes through.
            if input_x.numel() > 0:
                input_x = activation_fn(fc1[expert_idx](input_x))
                input_x = activation_dropout_module(input_x)
                input_x = fc2[expert_idx](input_x)
            return input_x

        out = torch.vstack([forward_fc(chunk, i) for i, chunk in enumerate(chunks)])
        # Undo the sort, then scale each token by its gate multiplier.
        out = out[order.argsort(0)] * multiplier
        return out.view(seq_len, bsz, dim), num_tokens, balance_loss
