# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

import logging
import torch
from torch import Tensor

from .jit_compiler import IS_HIP_EXTENSION
from ..jit_kernels import sparse as jit_kernel
from ..jit_kernels.gating import fast_cumsum_sub_one
from .communicate import get_world_rank, simple_all_reduce 
from . import losses

class GatingEncoder(torch.autograd.Function):
    """Autograd function that dispatches token rows into a per-expert buffer.

    The actual data movement is delegated to the kernels carried by ``config``
    (``func_fwd`` in the forward pass, ``func_bwd_data`` / ``func_bwd_gate``
    in the backward pass); this class only allocates buffers and wires the
    kernels together.
    """

    @staticmethod
    def _kernel_extra(config: Any):
        # Scalar arguments handed to every kernel launch. Built on demand so
        # that nothing is evaluated when there is no routing slot to process.
        return [
            config.indices_[0].size(0),
            config.aligned_dim,
            config.capacity,
            config.num_global_experts,
        ]

    @staticmethod
    def forward(ctx: Any, config: Any, reshaped_input: Tensor, *gates_):
        """Fill a ``[num_global_experts * capacity, model_dim]`` buffer.

        Args:
            ctx: autograd context; the input, config and gates are stored on it.
            config: dispatcher object providing routing tensors and kernels.
            reshaped_input: token tensor passed to ``config.func_fwd``.
            *gates_: optional per-slot gate tensors. When omitted, the shared
                all-ones helper of ``config`` is used for every routing slot.

        Returns:
            The dispatched buffer, written in place by ``config.func_fwd``.
        """
        ctx.reshaped_input = reshaped_input
        ctx.config = config

        if not gates_:
            # One reference to the all-ones helper per routing slot
            # (original note: an all-ones matrix of size (num tokens, 2), k of them).
            ctx.gates_h2 = [config.ones_helper] * len(config.indices_)
        else:
            # float16 gates are duplicated into two columns; other dtypes pass through.
            ctx.gates_h2 = [
                gate.view(-1, 1).repeat(1, 2) if gate.dtype == torch.float16 else gate
                for gate in gates_
            ]

        # Original note: capacity is top-k times the upper bound of the average
        # number of tokens per expert; each expert keeps only its first
        # `capacity` inputs.
        buffer_shape = [config.num_global_experts * config.capacity, config.model_dim]
        dispatched_input = torch.zeros(
            buffer_shape, dtype=reshaped_input.dtype, device=reshaped_input.device
        )
        for gate, index, location in zip(ctx.gates_h2, config.indices_, config.locations_):
            config.func_fwd(
                gate, index, location, reshaped_input, dispatched_input,
                extra=GatingEncoder._kernel_extra(config),
            )
        return dispatched_input

    @staticmethod
    def backward(ctx: Any, dispatched_input: Tensor):
        """Return gradients for ``(config, reshaped_input, *gates_)``.

        ``config`` never receives a gradient. Gate gradients are produced only
        when explicit gates were given to ``forward``.
        """
        config = ctx.config
        grad_output = dispatched_input.contiguous()

        # Accumulate the input gradient over all routing slots.
        grad_input = None
        for gate, index, location in zip(ctx.gates_h2, config.indices_, config.locations_):
            slot_grad = torch.empty(
                ctx.reshaped_input.shape, dtype=grad_output.dtype, device=grad_output.device
            )
            config.func_bwd_data(
                gate, index, location, slot_grad, grad_output,
                extra=GatingEncoder._kernel_extra(config),
            )
            grad_input = slot_grad if grad_input is None else grad_input + slot_grad

        grad_gates = []
        # The first gate being the shared helper itself means no explicit gates
        # were passed to forward, so there is nothing to differentiate.
        if ctx.gates_h2[0] is not config.ones_helper:
            for index, location in zip(config.indices_, config.locations_):
                gate_grad = torch.empty(
                    [config.sample_size,], dtype=grad_output.dtype, device=grad_output.device
                )
                config.func_bwd_gate(
                    gate_grad, index, location, ctx.reshaped_input, grad_output,
                    extra=GatingEncoder._kernel_extra(config),
                )
                grad_gates.append(gate_grad)
        return (None, grad_input, *grad_gates)


