# Sparsely-Gated Mixture-of-Experts Layers.
# See "Outrageously Large Neural Networks"
# https://arxiv.org/abs/1701.06538
#
# Author: David Rau
#
# The code is based on the TensorFlow implementation:
# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py


import math
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

from .parallel_experts import ParallelExperts

# @torch.jit.script
def compute_gating(k: int, probs: torch.Tensor, top_k_gates: torch.Tensor, top_k_indices: torch.Tensor):
    """Convert per-token top-k routing into expert-grouped dispatch tensors.

    Args:
        k: number of routing slots per token (columns of ``top_k_indices``).
        probs: [num_tokens, num_experts] routing probabilities; only its
            shape/dtype/device are used here (to build ``gates``).
        top_k_gates: [num_tokens, k] gate value for each selected expert.
        top_k_indices: [num_tokens, k] expert id for each slot.

    Returns:
        batch_gates: gate value of every kept slot, ordered by expert id.
        batch_index: token row of every kept slot, in the same order.
        expert_size: [num_experts] count of strictly positive gates per expert.
        gates: dense [num_tokens, num_experts] gate matrix (zeros elsewhere).
        index_sorted_experts: flattened (token * k + slot) position of every
            kept slot, ordered by expert id.

    Slots whose gate is exactly zero are dropped from the dispatch tensors.
    """
    # Dense gate matrix: selected experts receive their gate, the rest stay 0.
    gates = torch.zeros_like(probs).scatter(1, top_k_indices, top_k_gates)
    # Per-expert load, counting only strictly positive gates.
    expert_size = (gates > 0).long().sum(0)

    flat_gates = top_k_gates.flatten()
    flat_experts = top_k_indices.flatten()

    # Flattened slot positions that carry a non-zero gate.
    kept = flat_gates.nonzero().squeeze(-1)
    # Sort the kept slots by expert id so each expert's rows are contiguous.
    order = flat_experts[kept].sort(0)[1]
    index_sorted_experts = kept[order]

    # Each token owns k consecutive flattened slots; integer division
    # recovers the token row of every slot.
    batch_index = index_sorted_experts.div(k, rounding_mode='trunc')
    batch_gates = flat_gates[index_sorted_experts]
    return batch_gates, batch_index, expert_size, gates, index_sorted_experts


