#!/usr/bin/env python
# coding=utf-8
# @Author: jiangpeijie.jpj
# @Date: Mon 4 Dec 2023 05:21:28 PM CST

import logging
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
import torch.distributed as dist
from numpy import random
from torch import nn
from torch.nn import CrossEntropyLoss
from whisper.model import AudioEncoder

from transformers.activations import ACT2CLS, ClassInstantier

try:
    from atorch.distributed.distributed import parallel_group, parallel_group_size
except Exception:
    parallel_group = None
    parallel_group_size = None


# ## Activations
class SwiGLUActivatition(nn.Module):

    def forward(self, input):
        input = torch.chunk(input, 2, dim=-1)
        return F.silu(input[0]) * input[1]


ACT2CLS["swiglu"] = SwiGLUActivatition
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
swiglu = get_activation("swiglu")


# Rotary Position Embedding Utils
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))


def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
    """Find dim range bounds based on rotations"""
    low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
    high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case


def linear_ramp_mask(min, max, dim):
    if min == max:
        max += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func


def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    # dim=-1 triggers a bug in earlier torch versions
    return torch.cat((-x2, x1), dim=x1.ndim - 1)


# Comment torchscript func for accurate calculate
# @torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
        position_id, sin.squeeze(1)
    ).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # inv_freq 保留float精度，避免bf16损失
        # inv_freq = inv_freq.to(precision)
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            # freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
            freqs = torch.outer(t, inv_freq.to(x.device))
            assert freqs.dtype == torch.float32
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn, *args, **kwargs):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn, *args, **kwargs)


class LinearScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with linear scaling."""

    def __init__(
        self, dim, base=10000, precision=torch.half, learnable=False, max_embedding_length=2048, scaling_factor=1.0
    ):
        self.scaling_factor = scaling_factor
        self.max_embedding_length = max_embedding_length
        super().__init__(dim, base, precision, learnable)

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len

            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            t = t / self.scaling_factor
            # freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
            freqs = torch.outer(t, inv_freq.to(x.device))
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


class NTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling."""

    def __init__(
        self, dim, base=10000, precision=torch.half, learnable=False, max_embedding_length=2048, scaling_factor=1.0
    ):
        self.scaling_factor = scaling_factor
        self.max_embedding_length = max_embedding_length
        super().__init__(dim, base, precision, learnable)

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len

            base = self.base
            if seq_len > self.max_embedding_length:
                base = self.base * (
                    (self.scaling_factor * seq_len / self.max_embedding_length) - (self.scaling_factor - 1)
                ) ** (self.dim / (self.dim - 2))

            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            # freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
            freqs = torch.outer(t, inv_freq.to(x.device))
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]


class DynamicYaRNScaledRotaryEmbedding(RotaryEmbedding):
    def __init__(
        self,
        dim,
        base=10000,
        precision=torch.half,
        learnable=False,
        max_embedding_length=2048,
        extrapolation_factor=1,
        attn_factor=1,
        beta_fast=32,
        beta_slow=1,
    ):
        self.max_embedding_length = max_embedding_length
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow

        super().__init__(dim, base, precision, learnable)

    def forward(self, x, seq_dim=1, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len is None:
            seq_len = x.shape[seq_dim]

        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = seq_len

            if seq_len > self.max_embedding_length:
                self.yarn(seq_len / self.max_embedding_length, x.device)

            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            # freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
            freqs = torch.outer(t, self.inv_freq.to(x.device))
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached

        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def yarn(self, scale, device):
        pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scale * pos_freqs)

        low, high = find_correction_range(
            self.beta_fast, self.beta_slow, self.dim, self.base, self.max_embedding_length
        )
        inv_freq_mask = (
            1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
        ) * self.extrapolation_factor  # Get n-d rotational scaling corrected for extrapolation
        inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

        self.register_buffer("inv_freq", inv_freq)  # Get n-d magnitude scaling corrected for interpolation


# ## LongGLM Utils


@dataclass
class LongGLMMemCache:
    """
    Class with LongLlama's memory cache

    Args:
        key (`torch.FloatTensor` of shape `(batch_size, mem_length, head_nums, embed_size_per_head)`)
        value (`torch.FloatTensor` of shape `(batch_size, mem_length, head_nums, embed_size_per_head)`)
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`)
            For masking out parts of memory
    """

    key: torch.FloatTensor
    value: torch.FloatTensor
    masks: torch.FloatTensor


