
r"""
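Noisy gating and capacity-limited mixture-of-experts layers (built on FastMoE)
for GShard/Switch-style routing in multi-modal transformers.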
"""
from copy import deepcopy
from typing import Dict
from fmoe.layers import *
from fmoe.layers import _fmoe_general_global_forward
from fmoe.linear import FMoELinear

import tree
import os
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.nn import Dropout
import torch.nn.functional as F
import math

from fmoe.functions import prepare_forward, ensure_comm, count_by_gate
from fmoe.functions import MOEScatter, MOEGather
from fmoe.functions import AllGather, Slice
from fmoe.gates import NaiveGate, NoisyGate
from fmoe.gates.base_gate import BaseGate

from fmoe.fastermoe.config import switch_from_env

from perceiver_pytorch.perceiver_pytorch import PreNorm, Attention, default, rearrange, checkpoint, partial, einsum, exists, repeat

# from custom_model.mmoe_transformer_encoder import MultiModalityConfig

class MultiModalityConfig:
    """Configuration container for the multi-modal MoE layers in this module.

    Class attributes hold the defaults. Every keyword argument passed to the
    constructor overrides the attribute of the same name on that instance only.
    """
    num_experts = 16
    base_capacity = 16
    capacity_per_expert = 10
    gate = NoisyGate
    num_tasks = 1
    load_expert_count = False
    seed = 1
    img_path = 'log/exp_img'
    limited_capacity_on_mlp = True
    seperate_qkv = False
    co_input = False # input all modalities at the same time if co_input is True, otherwise, input each modality sequentially
    attn_modality_specific = False
    mlp_modality_specific = False
    modalities_name = []
    modality_gating_merge = False
    task_gating_merge = False
    modality_remap = {}
    task_remap = {}
    capacity_ratio = 1.0
    capacity_ratios = None
    dynamic_reweight = False
    cross_modality_attn = False
    conditional_weight = False
    padding_prompt = False
    mlp_top_k = 2
    attn_top_k = 2
    cross_attn_use_moe = False
    auto_gate_loss = False
    gating_loss_map = None
    use_individual_latent_dim = False
    individual_latent_dim = {0: 12, 1: 12, 2: 20}
    outter_task_loss = False
    grad_clip_value = 1.
    modality_joint = False
    attn_use_moe = True
    mlp_use_moe = True
    equal_dense = False

    # Class-level defaults that are mutable containers. Each instance gets its
    # own copy in __init__, so an in-place edit on one config (for example
    # ``cfg.modalities_name.append(...)``) cannot silently change the defaults
    # seen by every other config through the shared class attribute.
    _MUTABLE_DEFAULTS = ('modalities_name', 'modality_remap', 'task_remap', 'individual_latent_dim')

    def __init__(self, **kwargs):
        """Create a config, overriding any class default given as a keyword."""
        for name in self._MUTABLE_DEFAULTS:
            setattr(self, name, deepcopy(getattr(type(self), name)))
        for k, v in kwargs.items():
            setattr(self, k, v)

    def setting_modality_remap(self, mapping: Dict):
        """Install a modality-name remapping and rebuild ``modalities_name``.

        ``modalities_name`` becomes the distinct remapped names (the values of
        ``mapping``), in the order they first appear in ``mapping``.
        """
        self.modality_remap = mapping
        # dict.fromkeys de-duplicates while preserving first-seen order.
        self.modalities_name = list(dict.fromkeys(mapping.values()))

    def setting_task_remap(self, mapping: Dict):
        """Install the task-index remapping (keys are stringified task indices)."""
        self.task_remap = mapping


class AttentionWithPrint(Attention):
    """Perceiver ``Attention`` that remembers its latest attention weights.

    The forward pass mirrors the parent's. The only addition is that the
    softmax-normalised attention matrix (shape
    ``(batch * heads, n_query, n_context)``) is stored as-is (not detached) on
    ``self.printattn``, so it can be inspected after a call.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0):
        super().__init__(query_dim, context_dim, heads, dim_head, dropout)
        # Attention weights of the most recent forward call; None until then.
        self.printattn = None

    def forward(self, x, context=None, mask=None):
        n_heads = self.heads

        query = self.to_q(x)
        context = default(context, x)
        key, value = self.to_kv(context).chunk(2, dim=-1)
        # Keep the query/key product in float32 to avoid instability as the
        # attention weights grow during training,
        # per https://twitter.com/tsuname/status/1430653484827697155?s=20
        query, key = query.float(), key.float()

        def to_heads(t):
            return rearrange(t, 'b n (h d) -> (b h) n d', h=n_heads)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)

        scores = checkpoint(partial(einsum, 'b i d, b j d -> b i j'), query, key) * self.scale

        if exists(mask):
            # Flatten the mask per sample, then broadcast it over heads and queries.
            mask = repeat(rearrange(mask, 'b ... -> b (...)'), 'b j -> (b h) () j', h=n_heads)
            scores.masked_fill_(~mask, -torch.finfo(scores.dtype).max)

        weights = scores.softmax(dim=-1)
        self.printattn = weights

        mixed = checkpoint(partial(einsum, 'b i j, b j d -> b i d'), weights, value)
        # Cast back to the input dtype before merging heads.
        mixed = rearrange(mixed.type(x.dtype), '(b h) n d -> b n (h d)', h=n_heads)
        return self.to_out(mixed)

def _limit_expert_capacity(pos, token_of_pos, fwd_expert_count, gate_score, expert_capacity):
    r"""
    Decide which scattered rows each over-full expert keeps.

    Rows of the scattered buffer are grouped by expert: expert ``e`` owns the
    ``fwd_expert_count[e]`` rows that follow those of experts ``< e``. As in
    FastMoE's scatter, row ``j`` holds the token ``token_of_pos[j]``, which
    comes from flattened gate entry ``pos[j]``. For every expert that received
    more than ``expert_capacity`` rows, the rows whose token has the lowest
    priority (the token's maximum gate score) are dropped.

    ``fwd_expert_count`` is updated in place, so each over-full expert is
    reported as receiving exactly ``expert_capacity`` rows.

    Returns:
        (keep_mask, drop_idx, select_idx): a bool mask over scattered rows
        (``False`` = dropped), plus dicts mapping an over-full expert's index
        to the ``pos`` values of its dropped and kept rows.
    """
    # A non-positive capacity means "keep nothing"; the count is then 0 and
    # consistent with the rows actually kept.
    capacity = max(int(expert_capacity), 0)
    token_priority, _ = gate_score.max(dim=-1)
    keep_mask = torch.ones_like(pos, dtype=torch.bool)
    drop_idx = {}
    select_idx = {}

    seg_ends = fwd_expert_count.cumsum(dim=0).tolist()
    seg_starts = [0] + seg_ends[:-1]
    for expert, (start, end) in enumerate(zip(seg_starts, seg_ends)):
        num_rows = end - start
        if num_rows <= capacity:
            continue
        seg_pos = pos[start:end]
        # One ascending argsort shared by both splits, so the dropped and kept
        # sets are guaranteed to be complementary.
        order = token_priority[token_of_pos[start:end]].argsort()
        num_drop = num_rows - capacity
        dropped, kept = order[:num_drop], order[num_drop:]
        drop_idx[expert] = seg_pos[dropped]
        select_idx[expert] = seg_pos[kept]
        # The mask is over *scattered rows*, so index it by sorted position
        # (segment start + offset), not by the pos value of the row.
        keep_mask[start + dropped] = False
        fwd_expert_count[expert] = capacity
    return keep_mask, drop_idx, select_idx


def _fmoe_limited_global_forward(inp, gate, gate_score, expert_fn, num_expert, world_size, expert_capacity, is_train, top_k = 2, **kwargs):
    r"""
    Capacity-limited variant of FastMoE's global MoE forward.

    * Count the tokens sent to each expert and scatter the (token, expert)
      rows so that each expert's inputs are contiguous in memory.
    * For every expert that received more than ``expert_capacity`` rows, keep
      only the ``expert_capacity`` rows whose token has the highest
      ``gate_score.max(dim=-1)`` and drop the rest.
    * Run ``expert_fn`` on the kept rows only. Dropped rows keep their
      scattered *input* value, so they bypass the expert unchanged.
    * Gather the rows back into the original (token, top-k) order.

    Args:
        inp: input tensor. The recovery step indexes the scattered buffer
            directly, so in practice only a single tensor (not a nested
            structure) is supported.
        gate: expert indices, shape ``(tokens,)`` or ``(tokens, top_k)``.
        gate_score: gate weights, one row per token. The per-token maximum is
            used as the priority when an expert overflows.
        expert_fn: callable ``(inputs, fwd_expert_count) -> outputs``.
        num_expert, world_size: forwarded to FastMoE's ``prepare_forward``.
        expert_capacity: maximum number of rows each expert processes.
        is_train: currently unused; the capacity is enforced in train and eval.
        top_k: kept for backward compatibility. The per-token fan-out is read
            from ``gate.shape[1]`` instead.
        **kwargs: ignored (accepts for example ``experts=``).

    Returns:
        (outp, drop_idx, select_idx): the gathered outputs, plus dicts mapping
        an over-full expert's index to the ``pos`` values (flattened gate
        indices; ``pos // top_k`` is the token) of its dropped and kept rows.

    NOTE(review): the capacity mask is built over ``pos`` (the locally
    scattered buffer), which matches the expert buffer only when
    ``world_size == 1``.
    """
    (
        pos,
        local_expert_count,
        global_expert_count,
        fwd_expert_count,
        fwd_batch_size,
    ) = prepare_forward(gate, num_expert, world_size)
    # Each token occupies `topk` consecutive entries of the flattened gate.
    topk = gate.shape[1] if len(gate.shape) == 2 else 1
    token_of_pos = torch.div(pos, topk, rounding_mode='floor')

    def scatter_func(tensor):
        return MOEScatter.apply(
            tensor,
            token_of_pos,
            local_expert_count,
            global_expert_count,
            fwd_batch_size,
            world_size,
        )

    x = tree.map_structure(scatter_func, inp)

    keep_mask, drop_idx, select_idx = _limit_expert_capacity(
        pos, token_of_pos, fwd_expert_count, gate_score, expert_capacity
    )

    # Only the kept rows are sent through the experts.
    exp_inp = tree.map_structure(lambda tensor: tensor[keep_mask], x)
    exp_out = expert_fn(exp_inp, fwd_expert_count)

    # Write the expert outputs back into the kept rows. Dropped rows keep
    # their scattered input values.
    def recover_func(tensor):
        x[keep_mask] = tensor
        return x

    x = tree.map_structure(recover_func, exp_out)

    out_batch_size = tree.flatten(inp)[0].shape[0]
    if len(gate.shape) == 2:
        out_batch_size *= gate.shape[1]

    def gather_func(tensor):
        return MOEGather.apply(
            tensor,
            pos,
            local_expert_count,
            global_expert_count,
            out_batch_size,
            world_size,
        )

    outp = tree.map_structure(gather_func, x)
    return outp, drop_idx, select_idx

class NoisyVMoEGate(BaseGate):
    r"""Noisy top-k gate that applies the softmax *before* selecting experts.

    The clean gate logits are ``inp @ w_gate``. During training, Gaussian noise
    with an input-dependent std ``softplus(inp @ w_noise) + noise_epsilon`` is
    added. The noisy logits are softmax-normalised over all experts, the
    ``top_k`` largest probabilities are selected, and those are softmax-ed
    again to form the returned gate weights. On every call, a load-balancing
    loss ``cv_squared(importance) + cv_squared(load)`` is accumulated through
    ``set_loss``.

    Args:
        d_model: input feature size.
        num_expert, world_size: forwarded to ``BaseGate``. ``self.tot_expert``
            is the width of the gate matrices.
        top_k: number of experts selected per token.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2):
        super().__init__(num_expert, world_size)
        self.w_gate = nn.Parameter(
            torch.zeros(d_model, self.tot_expert), requires_grad=True
        )
        self.w_noise = nn.Parameter(
            torch.zeros(d_model, self.tot_expert), requires_grad=True
        )
        self.top_k = top_k
        self.softplus = nn.Softplus()
        self.softmax = nn.Softmax(1)

        self.noise_epsilon = 1e-2
        # Clean logits of every forward call since the last get_logits(clear=True).
        self.output_logits = []
        # Top-k expert *indices* of the latest forward call.
        self.topk_value = None

        self.reset_parameters()

    def reset_parameters(self):
        # Approach is the same as in torch.nn.Linear
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/linear.py#L88
        torch.nn.init.kaiming_uniform_(self.w_gate, a=math.sqrt(5))
        torch.nn.init.kaiming_uniform_(self.w_noise, a=math.sqrt(5))

    def _gates_to_load(self, gates):
        """Compute the true load per expert, given the gates.

        The load is the number of examples for which the corresponding gate
        is > 0.

        Args:
            gates: a `Tensor` of shape [batch_size, n].
        Returns:
            an integer `Tensor` of shape [n] (counts, not probabilities).
        """
        return (gates > 0).sum(0)

    def _prob_in_top_k(
        self, clean_values, noisy_values, noise_stddev, noisy_top_values
    ):
        """Probability that each value is in the top k, under the random noise.

        This gives a differentiable way to backpropagate from a loss that
        balances how often each expert is in the top k experts per example.

        Args:
            clean_values: a `Tensor` of shape [batch, n] (logits).
            noisy_values: a `Tensor` of shape [batch, n]. Equal to the clean
                values plus normal noise with standard deviation noise_stddev.
            noise_stddev: a `Tensor` of shape [batch, n].
            noisy_top_values: a `Tensor` of shape [batch, m] holding the top-m
                *noisy logits* per row in descending order, with m >= k+1.
                They must be in the same space as ``noisy_values``.
        Returns:
            a `Tensor` of shape [batch, n].
        """
        batch = clean_values.size(0)
        m = noisy_top_values.size(1)
        top_values_flat = noisy_top_values.flatten()
        threshold_positions_if_in = (
            torch.arange(batch, device=clean_values.device) * m + self.top_k
        )
        threshold_if_in = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_in), 1
        )
        is_in = torch.gt(noisy_values, threshold_if_in)
        threshold_positions_if_out = threshold_positions_if_in - 1
        threshold_if_out = torch.unsqueeze(
            torch.gather(top_values_flat, 0, threshold_positions_if_out), 1
        )
        normal = Normal(
            torch.tensor([0.0], device=clean_values.device),
            torch.tensor([1.0], device=clean_values.device),
        )
        prob_if_in = normal.cdf((clean_values - threshold_if_in) / noise_stddev)
        prob_if_out = normal.cdf((clean_values - threshold_if_out) / noise_stddev)
        return torch.where(is_in, prob_if_in, prob_if_out)

    def get_logits(self, clear = True):
        """Return, for each recorded forward call, the std of per-token routing entropy.

        Args:
            clear: if True, also empty the recorded-logits buffer.
        Returns:
            a list with one scalar tensor per recorded call.
        """
        logits = self.output_logits
        if clear:
            self.output_logits = []
        entropy_stds = []
        for call_logits in logits:
            # log_softmax stays finite, so underflowed probabilities give
            # 0 * finite = 0 instead of the NaN from p * log(p) with p == 0.
            log_prob = call_logits.log_softmax(dim=-1)
            entropy = -(log_prob.exp() * log_prob).sum(dim=-1)
            entropy_stds.append(entropy.std())
        return entropy_stds

    def set_logits(self, logits):
        """Record one call's clean logits for a later get_logits()."""
        self.output_logits.append(logits)

    def set_loss(self, loss):
        """Accumulate ``loss`` into ``self.loss`` (out of place)."""
        # Out-of-place add: a previously stored loss tensor is never mutated.
        self.loss = loss if self.loss is None else self.loss + loss

    def set_topk(self, topk):
        self.top_k = topk

    def set_topk_logit(self, topk):
        """Store the latest top-k expert indices (despite the name, not logits)."""
        self.topk_value = topk

    def gate_topk_logits(self, clear=True):
        """Return the stored top-k expert indices; optionally clear them."""
        topk = self.topk_value
        if clear:
            self.topk_value = None
        return topk

    def cv_squared(self, x):
        """The squared coefficient of variation of a sample.

        Useful as a loss to encourage a positive distribution to be more
        uniform. Epsilons are added for numerical stability. Returns a
        1-element zero tensor (on ``x``'s device) for an empty or
        single-element input.

        Args:
            x: a `Tensor`.
        Returns:
            a `Scalar`.
        """
        eps = 1e-10
        # Covers num_expert == 1 and empty input; keeps the result on x's
        # device so it can be added to losses computed on the GPU.
        if x.numel() <= 1:
            return torch.zeros(1, device=x.device)
        x = x.float()
        return x.var() / (x.mean() ** 2 + eps)

    def forward(self, inp):
        """Route ``inp`` (rows = tokens) to ``top_k`` experts.

        Returns:
            (top_k_indices flattened to shape (tokens * top_k,),
             top_k_gates of shape (tokens, 1, top_k)).
        """
        clean_logits = inp @ self.w_gate
        raw_noise_stddev = inp @ self.w_noise
        # Multiplying by self.training disables the noise in eval mode.
        noise_stddev = (
            self.softplus(raw_noise_stddev) + self.noise_epsilon
        ) * self.training
        noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev)

        self.set_logits(clean_logits)
        # Softmax over all experts *before* the top-k selection.
        probs = self.softmax(noisy_logits)
        # top-k + 1 is needed by the noisy load estimate below.
        top_probs, top_indices = probs.topk(
            min(self.top_k + 1, self.tot_expert), dim=1
        )
        top_k_indices = top_indices[:, : self.top_k]
        top_k_gates = self.softmax(top_probs[:, : self.top_k])

        zeros = torch.zeros_like(probs, requires_grad=True)
        gates = zeros.scatter(1, top_k_indices, top_k_gates)

        if self.top_k < self.tot_expert and self.training:
            # _prob_in_top_k compares clean/noisy *logits* against thresholds,
            # so the thresholds must be logits as well. Softmax is monotonic,
            # so the top-(k+1) experts are identical; gathering the noisy
            # logits at those indices keeps the descending order.
            top_noisy_logits = noisy_logits.gather(1, top_indices)
            load = (
                self._prob_in_top_k(
                    clean_logits, noisy_logits, noise_stddev, top_noisy_logits
                )
            ).sum(0)
        else:
            load = self._gates_to_load(gates)

        importance = gates.sum(0)
        loss = self.cv_squared(importance) + self.cv_squared(load)
        self.set_loss(loss)

        self.set_topk_logit(top_k_indices)

        return (
            top_k_indices.contiguous().view(-1),
            top_k_gates.contiguous().unsqueeze(1),
        )

class ModifiedFMoE(FMoE):
    r"""FMoE layer that can gate a concatenated multi-modality batch per modality.

    When ``self.co_input_modalities_name`` is set and the (sub)class provides
    ``self.tasks_gates``, the input is split into equal, consecutive chunks,
    one per listed modality, and each chunk is gated by that modality's gate.
    Otherwise ``self.gate`` routes the whole input. There is no per-expert
    capacity limit (see ``_fmoe_general_global_forward``).
    """

    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None, slice_group=None, moe_group=None, top_k=2, gate=NaiveGate, expert=None, gate_hook=None, mask=None, mask_dict=None, capacity_per_expert = 0):
        # ``gate`` previously defaulted to ``...`` (Ellipsis), which FMoE cannot
        # instantiate. NaiveGate matches FMoE's own default.
        super().__init__(num_expert, d_model, world_size, mp_group, slice_group, moe_group, top_k, gate, expert, gate_hook, mask, mask_dict)
        # Not used by this class's forward; kept for signature parity with
        # LimitCapacityMoE.
        self.capacity_per_expert = capacity_per_expert
        self.co_input_modalities_name = None

    def _gate_key(self, modality_name):
        """Map a modality name to its key in ``self.tasks_gates``.

        Mirrors the single-modality lookup: with ``args.modality_gating_merge``,
        gates are keyed by the remapped modality name.
        """
        args = getattr(self, 'args', None)
        if args is not None and getattr(args, 'modality_gating_merge', False):
            return args.modality_remap[modality_name]
        return modality_name

    def _co_input_gate(self, moe_inp):
        """Gate each modality's chunk of ``moe_inp`` with that modality's gate."""
        names = self.co_input_modalities_name
        step = moe_inp.shape[0] // len(names)
        idx_list = []
        score_list = []
        for i, name in enumerate(names):
            # Bug fix: the upper bound used to be ``(i+1) * moe_inp`` (a
            # tensor), so this path always failed.
            chunk = moe_inp[i * step:(i + 1) * step]
            idx, score = self.tasks_gates[self._gate_key(name)](chunk)
            idx_list.append(idx)
            score_list.append(score)
        return torch.cat(idx_list, dim=0), torch.cat(score_list, dim=0)

    def forward(self, moe_inp):
        r"""
        Compute the gate output first, then run the MoE forward according to
        the gate. Each selected expert's output is weighted by its gate score.
        """
        moe_inp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
        )
        assert all(
            batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size
        ), "MoE inputs must have the same batch size"

        if self.world_size > 1:

            def ensure_comm_func(tensor):
                ensure_comm(tensor, self.moe_group)

            tree.map_structure(ensure_comm_func, moe_inp)
        if self.slice_size > 1:

            def slice_func(tensor):
                return Slice.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_inp = tree.map_structure(slice_func, moe_inp)

        if self.co_input_modalities_name is not None and hasattr(self, 'tasks_gates'):
            gate_top_k_idx, gate_score = self._co_input_gate(moe_inp)
        else:
            gate_top_k_idx, gate_score = self.gate(moe_inp)
        # Gates return flat indices and (tokens, 1, top_k) scores; normalise
        # both to (tokens, top_k).
        gate_score = gate_score.reshape(moe_inp.shape[0], self.top_k)
        gate_top_k_idx = gate_top_k_idx.reshape(moe_inp.shape[0], self.top_k)
        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)

        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            # TODO: to fix
            def delete_mask_func(tensor):
                # to: (BxL') x d_model
                tensor = tensor[mask == 0, :]
                return tensor

            mask = self.mask.view(-1)
            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd = _fmoe_general_global_forward(
            moe_inp, gate_top_k_idx, self.expert_fn,
            self.num_expert, self.world_size,
            experts=self.experts
        )

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:

            def recover_func(tensor):
                # to: (BxL') x top_k x dim
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                # to: (BxL) x top_k x d_model
                x = torch.zeros(
                    mask.shape[0],
                    self.top_k,
                    dim,
                    device=tensor.device,
                    dtype=tensor.dtype,
                )
                # recover
                x[mask == 0] = tensor
                for k, v in self.mask_dict.items():
                    x[mask == k] = v
                return x

            moe_outp = tree.map_structure(recover_func, fwd)
        else:

            def view_func(tensor):
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                return tensor

            moe_outp = tree.map_structure(view_func, fwd)

        gate_score = gate_score.view(-1, 1, self.top_k)

        # Weighted sum over the top_k expert outputs of each token.
        def bmm_func(tensor):
            dim = tensor.shape[-1]
            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
            return tensor

        moe_outp = tree.map_structure(bmm_func, moe_outp)

        if self.slice_size > 1:

            def all_gather_func(tensor):
                return AllGather.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_outp = tree.map_structure(all_gather_func, moe_outp)

        moe_outp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
        )
        assert all(
            batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size
        ), "MoE outputs must have the same batch size"
        return moe_outp

class LimitCapacityMoE(FMoE):
    """FMoE layer that caps how many (token, expert) rows each expert processes.

    Compared with FMoE.forward, this layer:
    * can gate a concatenated ("co-input") multi-modality batch per modality
      through ``self.tasks_gates``. That attribute and ``self.args`` are not
      created here and must be provided by a subclass (VMoETransformerMLP
      sets both);
    * runs the experts through ``_fmoe_limited_global_forward``, which keeps
      at most ``capacity_per_expert`` rows per expert;
    * exposes the dropped / kept row indices of the last forward pass as
      ``self.drop_idx`` / ``self.select_idx``.
    """
    def __init__(self, num_expert=32, d_model=1024, world_size=1, mp_group=None, slice_group=None, moe_group=None, top_k=2, gate=NaiveGate, expert=None, gate_hook=None, mask=None, mask_dict=None, capacity_per_expert = 10):
        super().__init__(num_expert = num_expert, 
                         d_model = d_model, world_size = world_size, 
                         mp_group = mp_group, slice_group = slice_group, 
                         moe_group = moe_group, top_k = top_k, 
                         gate = gate, expert = expert, 
                         gate_hook = gate_hook, mask = mask, mask_dict = mask_dict)
        # Maximum number of rows each expert processes per forward call.
        self.capacity_per_expert = capacity_per_expert
        # List of modality names for co-input batches, or None for single-gate routing.
        self.co_input_modalities_name = None
        # Per-expert dropped / kept row indices from the latest forward call.
        self.drop_idx = None
        self.select_idx = None
        self.batch_size = 0
        
    def set_capacity(self, capacity):
        """Set the per-expert capacity used by subsequent forward calls."""
        self.capacity_per_expert = capacity
        
    def forward(self, moe_inp):
        r"""
        Compute the gate output first, then run the capacity-limited MoE
        forward according to the gate. Each selected expert's output is
        weighted by its gate score.

        Note: ``self.top_k`` is overwritten from the gate(s) used in this call.
        """
        
        moe_inp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_inp)
        )
        assert all(
            [batch_size == moe_inp_batch_size[0] for batch_size in moe_inp_batch_size]
        ), "MoE inputs must have the same batch size"

        if self.world_size > 1:

            def ensure_comm_func(tensor):
                ensure_comm(tensor, self.moe_group)

            tree.map_structure(ensure_comm_func, moe_inp)
        if self.slice_size > 1:

            def slice_func(tensor):
                return Slice.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_inp = tree.map_structure(slice_func, moe_inp)

        if self.co_input_modalities_name is not None and hasattr(self, 'tasks_gates'):
            # Co-input: the batch is split into equal consecutive chunks, one
            # per modality, and each chunk is routed by that modality's gate.
            modality_step_in_inp = moe_inp.shape[0] // len(self.co_input_modalities_name)
            idx_list = []
            score_list = []
            for i in range(len(self.co_input_modalities_name)):
                if self.args.modality_gating_merge:
                    # Merged gating: gates are keyed by the remapped modality name.
                    idx, score = self.tasks_gates[self.args.modality_remap[self.co_input_modalities_name[i]]](moe_inp[i * modality_step_in_inp: (i+1) * modality_step_in_inp])
                    self.top_k = self.tasks_gates[self.args.modality_remap[self.co_input_modalities_name[i]]].top_k
                else:
                    idx, score = self.tasks_gates[self.co_input_modalities_name[i]](moe_inp[i * modality_step_in_inp: (i+1) * modality_step_in_inp])
                    self.top_k = self.tasks_gates[self.co_input_modalities_name[i]].top_k
                # NOTE(review): self.top_k ends up as the *last* modality's
                # top_k; the reshape below assumes all gates share one top_k.

                idx_list.append(idx)
                score_list.append(score)
                
            gate_top_k_idx = torch.cat(idx_list, dim=0)
            gate_score = torch.cat(score_list, dim=0)
        else:
            gate_top_k_idx, gate_score = self.gate(moe_inp)
            self.top_k = self.gate.top_k

        # Normalise both gate outputs to (tokens, top_k).
        gate_score = gate_score.reshape(moe_inp.shape[0], self.top_k)
        gate_top_k_idx = gate_top_k_idx.reshape(moe_inp.shape[0], self.top_k)

        if self.gate_hook is not None:
            self.gate_hook(gate_top_k_idx, gate_score, None)
        
        # delete masked tensors
        if self.mask is not None and self.mask_dict is not None:
            # TODO: to fix
            # NOTE(review): only gate_top_k_idx is filtered here while the full
            # gate_score is passed below, so the limited forward's per-token
            # priorities may be misaligned in this branch.
            def delete_mask_func(tensor):
                # to: (BxL') x d_model
                tensor = tensor[mask == 0, :]
                return tensor

            mask = self.mask.view(-1)
            moe_inp = tree.map_structure(delete_mask_func, moe_inp)
            gate_top_k_idx = gate_top_k_idx[mask == 0, :]

        fwd, drop_idx, select_idx = _fmoe_limited_global_forward(
            moe_inp, gate_top_k_idx, gate_score, self.expert_fn,
            self.num_expert, self.world_size,
            experts=self.experts, expert_capacity=self.capacity_per_expert, is_train=self.training, top_k=self.top_k
        )
        self.drop_idx = drop_idx
        self.select_idx = select_idx

        # recover deleted tensors
        if self.mask is not None and self.mask_dict is not None:

            def recover_func(tensor):
                # to: (BxL') x top_k x dim
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                # to: (BxL) x top_k x d_model
                x = torch.zeros(
                    mask.shape[0],
                    self.top_k,
                    dim,
                    device=tensor.device,
                    dtype=tensor.dtype,
                )
                # recover
                x[mask == 0] = tensor
                for k, v in self.mask_dict.items():
                    x[mask == k] = v
                return x

            moe_outp = tree.map_structure(recover_func, fwd)
        else:

            def view_func(tensor):
                dim = tensor.shape[-1]
                tensor = tensor.view(-1, self.top_k, dim)
                return tensor

            moe_outp = tree.map_structure(view_func, fwd)

        # Weighted sum over each token's top_k expert outputs.
        gate_score = gate_score.view(-1, 1, self.top_k)

        def bmm_func(tensor):
            dim = tensor.shape[-1]
            tensor = torch.bmm(gate_score, tensor).reshape(-1, dim)
            return tensor

        moe_outp = tree.map_structure(bmm_func, moe_outp)

        if self.slice_size > 1:

            def all_gather_func(tensor):
                return AllGather.apply(
                    tensor, self.slice_rank, self.slice_size, self.slice_group
                )

            moe_outp = tree.map_structure(all_gather_func, moe_outp)

        moe_outp_batch_size = tree.flatten(
            tree.map_structure(lambda tensor: tensor.shape[0], moe_outp)
        )
        assert all(
            [batch_size == moe_outp_batch_size[0] for batch_size in moe_outp_batch_size]
        ), "MoE outputs must have the same batch size"
        return moe_outp


class MLPExpert(nn.Module):
    r"""
    Two-layer feed-forward expert bank for all local experts of one worker.

    Both projections are FMoELinear layers that receive ``fwd_expert_count``.
    They are presumably used by FastMoE to apply each expert's weights to its
    contiguous block of rows. The pipeline is
    ``htoh4 -> dropout -> activation -> h4toh -> dropout``.
    """
    def __init__(self, num_expert, d_model, d_hidden, activation, rank=0, args = None, drop=0.):
        super().__init__()
        # Construction order is kept (it fixes parameter-initialisation order).
        self.htoh4 = FMoELinear(num_expert, d_model, d_hidden, bias=True, rank=rank)
        self.h4toh = FMoELinear(num_expert, d_hidden, d_model, bias=True, rank=rank)
        self.activation = activation()
        self.drop1, self.drop2 = Dropout(drop), Dropout(drop)
        # Last fwd_expert_count seen, recorded only if args.load_expert_count.
        self.expert_count = None
        self.args = args

    def get_expert_count(self):
        """Return the last recorded per-expert row counts, or None if recording is off."""
        return self.expert_count if self.args.load_expert_count else None

    def forward(self, inp, fwd_expert_count):
        r"""
        Project to the hidden size, apply dropout and the activation, then
        project back to ``d_model`` and apply dropout again.
        """
        if self.args.load_expert_count:
            self.expert_count = fwd_expert_count
        hidden = self.drop1(self.htoh4(inp, fwd_expert_count))
        hidden = self.activation(hidden)
        return self.drop2(self.h4toh(hidden, fwd_expert_count))

class AttentionLinearExpert(nn.Module):
    r"""
    Bank of single-layer linear experts (one FMoELinear) for the local experts
    of one worker, used for the MoE attention projections.
    """

    def __init__(self, num_expert, in_dim, out_dim, bias, rank=0, args = None):
        super().__init__()
        self.qkv_linear = FMoELinear(num_expert, in_dim, out_dim, bias=bias, rank=rank)
        self.args = args
        # Last fwd_expert_count seen, recorded only if args.load_expert_count.
        self.expert_count = None
        # Read by get_loss(). It was previously never initialised, so
        # get_loss() raised AttributeError.
        self.loss = None

    def get_loss(self, clear=True):
        """Return the stored loss (None if none was set); optionally clear it."""
        # getattr guards instances created without running __init__.
        loss = getattr(self, 'loss', None)
        if clear:
            self.loss = None
        return loss

    def get_expert_count(self):
        """Return the last recorded per-expert row counts, or None if recording is off."""
        if self.args.load_expert_count:
            return self.expert_count
        return None

    def forward(self, inp, fwd_expert_count):
        r"""
        Apply each expert's linear projection to the rows routed to it
        (``fwd_expert_count`` is forwarded to FMoELinear).
        """
        if self.args.load_expert_count:
            self.expert_count = fwd_expert_count
        return self.qkv_linear(inp, fwd_expert_count)

class VMoETransformerAttentionQKV(ModifiedFMoE):
    r"""MoE linear projection (e.g. for attention q/k/v) with per-task or per-modality gates.

    Gates live in ``self.tasks_gates``:
    * keyed by ``str(task_idx)`` (one per task) when
      ``args.attn_modality_specific`` is False;
    * keyed by modality name (one per ``args.modalities_name`` entry) when it
      is True. With ``args.modality_gating_merge``, lookups go through
      ``args.modality_remap``.
    The input's last dimension ``d_model`` is projected to ``out_dim``.
    """
    def __init__(self, 
                 num_expert=32, 
                 d_model=1024,
                 out_dim = 1024, 
                 bias = False,
                 expert_dp_comm="none",
                 world_size=1, 
                 mp_group=None, 
                 slice_group=None, 
                 moe_group=None, 
                 top_k=2, 
                 gate=NaiveGate, 
                 expert=None, 
                 gate_hook=None, 
                 mask=None, 
                 mask_dict=None,
                 args : MultiModalityConfig = None):
        super().__init__(num_expert, d_model, world_size, mp_group, slice_group, moe_group, top_k, gate, expert, gate_hook, mask, mask_dict)
        self.experts = AttentionLinearExpert(num_expert, d_model, out_dim, bias, args = args)
        self.out_dim = out_dim
        self.args = args
        self.tasks_gates = nn.ModuleDict()
        if args.attn_modality_specific:
            gate_keys = list(args.modalities_name)
        else:
            gate_keys = [str(i) for i in range(args.num_tasks)]
        for key in gate_keys:
            self.tasks_gates[key] = gate(d_model, num_expert, world_size, top_k)
        self.mark_parallel_comm(expert_dp_comm)
        self.expert_local_count = None

    def _select_gate(self, task_idx=None, modality_name=None):
        """Return the gate serving ``task_idx`` or ``modality_name`` per the config."""
        if not self.args.attn_modality_specific:
            return self.tasks_gates[str(task_idx)]
        if self.args.modality_gating_merge:
            return self.tasks_gates[self.args.modality_remap[modality_name]]
        return self.tasks_gates[modality_name]

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return (and clear) the selected gate's accumulated balancing loss."""
        return self._select_gate(task_idx, modality_name).get_loss()

    def get_expert_count(self):
        return self.experts.get_expert_count()

    def gate_topk_logits(self, task_idx = None, modality_name = None):
        """Return the selected gate's stored top-k output (see the gate's gate_topk_logits)."""
        return self._select_gate(task_idx, modality_name).gate_topk_logits()

    def forward(self, inp: torch.Tensor, task_idx = None, modality_name = None):
        """Project ``inp`` (``(..., d_model)``) to ``(..., out_dim)``.

        A list ``modality_name`` (with modality-specific gating) means the
        batch holds one equal chunk per listed modality, each routed by its
        own gate. Otherwise a single gate is selected for the whole batch.
        """
        if isinstance(modality_name, list) and self.args.attn_modality_specific:
            self.co_input_modalities_name = modality_name
        else:
            self.gate = self._select_gate(task_idx, modality_name)
            self.co_input_modalities_name = None
        output_shape = list(inp.shape)
        output_shape[-1] = self.out_dim
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(output_shape)
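# Base class for VMoETransformerMLP: LimitCapacityMoE (used below) caps the rows
# per expert; ModifiedFMoE is the alternative base without a capacity limit.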
class VMoETransformerMLP(LimitCapacityMoE):
    r"""
    A complete MoE MLP module in a Transformer block, built on ``LimitCapacityMoE``.

    One gate is created per task (key ``str(i)``, or ``args.task_remap[str(i)]`` when
    ``args.task_gating_merge`` is set) or, when ``args.mlp_modality_specific`` is set,
    one per entry of ``args.modalities_name``. The gates live in ``self.tasks_gates``;
    ``forward`` selects the active one before calling the parent forward.

    * `activation` is the activation function to be used in MLP in each expert.
    * `d_hidden` is the dimension of the MLP layer.
    * `args` must be provided: the gate class and all gating flags/remaps are read from it.
    """

    def __init__(
        self,
        num_expert=32,
        args : MultiModalityConfig= None,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU,
        expert_dp_comm="none",
        expert_rank=0,
        capacity_per_expert = 10,
        drop = 0.,
        **kwargs
    ):
        super().__init__(num_expert=num_expert, d_model=d_model, capacity_per_expert = capacity_per_expert, gate = args.gate, **kwargs)
        self.experts = MLPExpert(
            num_expert, d_model, d_hidden, activation, rank=expert_rank, args = args, drop = drop
        )
        self.args = args
        gate = args.gate
        # Number of tokens passed to the most recent parent forward; read by get_expert_count.
        self.batch_size = 0
        self.tasks_gates = nn.ModuleDict()
        if self.args.mlp_modality_specific:
            for mn in self.args.modalities_name:
                self.tasks_gates[mn] = gate(d_model, num_expert,self.world_size, self.top_k)
        else:
            for i in range(self.args.num_tasks):
                # With task_gating_merge, tasks remapped to the same key end up sharing one gate.
                if self.args.task_gating_merge:
                    self.tasks_gates[self.args.task_remap[str(i)]] = gate(d_model, num_expert,self.world_size, self.top_k)
                else:
                    self.tasks_gates[str(i)] = gate(d_model, num_expert,self.world_size, self.top_k)

        self.mark_parallel_comm(expert_dp_comm)
        # Optional {modality_key: top_k} mapping; when set, forward routes each modality
        # chunk of the sequence with its own top-k (see set_modality_topk / forward).
        self.modality_topk = None

    def set_modality_topk(self, modality_topk):
        """Store a ``{modality_key: top_k}`` mapping used by ``forward`` (``None`` disables it)."""
        self.modality_topk = modality_topk

    def set_topk(self, topk):
        """Call ``set_topk(topk)`` on every gate in ``self.tasks_gates``."""
        for key in self.tasks_gates:
            self.tasks_gates[key].set_topk(topk)

    def gate_logits(self, task_idx):
        """Return ``get_logits()`` of gate ``str(task_idx)``, falling back to gate ``'0'``
        when no gate is registered under that key."""
        if str(task_idx) not in self.tasks_gates:
            return self.tasks_gates[str(0)].get_logits()
        return self.tasks_gates[str(task_idx)].get_logits()

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return ``get_loss()`` of the gate selected by the same rules ``forward`` uses."""
        if self.args.mlp_modality_specific:
            if self.args.modality_gating_merge:
                return self.tasks_gates[self.args.modality_remap[modality_name]].get_loss()
            else:
                return self.tasks_gates[modality_name].get_loss()
        else:
            if self.args.task_gating_merge:
                return self.tasks_gates[self.args.task_remap[str(task_idx)]].get_loss()
            else:
                return self.tasks_gates[str(task_idx)].get_loss()

    def gate_topk_logits(self, task_idx = None, modality_name = None):
        """Return ``gate_topk_logits()`` of the gate selected by the same rules ``forward`` uses."""
        if self.args.mlp_modality_specific:
            if self.args.modality_gating_merge:
                return self.tasks_gates[self.args.modality_remap[modality_name]].gate_topk_logits()
            else:
                return self.tasks_gates[modality_name].gate_topk_logits()
        else:
            if self.args.task_gating_merge:
                return self.tasks_gates[self.args.task_remap[str(task_idx)]].gate_topk_logits()
            else:
                return self.tasks_gates[str(task_idx)].gate_topk_logits()

    def get_expert_count(self, num_modalities = 2):
        """Return per-expert token counts.

        Without ``args.co_input`` this delegates to ``self.experts.get_expert_count()``.
        With ``args.co_input``, the routed slots (``batch_size * top_k``) are split into
        ``num_modalities`` equal index ranges. For each range and each expert, this counts
        how many entries of ``self.select_idx[e]`` fall inside it. The result is a dict
        ``{modality_index: tensor of length num_expert}``.
        NOTE(review): ``self.select_idx`` is presumably populated by the parent forward.
        """
        if not self.args.co_input:
            return self.experts.get_expert_count()
        else:
            index_range = self.batch_size * self.top_k // num_modalities
            all_expert_count = {}
            for i in range(num_modalities):
                expert_count = [0 for _ in range(self.num_expert)]
                for e in range(self.num_expert):
                    if e in self.select_idx:
                        lower_bound = self.select_idx[e] >= i * index_range
                        higher_bound = self.select_idx[e] < (i+1) * index_range
                        expert_count[e] = (lower_bound * higher_bound).sum().cpu().item()
                all_expert_count[i] = torch.tensor(expert_count)
            return all_expert_count

    def forward(self, inp: torch.Tensor, task_idx, modality_name = None):
        r"""
        Route ``inp`` through the MoE MLP using the gate selected for this call.

        Gate selection: with ``args.mlp_modality_specific`` the key is ``modality_name``
        (remapped via ``args.modality_remap`` when ``args.modality_gating_merge``); otherwise
        it is ``str(task_idx)`` (remapped via ``args.task_remap`` when
        ``args.task_gating_merge``). If ``modality_name`` is a list and
        ``args.mlp_modality_specific`` is set, no gate is selected here and the list is
        stored in ``self.co_input_modalities_name`` instead.

        Tokens are flattened to ``(-1, d_model)`` for the parent forward, and the output
        is reshaped back to the input shape. No residual or normalisation is applied here.

        When ``set_modality_topk`` was called, ``inp`` must be 3-D ``(B, N, C)``. Its
        sequence axis is split into ``len(modality_topk)`` equal, contiguous chunks, one
        per modality in sorted key order. Each chunk is routed with that modality's top-k,
        and the chunk outputs are concatenated back along dim 1.

        Raises:
            ValueError: if ``N`` is not divisible by the number of modalities in
                ``modality_topk``. Otherwise the trailing tokens would be silently dropped
                and the output would not match the input's shape.
        """
        if type(modality_name) is list and self.args.mlp_modality_specific:
            self.co_input_modalities_name = modality_name
        else:
            if self.args.mlp_modality_specific:
                if self.args.modality_gating_merge:
                    self.gate = self.tasks_gates[self.args.modality_remap[modality_name]]
                else:
                    self.gate = self.tasks_gates[modality_name]
            else:
                if self.args.task_gating_merge:
                    self.gate = self.tasks_gates[self.args.task_remap[str(task_idx)]]
                else:
                    self.gate = self.tasks_gates[str(task_idx)]
            self.co_input_modalities_name = None
        if self.modality_topk is None:
            original_shape = inp.shape
            inp = inp.reshape(-1, self.d_model)
            self.batch_size = inp.shape[0]
            output = super().forward(inp)
            return output.reshape(original_shape)
        else:
            B, N, C = inp.shape
            modality_list = sorted(self.modality_topk.keys())
            n_modality = len(self.modality_topk)
            if N % n_modality != 0:
                raise ValueError(
                    f"sequence length {N} is not divisible by the number of modalities "
                    f"{n_modality} in modality_topk; trailing tokens would be dropped"
                )
            n_seq = N // n_modality
            mlp = []
            for i in range(n_modality):
                topk = self.modality_topk[modality_list[i]]
                # Applies to every gate; the last modality's value stays in effect afterwards.
                self.set_topk(topk)
                x_input = inp[:, n_seq * i: n_seq * (i+1)]
                original_shape = x_input.shape
                x_input = x_input.reshape(-1, self.d_model)
                self.batch_size = x_input.shape[0]
                output = super().forward(x_input)
                mlp.append(output.reshape(original_shape))
            return torch.concat(mlp, dim=1)


class VMoETransformerMLPUnlimitCapacity(ModifiedFMoE):
    r"""
    A complete MoE MLP module in a Transformer block, built on ``ModifiedFMoE``.

    One gate of class ``gate`` is created per task (or per modality when
    ``args.mlp_modality_specific`` is set) and stored in ``self.tasks_gates``.
    ``forward`` selects the active gate before calling the parent forward.

    * `activation` is the activation function to be used in MLP in each expert.
    * `d_hidden` is the dimension of the MLP layer.
    * `gate` is the gate class used both for the parent module and for the per-task gates.
    """

    def __init__(
        self,
        num_expert=32,
        args: MultiModalityConfig = None,
        d_model=1024,
        d_hidden=4096,
        activation=torch.nn.GELU,
        expert_dp_comm="none",
        expert_rank=0,
        capacity_per_expert = 10,
        gate = NoisyVMoEGate,
        drop=0.,
        **kwargs
    ):
        super().__init__(num_expert=num_expert, d_model=d_model, capacity_per_expert = capacity_per_expert, gate = gate, **kwargs)
        self.experts = MLPExpert(
            num_expert, d_model, d_hidden, activation, rank=expert_rank, args = args, drop=drop
        )
        self.args = args
        # NOTE: ``self.gate`` is deliberately NOT set to ``args.gate`` here. That is a gate
        # *class*, and assigning a non-Module over a registered child module makes
        # nn.Module.__setattr__ raise TypeError. forward() assigns the real gate module.
        self.tasks_gates = nn.ModuleDict()
        if self.args.mlp_modality_specific:
            for mn in self.args.modalities_name:
                self.tasks_gates[mn] = gate(d_model, num_expert,self.world_size, self.top_k)
        else:
            for i in range(args.num_tasks):
                # With task_gating_merge, tasks remapped to the same key end up sharing one gate.
                if self.args.task_gating_merge:
                    self.tasks_gates[self.args.task_remap[str(i)]] = gate(d_model, num_expert,self.world_size, self.top_k)
                else:
                    self.tasks_gates[str(i)] = gate(d_model, num_expert,self.world_size, self.top_k)

        self.mark_parallel_comm(expert_dp_comm)

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return ``get_loss()`` of the gate selected by the same rules ``forward`` uses."""
        if self.args.mlp_modality_specific:
            if self.args.modality_gating_merge:
                return self.tasks_gates[self.args.modality_remap[modality_name]].get_loss()
            else:
                return self.tasks_gates[modality_name].get_loss()
        else:
            if self.args.task_gating_merge:
                return self.tasks_gates[self.args.task_remap[str(task_idx)]].get_loss()
            else:
                return self.tasks_gates[str(task_idx)].get_loss()

    def get_expert_count(self):
        """Return ``self.experts.get_expert_count()``."""
        return self.experts.get_expert_count()

    def forward(self, inp: torch.Tensor, task_idx = None, modality_name = None):
        r"""
        Route ``inp`` through the MoE MLP using the gate selected for this call.

        Gate selection: with ``args.mlp_modality_specific`` the key is ``modality_name``
        (remapped via ``args.modality_remap`` when ``args.modality_gating_merge``); otherwise
        it is ``str(task_idx)`` (remapped via ``args.task_remap`` when
        ``args.task_gating_merge``). If ``modality_name`` is a list and
        ``args.mlp_modality_specific`` is set, the list is stored in
        ``self.co_input_modalities_name`` and no gate is selected here.

        Tokens are flattened to ``(-1, d_model)`` for the parent forward, and the output
        is reshaped back to the input shape. No residual or normalisation is applied here.
        """
        if type(modality_name) is list and self.args.mlp_modality_specific:
            self.co_input_modalities_name = modality_name
        else:
            if self.args.mlp_modality_specific:
                if self.args.modality_gating_merge:
                    self.gate = self.tasks_gates[self.args.modality_remap[modality_name]]
                else:
                    self.gate = self.tasks_gates[modality_name]
            else:
                if self.args.task_gating_merge:
                    self.gate = self.tasks_gates[self.args.task_remap[str(task_idx)]]
                else:
                    self.gate = self.tasks_gates[str(task_idx)]
            self.co_input_modalities_name = None
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)


class QKVSeperateExpert(nn.Module):
    r"""
    A bank of per-expert linear projections backed by a single ``FMoELinear``.

    There is no hidden layer or activation: ``forward`` applies one ``FMoELinear``
    mapping ``in_dim`` -> ``out_dim`` to the tokens, using ``fwd_expert_count``
    to tell it how many tokens belong to each expert.
    """

    def __init__(self, num_expert, in_dim, out_dim, bias, rank=0, args = None):
        super().__init__()
        self.seperate_qkv = FMoELinear(num_expert, in_dim, out_dim, bias=bias, rank=rank)
        self.args = args
        # Per-expert token counts from the last forward (recorded only when
        # args.load_expert_count is truthy).
        self.expert_count = None
        # Auxiliary-loss slot read by get_loss(). Nothing in this class writes it,
        # so it must exist up front or get_loss() would raise AttributeError.
        self.loss = None

    def _records_expert_count(self):
        """True when ``args`` is present and ``args.load_expert_count`` is truthy."""
        return self.args is not None and self.args.load_expert_count

    def get_loss(self, clear=True):
        """Return the stored auxiliary loss (``None`` if unset); reset it when ``clear``."""
        loss = getattr(self, "loss", None)
        if clear:
            self.loss = None
        return loss

    def get_expert_count(self):
        """Return the expert counts of the last forward, or ``None`` when not recorded."""
        if self._records_expert_count():
            return self.expert_count
        return None

    def forward(self, inp, fwd_expert_count):
        r"""
        Apply the per-expert linear projection to ``inp``.

        ``fwd_expert_count`` is passed straight to ``FMoELinear``. It is also stored in
        ``self.expert_count`` when ``args.load_expert_count`` is enabled.
        """
        if self._records_expert_count():
            self.expert_count = fwd_expert_count
        return self.seperate_qkv(inp, fwd_expert_count)

class VMoETransformerSeperateQKV(LimitCapacityMoE):
    r"""
    Capacity-limited MoE linear projection (``d_model`` -> ``out_dim``) whose experts
    are a ``QKVSeperateExpert``.

    One gate of class ``gate`` is created per task (key ``str(i)``) or, when
    ``args.attn_modality_specific`` is set, one per entry of ``args.modalities_name``.
    The gates live in ``self.tasks_gates``; ``forward`` selects the active one before
    calling the parent forward. ``args`` must be provided.
    """
    def __init__(self, 
                 num_expert=32, 
                 d_model=1024,
                 out_dim = 1024, 
                 bias = False,
                 expert_dp_comm="none",
                 world_size=1, 
                 mp_group=None, 
                 slice_group=None, 
                 moe_group=None, 
                 top_k=2, 
                 gate=NaiveGate, 
                 expert=None, 
                 gate_hook=None, 
                 mask=None, 
                 mask_dict=None,
                 args : MultiModalityConfig = None):
        super().__init__(num_expert, d_model, world_size, mp_group, slice_group, moe_group, top_k, gate, expert, gate_hook, mask, mask_dict)
        self.experts = QKVSeperateExpert(num_expert, d_model, out_dim, bias, args = args)
        self.capacity_per_expert = args.capacity_per_expert
        self.out_dim = out_dim
        self.args = args
        # Number of tokens in the most recent forward; read by get_expert_count().
        # Initialised here so get_expert_count() does not raise before the first forward.
        self.batch_size = 0
        self.tasks_gates = nn.ModuleDict()
        if not self.args.attn_modality_specific:
            for i in range(args.num_tasks):
                self.tasks_gates[str(i)] = gate(d_model, num_expert, world_size, top_k)
        else:
            for ms in self.args.modalities_name:
                self.tasks_gates[ms] = gate(d_model, num_expert, world_size, top_k)
        self.mark_parallel_comm(expert_dp_comm)
        self.expert_local_count = None

    def set_topk(self, topk):
        """Call ``set_topk(topk)`` on every gate in ``self.tasks_gates``."""
        for key in self.tasks_gates:
            self.tasks_gates[key].set_topk(topk)

    def gate_topk_logits(self, task_idx = None, modality_name = None):
        """Return ``gate_topk_logits()`` of the gate selected by the same rules ``forward`` uses."""
        if self.args.attn_modality_specific:
            if self.args.modality_gating_merge:
                return self.tasks_gates[self.args.modality_remap[modality_name]].gate_topk_logits()
            else:
                return self.tasks_gates[modality_name].gate_topk_logits()
        else:
            return self.tasks_gates[str(task_idx)].gate_topk_logits()

    def gate_loss(self, task_idx = None, modality_name = None):
        """Return ``get_loss()`` of the gate selected by the same rules ``forward`` uses."""
        if self.args.attn_modality_specific:
            if self.args.modality_gating_merge:
                return self.tasks_gates[self.args.modality_remap[modality_name]].get_loss()
            else:
                return self.tasks_gates[modality_name].get_loss()
        else:
            return self.tasks_gates[str(task_idx)].get_loss()

    def get_expert_count(self, num_modalities = 2):
        """Return per-expert token counts.

        Without ``args.co_input`` this delegates to ``self.experts.get_expert_count()``.
        With ``args.co_input``, the routed slots (``batch_size * top_k``) are split into
        ``num_modalities`` equal index ranges. For each range and each expert, this counts
        how many entries of ``self.select_idx[e]`` fall inside it. The result is a dict
        ``{modality_index: tensor of length num_expert}``.
        NOTE(review): ``self.select_idx`` is presumably populated by the parent forward.
        """
        if not self.args.co_input:
            return self.experts.get_expert_count()
        else:
            index_range = self.batch_size * self.top_k // num_modalities
            all_expert_count = {}
            for i in range(num_modalities):
                expert_count = [0 for _ in range(self.num_expert)]
                for e in range(self.num_expert):
                    if e in self.select_idx:
                        lower_bound = self.select_idx[e] >= i * index_range
                        higher_bound = self.select_idx[e] < (i+1) * index_range
                        expert_count[e] = (lower_bound * higher_bound).sum().cpu().item()
                all_expert_count[i] = torch.tensor(expert_count)
            return all_expert_count

    def forward(self, inp: torch.Tensor, task_idx = None, modality_name = None):
        """Select the active gate, then project flattened tokens with the parent forward.

        If ``modality_name`` is a list and ``args.attn_modality_specific`` is set, the
        list is stored in ``self.co_input_modalities_name`` and no gate is selected here.
        The output keeps the input's leading dimensions, with the last one replaced by
        ``self.out_dim``.
        """
        if type(modality_name) is list and self.args.attn_modality_specific:
            self.co_input_modalities_name = modality_name
        else:
            if self.args.attn_modality_specific:
                if self.args.modality_gating_merge:
                    self.gate = self.tasks_gates[self.args.modality_remap[modality_name]]
                else:
                    self.gate = self.tasks_gates[modality_name]
            else:
                self.gate = self.tasks_gates[str(task_idx)]
            self.co_input_modalities_name = None
        output_shape = list(inp.shape)
        output_shape[-1] = self.out_dim

        inp = inp.reshape(-1, self.d_model)
        self.batch_size = inp.shape[0]
        output = super().forward(inp)
        return output.reshape(output_shape)

if __name__ == '__main__':
    # Smoke test: build a small attention MoE on the GPU, push one random batch
    # through it, then print the gate loss and the output shape.
    # NOTE(review): neither task_idx nor modality_name is passed below. Confirm that the
    # gate lookup in forward()/gate_loss() and the constructor's handling of a missing
    # `args` actually support these defaults before relying on this demo.
    model = VMoETransformerAttention(8, 16, 32, True, gate=NoisyGate).cuda()
    sample = torch.randn(8, 10, 16).cuda()
    result = model(sample)
    print(model.gate_loss())
    print(result.shape)