# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys
from datetime import datetime

import torch
import torch.nn.functional as F

from swift.llm import git_clone_github
from swift.utils import (
    get_logger,
    is_megatron_available,
    safe_ddp_context,
    subprocess_run,
)

logger = get_logger()


def _patch_transformer_engine():
    """Alias two helpers onto ``transformer_engine.pytorch.attention`` when missing.

    Each name is first probed with a plain ``from ... import``. Only when that
    probe raises ``ImportError`` is the alias installed from the alternative
    attribute path. If the alternative path cannot be resolved either, the
    name is left untouched and no exception escapes.
    """
    import transformer_engine

    # apply_rotary_pos_emb: fall back to ``attention.rope.apply_rotary_pos_emb``.
    try:
        from transformer_engine.pytorch.attention import apply_rotary_pos_emb  # noqa: F401
    except ImportError:
        try:
            rope_impl = transformer_engine.pytorch.attention.rope.apply_rotary_pos_emb
            attention_module = transformer_engine.pytorch.attention
            attention_module.apply_rotary_pos_emb = rope_impl
            logger.info("Patch apply_rotary_pos_emb successfully applied.")
        except (ImportError, AttributeError):
            # Neither location provides the function; leave things as they are.
            pass

    # _SplitAlongDim: fall back to ``pytorch.utils.SplitAlongDim``.
    try:
        from transformer_engine.pytorch.attention import _SplitAlongDim  # noqa: F401
    except ImportError:
        try:
            split_impl = transformer_engine.pytorch.utils.SplitAlongDim
            attention_module = transformer_engine.pytorch.attention
            attention_module._SplitAlongDim = split_impl
            logger.info("Patch _SplitAlongDim successfully applied.")
        except (ImportError, AttributeError):
            # Neither location provides the class; leave things as they are.
            pass


def _patch__batched_p2p_ops():
    """Wrap Megatron's ``_batched_p2p_ops`` so every call passes ``group=None``.

    The original function is captured once and the module attribute is
    replaced by a keyword-only wrapper that overrides (or adds) the ``group``
    keyword before delegating.
    """
    from megatron.core.pipeline_parallel import p2p_communication

    wrapped_impl = p2p_communication._batched_p2p_ops

    def _batched_p2p_ops(**call_kwargs):
        # Whatever group the caller supplied is discarded in favour of None.
        forwarded = {**call_kwargs, "group": None}
        return wrapped_impl(**forwarded)

    p2p_communication._batched_p2p_ops = _batched_p2p_ops