def mem_apply_update(prev_external_mem_cache: LongGLMMemCache, new_mem_content: LongGLMMemCache):

    def update_one(prev, new, dim=1):
        if len(prev.shape) != len(new.shape):
            raise ValueError(f"Memory cache content should be consistent in shape got {prev.shape} {new.shape}")

        return torch.concat([prev, new], dim=dim)

    insert_size = new_mem_content.key.shape[1]

    assert new_mem_content.key.shape[1] == new_mem_content.value.shape[1]
    if new_mem_content.masks.shape[-2] != insert_size:
        raise ValueError("Inconsistent mem_length in new_mem_content")

    return LongGLMMemCache(
        key=update_one(prev_external_mem_cache.key, new_mem_content.key),
        value=update_one(prev_external_mem_cache.value, new_mem_content.value),
        masks=update_one(prev_external_mem_cache.masks, new_mem_content.masks, dim=-2),
    )


def generate_prompt_keypass(n_garbage: int, seed: int = None):
    """Generates a text file and inserts an execute line at a random position."""
    if seed is not None:
        rnd_state = random.get_state()
        random.seed(seed)
    n_garbage_prefix = random.randint(0, n_garbage)
    n_garbage_suffix = n_garbage - n_garbage_prefix

    task_description = "在下文的大量无关紧要的文字中隐藏着一个非常重要的信息，请找到并记住它们，后面将使用到这个信息。"
    garbage = "草是绿色的。天空是蓝色的。太阳是黄色的。我们走。我们离开又回来了。"
    garbage_inf = "".join([garbage] * 5000)
    assert len(garbage_inf) >= n_garbage
    garbage_prefix = garbage_inf[:n_garbage_prefix]
    garbage_suffix = garbage_inf[:n_garbage_suffix]
    pass_key = random.randint(1, 50000)
    information_line = (
        f"以下是本段文本的重要信息: “通行密码是'{pass_key}'，这是非常重要的信息，请记住'{pass_key}'是通行密码。”"
    )
    information_line = "\n".join([information_line] * 3)
    final_question = "请问通行密码是多少？"
    lines = [
        task_description,
        garbage_prefix,
        information_line,
        garbage_suffix,
        final_question,
    ]
    if seed is not None:
        random.set_state(rnd_state)
    return "\n".join(lines), str(pass_key)


# ## Loss Fuctions


def _unpack_router_logits(router_outputs):
    """
    Unpack the router tuple for blance loss calculation.
    """
    total_router_logits = []
    total_expert_indexes = []
    for router_output in router_outputs:
        if router_output[0] is not None:
            router_logits, expert_indexes = router_output
            total_router_logits.append(router_logits.unsqueeze(0))
            total_expert_indexes.append(expert_indexes.unsqueeze(0))
    # return torch.cat(total_router_logits, dim=0), torch.cat(total_expert_indexes, dim=0)
    return torch.cat(total_router_logits, dim=0), total_expert_indexes