class GatingDecoder(torch.autograd.Function):
    """Autograd function that combines expert outputs back into token rows.

    It mirrors the encoder's kernel usage: the forward pass calls
    ``config.func_bwd_data`` and the backward pass calls ``config.func_fwd``
    (plus ``config.func_bwd_gate`` when explicit gates were supplied).
    """

    @staticmethod
    def _kernel_extra(config: Any):
        # Scalar arguments handed to every kernel launch. Built on demand so
        # that nothing is evaluated when there is no routing slot to process.
        return [
            config.indices_[0].size(0),
            config.aligned_dim,
            config.capacity,
            config.num_global_experts,
        ]

    @staticmethod
    def forward(ctx: Any, config: Any, expert_output: Tensor, *gates_):
        """Produce a ``[sample_size, model_dim]`` tensor from expert outputs.

        Args:
            ctx: autograd context; the expert output, config and gates are
                stored on it.
            config: dispatcher object providing routing tensors and kernels.
            expert_output: output of the experts.
            *gates_: optional per-slot gate tensors. When omitted, the shared
                all-ones helper of ``config`` is used for every routing slot.

        Returns:
            The sum of the per-slot outputs, or ``None`` when there is no
            routing slot at all.
        """
        ctx.expert_output = expert_output
        ctx.config = config

        if not gates_:
            ctx.gates_h2 = [config.ones_helper] * len(config.indices_)
        else:
            # float16 gates are duplicated into two columns; other dtypes pass through.
            ctx.gates_h2 = [
                gate.view(-1, 1).repeat(1, 2) if gate.dtype == torch.float16 else gate
                for gate in gates_
            ]

        combined = None
        for gate, index, location in zip(ctx.gates_h2, config.indices_, config.locations_):
            slot_output = torch.empty(
                [config.sample_size, config.model_dim],
                dtype=expert_output.dtype, device=expert_output.device,
            )
            config.func_bwd_data(
                gate, index, location, slot_output, expert_output,
                extra=GatingDecoder._kernel_extra(config),
            )
            combined = slot_output if combined is None else combined + slot_output
        return combined

    @staticmethod
    def backward(ctx: Any, combined_output: Tensor):
        """Return gradients for ``(config, expert_output, *gates_)``.

        ``config`` never receives a gradient. Gate gradients are produced only
        when explicit gates were given to ``forward``.
        """
        config = ctx.config
        grad_output = combined_output.contiguous()

        # Zero-initialised because func_fwd writes into the same buffer once
        # per routing slot.
        grad_expert_output = torch.zeros(
            ctx.expert_output.shape, dtype=grad_output.dtype, device=grad_output.device
        )
        for gate, index, location in zip(ctx.gates_h2, config.indices_, config.locations_):
            config.func_fwd(
                gate, index, location, grad_output, grad_expert_output,
                extra=GatingDecoder._kernel_extra(config),
            )

        grad_gates = []
        # The first gate being the shared helper itself means no explicit gates
        # were passed to forward, so there is nothing to differentiate.
        if ctx.gates_h2[0] is not config.ones_helper:
            for index, location in zip(config.indices_, config.locations_):
                gate_grad = torch.empty(
                    [config.sample_size,], dtype=grad_output.dtype, device=grad_output.device
                )
                config.func_bwd_gate(
                    gate_grad, index, location, grad_output, ctx.expert_output,
                    extra=GatingDecoder._kernel_extra(config),
                )
                grad_gates.append(gate_grad)
        return (None, grad_expert_output, *grad_gates)