class MoE(nn.Module):

    """Sparsely gated mixture-of-experts layer with 2-layer feed-forward experts.

    Every token is routed to its top-``k`` experts. ``self.experts`` is a
    ``ParallelExperts(num_experts, input_size, head_size, bias)`` bank and
    ``self.output_experts`` a ``ParallelExperts(num_experts, head_size,
    input_size, bias)`` bank; ``activation`` is applied between them and the
    expert outputs are combined with the gate values.

    Args:
        input_size: int - size of the input (and of the output).
        head_size: int - hidden size of each expert.
        num_experts: int - number of experts.
        k: int - experts per token (clamped to ``num_experts``).
        cvloss, switchloss, zloss: float - weights of the auxiliary
            load-balancing losses.
        bias: bool - forwarded to ``ParallelExperts``.
        gating_activation: currently unused (kept for API compatibility).
        activation: callable applied between the two expert projections;
            ``forward`` calls it unconditionally, so it must be provided.
        noisy_gating: bool - if True the gate also predicts per-expert noise
            scales and Gaussian noise is added to the logits in training.
        usage_mem: currently unused (kept for API compatibility).
        acc_aux_loss: bool - if True, auxiliary statistics are accumulated
            across calls and converted into a loss by
            ``get_aux_loss_and_clear`` instead of being returned per call.
    """

    def __init__(self, input_size, head_size, num_experts, k,
                 cvloss=0, switchloss=0, zloss=0,
                 bias=False, gating_activation=None,
                 activation=None, noisy_gating=True, usage_mem = 10000,
                 acc_aux_loss=False):
        super(MoE, self).__init__()

        self.noisy_gating = noisy_gating
        self.num_experts = num_experts
        self.input_size = input_size
        self.head_size = head_size
        self.bias = bias
        self.experts = ParallelExperts(num_experts, input_size, head_size, bias)
        self.output_experts = ParallelExperts(num_experts, head_size, input_size, bias)
        self.k = min(k, self.num_experts)
        self.cvloss = cvloss
        self.switchloss = switchloss
        self.zloss = zloss
        self.activation = activation

        self.acc_aux_loss = acc_aux_loss
        if self.acc_aux_loss:
            self.init_aux_statistics()

        # With noisy gating the gate emits 2 * num_experts values per token:
        # the clean logits followed by the raw noise scales (see top_k_gating).
        # Kept as an nn.Sequential so parameter names (f_gate.0.weight) stay
        # compatible with existing checkpoints.
        self.f_gate = nn.Sequential(
            nn.Linear(input_size,
                      2 * num_experts if noisy_gating else num_experts,
                      bias=False)
        )
        # Zero init: all clean logits start at 0, i.e. uniform routing.
        nn.init.zeros_(self.f_gate[-1].weight)

    def extra_repr(self):
        return 'k={}, cvloss={}, switchloss={}, zloss={}, noisy_gating={}'.format(
            self.k, self.cvloss, self.switchloss, self.zloss, self.noisy_gating)

    def cv_squared(self, x):
        """Squared coefficient of variation (var / mean**2) of a 1-D sample.

        Useful as a loss that pushes a positive distribution towards uniform.
        An epsilon is added to the denominator for numerical stability.

        Args:
            x: a 1-D tensor.

        Returns:
            A scalar tensor, or the int 0 when ``x`` has a single element
            (e.g. only one expert).
        """
        eps = 1e-10
        if x.shape[0] == 1:
            return 0
        return x.float().var() / (x.float().mean()**2 + eps)

    def init_aux_statistics(self):
        """Reset the running sums used by ``get_aux_loss_and_clear``."""
        self.acc_probs = 0.
        self.acc_gates = 0.
        self.acc_freq = 0.
        self.acc_lsesq = 0.
        self.acc_lsesq_count = 0.

    def update_aux_statistics(self, logits, probs, gates):
        """Add one batch of routing statistics to the running sums.

        Args:
            logits: [num_tokens, num_experts] gate logits.
            probs: [num_tokens, num_experts] routing probabilities.
            gates: [num_tokens, num_experts] dense top-k gate matrix.
        """
        # Squared log-sum-exp per token (z-loss term); the small constant
        # keeps the log finite.
        lsesq = torch.log(torch.exp(logits).sum(dim=1) + 0.000001) ** 2
        self.acc_probs = self.acc_probs + probs.sum(0)
        self.acc_gates = self.acc_gates + gates.sum(0)
        self.acc_freq = self.acc_freq + (gates > 0).float().sum(0)
        self.acc_lsesq = self.acc_lsesq + lsesq.sum()
        self.acc_lsesq_count = self.acc_lsesq_count + lsesq.size(0)

    def get_aux_loss_and_clear(self):
        """Turn the accumulated statistics into the weighted auxiliary loss.

        Must be called after at least one ``update_aux_statistics`` call.
        The running sums are reset afterwards.
        """
        cvloss = self.cv_squared(F.normalize(self.acc_gates, p=1, dim=0))
        switchloss = (F.normalize(self.acc_probs, p=1, dim=0) *
                      F.normalize(self.acc_freq, p=1, dim=0)).sum() * self.num_experts
        zloss = self.acc_lsesq / (self.acc_lsesq_count)
        loss = (self.cvloss * cvloss +
                self.switchloss * switchloss +
                self.zloss * zloss)

        self.init_aux_statistics()
        return loss

    def compute_cvloss(self, probs):
        """CV^2 of the per-expert column sums of ``probs`` (normalized to sum 1)."""
        return self.cv_squared(F.normalize(probs.sum(0), p=1, dim=0))

    def compute_switchloss(self, probs, freqs):
        """Switch-Transformer style balance loss: E * sum(P_e * f_e)."""
        loss = F.normalize(probs.sum(0), p=1, dim=0) * \
               F.normalize(freqs.float(), p=1, dim=0)
        return loss.sum() * self.num_experts

    def compute_zloss(self, logits):
        """Mean squared log-sum-exp of the gate logits (router z-loss)."""
        # logsumexp is the overflow-safe form of log(exp(logits).sum(dim=1)).
        zloss = torch.mean(torch.logsumexp(logits, dim=1) ** 2)
        return zloss

    def top_k_gating(self, x, skip_mask=None, sample_topk=0, noise_epsilon=1e-2):
        """Noisy top-k gating (https://arxiv.org/abs/1701.06538).

        Computes the routing of every row of ``x`` and stores the dispatch
        tensors on ``self`` (``expert_size``, ``index_sorted_experts``,
        ``batch_index``, ``batch_gates``) for ``forward``/``map``/``reduce``.

        Args:
            x: input tensor of shape [num_tokens, input_size].
            skip_mask: optional mask broadcastable to [num_tokens, num_experts];
                masked positions get probability 0.
            sample_topk: must be 0 (sampled routing is currently disabled).
            noise_epsilon: minimum noise std added to the softplus scale.

        Returns:
            The auxiliary loss, or 0. when ``acc_aux_loss`` is set (the
            statistics are accumulated instead).
        """
        gate_out = self.f_gate(x)
        if self.noisy_gating and self.training:
            # Training: perturb the clean logits with learned-scale noise.
            clean_logits, raw_noise_stddev = gate_out.chunk(2, dim=-1)
            noise_stddev = F.softplus(raw_noise_stddev) + noise_epsilon
            logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
        elif self.noisy_gating:
            # Evaluation: use the clean half only, without noise.
            logits, _ = gate_out.chunk(2, dim=-1)
        else:
            logits = gate_out

        probs = torch.softmax(logits, dim=1)
        if skip_mask is not None:
            probs = torch.masked_fill(probs, skip_mask, 0)

        # Sampled routing is disabled; the branch below is kept for reference.
        assert sample_topk == 0
        if self.training and (sample_topk > 0):
            assert sample_topk <= self.k

            _, top_km1_indices = probs.topk(self.k - sample_topk, dim=1)
            masked_probs = probs + 1e-6
            masked_probs[torch.arange(probs.size(0)).unsqueeze(
                1), top_km1_indices] = 0
            k_indices = torch.multinomial(masked_probs, sample_topk)
            top_k_indices = torch.cat([top_km1_indices, k_indices], dim=-1)
            top_k_gates = torch.gather(probs, 1, top_k_indices)
        else:
            # top_k_gates / top_k_indices: [num_tokens, k]
            top_k_gates, top_k_indices = probs.topk(self.k, dim=1)

        batch_gates, batch_index, expert_size, gates, index_sorted_experts = compute_gating(self.k, probs, top_k_gates, top_k_indices)

        self.expert_size = expert_size
        self.index_sorted_experts = index_sorted_experts
        self.batch_index = batch_index
        self.batch_gates = batch_gates

        loss = 0.
        if self.acc_aux_loss:
            self.update_aux_statistics(logits, probs, gates)
        else:
            loss += self.cvloss * self.compute_cvloss(gates)
            loss += self.switchloss * \
                self.compute_switchloss(probs, self.expert_size)
            loss += self.zloss * self.compute_zloss(logits)
        return loss

    def forward(self, x, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Route, run the experts, and combine their outputs.

        Args:
            x: [bsz, length, input_size] input.
            skip_mask: optional per-token mask (reshaped to [bsz*length, 1]).
            sample_topk: forwarded to ``top_k_gating``.
            multiply_by_gates: scale each expert output by its gate value.

        Returns:
            (y, loss): y has shape [bsz, length, input_size]; loss is the
            auxiliary loss from ``top_k_gating``.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)
        loss = self.top_k_gating(x, skip_mask, sample_topk=sample_topk)

        # One row per kept (token, expert) pair, grouped by expert.
        expert_inputs = x[self.batch_index]
        h = self.experts(expert_inputs, self.expert_size)
        h = self.activation(h)
        expert_outputs = self.output_experts(h, self.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * self.batch_gates[:, None]

        # Sum the k expert contributions back into each token's row.
        zeros = torch.zeros((bsz * length, self.input_size),
            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)
        return y, loss

    def forward_(self, x, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Debugging reference: dense mixture over ALL experts weighted by probs.

        Assumes ``self.experts.w`` / ``self.output_experts.w`` are indexable
        per-expert weight matrices (bias-free). Ignores ``skip_mask``.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        gate_out = self.f_gate(x)
        # With noisy gating only the first half of the gate output is logits.
        logits = gate_out.chunk(2, dim=-1)[0] if self.noisy_gating else gate_out
        probs = torch.softmax(logits, dim=1)

        loss = self.top_k_gating(x, skip_mask, sample_topk=sample_topk)
        assert not self.bias
        hs = [torch.einsum('li,ij->lj', x, self.experts.w[i]) for i in range(self.num_experts)]
        hs = [self.activation(h) for h in hs]
        expert_outputs = [
            torch.einsum('li,ij->lj', hs[i], self.output_experts.w[i]) for i in range(self.num_experts)
        ]
        y = sum(probs[..., i][..., None] * expert_outputs[i] for i in range(self.num_experts))
        y = y.view(bsz, length, self.input_size)
        return y

    def map(self, x, skip_mask=None, sample_topk=0):
        """Route tokens and apply only the first expert projection.

        Args:
            x: [bsz, length, input_size] input.
            skip_mask: optional per-token mask (reshaped to [bsz*length, 1]).
            sample_topk: forwarded to ``top_k_gating``.

        Returns:
            (y, loss): y has shape [bsz, length, k, head_size], one row per
            routing slot (slots with a zero gate stay zero); loss is the
            auxiliary loss. The routing is kept on ``self`` for ``reduce``.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)
        loss = self.top_k_gating(x, skip_mask, sample_topk=sample_topk)

        expert_inputs = x[self.batch_index]
        expert_outputs = self.experts(expert_inputs, self.expert_size)

        # Scatter expert-ordered rows back to their (token, slot) positions.
        zeros = torch.zeros((bsz * length * self.k, self.head_size),
            dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.index_sorted_experts, expert_outputs)
        y = y.view(bsz, length, self.k, -1)
        return y, loss

    def reduce(self, x, multiply_by_gates=True):
        """Apply the second expert projection and combine; pairs with ``map``.

        Uses the routing stored by the preceding ``map`` call.

        Args:
            x: [bsz, length, k, head_size] tensor laid out like ``map``'s output.
            multiply_by_gates: scale each expert output by its gate value.

        Returns:
            Tensor of shape [bsz, length, input_size].
        """
        bsz, length, k, emb_size = x.size()
        x = x.view(-1, emb_size)

        expert_inputs = x[self.index_sorted_experts]
        expert_outputs = self.output_experts(expert_inputs, self.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * self.batch_gates[:, None]

        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)
        return y


class TaskMoE(MoE):

    """Task-conditioned sparsely gated MoE layer.

    Shares the experts of ``MoE`` but owns one independent gating network per
    task; ``forward``/``map`` take a ``task_bh`` index that selects the gate.

    Args:
        input_size: int - size of the input (and of the output).
        head_size: int - hidden size of each expert.
        num_experts: int - number of experts.
        k: int - experts per token.
        w_MI: float - weight of the task/expert mutual-information loss.
        limit_k: int - number of experts kept by ``get_topk_loss_and_clear``
            (raised to at least ``k``).
        w_topk_loss: float - weight of the top-k concentration loss.
        task_num: int - number of tasks, i.e. number of gating networks.
        noisy_gating: bool - see ``MoE``.
        gating_activation: currently unused (kept for API compatibility).
        **kwargs: forwarded to ``MoE`` (e.g. ``activation``, ``cvloss``,
            ``acc_aux_loss``).
    """

    def __init__(self,  input_size, head_size, num_experts, k, w_MI=0, limit_k=0, w_topk_loss=0.0, task_num=9, noisy_gating=True, gating_activation=None, **kwargs):
        # Set before MoE.__init__, which may call init_aux_statistics().
        self.task_num = task_num
        self.w_topk_loss = w_topk_loss
        self.w_MI = w_MI

        self.limit_k = max(k, limit_k)

        super(TaskMoE, self).__init__(input_size, head_size, num_experts, k, noisy_gating=noisy_gating, gating_activation=gating_activation, **kwargs)

        if gating_activation is None:
            gating_activation = nn.ReLU()

        # Replace the single MoE gate with one gate per task.
        self.f_gate = nn.ModuleList([nn.Sequential(
                                        nn.Linear(input_size,
                                                  2 * num_experts if noisy_gating else num_experts,
                                                  bias=False)
                                    ) for i in range(task_num)])
        for i in range(task_num):
            nn.init.zeros_(self.f_gate[i][-1].weight)

    def init_aux_statistics(self, clear=True):
        """Reset the running auxiliary statistics.

        Args:
            clear: also reset ``task_gate_freq`` and ``topk_acc_probs``.
        """
        self.acc_probs = 0.
        self.acc_gates = 0.
        self.acc_freq = 0.
        self.acc_lsesq = 0.
        self.acc_lsesq_count = 0.

        if clear:
            self.task_gate_freq = [0] * self.task_num
            self.topk_acc_probs = 0.

        # Keep the [task_num, num_experts] accumulator on the device it was
        # last used on; the first time, prefer CUDA when it exists so the
        # layer also works on CPU-only machines.
        prev = getattr(self, 'MI_task_gate', None)
        if isinstance(prev, torch.Tensor):
            device = prev.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.MI_task_gate = torch.zeros(self.task_num, self.num_experts, device=device)

    def update_aux_statistics(self, logits, probs, gates, task_bh):
        """Add one batch of routing statistics for task ``task_bh``."""
        lsesq = torch.log(torch.exp(logits).sum(dim=1) + 0.0001) ** 2
        self.acc_probs = self.acc_probs + probs.sum(0)
        self.acc_gates = self.acc_gates + gates.sum(0)
        self.acc_freq = self.acc_freq + (gates > 0).float().sum(0)
        self.acc_lsesq = self.acc_lsesq + lsesq.sum()
        self.acc_lsesq_count = self.acc_lsesq_count + lsesq.size(0)

        self.topk_acc_probs = self.topk_acc_probs + probs.mean(0)

        # Exponential moving average of how often each expert is selected.
        self.task_gate_freq[task_bh] = self.task_gate_freq[task_bh]*0.95 + ((gates > 0).float().sum(0)).detach()*0.05

        if self.MI_task_gate.device != probs.device:
            self.MI_task_gate = self.MI_task_gate.to(probs.device)
        self.MI_task_gate[task_bh] = self.MI_task_gate[task_bh] + probs.sum(0)

    def get_topk_loss_and_clear(self):
        """Squared mass of ``topk_acc_probs`` outside its ``limit_k`` largest entries.

        Returns the loss scaled by ``w_topk_loss`` and resets the accumulator.
        """
        top_k_probs, top_k_indices = self.topk_acc_probs.topk(self.limit_k, dim=0)
        zeros = torch.zeros_like(self.topk_acc_probs)
        gates = zeros.scatter(0, top_k_indices, top_k_probs)
        topk_loss = ((self.topk_acc_probs - gates) * (self.topk_acc_probs - gates)).sum()

        self.topk_acc_probs = 0.
        return topk_loss * self.w_topk_loss

    def get_aux_loss_and_clear(self):
        '''Turn the accumulated statistics into the weighted auxiliary loss.

        acc_gates: sum of the top-k gate values per expert
        acc_freq: number of times each expert was chosen
        acc_probs: sum of routing probabilities (softmax of the logits)

        Also adds ``w_MI`` times the negative task/expert mutual information
        estimated from ``MI_task_gate``. Statistics are reset afterwards
        (``task_gate_freq``/``topk_acc_probs`` are kept).
        '''

        cvloss = self.cv_squared(F.normalize(self.acc_gates, p=1, dim=0))
        switchloss = (F.normalize(self.acc_probs, p=1, dim=0) *
                      F.normalize(self.acc_freq, p=1, dim=0)).sum() * self.num_experts
        zloss = self.acc_lsesq / (self.acc_lsesq_count)

        # Joint P(task, expert) estimate normalized by the number of tokens.
        tot = self.acc_freq.sum() / self.k
        self.MI_task_gate = self.MI_task_gate / (tot+0.0001)
        P_TI = torch.sum(self.MI_task_gate, dim=1, keepdim=True) + 0.0001
        P_EI = torch.sum(self.MI_task_gate, dim=0, keepdim=True) + 0.0001

        # Negative mutual information: minimizing it maximizes MI.
        MI_loss = -(self.MI_task_gate * torch.log(self.MI_task_gate / P_TI / P_EI + 0.0001)).sum()

        loss = (self.cvloss * cvloss +
                self.switchloss * switchloss +
                self.zloss * zloss +
                self.w_MI * MI_loss
                )

        self.init_aux_statistics(clear=False)
        return loss

    def top_k_gating(self, x, task_bh, skip_mask=None, sample_topk=0, noise_epsilon=1e-2):
        """Noisy top-k gating using the gate of task ``task_bh``.

        Stores the dispatch tensors on ``self`` (``expert_size``,
        ``index_sorted_experts``, ``batch_index``, ``batch_gates``).

        Args:
            x: input tensor of shape [num_tokens, input_size].
            task_bh: index of the task gate in ``self.f_gate``.
            skip_mask: optional mask broadcastable to [num_tokens, num_experts].
            sample_topk: number of experts sampled instead of taken greedily
                (training only).
            noise_epsilon: minimum noise std added to the softplus scale.

        Returns:
            The auxiliary loss, or 0. when ``acc_aux_loss`` is set.
        """
        gate_out = self.f_gate[task_bh](x)

        if self.noisy_gating and self.training:
            clean_logits, raw_noise_stddev = gate_out.chunk(2, dim=-1)
            noise_stddev = F.softplus(raw_noise_stddev) + noise_epsilon
            logits = clean_logits + torch.randn_like(clean_logits) * noise_stddev
        elif self.noisy_gating:
            logits, _ = gate_out.chunk(2, dim=-1)
        else:
            logits = gate_out

        # The small offset keeps every selected gate strictly positive, so
        # compute_gating never drops a top-k slot.
        probs = torch.softmax(logits, dim=1) + 1e-4

        if skip_mask is not None:
            probs = torch.masked_fill(probs, skip_mask, 0)

        if self.training and (sample_topk > 0):
            assert sample_topk <= self.k

            _, top_km1_indices = probs.topk(self.k - sample_topk, dim=1)
            masked_probs = probs + 1e-6
            masked_probs[torch.arange(probs.size(0)).unsqueeze(
                1), top_km1_indices] = 0
            k_indices = torch.multinomial(masked_probs, sample_topk)
            top_k_indices = torch.cat([top_km1_indices, k_indices], dim=-1)
            top_k_gates = torch.gather(probs, 1, top_k_indices)
        else:
            # top_k_gates / top_k_indices: [num_tokens, k]
            top_k_gates, top_k_indices = probs.topk(self.k, dim=1)

        batch_gates, batch_index, expert_size, gates, index_sorted_experts = compute_gating(self.k, probs, top_k_gates, top_k_indices)

        self.expert_size = expert_size
        self.index_sorted_experts = index_sorted_experts
        self.batch_index = batch_index
        self.batch_gates = batch_gates

        loss = 0.
        if self.acc_aux_loss:
            self.update_aux_statistics(logits, probs, gates, task_bh)
        else:
            loss += self.cvloss * self.compute_cvloss(gates)
            loss += self.switchloss * self.compute_switchloss(probs, self.expert_size)
            loss += self.zloss * self.compute_zloss(logits)

        return loss

    def forward(self, x, task_bh, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Route with task ``task_bh``'s gate, run the experts, and combine.

        Args:
            x: [bsz, length, input_size] input.
            task_bh: index of the task gate.
            skip_mask: optional per-token mask (reshaped to [bsz*length, 1]).
            sample_topk: forwarded to ``top_k_gating``.
            multiply_by_gates: scale each expert output by its gate value.

        Returns:
            (y, loss): y has shape [bsz, length, input_size].
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        loss = self.top_k_gating(x, task_bh, skip_mask,  sample_topk=sample_topk)

        # One row per kept (token, expert) pair, grouped by expert;
        # self.expert_size sums to the number of rows.
        expert_inputs = x[self.batch_index]
        h = self.experts(expert_inputs, self.expert_size)
        h = self.activation(h)
        expert_outputs = self.output_experts(h, self.expert_size)
        if multiply_by_gates:
            expert_outputs = expert_outputs * self.batch_gates[:, None]

        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)

        return y, loss

    def forward_(self, x, task_bh, skip_mask=None, sample_topk=0, multiply_by_gates=True):
        """Debugging reference: dense mixture over ALL experts weighted by probs.

        Assumes ``self.experts.w`` / ``self.output_experts.w`` are indexable
        per-expert weight matrices (bias-free). Ignores ``skip_mask``.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        gate_out = self.f_gate[task_bh](x)
        # With noisy gating only the first half of the gate output is logits.
        logits = gate_out.chunk(2, dim=-1)[0] if self.noisy_gating else gate_out
        probs = torch.softmax(logits, dim=1)

        loss = self.top_k_gating(x, task_bh, skip_mask, sample_topk=sample_topk)
        assert not self.bias
        hs = [torch.einsum('li,ij->lj', x, self.experts.w[i]) for i in range(self.num_experts)]
        hs = [self.activation(h) for h in hs]
        expert_outputs = [
            torch.einsum('li,ij->lj', hs[i], self.output_experts.w[i]) for i in range(self.num_experts)
        ]
        y = sum(probs[..., i][..., None] * expert_outputs[i] for i in range(self.num_experts))
        y = y.view(bsz, length, self.input_size)
        return y

    def map(self, x, task_bh, skip_mask=None, sample_topk=0):
        """Route with task ``task_bh``'s gate and apply only the first projection.

        Unlike ``forward``, no activation, second projection, or gate
        weighting is applied; ``reduce`` does the second projection.

        Args:
            x: [bsz, length, input_size] input.
            task_bh: index of the task gate.
            skip_mask: optional per-token mask (reshaped to [bsz*length, 1]).
            sample_topk: forwarded to ``top_k_gating``.

        Returns:
            (y, loss): y has shape [bsz, length, k, head_size], one row per
            routing slot; loss is the auxiliary loss.
        """
        bsz, length, emb_size = x.size()
        x = x.reshape(-1, emb_size)
        if skip_mask is not None:
            skip_mask = skip_mask.view(-1, 1)

        loss = self.top_k_gating(x, task_bh, skip_mask,  sample_topk=sample_topk)

        expert_inputs = x[self.batch_index]
        expert_outputs = self.experts(expert_inputs, self.expert_size)

        # Scatter expert-ordered rows back to their (token, slot) positions.
        zeros = torch.zeros((bsz * length * self.k, self.head_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.index_sorted_experts, expert_outputs)
        y = y.view(bsz, length, self.k, -1)
        return y, loss

    def reduce(self, x, multiply_by_gates=True):
        """Apply the second expert projection and combine; pairs with ``map``.

        Args:
            x: [bsz, length, k, head_size] tensor laid out like ``map``'s output.
            multiply_by_gates: scale each expert output by its gate value.

        Returns:
            Tensor of shape [bsz, length, input_size].
        """
        bsz, length, k, emb_size = x.size()
        x = x.view(-1, emb_size)

        expert_inputs = x[self.index_sorted_experts]
        expert_outputs = self.output_experts(expert_inputs, self.expert_size)

        if multiply_by_gates:
            expert_outputs = expert_outputs * self.batch_gates[:, None]

        zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
        y = zeros.index_add(0, self.batch_index, expert_outputs)
        y = y.view(bsz, length, self.input_size)

        return y

class RandomMoE(MoE):

    """MoE variant with (near-)random routing instead of a learned gate.

    Routing logits are drawn per call as ``1 + 0.001 * N(0, 1)``, so all
    experts have almost equal probability and the top-k choice is
    effectively random. No auxiliary loss is produced.
    Accepts the same constructor arguments as ``MoE``.
    """

    def __init__(self, *args, **kwargs):
        super(RandomMoE, self).__init__(*args, **kwargs)
        # The learned gate is not used by this variant.
        del self.f_gate
        # MoE does not create ``w_noise``; only delete it when present so
        # construction does not raise AttributeError.
        if hasattr(self, 'w_noise'):
            del self.w_noise

    def top_k_gating(self, x, skip_mask=None, sample_topk=0, noise_epsilon=1e-2):
        """Random top-k gating.

        Stores the dispatch tensors on ``self`` (``expert_size``,
        ``index_sorted_experts``, ``batch_index``, ``batch_gates``).

        Args:
            x: input tensor of shape [num_tokens, input_size]; only its
                number of rows and device are used.
            skip_mask: optional mask broadcastable to [num_tokens, num_experts].
            sample_topk: number of experts sampled instead of taken greedily
                (training only).
            noise_epsilon: unused (kept for signature compatibility).

        Returns:
            0 (no auxiliary loss).
        """
        # Nearly constant logits -> nearly uniform probabilities.
        logits = torch.randn(x.shape[0], self.num_experts).to(x.device) * 0.001 + 1

        probs = torch.softmax(logits, dim=1)
        if skip_mask is not None:
            probs = torch.masked_fill(probs, skip_mask, 0)

        if self.training and (sample_topk > 0):
            assert sample_topk <= self.k

            _, top_km1_indices = probs.topk(self.k - sample_topk, dim=1)
            masked_probs = probs + 1e-6
            masked_probs[torch.arange(probs.size(0)).unsqueeze(
                1), top_km1_indices] = 0
            k_indices = torch.multinomial(masked_probs, sample_topk)
            top_k_indices = torch.cat([top_km1_indices, k_indices], dim=-1)
            top_k_gates = torch.gather(probs, 1, top_k_indices)
        else:
            top_k_gates, top_k_indices = probs.topk(self.k, dim=1)

        # Same dispatch computation as the learned-gate layers.
        batch_gates, batch_index, expert_size, _, index_sorted_experts = compute_gating(
            self.k, probs, top_k_gates, top_k_indices)
        self.expert_size = expert_size
        self.index_sorted_experts = index_sorted_experts
        self.batch_index = batch_index
        self.batch_gates = batch_gates

        return 0