def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor, labels: torch.Tensor) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [num_layers, batch_size, seqeunce_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [num_layers, batch_size, seqeunce_length] identifying the selected expert for a given token.

    Returns:
        The auxiliary loss.
    """

    num_layers, _, seq_len, num_experts = router_probs.shape
    num_experts = router_probs.shape[-1]
    new_labels = labels.clone().detach()
    ##
    for batch_tensor in new_labels:
        neg_mask = batch_tensor == -100
        diff_neg_ones = torch.diff(neg_mask.float())
        start_pos = torch.where(diff_neg_ones == 1.0)[0]  # 找到-1序列开始的位置
        if start_pos.nelement() == 0:  # 如果没有找到开始位置，可能需要根据实际情况调整
            pass
        else:
            last_start = start_pos[-1]  # 需要修改的最后一串-1的开始位置
            batch_tensor[:last_start] = 0  # 将这部分-1全部改为0
    new_labels = new_labels.to(torch.int64)

    # cast the expert indices to int64, otherwise one-hot encoding will fail

    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)

    if len(expert_indices.shape) == 3:
        expert_indices = expert_indices.unsqueeze(3)

    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)

    # For a given token, determine if it was routed to a given expert.
    expert_mask = torch.max(expert_mask, axis=-2).values

    # cast to float32 otherwise mean will fail
    expert_mask = expert_mask.to(torch.float32)
    labels_mask = (new_labels[None, ..., None].expand_as(expert_mask) != -100).long()

    # sample level balance loss
    tokens_per_group_and_expert = torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
    router_prob_per_group_and_expert = torch.sum(router_probs * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
    tmp_per_group_and_expert = torch.mean(expert_mask)
    return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
    '''
    # batch level balance loss
    expert_mask = expert_mask.view(num_layers, -1, num_experts).detach()
    labels_mask = labels_mask.view(num_layers, -1, num_experts).detach()
    origin_mask = labels_mask.clone()
    router_probs = router_probs.view(num_layers, -1, num_experts)

    from antllm.utils import mpu

    torch.distributed.all_reduce(expert_mask, group=mpu.get_data_parallel_group())
    torch.distributed.all_reduce(labels_mask, group=mpu.get_data_parallel_group())

    labels_mask = labels_mask.bool().long()

    world_size = torch.distributed.get_world_size()

    tokens_per_group_and_expert = (
        torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2) / world_size
    )
    router_prob_per_group_and_expert = torch.sum(router_probs * origin_mask, dim=-2) / torch.sum(origin_mask, dim=-2)
    layer_loss = tokens_per_group_and_expert * router_prob_per_group_and_expert
    loss = layer_loss.sum(-1).mean() * num_experts
    return loss
    '''


def group_level_device_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [num_layers, batch_size, seqeunce_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [num_layers, batch_size, seqeunce_length] identifying the selected expert for a given token.

    Returns:
        The auxiliary loss.
    """
    assert parallel_group is not None and parallel_group_size is not None

    num_layers, _, seq_len, num_experts = router_probs.shape

    # cast the expert indices to int64, otherwise one-hot encoding will fail
    if expert_indices.dtype != torch.int64:
        expert_indices = expert_indices.to(torch.int64)

    if len(expert_indices.shape) == 3:
        expert_indices = expert_indices.unsqueeze(3)

    expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)

    # For a given token, determine if it was routed to a given expert.
    expert_mask = torch.max(expert_mask, axis=-2).values

    # cast to float32 otherwise mean will fail
    expert_mask = expert_mask.to(torch.float32)
    torch.distributed.all_reduce(expert_mask, group=parallel_group("expert"))
    expert_parallel_size = parallel_group_size("expert")
    num_experts_per_device = num_experts / expert_parallel_size

    # sample level balance loss
    expert_mask = torch.sum(
        torch.cat(torch.chunk(expert_mask.unsqueeze(-2), expert_parallel_size, dim=-1), dim=-2), dim=-1
    )
    tokens_per_group_and_device = torch.mean(expert_mask, axis=-2) / expert_parallel_size

    router_probs = torch.sum(
        torch.cat(torch.chunk(router_probs.unsqueeze(-2), expert_parallel_size, dim=-1), dim=-2), dim=-1
    )
    router_prob_per_group_and_device = torch.mean(router_probs, axis=-2)

    device_loss = tokens_per_group_and_device * router_prob_per_group_and_device * expert_parallel_size
    loss = device_loss.sum(-1).mean()

    return loss


def router_z_loss_func(router_logits: torch.Tensor, labels: torch.Tensor) -> float:
    r"""
    Compute the router z-loss implemented in PyTorch.

    The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
    It encourages router logits to remain small in an effort to improve stability.

    Args:
        router_logits (`float`):
            Input logits of shape [num_layers, batch_size, sequence_length, num_experts]

    Returns:
        Scalar router z-loss.
    """
    num_layers, num_groups, tokens_per_group, _ = router_logits.shape
    labels_mask = (labels[None, ..., None].expand_as(router_logits) != -100).long()

    ori_dtype = router_logits.dtype
    if ori_dtype == torch.bfloat16:
        loss_func_inputs = (router_logits * labels_mask).to(torch.float32)
    else:
        loss_func_inputs = router_logits * labels_mask
    log_z = torch.logsumexp(loss_func_inputs, dim=-1).to(ori_dtype)
    z_loss = log_z**2

    # log_z = torch.logsumexp(router_logits * labels_mask, dim=-1)
    # z_loss = log_z**2

    return torch.sum(z_loss) / (num_layers * num_groups * tokens_per_group)