class TutelMoeFastDispatcher:
    """Holds routing state and kernels used by GatingEncoder / GatingDecoder.

    An instance is passed as the ``config`` argument of the two autograd
    functions, which read its routing tensors (``indices_``, ``locations_``),
    sizes (``capacity``, ``sample_size``, ``model_dim``, ``aligned_dim``,
    ``num_global_experts``), kernels (``func_fwd``, ``func_bwd_data``,
    ``func_bwd_gate``) and ``ones_helper``.
    """

    # Kernel triples (fwd, bwd_data, bwd_gate) shared by all instances, keyed
    # by the ``is_cuda`` flag of the routing tensors.
    # NOTE(review): kernels are created with ``self.dtype`` but cached by
    # ``is_cuda`` only; if the kernel factories are dtype-specific, dispatchers
    # of different dtypes would share the wrong kernels - confirm.
    kernel_pool = dict()
    # All-ones ``[N, 2]`` tensor shared by all instances; created in update().
    ones_helper = None

    def __init__(self, num_global_experts, capacity, model_dim, dispatch_dtype):
        """Record static sizes and pick the dtype used for dispatching.

        Args:
            num_global_experts: total number of experts.
            capacity: initial per-expert capacity; used by ``update`` when it
                is called without a (non-zero) capacity.
            model_dim: size of the last dimension of the dispatched data.
            dispatch_dtype: dtype of the caller's data; results are cast back
                to it by ``encode`` / ``decode``.
        """
        self.num_global_experts = int(num_global_experts)
        self.capacity = int(capacity)
        self.model_dim = int(model_dim)
        self.original_dtype = dispatch_dtype
        # float16 is kept only on non-HIP builds; everything else is
        # dispatched in float32.
        if IS_HIP_EXTENSION or dispatch_dtype != torch.float16:
            self.dtype = torch.float32
        else:
            self.dtype = dispatch_dtype
        # With float16 the aligned dimension is half of model_dim.
        self.aligned_dim = self.model_dim // (2 if self.dtype == torch.float16 else 1)
        self.is_cuda = None

    def update(self, indices_, locations_, gates_, capacity=None, is_postscore=True):
        """Install the routing decision for the next encode/decode calls.

        Args:
            indices_: per routing slot, the expert index chosen for every
                token (original note: [top-1 expert index of all tokens, ...,
                top-k expert index of all tokens]).
            locations_: per routing slot, the location of every token inside
                its expert's buffer.
            gates_: per routing slot, the gate score of every token.
            capacity: per-expert capacity. ``None`` or ``0`` keeps the
                capacity currently stored on the dispatcher.
            is_postscore: when True the gates are applied in ``decode``,
                otherwise in ``encode``.
        """
        self.indices_ = [x.to(torch.int32).view(-1) for x in indices_]
        self.locations_ = [x.to(torch.int32) for x in locations_]
        self.gates_ = [x.to(self.dtype) for x in gates_]
        self.is_postscore = is_postscore
        self.sample_size = int(self.indices_[0].size(0))
        # The previous `int(capacity) or self.capacity` raised TypeError for
        # the declared default `capacity=None`.
        requested_capacity = 0 if capacity is None else int(capacity)
        self.capacity = requested_capacity or self.capacity

        is_cuda = self.indices_[0].is_cuda
        if self.is_cuda != is_cuda:
            self.is_cuda = is_cuda
            pool = TutelMoeFastDispatcher.kernel_pool
            if is_cuda not in pool:  # the pool starts out empty
                pool[is_cuda] = (
                    jit_kernel.create_forward_padding(self.dtype, is_cuda),
                    jit_kernel.create_backward_data_padding(self.dtype, is_cuda),
                    jit_kernel.create_backward_gate_padding(self.dtype, is_cuda),
                )
            self.func_fwd, self.func_bwd_data, self.func_bwd_gate = pool[is_cuda]

        # The all-ones helper is shared class-wide. Rebuild it when it is too
        # small, or when it was created for another dtype/device (previously
        # only the CPU/CUDA flag was compared, so a helper of the wrong dtype
        # or on another device could be reused).
        device = self.indices_[0].device
        shared = TutelMoeFastDispatcher.ones_helper
        if shared is None or shared.size(0) < self.sample_size:
            shared = torch.ones([self.sample_size, 2], dtype=self.dtype, device=device)
        elif shared.dtype != self.dtype or shared.device != device:
            shared = torch.ones([shared.size(0), 2], dtype=self.dtype, device=device)
        TutelMoeFastDispatcher.ones_helper = shared
        # The autograd functions compare against this per-instance reference
        # to detect that no explicit gates were supplied.
        self.ones_helper = shared

    def encode(self, data):
        """Dispatch ``data`` to expert slots; gates are applied here only when
        ``is_postscore`` is False. The result is cast back to the original dtype."""
        payload = data.to(self.dtype)
        if self.is_postscore:
            encoded = GatingEncoder.apply(self, payload)
        else:
            encoded = GatingEncoder.apply(self, payload, *self.gates_)
        return encoded.to(self.original_dtype)

    def decode(self, data):
        """Combine expert outputs; gates are applied here only when
        ``is_postscore`` is True. The result is cast back to the original dtype."""
        payload = data.to(self.dtype)
        if self.is_postscore:
            decoded = GatingDecoder.apply(self, payload, *self.gates_)
        else:
            decoded = GatingDecoder.apply(self, payload)
        return decoded.to(self.original_dtype)

# Module-level alias that exposes the dispatcher class under a shorter name.
fast_dispatcher = TutelMoeFastDispatcher