def _patch_training_log():
    """Replace ``megatron.training.training.training_log`` with the variant below.

    All Megatron imports happen inside this function, so an import failure
    raises from here and is left for the caller to handle. The replacement is
    installed by the assignment on the last line of this function.
    """
    # TODO: support swanlab
    from megatron.core import mpu
    from megatron.core.transformer.moe.moe_utils import track_moe_metrics
    from megatron.training.theoretical_memory_usage import report_theoretical_memory
    from megatron.core.transformer.multi_token_prediction import MTPLossLoggingHelper
    from megatron.training import (
        training,
        get_args,
        get_timers,
        get_tensorboard_writer,
        get_wandb_writer,
        get_one_logger,
        one_logger_utils,
        is_last_rank,
        print_rank_last,
    )
    from megatron.training.training import num_floating_point_operations
    from megatron.core.num_microbatches_calculator import get_num_microbatches
    from megatron.training.utils import (
        reduce_max_stat_across_model_parallel_group,
        report_memory,
    )

    # Code borrowed from NVIDIA/Megatron-LM
    def training_log(
        loss_dict,
        total_loss_dict,
        learning_rate,
        decoupled_learning_rate,
        iteration,
        loss_scale,
        report_memory_flag,
        skipped_iter,
        grad_norm,
        params_norm,
        num_zeros_in_grad,
    ):
        """Log training information such as losses, timing, ....

        Args:
            loss_dict: Per-iteration losses keyed by name. Values are added to
                the running totals and must support ``.float().sum().item()``.
            total_loss_dict: Running accumulator, mutated in place. Besides
                the loss keys it holds the three counter keys defined below;
                all of them are reset after a console log line is emitted.
            learning_rate: Current learning rate; reduced across the model
                parallel group before use.
            decoupled_learning_rate: Logged only when ``args.decoupled_lr`` is
                set; asserted to be ``None`` otherwise (see the console path).
            iteration: Current iteration number, used as the logging step.
            loss_scale: Current loss scale.
            report_memory_flag: When truthy, memory is reported once on the
                next console log and the flag is cleared.
            skipped_iter: Truthy when this iteration was skipped; also added
                numerically to the skipped-iterations counter.
            grad_norm: Optional gradient norm; logged when not ``None``.
            params_norm: Optional parameter norm; logged when not ``None``.
            num_zeros_in_grad: Optional zero-gradient count; logged when not
                ``None``.

        Returns:
            The (possibly cleared) ``report_memory_flag``.
        """
        args = get_args()
        timers = get_timers()
        writer = get_tensorboard_writer()
        wandb_writer = get_wandb_writer()

        # Advanced, skipped, and Nan iterations.
        advanced_iters_key = "advanced iterations"
        skipped_iters_key = "skipped iterations"
        nan_iters_key = "nan iterations"
        # Advanced iterations.
        if not skipped_iter:
            total_loss_dict[advanced_iters_key] = (
                total_loss_dict.get(advanced_iters_key, 0) + 1
            )
        else:
            # Make sure the key exists even if the very first iteration is skipped.
            if advanced_iters_key not in total_loss_dict:
                total_loss_dict[advanced_iters_key] = 0
        # Skipped iterations.
        total_loss_dict[skipped_iters_key] = (
            total_loss_dict.get(skipped_iters_key, 0) + skipped_iter
        )
        # Update losses and set nan iterations
        got_nan = False
        for key in loss_dict:
            if not skipped_iter:
                # Accumulate onto a CUDA float tensor (created on first use).
                total_loss_dict[key] = (
                    total_loss_dict.get(
                        key, torch.tensor([0.0], dtype=torch.float, device="cuda")
                    )
                    + loss_dict[key]
                )
            else:
                # On skipped iterations only check whether the loss is inf/NaN;
                # ``value != value`` is the NaN test.
                value = loss_dict[key].float().sum().item()
                is_nan = (
                    value == float("inf") or value == -float("inf") or value != value
                )
                got_nan = got_nan or is_nan
        total_loss_dict[nan_iters_key] = total_loss_dict.get(nan_iters_key, 0) + int(
            got_nan
        )

        # Logging.
        # Names of the timers passed to ``timers.write`` / ``timers.log`` below.
        timers_to_log = [
            "forward-backward",
            "forward-compute",
            "backward-compute",
            "batch-generator",
            "forward-recv",
            "forward-send",
            "backward-recv",
            "backward-send",
            "forward-send-forward-recv",
            "forward-send-backward-recv",
            "backward-send-forward-recv",
            "backward-send-backward-recv",
            "forward-backward-send-forward-backward-recv",
            "layernorm-grads-all-reduce",
            "embedding-grads-all-reduce",
            "all-grads-sync",
            "params-all-gather",
            "optimizer-copy-to-main-grad",
            "optimizer-unscale-and-check-inf",
            "optimizer-clip-main-grad",
            "optimizer-count-zeros",
            "optimizer-inner-step",
            "optimizer-copy-main-to-model-params",
            "optimizer",
        ]

        # Calculate batch size.
        batch_size = (
            args.micro_batch_size * args.data_parallel_size * get_num_microbatches()
        )

        # Track app tag & app tag ID
        one_logger_utils.track_app_tag(batch_size, args.world_size, args.seq_length)

        # Iterations accumulated since the counters were last reset.
        total_iterations = (
            total_loss_dict[advanced_iters_key] + total_loss_dict[skipped_iters_key]
        )

        # learning rate will be None on ranks without trainable params, so we must gather across mp ranks
        learning_rate = reduce_max_stat_across_model_parallel_group(learning_rate)
        # Tensorboard values.
        # Timer requires all the ranks to call.
        if args.log_timers_to_tensorboard and (
            iteration % args.tensorboard_log_interval == 0
        ):
            timers.write(timers_to_log, writer, iteration, normalizer=total_iterations)
        # Scalar logging: everything below is gated on a tensorboard writer
        # being present; wandb logging only happens inside this branch.
        if writer and (iteration % args.tensorboard_log_interval == 0):
            if wandb_writer:
                wandb_writer.log(
                    {"samples vs steps": args.consumed_train_samples}, iteration
                )
            writer.add_scalar("learning-rate", learning_rate, iteration)
            writer.add_scalar(
                "learning-rate vs samples", learning_rate, args.consumed_train_samples
            )
            if wandb_writer:
                wandb_writer.log({"learning-rate": learning_rate}, iteration)
            if args.decoupled_lr is not None:
                writer.add_scalar(
                    "decoupled-learning-rate", decoupled_learning_rate, iteration
                )
            if args.skipped_train_samples > 0:
                writer.add_scalar(
                    "skipped-train-samples", args.skipped_train_samples, iteration
                )
                if wandb_writer:
                    wandb_writer.log(
                        {"skipped-train-samples": args.skipped_train_samples}, iteration
                    )
            writer.add_scalar("batch-size", batch_size, iteration)
            writer.add_scalar(
                "batch-size vs samples", batch_size, args.consumed_train_samples
            )
            if wandb_writer:
                wandb_writer.log({"batch-size": batch_size}, iteration)
            # Per-iteration (not accumulated) loss values.
            for key in loss_dict:
                writer.add_scalar(key, loss_dict[key], iteration)
                writer.add_scalar(
                    key + " vs samples", loss_dict[key], args.consumed_train_samples
                )
                if wandb_writer:
                    wandb_writer.log({key: loss_dict[key]}, iteration)
            if args.log_loss_scale_to_tensorboard:
                writer.add_scalar("loss-scale", loss_scale, iteration)
                writer.add_scalar(
                    "loss-scale vs samples", loss_scale, args.consumed_train_samples
                )
                if wandb_writer:
                    wandb_writer.log({"loss-scale": loss_scale}, iteration)
            if args.log_world_size_to_tensorboard:
                writer.add_scalar("world-size", args.world_size, iteration)
                writer.add_scalar(
                    "world-size vs samples",
                    args.world_size,
                    args.consumed_train_samples,
                )
                if wandb_writer:
                    wandb_writer.log({"world-size": args.world_size}, iteration)
            if grad_norm is not None:
                writer.add_scalar("grad-norm", grad_norm, iteration)
                writer.add_scalar(
                    "grad-norm vs samples", grad_norm, args.consumed_train_samples
                )
                if wandb_writer:
                    wandb_writer.log({"grad-norm": grad_norm}, iteration)
            if num_zeros_in_grad is not None:
                writer.add_scalar("num-zeros", num_zeros_in_grad, iteration)
                writer.add_scalar(
                    "num-zeros vs samples",
                    num_zeros_in_grad,
                    args.consumed_train_samples,
                )
                if wandb_writer:
                    wandb_writer.log({"num-zeros": num_zeros_in_grad}, iteration)
            if params_norm is not None:
                writer.add_scalar("params-norm", params_norm, iteration)
                writer.add_scalar(
                    "params-norm vs samples", params_norm, args.consumed_train_samples
                )
                if wandb_writer:
                    wandb_writer.log({"params-norm": params_norm}, iteration)
            if args.log_memory_to_tensorboard:
                mem_stats = torch.cuda.memory_stats()
                writer.add_scalar(
                    "mem-reserved-bytes",
                    mem_stats["reserved_bytes.all.current"],
                    iteration,
                )
                writer.add_scalar(
                    "mem-allocated-bytes",
                    mem_stats["allocated_bytes.all.current"],
                    iteration,
                )
                writer.add_scalar(
                    "mem-max-allocated-bytes",
                    mem_stats["allocated_bytes.all.peak"],
                    iteration,
                )
                writer.add_scalar(
                    "mem-allocated-count",
                    mem_stats["allocation.all.current"],
                    iteration,
                )
        # MoE metrics: called on every iteration when experts are configured,
        # independent of the tensorboard interval above.
        if args.num_experts is not None:
            moe_loss_scale = 1 / get_num_microbatches()
            track_names = []
            if args.moe_router_load_balancing_type in ["aux_loss", "seq_aux_loss"]:
                track_names.append("load_balancing_loss")
            if args.moe_z_loss_coeff is not None:
                track_names.append("z_loss")
            track_moe_metrics(
                loss_scale=moe_loss_scale,
                iteration=iteration,
                writer=writer,
                wandb_writer=wandb_writer,
                total_loss_dict=total_loss_dict,
                per_layer_logging=args.moe_per_layer_logging,
                force_initialize=True,
                track_names=track_names,
                num_layers=args.num_layers,
                moe_layer_freq=args.moe_layer_freq,
            )
        # Multi-token-prediction metrics, likewise on every iteration.
        if args.mtp_num_layers is not None:
            mtp_loss_scale = 1 / get_num_microbatches()
            MTPLossLoggingHelper.track_mtp_metrics(
                mtp_loss_scale, iteration, writer, wandb_writer, total_loss_dict
            )
        # Console logging: every ``log_interval`` iterations and additionally
        # on the very first iteration.
        if iteration % args.log_interval == 0 or iteration == 1:
            if args.record_memory_history and is_last_rank():
                snapshot = torch.cuda.memory._snapshot()
                from pickle import dump

                with open(args.memory_snapshot_path, "wb") as f:
                    dump(snapshot, f)

            # Average wall time per iteration over the current window.
            elapsed_time = timers("interval-time").elapsed(barrier=True)
            elapsed_time_per_iteration = elapsed_time / total_iterations

            # FLOPs per iteration divided by (seconds * 1e12 * world size),
            # i.e. TFLOP/s per GPU as labelled in the log string below.
            throughput = num_floating_point_operations(args, batch_size) / (
                elapsed_time_per_iteration * 10**12 * args.world_size
            )

            one_logger_utils.track_e2e_metrics(args.log_throughput, throughput)

            if args.log_timers_to_tensorboard:
                if writer:
                    writer.add_scalar(
                        "iteration-time", elapsed_time_per_iteration, iteration
                    )
                if wandb_writer:
                    wandb_writer.log(
                        {"iteration-time": elapsed_time_per_iteration}, iteration
                    )
            log_string = f" [{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}]"
            log_string += " iteration {:8d}/{:8d} |".format(iteration, args.train_iters)
            log_string += " consumed samples: {:12d} |".format(
                args.consumed_train_samples
            )
            if args.skipped_train_samples > 0:
                log_string += " skipped samples: {:12d} |".format(
                    args.skipped_train_samples
                )
            log_string += " elapsed time per iteration (ms): {:.1f} |".format(
                elapsed_time_per_iteration * 1000.0
            )
            if args.log_throughput:
                log_string += f" throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |"
                if args.log_timers_to_tensorboard:
                    if writer:
                        writer.add_scalar("throughput", throughput, iteration)
                    if wandb_writer:
                        wandb_writer.log({"throughput": throughput}, iteration)
            # Decoupled_learning_rate should be not None only on first and last pipeline stage.
            log_string += f" learning rate: {learning_rate:.6E} |"
            if args.decoupled_lr is not None and (
                mpu.is_pipeline_first_stage(ignore_virtual=True)
                or mpu.is_pipeline_last_stage(ignore_virtual=True)
            ):
                assert decoupled_learning_rate is not None
                log_string += (
                    f" decoupled learning rate: {decoupled_learning_rate:.6E} |"
                )
            else:
                assert decoupled_learning_rate is None
            log_string += f" global batch size: {batch_size:5d} |"
            # Print every accumulated loss averaged over the advanced
            # iterations, then reset its accumulator to zero.
            for key in total_loss_dict:
                if key not in [advanced_iters_key, skipped_iters_key, nan_iters_key]:
                    avg = total_loss_dict[key].item() / float(
                        max(1, total_loss_dict[advanced_iters_key])
                    )
                    log_string += " {}: {:.6E} |".format(key, avg)
                    total_loss_dict[key] = torch.tensor(
                        [0.0], dtype=torch.float, device="cuda"
                    )
            log_string += f" loss scale: {loss_scale:.1f} |"
            if grad_norm is not None:
                log_string += f" grad norm: {grad_norm:.3f} |"
            if num_zeros_in_grad is not None:
                log_string += f" num zeros: {num_zeros_in_grad} |"
            if params_norm is not None:
                log_string += f" params norm: {params_norm:.3f} |"
            log_string += " number of skipped iterations: {:3d} |".format(
                total_loss_dict[skipped_iters_key]
            )
            log_string += " number of nan iterations: {:3d} |".format(
                total_loss_dict[nan_iters_key]
            )
            # Start a fresh accumulation window.
            total_loss_dict[advanced_iters_key] = 0
            total_loss_dict[skipped_iters_key] = 0
            total_loss_dict[nan_iters_key] = 0
            print_rank_last(log_string)
            if report_memory_flag:
                # Report memory after optimizer state has been initialized.
                if torch.distributed.get_rank() == 0:
                    num_microbatches = get_num_microbatches()
                    report_theoretical_memory(
                        args, num_microbatches=num_microbatches, verbose=True
                    )
                report_memory(f"(after {iteration} iterations)")
                report_memory_flag = False
            # NOTE(review): this normalizes by ``args.log_interval`` while the
            # tensorboard path and the per-iteration time above use
            # ``total_iterations``. With the ``iteration == 1`` trigger the
            # window can be shorter than ``log_interval`` - confirm intended.
            timers.log(timers_to_log, normalizer=args.log_interval)

        return report_memory_flag

    # Install the replacement on the Megatron training module.
    training.training_log = training_log