def auxiliary_loss(outputs, labels):
    router_tuple = outputs.router_tuple
    balance_loss, z_loss, last_logits_l2_loss = 0.0, 0.0, 0.0

    loss = 0
    if router_tuple is not None:
        router_logits, layer_router_index = _unpack_router_logits(router_tuple)
        top1_expert_index = torch.cat(layer_router_index, dim=0)
        outputs["layer_expert_index"] = top1_expert_index
        z_loss = router_z_loss_func(router_logits, labels)
        router_probs = torch.nn.Softmax(dim=-1)(router_logits)
        balance_loss = load_balancing_loss_func(router_probs, top1_expert_index, labels)

        num_layers = router_probs.shape[0]
        num_experts = router_probs.shape[-1]
        router_probs_log = router_probs.detach().view(num_layers, -1, num_experts)
        router_probs_mean = router_probs_log.mean(1)
        router_probs_sort_mean = router_probs_log.sort(-1, descending=True)[0].mean(1)
        router_probs_log = torch.stack([router_probs_mean, router_probs_sort_mean], dim=1)
        dist.all_reduce(router_probs_log, dist.ReduceOp.SUM)
        router_probs_log = router_probs_log / torch.distributed.get_world_size()
        if dist.get_rank() == 0:
            router_probs_log = router_probs_log.float()
            router_probs_log /= router_probs_log.sum(-1, keepdim=True)
            outputs["layer_expert_probs"] = router_probs_log.float().cpu()

        group_balance_loss = 0
        if float(outputs["router_group_balance_loss_alpha"]) > 0:
            group_balance_loss = group_level_device_balancing_loss_func(router_probs, top1_expert_index)
        loss = (
            float(outputs["router_z_loss_alpha"]) * z_loss
            + float(outputs["router_balance_loss_alpha"]) * balance_loss
            + float(outputs["router_group_balance_loss_alpha"]) * group_balance_loss
        )

    last_logits_l2_loss = 0.0
    if float(outputs["last_logits_l2_alpha"]) >= 0:
        logits = outputs.logits.view(-1, outputs.logits.size(-1))
        labels_mask = (labels.view(-1) != -100).long()

        last_logits_l2_loss = torch.sum(torch.linalg.norm(logits.float(), 2.0, dim=-1) * labels_mask) / torch.sum(
            labels_mask
        )
        loss += float(outputs["last_logits_l2_alpha"]) * last_logits_l2_loss
        last_logits_l2_loss = last_logits_l2_loss.item()

    return loss, balance_loss, z_loss, last_logits_l2_loss


def expert_balanced_auxiliary_cross_entropy(outputs, labels, *args, **kwargs):
    """FOR PRETRAIN ONLY"""
    # Output losses without reduction for compute dataset loss
    if kwargs.get("output_losses", False):
        lm_loss, losses = cross_entropy_loss(outputs.logits, labels, *args, **kwargs)
    else:
        lm_loss = cross_entropy_loss(outputs.logits, labels, *args, **kwargs)
    aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
    loss = lm_loss + aux_loss
    if kwargs.get("output_losses", False):
        return loss, lm_loss, balance_loss, z_loss, last_logits_l2_loss, losses
    return loss, lm_loss, balance_loss, z_loss, last_logits_l2_loss


def expert_balanced_auxiliary_cross_entropy_for_sft(outputs, labels, *args, **kwargs):
    """FOR SFT ONLY"""
    lm_loss = sample_level_cross_entropy(outputs, labels, **kwargs)
    aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
    loss = lm_loss + aux_loss
    return loss


def expert_balanced_auxiliary_global_level_cross_entropy(outputs, labels, *args, **kwargs):
    """FOR SFT ONLY"""
    lm_loss = global_token_level_cross_entropy(outputs, labels, **kwargs)
    aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
    loss = lm_loss + aux_loss

    return [
        loss,
        {
            'aux_loss': aux_loss,
            'balance_loss': balance_loss,
            'z_loss': z_loss,
            'last_logits_l2_loss': last_logits_l2_loss,
        },
    ]