def compute_sorted_location(x, importance_scores):
    """Compute locations as if the rows of ``x`` were visited in ascending
    order of ``importance_scores``, then return them in the original row order.

    Args:
        x: tensor indexed along dim 0 by token (the caller passes a one-hot
            expert mask).
        importance_scores: one score per row of ``x``; rows with smaller
            scores are processed first.

    Returns:
        ``fast_cumsum_sub_one(sorted_x) * sorted_x`` rearranged back to the
        original row order.
    """
    # Sort once; the original computed the same argsort twice.
    order = importance_scores.argsort(dim=0)
    sorted_x = x[order]
    sorted_cumsum = fast_cumsum_sub_one(sorted_x) * sorted_x
    # argsort of a permutation is its inverse, which restores the row order.
    return sorted_cumsum[order.argsort(dim=0)]

def extract_critical(scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=1.0, batch_prioritized_routing=False, normalize_gate=True, alignment=1, group=None, inequivalent_tokens=False, one_score_gate=False):
    """Turn gate scores into the routing data consumed by the fast dispatcher.

    Args:
        scores: 2-D gate scores; dim 0 is the token, dim 1 the expert.
        top_k: either one int applied to every token, or a per-token sequence
            of k values. Per the original author's note, every k is assumed to
            be <= the number of experts, so no check is performed.
        loss_fn: called as ``loss_fn(scores, topk_indices)``; ``None`` skips
            the loss.
        capacity_factor: > 0 derives the capacity from the token count;
            otherwise the capacity is the largest observed location + 1
            (all-reduced with MAX), additionally capped when negative.
        batch_prioritized_routing: assign locations in order of each token's
            best score instead of token order.
        normalize_gate: divide the gates of a token by their sum (only done
            when the largest k is > 1 and ``one_score_gate`` is False).
        alignment: the capacity is rounded up to a multiple of this value.
        group: process group forwarded to the communication helpers.
        inequivalent_tokens: use the MAX-all-reduced token count instead of
            the local one.
        one_score_gate: replace every gate by ones.

    Returns:
        ``((num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss)``
        where
        - num_global_experts: total number of experts,
        - indices_s: per routing slot, the chosen expert index of every token;
          tokens with fewer than the largest k are padded with
          ``num_global_experts``,
        - locations_s: per routing slot, the location of every token inside
          its expert's buffer,
        - gates_s: per routing slot, the gate score of every token,
        - capacity: per-expert capacity after alignment,
        - l_loss: value returned by ``loss_fn`` (or ``None``).
    """
    num_global_experts = int(scores.size(1))  # total number of experts

    if isinstance(top_k, int):
        # Uniform k: a single batched top-k replaces one topk call per token
        # and avoids `max([])` on an empty batch.
        max_top_k = top_k
        topk_indices = torch.topk(scores, top_k, dim=1).indices
    else:
        # Per-token k: align every token to the largest k, padding with
        # `num_global_experts` as an invalid expert index.
        # int() keeps max_top_k (and the derived capacity) a plain Python int
        # even when top_k is a tensor.
        max_top_k = int(max(top_k))
        per_token = [torch.topk(row, int(top_k[i])).indices for i, row in enumerate(scores)]
        topk_indices = torch.nn.utils.rnn.pad_sequence(per_token, True, padding_value=num_global_experts)

    # [top-1 experts of all tokens, ..., top-k experts of all tokens]
    indices_s = [x.view(-1) for x in topk_indices.chunk(max_top_k, dim=1)]

    # One-hot masks of the chosen experts; per the original author's note,
    # padded entries become all-zero rows.
    masks_se = [losses._one_hot_with_dtype_and_padding(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s]
    # Score of the chosen expert for every token (zero for padded entries,
    # given all-zero mask rows).
    gates_s = [(scores * x).sum(dim=1) for x in masks_se]

    # TODO (from the original author): the loss still needs to be revised.
    l_loss = loss_fn(scores, topk_indices) if loss_fn is not None else None

    if batch_prioritized_routing:
        importance_scores = -1 * scores.max(dim=1)[0]
        compute_location = lambda x: compute_sorted_location(x, importance_scores)
    else:
        compute_location = fast_cumsum_sub_one

    # Original note: for every token and expert, the number of times the
    # expert has been a top-1 choice up to this token, minus one.
    locations1 = compute_location(masks_se[0])
    # Keep only the entry belonging to each token's own top-1 expert.
    locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)]

    if max_top_k > 1:
        acc_base = None
        for k in range(1, max_top_k):
            # Per-expert count of all tokens that chose it in slots 1..k-1.
            slot_counts = torch.sum(masks_se[k - 1], dim=0, keepdim=True)
            acc_base = slot_counts if acc_base is None else acc_base + slot_counts
            # Locations within slot k alone, then shifted past the entries
            # already taken by the earlier slots.
            locations2 = compute_location(masks_se[k])
            locations2 += acc_base
            locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32))

        if normalize_gate and not one_score_gate:
            # Clamp the denominator so an all-zero gate sum does not divide by zero.
            denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps)
            gates_s = [x / denom_s for x in gates_s]

    if one_score_gate:
        gates_s = [torch.ones_like(gates_s[0])] * max_top_k

    indices_s = [x.to(torch.int32) for x in indices_s]

    if inequivalent_tokens:
        num_samples = torch.tensor(scores.size(0), device=scores.device)
        num_samples = int(simple_all_reduce(num_samples, group=group, op=torch.distributed.ReduceOp.MAX))
    else:
        num_samples = int(scores.size(0))  # number of tokens

    # Ceiling of the average number of tokens per expert.
    samples_per_expert = (num_samples + num_global_experts - 1) // num_global_experts
    if capacity_factor > 0:
        capacity = max_top_k * int(capacity_factor * samples_per_expert)
    else:
        capacity = torch.max(torch.cat(locations_s, dim=0))
        capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX)) + 1
        if capacity_factor < 0:
            capacity = min(capacity, max_top_k * int(-capacity_factor * samples_per_expert))

    # Round the capacity up to a multiple of `alignment`.
    remainder = capacity % alignment
    if remainder > 0:
        capacity = capacity + alignment - remainder

    if get_world_rank(group) == 0:
        nominal_capacity = max_top_k * samples_per_expert
        # Guard the ratio so logging cannot raise ZeroDivisionError on an
        # empty batch.
        realtime_factor = capacity / nominal_capacity if nominal_capacity else float("nan")
        logging.info("Capacity = %s, real-time capacity-factor for top-%s = %s", capacity, max_top_k, realtime_factor)

    return (num_global_experts, indices_s, locations_s, gates_s, capacity), l_loss