def _patch_mla_attention():
    """Replace two multi-latent-attention methods with packed-sequence aware variants.

    Installs ``forward`` on ``MultiLatentAttention`` and
    ``get_query_key_value_tensors`` on ``MLASelfAttention``. Both replacements
    handle ``packed_seq_params`` (the "thd" qkv format); the tensor getter
    additionally takes ``rotary_pos_emb`` from its caller.
    """
    # support thd
    from megatron.core.utils import deprecate_inference_params
    from megatron.core import parallel_state, tensor_parallel
    from megatron.core.transformer.multi_latent_attention import (
        MultiLatentAttention,
        MLASelfAttention,
    )
    from megatron.core.tensor_parallel.mappings import (
        gather_from_sequence_parallel_region,
        gather_from_tensor_model_parallel_region,
        scatter_to_sequence_parallel_region,
    )

    # Code borrowed from NVIDIA/Megatron-LM
    def forward(
        self,
        hidden_states,
        attention_mask,
        key_value_states=None,
        inference_context=None,
        rotary_pos_emb=None,
        rotary_pos_cos=None,
        rotary_pos_sin=None,
        attention_bias=None,
        packed_seq_params=None,
        position_ids=None,
        sequence_len_offset=None,
        *,
        inference_params=None,
    ):
        """Forward pass for multi-latent attention.

        ``attention_bias``, ``rotary_pos_cos`` and ``rotary_pos_sin`` must be
        ``None`` (asserted). ``sequence_len_offset`` is accepted but not read
        by this implementation. ``inference_params`` is the deprecated
        spelling of ``inference_context`` and is folded into it.

        Returns:
            The ``(output, bias)`` pair produced by ``self.linear_proj``.
        """
        assert attention_bias is None, "Attention bias should not be passed into MLA."
        assert (
            rotary_pos_cos is None and rotary_pos_sin is None
        ), "MLA does not support Flash Decoding"

        # hidden_states: [sq, b, h]

        inference_context = deprecate_inference_params(
            inference_context, inference_params
        )

        # =====================
        # Query, Key, and Value
        # =====================
        # Get the query, key and value tensors based on the type of attention -
        # self or cross attn.
        # Example shapes - query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
        query, key, value = self.get_query_key_value_tensors(
            hidden_states,
            key_value_states,
            position_ids,
            packed_seq_params,
            rotary_pos_emb=rotary_pos_emb,
            inference_context=inference_context,
        )

        # ===================================================
        # Adjust key, value for inference
        # ===================================================
        # RoPE has already been applied inside get_query_key_value_tensors,
        # so rotary_pos_emb=None is passed here.
        query, key, value, _, attn_mask_type = self._adjust_key_value_for_inference(
            inference_context, query, key, value, rotary_pos_emb=None
        )

        # TODO: Currently, TE can only accept contiguous tensors for MLA
        query = query.contiguous()
        key = key.contiguous()
        value = value.contiguous()

        # ==================================
        # core attention computation
        # ==================================
        # Need corresponding TE change
        thd_qkv_format = packed_seq_params and packed_seq_params.qkv_format == "thd"
        v_dim = value.shape[-1]
        if thd_qkv_format and query.shape[-1] != v_dim:
            # In the thd format, zero-pad value's last dim up to query's head
            # dim; the padding is sliced off the attention output below.
            value = F.pad(value, [0, query.shape[-1] - v_dim])
            self.core_attention.hidden_size_per_attention_head_v = value.shape[-1]
        if self.checkpoint_core_attention and self.training:
            core_attn_out = self._checkpointed_attention_forward(
                query, key, value, attention_mask, packed_seq_params=packed_seq_params
            )
        else:
            core_attn_out = self.core_attention(
                query,
                key,
                value,
                attention_mask,
                packed_seq_params=packed_seq_params,
                attn_mask_type=attn_mask_type,
            )
        if thd_qkv_format:
            if core_attn_out.ndim == 2:
                # Split the flattened last dim back into (heads, head_dim).
                core_attn_out = core_attn_out.reshape(
                    *core_attn_out.shape[:-1], -1, value.shape[-1]
                )
            if query.shape[-1] != v_dim:
                # Drop the columns that correspond to the zero padding of value.
                core_attn_out = core_attn_out[..., :v_dim]
            # reshape to same output shape as unpacked case
            # (t, np, hn) -> (t, b=1, h=np*hn)
            # t is the pack size = sum (sq_i)
            # note that batch is a dummy dimension in the packed case
            core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)

        if self.recompute_up_proj:
            # The up-projection outputs were produced under a
            # CheckpointWithoutOutput; register them for recomputation.
            assert self.qkv_up_checkpoint is not None
            self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out)
            self.qkv_up_checkpoint = None

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.linear_proj(core_attn_out)

        return output, bias

    MultiLatentAttention.forward = forward

    def get_query_key_value_tensors(
        self,
        hidden_states,
        key_value_states=None,
        position_ids=None,
        packed_seq_params=None,
        inference_context=None,
        rotary_pos_emb=None,
        *,
        inference_params=None,
    ):
        """
        Derives `query`, `key` and `value` tensors from `hidden_states`.

        ``rotary_pos_emb`` is supplied by the caller and sliced to the query
        length before RoPE is applied. ``key_value_states`` and
        ``position_ids`` are accepted but not read by this implementation.

        Returns:
            A ``(query, key, value)`` tuple of contiguous tensors.
        """
        # s = sequence length, b = batch size, h = hidden size, n = num attention heads
        # Attention heads [s, b, n*h]
        assert (
            hidden_states.ndim == 3
        ), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"

        inference_context = deprecate_inference_params(
            inference_context, inference_params
        )

        if packed_seq_params is not None:
            cu_seqlens_q = packed_seq_params.cu_seqlens_q
            cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
        else:
            cu_seqlens_q = cu_seqlens_kv = None

        # =========================================
        # QKV down projection and layernorm
        # =========================================
        if self.config.q_lora_rank is not None:
            # if linear_q_down_proj is ColumnParallelLinear:
            #     q_compressed: [s, b, q_lora_rank / TP]
            # elif linear_q_down_proj is Linear:
            #     q_compressed: [s / TP, b, q_lora_rank]
            q_compressed, _ = self.linear_q_down_proj(hidden_states)

            # When output is sharded (ColumnParallelLinear), two things are needed to be
            # identical to a normal Linear.
            #   1. Manually gather output to restore output dim q_lora_rank;
            #   2. Scatter sequence back to s / TP if sequence-parallel since it was
            #      gathered by ColumnParallelLinear.
            if q_compressed.size(-1) != self.config.q_lora_rank:
                q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
                if self.config.sequence_parallel:
                    q_compressed = scatter_to_sequence_parallel_region(q_compressed)

            q_compressed = self.q_layernorm(q_compressed)
        else:
            q_compressed = hidden_states

        # if linear_kv_down_proj is ColumnParallelLinear:
        #     kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]
        # elif linear_kv_down_proj is Linear:
        #     kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]
        kv_combined, _ = self.linear_kv_down_proj(hidden_states)
        if (
            kv_combined.size(-1)
            != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim
        ):
            # kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
            kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
            # kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined,
                [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim],
                dim=-1,
            )
            if self.config.sequence_parallel:
                # kv_compressed:[s / TP, b, kv_lora_rank]
                kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
        else:
            # kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]
            kv_compressed, k_pos_emb = torch.split(
                kv_combined,
                [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim],
                dim=-1,
            )
            if parallel_state.get_tensor_model_parallel_world_size() > 1:
                # k_pos_emb: [s, b, qk_pos_emb_head_dim]
                k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)

        kv_compressed = self.kv_layernorm(kv_compressed)

        # =========================================
        # QKV up projection and RoPE apply
        # =========================================
        def qkv_up_proj_and_rope_apply(
            q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb
        ):
            """
            Apply the up projection and RoPE to the query and key.
            When sequence packing enabled, the input tensors adopt a packed shape of [t, ...];
            otherwise, they maintain the unpacked shape [s, b, ...]. In subsequent code comments,
            we uniformly use [num_tokens, ...] to denote [s, b, ...] or [t, ...] for two cases.
            """
            if self.config.q_lora_rank is not None:
                # q_compressed: [num_tokens, q_lora_rank]
                # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)]
                q, _ = self.linear_q_up_proj(q_compressed)
            else:
                # q_compressed: [num_tokens, hidden_size]
                # q: [num_tokens, n * (qk_head_dim + qk_pos_emb_head_dim)]
                q, _ = self.linear_q_proj(q_compressed)

            # q: [num_tokens, n, q_head_dim]
            q = q.view(
                *q.size()[:-1], self.num_attention_heads_per_partition, self.q_head_dim
            )

            # kv: [num_tokens, n * (qk_head_dim + v_head_dim)]
            kv, _ = self.linear_kv_up_proj(kv_compressed)

            # kv: [num_tokens, n, (qk_head_dim + v_head_dim)]
            kv = kv.view(
                *kv.size()[:-1],
                self.num_attention_heads_per_partition,
                self.config.qk_head_dim + self.config.v_head_dim,
            )

            q_len = q.size()[0]
            if inference_context is not None:
                # add offset to the sequence start for inference
                sequence_start = inference_context.sequence_len_offset
                sequence_end = sequence_start + q_len
                rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
            else:
                # Shorten rotary_pos_emb to the sequence length when inference_params
                # is not provided. This makes sure we can run forward directly with
                # any sequence length. During training, the sequence length is always
                # the full rotary_pos_emb length.
                rotary_pos_emb = rotary_pos_emb[0:q_len]

            # [num_tokens, qk_pos_emb_head_dim] -> [num_tokens, 1, qk_pos_emb_head_dim]
            k_pos_emb = torch.unsqueeze(k_pos_emb, -2)

            # q_no_pe: [num_tokens, n, qk_head_dim]
            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
            q_no_pe, q_pos_emb = torch.split(
                q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1
            )

            # k_no_pe: [num_tokens, n, qk_head_dim]
            # value: [num_tokens, n, v_head_dim]
            k_no_pe, value = torch.split(
                kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1
            )
            # This function will be patched and supports mscale.
            # Imported at call time so that the patched version is picked up.
            from megatron.core.transformer.attention import apply_rotary_pos_emb

            # q_pos_emb: [num_tokens, n, qk_pos_emb_head_dim]
            q_pos_emb = apply_rotary_pos_emb(
                q_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_q,
            )
            # k_pos_emb:[num_tokens, 1, qk_pos_emb_head_dim]
            k_pos_emb = apply_rotary_pos_emb(
                k_pos_emb,
                rotary_pos_emb,
                config=self.config,
                cu_seqlens=cu_seqlens_kv,
            )

            # query: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)]
            query = torch.cat([q_no_pe, q_pos_emb], dim=-1)

            # key: [num_tokens, n, (qk_head_dim + qk_pos_emb_head_dim)]
            # Broadcast the single shared positional key across all local heads.
            if k_pos_emb.ndim == 4:
                k_pos_emb = k_pos_emb.expand(
                    -1, -1, self.num_attention_heads_per_partition, -1
                )
            else:
                assert k_pos_emb.ndim == 3
                k_pos_emb = k_pos_emb.expand(
                    -1, self.num_attention_heads_per_partition, -1
                )
            key = torch.cat([k_no_pe, k_pos_emb], dim=-1)

            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            return query, key, value

        if packed_seq_params is not None:
            # If sequence packing, TE expect [t, h, d] shaped qkv input.
            # In Megatron-Core, the qkv shape is [t, 1, h, d].
            # So we need to reshape qkv from [t, 1, h, d] to [t, h, d].
            q_compressed = q_compressed.squeeze(1)
            kv_compressed = kv_compressed.squeeze(1)
            k_pos_emb = k_pos_emb.squeeze(1)

        if self.recompute_up_proj:
            # Keep the checkpoint handle on self; forward() later calls
            # discard_output_and_register_recompute on it.
            self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()
            query, key, value = self.qkv_up_checkpoint.checkpoint(
                qkv_up_proj_and_rope_apply,
                q_compressed,
                kv_compressed,
                k_pos_emb,
                rotary_pos_emb,
            )
        else:
            query, key, value = qkv_up_proj_and_rope_apply(
                q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb
            )

        return query, key, value

    MLASelfAttention.get_query_key_value_tensors = get_query_key_value_tensors