def cross_entropy_loss(logits, labels, loss_mask, *args, **kwargs):
    if kwargs["use_atorch_cross_entropy"]:
        from atorch.modules.transformer import losses as atorch_loss

        losses = atorch_loss.CrossEntropyLoss(reduction="none")(logits.view(-1, logits.size(-1)), labels.view(-1))
    else:
        losses = torch.nn.CrossEntropyLoss(reduction="none")(logits.view(-1, logits.size(-1)), labels.view(-1))

    loss = torch.sum(losses * loss_mask.view(-1))
    if loss_mask.sum().item() > 0:
        loss = loss / loss_mask.sum()
    if kwargs.get("output_losses", False):
        return loss, losses
    return loss


def local_token_level_cross_entropy(outputs, labels, **kwargs):
    # return outputs.loss / torch.distributed.get_world_size()
    # 在每个batch内部做token-level的平均,然后在所有batch间做平均
    # return outputs.loss
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1))

    return loss


def mini_batch_token_level_cross_entropy(outputs, labels, mini_batch=1, **kwargs):
    # 这个loss会先把batch分成小的mini_batch,在mini_batch内做个token-level的平均,然后做所有卡之间的平均
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
    if labels.shape[0] % mini_batch != 0:
        # 如果batch % mini_batch != 0, 则不切分计算. 有的数据量一个epoch结束的时候可能会出现这个情况
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        loss = loss_fct(outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1))
    else:
        loss = loss_fct(
            outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
        ).reshape(labels.shape[0] // mini_batch, -1)

        labels = labels.reshape(labels.shape[0] // mini_batch, -1)
        loss = loss.sum(-1) / (labels != -100).sum(-1)
        loss = loss.mean()
    return loss


def sample_level_cross_entropy(outputs, labels, **kwargs):
    # 先对所有样本字token-level的平均,然后计算所有sample的平均值
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
    loss = loss_fct(
        outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
    ).reshape(labels.shape[0], -1)
    loss = loss.sum(-1) / (labels != -100).sum(-1)
    loss = loss.mean()
    return loss


def global_token_level_cross_entropy(outputs, labels, **kwargs):
    # 对所有样本一起做token-level的平均
    loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
    loss = loss_fct(
        outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
    ).reshape(labels.shape[0], -1)
    num_tokens = (loss != 0).sum()
    loss = loss.sum()

    num_tokens_tensor = torch.zeros([1], device=loss.device, dtype=loss.dtype)
    num_tokens_tensor[0] = num_tokens.item()

    torch.distributed.all_reduce(num_tokens_tensor)

    global_num_tokens = num_tokens_tensor.sum()

    torch.distributed.barrier()
    # global_num_tokens是全局的token数，因为在梯度更新的时候回自动对所有卡求mean
    # 所有这里要乘一个world_size
    loss = loss.sum() / global_num_tokens * torch.distributed.get_world_size()

    return loss


LOSS_MAP = {
    'local_token_level_cross_entropy': local_token_level_cross_entropy,
    'mini_batch_token_level_cross_entropy': mini_batch_token_level_cross_entropy,
    'sample_level_cross_entropy': sample_level_cross_entropy,
    'global_token_level_cross_entropy': global_token_level_cross_entropy,
    "moe_auxiliary": expert_balanced_auxiliary_cross_entropy,
    "moe_auxiliary_sft": expert_balanced_auxiliary_cross_entropy_for_sft,
    "pretrain_default": cross_entropy_loss,
    "moe_auxiliary_global_token_level": expert_balanced_auxiliary_global_level_cross_entropy,
}

class Transpose(nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        return x.transpose(self.dim0, self.dim1)

def patch_continuous_features(
    input_embeddings: torch.Tensor,
    placeholder_loc_lens: torch.Tensor,
    encoded_feats: torch.Tensor,
    encoded_feat_lens: torch.Tensor,
):
    """
    Patch continuous features into input embeddings, while keeping a valid gradient flow.

    input_embeddings: torch.Tensor, size = [B, C?, T, D]
    placeholder_loc_lens: torch.LongTensor, size = [B, N, 2]
        Each 2-tuple represents (start, length) of a placeholder.
    encoded_feats: torch.Tensor, size = [B, L1 + L2 + ... + LN, ...]
    encoded_feat_lens: torch.LongTensor, size = [B, N]

    Example ('X' for patch placeholder tokens):
    Inputs:
        input_embeddings = [[1, 2, 3, X, X, X, 4, 5, 6, X, X, X, 7, 8]]
        placeholder_loc_lens = [[3, 3], [9, 3]]
        encoded_feats = [[A, A, A, B, B]]
        encoded_feat_lens = [[3, 2]]
    Outputs:
        embeddings = [[1, 2, 3, A, A, A, 4, 5, 6, B, B, X, 7, 8]]
    """
    batch_size = input_embeddings.size(0)
    audio_feats_mask = torch.zeros_like(input_embeddings, dtype=torch.bool)
    audio_feats_buffer = []
    for i in range(batch_size):
        sample_len = 0
        audio_feat_start = 0
        audio_feat_buffer = []
        for j in range(placeholder_loc_lens.shape[1]):
            placeholder_start: int = int(placeholder_loc_lens[i, j, 0].item())
            placeholder_len: int = int(placeholder_loc_lens[i, j, 1].item())
            if placeholder_len <= 0:
                break
            feat_len = int(encoded_feat_lens[i, j].item())
            real_feat_len = feat_len
            if feat_len > placeholder_len:
                # logger.warning(
                #     f"Feature length ({feat_len}) > placeholder length ({placeholder_len}). This is not expected. Please "
                #     "check the implementation of estimate_audio_feature_length(). We truncate the feature to avoid errors."
                # )
                feat_len = placeholder_len
            if placeholder_start > sample_len:
                audio_feat_buffer.append(input_embeddings.new_zeros((placeholder_start - sample_len, input_embeddings.shape[2])))
                sample_len = placeholder_start
            audio_feat_buffer.append(encoded_feats[i, audio_feat_start:audio_feat_start + feat_len])
            if feat_len < placeholder_len:
                audio_feat_buffer.append(encoded_feats.new_zeros(placeholder_len - feat_len))
            audio_feats_mask[i, sample_len:sample_len + feat_len] = 1
            audio_feat_start += real_feat_len
            sample_len += placeholder_len
        if sample_len < input_embeddings.shape[1]:
            audio_feat_buffer.append(
                input_embeddings.new_zeros((input_embeddings.shape[1] - sample_len, input_embeddings.shape[2]))
            )
        audio_feats_buffer.append(torch.cat(audio_feat_buffer))
    audio_feats_buffer = torch.stack(audio_feats_buffer, dim=0)
    embeddings = audio_feats_buffer * audio_feats_mask + input_embeddings * ~audio_feats_mask
    return embeddings

def unwrap_feats(feats: torch.Tensor, feats_lengths: torch.Tensor):
    """
    The input feats are in the "wrapped" format, which means that features from (at most) N audios are concatenated
    as a single sample feats[i]. In this case, each row of feats_lengths contains the lengths of the concatenated
    feature. This function unwraps the features.
    For samples with less than N segments, one should pad feats_lengths with 0. The result will contain valid
    segments only.

    feats: torch.Tensor, size = [B, L1 + L2 + ... + LN, ...]
    feats_lengths: torch.LongTensor, size = [B, N]

    Example ('X' for padding):
    Inputs:
        feats = [[A, A, A, A, X],
                 [B, B, C, C, C]]
        feats_lengths = [[4, 0],
                         [2, 3]]
    Outputs:
        feat_segs = [[A, A, A, A],
                     [B, B, X, X],
                     [C, C, C, X]]
        feat_seg_lengths = [4, 2, 3]
    """
    feat_segs = []
    feat_seg_lengths = []
    for i in range(feats_lengths.shape[0]):
        feat_index = 0
        for j in range(feats_lengths.shape[1]):
            feat_len = feats_lengths[i, j].item()
            if feat_len == 0: break
            feat_segs.append(feats[i, feat_index:feat_index + feat_len])
            feat_seg_lengths.append(feat_len)
            feat_index += feat_len
    feat_segs_batch = torch.nn.utils.rnn.pad_sequence(feat_segs, True).to(feats.device)
    feat_seg_lengths = torch.tensor(feat_seg_lengths, dtype=torch.long, device=feats.device)
    return feat_segs_batch, feat_seg_lengths

def wrap_feats(feat_segs: torch.Tensor, feats_lengths: torch.Tensor, feats_seg_lengths: Optional[torch.Tensor] = None):
    """
    Wrap segmented features back to the wrapped format.
    This function is the inverse operation of unwrap_feats(). See its documentation for details.
    Note that the feats_lengths value does not matter a lot. We only check the location of the first 0 to determine the
    number of feature segments.
    """
    feat_idx = 0
    feats_buffer = []
    feats_locs_buffer = []
    feats_lengths_buffer = []
    for i in range(feats_lengths.shape[0]):
        feat_buffer = []
        feat_locs_buffer = []
        feat_lengths_buffer = []
        feat_total_len = 0
        for j in range(feats_lengths.shape[1]):
            feat_len = feats_lengths[i, j].item()
            if feat_len == 0:
                break
            if feats_seg_lengths is not None:
                feat_len = feats_seg_lengths[feat_idx].item()
            feat_buffer.append(feat_segs[feat_idx, :feat_len])
            feat_locs_buffer.append(feat_total_len)
            feat_lengths_buffer.append(feat_len)
            feat_idx += 1
            feat_total_len += feat_len
        feats_buffer.append(torch.cat(feat_buffer))
        feats_locs_buffer.append(torch.tensor(feat_locs_buffer, dtype=torch.long))
        feats_lengths_buffer.append(torch.tensor(feat_lengths_buffer, dtype=torch.long))
    feats = torch.nn.utils.rnn.pad_sequence(feats_buffer, True).to(feat_segs.device)
    feats_locs = torch.nn.utils.rnn.pad_sequence(feats_locs_buffer, True).to(feats_lengths.device)
    feats_new_lengths = torch.nn.utils.rnn.pad_sequence(feats_lengths_buffer, True).to(feats_lengths.device)
    return feats, feats_locs, feats_new_lengths

def encode_audio_segments(
    encoder,
    proj_layer,
    wav_feats=None,
    wav_feats_lengths=None,
    waveforms=None,
    waveforms_lengths=None,
    use_waveform=False,
    audio_config=None,
):
    """
    Apply audio encoder to input audio features in wrapped format.
    See the documentation of unwrap_feats() for details about 'wrapped format'.
    """

    # Forward audio encoder.
    if use_waveform:
        assert waveforms is not None and waveforms_lengths is not None
        # Unwrap the waveforms so each waveform is placed at an independent row.
        waveform_segs_batch, waveform_seg_lengths = unwrap_feats(waveforms, waveforms_lengths)
        audio_feats_seg, audio_feat_seg_lengths = encoder(waveform_segs_batch, waveform_seg_lengths)[:2]
    else:
        assert wav_feats is not None and wav_feats_lengths is not None
        # Unwrap the features so the feature of each waveform is placed at an independent row.
        feat_segs_batch, feat_seg_lengths = unwrap_feats(wav_feats, wav_feats_lengths)
        assert isinstance(encoder, AudioEncoder)
        # for whisper encoder
        # feat_segs_batch: [B, T, n_mels]
        # feat_seg_lengths: [B]
        audio_feats_seg = encoder(feat_segs_batch)
        audio_feats_seg_proj = proj_layer(audio_feats_seg.transpose(-1, -2)).transpose(-1, -2)
        feat_seg_lengths = feat_seg_lengths.to(feat_segs_batch.device)
        # whisper encoder conv
        audio_feat_seg_lengths = (feat_seg_lengths - 3 + 2 * 1) // 2 + 1
        # project layer conv
        audio_feat_seg_lengths = (audio_feat_seg_lengths - audio_config.ds_kernel_size + 2 *
                                  (audio_config.ds_kernel_size//2)) // audio_config.ds_stride + 1

    # Wrap the features so the 1st dim represents batch_size.
    input_lengths = waveforms_lengths if use_waveform else wav_feats_lengths
    assert input_lengths is not None
    audio_feats, _, audio_feats_lengths = wrap_feats(audio_feats_seg, input_lengths, audio_feat_seg_lengths)
    audio_feats_proj, _, audio_feats_lengths2 = wrap_feats(audio_feats_seg_proj, input_lengths, audio_feat_seg_lengths)
    assert torch.all(audio_feats_lengths == audio_feats_lengths2), f"{audio_feats_lengths}, {audio_feats_lengths2}"

    return audio_feats_proj, audio_feats, audio_feats_lengths

def patch_continuous_features(
    input_embeddings: torch.Tensor,
    placeholder_loc_lens: torch.Tensor,
    encoded_feats: torch.Tensor,
    encoded_feat_lens: torch.Tensor,
):
    """
    Patch continuous features into input embeddings, while keeping a valid gradient flow.

    input_embeddings: torch.Tensor, size = [B, C?, T, D]
    placeholder_loc_lens: torch.LongTensor, size = [B, N, 2]
        Each 2-tuple represents (start, length) of a placeholder.
    encoded_feats: torch.Tensor, size = [B, L1 + L2 + ... + LN, ...]
    encoded_feat_lens: torch.LongTensor, size = [B, N]

    Example ('X' for patch placeholder tokens):
    Inputs:
        input_embeddings = [[1, 2, 3, X, X, X, 4, 5, 6, X, X, X, 7, 8]]
        placeholder_loc_lens = [[3, 3], [9, 3]]
        encoded_feats = [[A, A, A, B, B]]
        encoded_feat_lens = [[3, 2]]
    Outputs:
        embeddings = [[1, 2, 3, A, A, A, 4, 5, 6, B, B, X, 7, 8]]
    """
    batch_size = input_embeddings.size(0)
    audio_feats_mask = torch.zeros_like(input_embeddings, dtype=torch.bool)
    audio_feats_buffer = []
    for i in range(batch_size):
        sample_len = 0
        audio_feat_start = 0
        audio_feat_buffer = []
        for j in range(placeholder_loc_lens.shape[1]):
            placeholder_start: int = int(placeholder_loc_lens[i, j, 0].item())
            placeholder_len: int = int(placeholder_loc_lens[i, j, 1].item())
            if placeholder_len <= 0:
                break
            feat_len = int(encoded_feat_lens[i, j].item())
            real_feat_len = feat_len
            if feat_len > placeholder_len:
                logging.warning(
                    f"Feature length ({feat_len}) > placeholder length ({placeholder_len}). This is not expected. Please "
                    "check the implementation of estimate_audio_feature_length(). We truncate the feature to avoid errors."
                )
                feat_len = placeholder_len
            if placeholder_start > sample_len:
                audio_feat_buffer.append(input_embeddings.new_zeros((placeholder_start - sample_len, input_embeddings.shape[2])))
                sample_len = placeholder_start
            audio_feat_buffer.append(encoded_feats[i, audio_feat_start:audio_feat_start + feat_len])
            if feat_len < placeholder_len:
                audio_feat_buffer.append(encoded_feats.new_zeros(placeholder_len - feat_len))
            audio_feats_mask[i, sample_len:sample_len + feat_len] = 1
            audio_feat_start += real_feat_len
            sample_len += placeholder_len
        if sample_len < input_embeddings.shape[1]:
            audio_feat_buffer.append(
                input_embeddings.new_zeros((input_embeddings.shape[1] - sample_len, input_embeddings.shape[2]))
            )
        audio_feats_buffer.append(torch.cat(audio_feat_buffer))
    audio_feats_buffer = torch.stack(audio_feats_buffer, dim=0)
    embeddings = audio_feats_buffer * audio_feats_mask + input_embeddings * ~audio_feats_mask
    return embeddings

def build_modality_mask(placeholder_loc_lens: torch.Tensor, shape: torch.Size):
    mask = torch.zeros(shape, dtype=torch.bool)
    for i in range(placeholder_loc_lens.shape[0]):
        for j in range(placeholder_loc_lens.shape[1]):
            start: int = int(placeholder_loc_lens[i, j, 0].item())
            length: int = int(placeholder_loc_lens[i, j, 1].item())
            if length <= 0:
                break
            mask[i, start:start + length] = True
    return mask