def fast_encode(data, critial_data, is_postscore=True):
    """Dispatch ``data`` to experts with a freshly built dispatcher.

    Args:
        data: tensor to dispatch; its last dimension is used as the model
            dimension.
        critial_data: sequence whose first element is the number of global
            experts and whose remaining elements are passed positionally to
            ``TutelMoeFastDispatcher.update``, i.e.
            [0] num_global_experts: total number of experts,
            [1] indices_s: per routing slot, the expert index of every token,
            [2] locations_s: per routing slot, the location of every token in
                its expert's buffer,
            [3] gates_s: per routing slot, the gate score of every token,
            [4] capacity: per-expert capacity.
        is_postscore: forwarded to the dispatcher's ``update``.

    Returns:
        The encoded data viewed as ``(num_global_experts, -1, data.size(-1))``.
    """
    expert_count = critial_data[0]
    routing_args = critial_data[1:]
    hidden_size = data.size(-1)

    # The initial capacity is 0; the real one comes from `routing_args`.
    dispatcher = TutelMoeFastDispatcher(expert_count, 0, hidden_size, data.dtype)
    dispatcher.update(*routing_args, is_postscore=is_postscore)

    encoded = dispatcher.encode(data)
    return encoded.view(expert_count, -1, hidden_size)

def fast_decode(data, critial_data, is_postscore=True):
    """Combine expert outputs back into token order with a fresh dispatcher.

    Args:
        data: expert output tensor; its last dimension is used as the model
            dimension.
        critial_data: sequence whose first element is the number of global
            experts and whose remaining elements are passed positionally to
            ``TutelMoeFastDispatcher.update`` (indices, locations, gates,
            capacity).
        is_postscore: forwarded to the dispatcher's ``update``.

    Returns:
        The decoded data viewed as ``(-1, data.size(-1))``.
    """
    expert_count = critial_data[0]
    routing_args = critial_data[1:]
    hidden_size = data.size(-1)

    # The initial capacity is 0; the real one comes from `routing_args`.
    dispatcher = TutelMoeFastDispatcher(expert_count, 0, hidden_size, data.dtype)
    dispatcher.update(*routing_args, is_postscore=is_postscore)

    decoded = dispatcher.decode(data)
    return decoded.view(-1, hidden_size)