def _patch_megatron():
    """Apply every Megatron-LM / Transformer Engine monkey patch in this module.

    The first three patches are mandatory: any exception they raise
    propagates. The ``training_log`` patch is best-effort: if it fails,
    Megatron's own implementation stays in place, and the reason is logged as
    a warning instead of being silently discarded.
    """
    _patch_transformer_engine()
    _patch__batched_p2p_ops()
    _patch_mla_attention()
    try:
        _patch_training_log()
        logger.info("Patch training_log successfully applied.")
    except Exception as e:
        # Deliberately non-fatal, but surface why the patch was skipped so a
        # missing or incompatible Megatron symbol can be diagnosed.
        logger.warning("Failed to patch training_log, skipping: %r", e)


def init_megatron_env() -> None:
    """Prepare a Megatron-LM checkout, make it importable and apply the patches.

    Steps, in order:
      1. If ``MEGATRON_LM_PATH`` is not set, clone the ``core_r0.12.0`` branch
         of NVIDIA/Megatron-LM and store the returned path in that variable.
      2. Inside ``safe_ddp_context``, ``pip install -e`` the checkout when
         ``is_megatron_available()`` reports it is not available.
      3. Put the checkout at the front of ``sys.path``.
      4. Apply the monkey patches via ``_patch_megatron``.
    """
    env_key = "MEGATRON_LM_PATH"

    if os.environ.get(env_key) is None:
        cloned_path = git_clone_github(
            "https://github.com/NVIDIA/Megatron-LM", branch="core_r0.12.0"
        )
        os.environ[env_key] = cloned_path

    with safe_ddp_context(hash_id="megatron-lm"):
        if not is_megatron_available():
            # Editable install with the interpreter that runs this process.
            install_cmd = [sys.executable, "-m", "pip", "install", "-e"]
            install_cmd.append(os.environ[env_key])
            subprocess_run(install_cmd)

    # Position 0 gives this checkout priority over other entries on sys.path.
    sys.path.insert(0, os.environ[env_key])
    _patch_megatron()
