# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import json
import math
import os
import sys
import time
from collections import OrderedDict
from contextlib import nullcontext
from datetime import datetime, timedelta
from functools import partial, wraps
import types
from typing import List, Optional, Tuple, Union, Dict
from torch import Tensor
import itertools
import numpy as np
from collections import defaultdict
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable import checkpoint
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from concurrent.futures import wait

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import ReduceOp
import torch.nn.functional as F
import torch.distributed.checkpoint as dcp
import torch.nn as nn

from torch.nn import CrossEntropyLoss

from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)
from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                           CausalLMOutputWithPast,
                                           MoeCausalLMOutputWithPast,
                                           SequenceClassifierOutputWithPast)

from accelerate.utils import set_module_tensor_to_device
# from datasets import load_from_disk
from mmengine import mkdir_or_exist
from mmengine.dist import infer_launcher, init_dist
from mmengine.runner import set_random_seed
from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env
from mmengine import MessageHub
from tabulate import tabulate

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import \
    apply_activation_checkpointing
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                     get_state_dict,
                                                     set_state_dict)
from torch.distributed.device_mesh import init_device_mesh, DeviceMesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import _or_policy
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from torch.profiler import ProfilerActivity, profile, record_function
from torch.utils.data import ConcatDataset, DataLoader
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.modeling_utils import PreTrainedModel, load_state_dict
from transformers.utils import (SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME,
                                is_safetensors_available)
from transformers.utils.import_utils import (is_flash_attn_2_available,
                                             is_torch_sdpa_available)

from xtuner._lite import (AutoTokenizer, get_device, get_logger,
                          get_torch_device_module)
from xtuner._lite.accelerate import (LORA_TARGET_MAP, dispatch_hf_code, LoadWoInit,
                                     packed_sequence, varlen_attn_is_available, profile_time_and_memory)
from xtuner._lite.algorithms.sft import SftCollator, SftTokenizeFunction
from xtuner._lite.chat import CHAT_TEMPLATE_MAP
from xtuner._lite.datasets import (DATASET_CLS_MAP, OPENAI_CONVERT_MAP,
                                   SoftPackDataset, HardPackDataset, load_datasets)
from xtuner._lite.parallel import (LengthGroupedSampler, ParallelSampler,
                                   get_dp_mesh, get_sp_mesh,
                                   pad_for_sequence_parallel,
                                   reduce_sequence_parallel_loss,
                                   setup_parallel, split_for_sequence_parallel)

from xtuner._lite.parallel import (ParallelSampler, get_dp_mesh, get_fsdp_mesh,
                                   get_sp_mesh, get_tp_mesh, get_world_mesh, get_same_data_mesh,
                                   pad_for_sequence_parallel, setup_parallel,
                                   reduce_sequence_parallel_loss,
                                   split_for_sequence_parallel,
                                   get_ep_mesh, get_experts_fsdp_mesh, barrier)
from xtuner._lite.parallel.fsdp import clip_grad_norm_ as dense_clip_grad_norm_

from xpuyu.datasets import (MultiStreamingDataset, PretrainTokenizeFunction,
                            Streaming, StreamingDataset)
from xpuyu.accelerate import dispatch_hf_code
from xpuyu.parallel.megatron import megatron_internlm3_moe_casual
from xpuyu.monitor.recorder import GateRecorder, ExpertActivationRecorder


from internlm.utils.common import assert_current_device_empty
from internlm.utils.execution_time import execution_time_collecter as etc
from torch.utils.tensorboard import SummaryWriter
import threading
import queue
from torch import multiprocessing as mp

# NOTE(review): presumably checks that nothing has been allocated on the
# current device before the heavy imports below run — confirm in
# internlm.utils.common.
assert_current_device_empty()
# Time the internlm imports below under the "import_time" key of the
# execution-time collector.
with etc.collect_execute_time("import_time"):
    from internlm.core.context import ParallelMode
    from internlm.core.context import global_context as gpc
    from internlm.data.build_dataloader import (
        build_train_loader_with_data_type,
    )
    from internlm.data.utils import get_lang_subset_types
    from internlm.train.pipeline import load_new_batch_with_train_state
    from internlm.data.train_state import get_train_state
    from internlm.initialize import initialize_distributed_env
    from internlm.utils.common import (
        BatchSkipper,
        catch_error_node,
        enable_pytorch_expandable_segments,
        get_current_device,
        get_gpu_id,
        get_megatron_flops,
        launch_time,
        switch_topology_aware_rank_scheduling,
    )

from torch.distributed._tensor import Shard, distribute_tensor, Replicate, DTensor

# Module-level logger obtained from xtuner.
logger = get_logger()

# Device identifier for this process and the matching torch device module, as
# reported by xtuner. NOTE(review): presumably "cuda" / torch.cuda on GPU
# nodes — confirm.
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()

# Names of the supported data formats: the keys of OPENAI_CONVERT_MAP. This is
# a live keys view, not a copy (presumably OPENAI_CONVERT_MAP is a dict).
SUPPORT_DATA_FORMATS = OPENAI_CONVERT_MAP.keys()

def record_tensorboard(tensorboard_kargs, queue):
    """Drain ``(tag, value, step)`` tuples from ``queue`` into a TensorBoard writer.

    This is the target of the writer process started by
    ``SummaryWriterWrapper``. A ``SummaryWriter`` is built from
    ``tensorboard_kargs``. Each received tuple is written with
    ``add_scalar`` until a tuple whose tag is ``'over'`` arrives; then the
    writer is closed and the function returns.

    Args:
        tensorboard_kargs: Keyword arguments forwarded to ``SummaryWriter``.
        queue: A multiprocessing queue yielding ``(tag, value, step)`` tuples.
    """
    writer = SummaryWriter(**tensorboard_kargs)
    try:
        while True:
            # A blocking get replaces the empty()/sleep polling loop. It avoids
            # burning CPU while idle and does not rely on Queue.empty(), which
            # multiprocessing documents as unreliable.
            tag, value, step = queue.get()
            if tag == 'over':
                break
            writer.add_scalar(tag, value, step)
    finally:
        # Flush and close even if add_scalar raises, so logged data is not lost.
        writer.close()

class SummaryWriterWrapper(SummaryWriter):
    """TensorBoard writer that forwards scalars to a separate writer process.

    ``SummaryWriter.__init__`` is not called here. The actual writer is created
    inside a child process, started with the ``spawn`` context and running
    ``record_tensorboard``. Scalars are sent to it through a bounded
    multiprocessing queue; ``add_scalar`` blocks when that queue is full.

    When ``only_rank0`` is True, ranks other than 0 create neither queue nor
    process. They still take part in any ``reduce_op`` all-reduce inside
    ``add_scalar``, but otherwise drop the values.

    Args:
        log_dir, comment, purge_step, max_queue, flush_secs, filename_suffix:
            Forwarded to ``SummaryWriter`` in the child process.
        dataset_types: Names for the per-type metrics, indexed by type id. An
            extra ``"undefined"`` bucket is appended for type id ``-1``.
        queue_size: Maximum number of pending scalars in the queue.
        only_rank0: When True, only rank 0 writes.
    """

    def __init__(
        self,
        # tensorboard args
        log_dir=None,
        comment="",
        purge_step=None,
        max_queue=10,
        flush_secs=120,
        filename_suffix="",
        # others
        dataset_types=None,
        queue_size=3000,
        only_rank0=True,
    ):
        if only_rank0 and dist.get_rank() != 0:
            self.queue = None
            self.thread = None
        else:
            tensorboard_kargs = dict(
                log_dir=log_dir,
                comment=comment,
                purge_step=purge_step,
                max_queue=max_queue,
                flush_secs=flush_secs,
                filename_suffix=filename_suffix,
            )
            ctx = mp.get_context('spawn')
            self.queue = ctx.Queue(maxsize=queue_size)
            # Despite the attribute name, this is a process, not a thread.
            self.thread = ctx.Process(
                target=record_tensorboard, args=(tensorboard_kargs, self.queue)
            )
            self.thread.start()
        # Build a fresh list (never alias the caller's). The trailing bucket
        # receives tokens whose type id is -1.
        self.dataset_types = list(dataset_types or []) + ["undefined"]  # TODO: Check this mapping

    def qsize(self):
        """Return the approximate number of queued scalars (0 without a queue)."""
        if self.queue is not None:
            return self.queue.qsize()
        else:
            return 0

    def add_scalar(
        self,
        tag,
        scalar_value,
        global_step=None,
        walltime=None,
        new_style=False,
        double_precision=False,
        reduce_op=None,
    ):
        """Queue one scalar for the writer process.

        ``walltime``, ``new_style`` and ``double_precision`` are accepted for
        signature compatibility with ``SummaryWriter.add_scalar`` and ignored.
        When ``reduce_op`` is given, the value is first all-reduced on the
        default process group. That is a collective, so every rank must make
        the matching call.
        """
        if reduce_op is not None:
            # Work on a copy so the in-place all_reduce never mutates a tensor
            # owned by the caller.
            reduced = torch.as_tensor(scalar_value).detach().clone().cuda()
            dist.all_reduce(reduced, op=reduce_op)
            scalar_value = reduced.item()
        if self.thread is None:
            return
        if isinstance(scalar_value, torch.Tensor):
            # Send a plain Python number. Pickling a (possibly CUDA) tensor to
            # the spawned writer process would require sharing device memory.
            scalar_value = scalar_value.item()
        self.queue.put((tag, scalar_value, global_step))

    def add_train_dynamics(self, loss, unreduced_loss, correct_preds, batch, steps):
        """Log the total loss plus per-dataset-type loss and accuracy.

        ``batch[0]["type_ids"]`` is indexed as ``[:, :-1]`` (2-D) and is not
        modified. Assumes (TODO confirm against caller) that the
        concatenations of ``unreduced_loss`` (dim 0) and ``correct_preds``
        (dim -1) flatten to one value per position of ``type_ids[:, :-1]``.
        """
        self.add_scalar("train_loss/total_loss", loss, global_step=steps)
        num_types = len(self.dataset_types)

        # loss per class type
        unreduced_loss = torch.cat(unreduced_loss, dim=0).flatten()  # (B T-1)
        type_ids = batch[0]["type_ids"].to(unreduced_loss.device)  # B T
        type_ids = type_ids[:, :-1].flatten()  # (B T-1)
        # Out-of-place: flatten() may return a view of the caller's tensor, so
        # an in-place assignment here could corrupt batch[0]["type_ids"].
        type_ids = type_ids.masked_fill(type_ids == -1, num_types - 1)
        loss_scatter = torch.zeros(
            [num_types],
            device=unreduced_loss.device,
            dtype=unreduced_loss.dtype,
        )
        count = torch.bincount(type_ids, minlength=num_types)
        loss_scatter.scatter_add_(0, type_ids, unreduced_loss)
        loss_scatter = loss_scatter / (count + 1e-6)
        for name, type_loss in zip(self.dataset_types, loss_scatter.tolist()):
            self.add_scalar(f"train_loss/{name}", type_loss, global_step=steps)

        # acc per class type
        correct_preds = (
            torch.cat(correct_preds, dim=-1).flatten().to(unreduced_loss.dtype)
        )
        right_number = torch.zeros(
            [num_types],
            device=unreduced_loss.device,
            dtype=unreduced_loss.dtype,
        )
        right_number.scatter_add_(0, type_ids, correct_preds)
        acc = right_number / (count + 1e-6)
        for name, acc_per_type in zip(self.dataset_types, acc.tolist()):
            self.add_scalar(f"train_acc/{name}", acc_per_type, global_step=steps)
        total_acc = right_number.sum() / (count.sum() + 1e-6)
        self.add_scalar("train_acc/total", total_acc, global_step=steps)

    def add_optimize_info(self, grad_norm, train_state, steps):
        """Log the gradient norm and the count of inf/nan-skipped batches."""
        self.add_scalar("optimize/grad_norm", grad_norm, global_step=steps)
        self.add_scalar(
            "optimize/inf_nan_skip_batches",
            train_state.inf_nan_skip_batches,
            global_step=steps,
        )

    def add_data_infos(self, batch, train_state, step):
        """Log per-type token counts and per-subset epoch counters.

        ``batch[0]["type_ids"]`` is read but not modified.
        """
        # tokens for classes
        type_ids = batch[0]["type_ids"]  # B L
        # Out-of-place remap so the caller's batch keeps its -1 entries.
        type_ids = type_ids.masked_fill(type_ids == -1, len(self.dataset_types) - 1)
        count = torch.bincount(type_ids.flatten(), minlength=len(self.dataset_types))
        count = dict(
            (f"{self.dataset_types[i]}", v) for i, v in enumerate(count.tolist())
        )
        for k, v in count.items():
            self.add_scalar("data_tokens/" + k, v, step)

        # epochs for subsets
        used_epochs = train_state.data_state_dict["used_epochs"]
        for file_name, e in used_epochs.items():
            self.add_scalar(
                f"data_subset_epochs_rank0/{file_name}", e, step, reduce_op=None
            )  # only in rank 0

    def add_speed_info(self, tgs, e2e_tgs, step):
        """Log throughput values and the current writer-queue size."""
        self.add_scalar("speed/tgs", tgs, step, reduce_op=None)
        self.add_scalar("speed/e2e_tgs", e2e_tgs, step, reduce_op=None)
        self.add_scalar("speed/tb_qsize", self.qsize(), step, reduce_op=None)

    def add_moe_info(self, model_level_info, per_layer_info, step):
        """Log model-level MoE scalars and per-layer MoE sequences."""
        # model info
        for key, value in model_level_info.items():
            self.add_scalar(f"moe/{key}", value, step)

        # per layer info
        for key, value in per_layer_info.items():
            for layer_id, v in enumerate(value):
                self.add_scalar(f"moe_{key}/{layer_id}", v, step)

    def close(self):
        """Tell the writer process to finish, then wait for it to exit.

        Safe to call more than once. After closing, ``add_scalar`` only
        performs its optional all-reduce and drops the value.
        """
        if self.queue is not None:
            self.queue.put(('over', 0, 0))
            self.thread.join()
            self.queue = None
            self.thread = None

def log_format(rank, debug=False):
    """Build the log-line format string for this process.

    The prefix tags each message with the global ``rank`` and this process's
    local rank in the DP, SP and TP meshes. The ``{time...}``, ``{level}``,
    ``{message}``, ``{name}``, ``{function}`` and ``{line}`` placeholders and
    the ``<level>``/``<cyan>`` tags use loguru-style syntax. They are filled in
    by the logger (presumably the xtuner logger — confirm).

    Args:
        rank: Global rank shown in the prefix.
        debug: When True, also include the ``{name}``, ``{function}`` and
            ``{line}`` fields.

    Returns:
        The format string.
    """
    sp_rank = get_sp_mesh().get_local_rank()
    dp_rank = get_dp_mesh().get_local_rank()
    tp_rank = get_tp_mesh().get_local_rank()

    formatter = f'[XTuner][RANK {rank}][DP {dp_rank}][SP {sp_rank}][TP {tp_rank}]'
    formatter += '[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]'

    if debug:
        formatter += '[<cyan>{name}</cyan>:'
        formatter += '<cyan>{function}</cyan>:'
        formatter += '<cyan>{line}</cyan>]'

    formatter += ' <level>{message}</level>'
    return formatter


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train LLM')

    model_args = parser.add_argument_group(
        'model', 'Model loading and precision options')
    model_args.add_argument('--llm', help='config file name or path.')
    model_args.add_argument('--train-cfg', help='interntrain config file')
    model_args.add_argument(
        '-t',
        '--tokenizer',
        help=('repo id or local path of the tokenizer. '
              'Defaults to the same as `model`'))
    model_args.add_argument('--load-pretrain', action='store_true')

    model_args.add_argument(
        '--dtype',
        default='auto',
        choices=['fp16', 'bf16', 'auto'],
        help=("the dtype of the model forward. When set to 'auto', it will "
              'automatically determine whether bf16 is available, '
              'prioritizing the use of bf16.'))
    model_args.add_argument(
        '--selective-recompute',
        default=1.0,
        type=float,
        help=('the ratio of re-computation for transformer layers. '
              'The maximum is 1; the larger the value, the less memory '
              'required for training. The default is 1, meaning all layers '
              'need to be re-computed.'))
    model_args.add_argument('--cpu-offload', action='store_true', help=(''))
    model_args.add_argument(
        '--shard-strategy',
        default='full',
        choices=['full', 'hybrid', 'zero2', 'no', 'hybrid_zero2'],
        help=('The sharding strategy to be used for distributed training.'))

    # Each option left at None keeps the corresponding value of the loaded config.
    custom_model_args = parser.add_argument_group(
        'custom model',
        'Custom model structure (used when not loading pretrained weights)')
    custom_model_args.add_argument('--hidden-size', type=int, default=None)
    custom_model_args.add_argument(
        '--num-attention-heads', type=int, default=None)
    custom_model_args.add_argument(
        '--num-key-value-heads', type=int, default=None)
    custom_model_args.add_argument(
        '--intermediate-size', type=int, default=None)
    custom_model_args.add_argument(
        '--num-hidden-layers', type=int, default=None)
    custom_model_args.add_argument(
        '--vocab-size', type=int, default=None)
    custom_model_args.add_argument(
        '--n-shared-experts', type=int, default=None)
    custom_model_args.add_argument(
        '--num-experts-per-tok', type=int, default=None)
    custom_model_args.add_argument(
        '--num-routed-experts', type=int, default=None)
    custom_model_args.add_argument(
        '--head-dim', type=int, default=None)
    custom_model_args.add_argument(
        '--aux-loss-alpha', type=float, default=None)

    dist_args = parser.add_argument_group(
        'dist', 'Distributed parallelism options')
    dist_args.add_argument('--sp-size', type=int, default=1, help='')
    dist_args.add_argument('--ep-size', type=int, default=1, help='')

    optim_args = parser.add_argument_group('optimizer', 'Optimizer options')
    optim_args.add_argument(
        '--max-grad-norm', default=1, type=float, help='gradient clipping')
    parser.add_argument(
        '--work-dir',
        default='work_dirs',
        help='the dir to save logs and models')
    parser.add_argument(
        '--checkpoint-interval',
        default=-1,
        type=float,
        help=('how many steps to save a checkpoint; it can be a floating '
              'point number less than 1, or an integer greater than or equal '
              "to 1. When it's a floating point, it will be multiplied by the "
              'total number of training steps.'))
    parser.add_argument(
        '--hf-interval',
        default=-1,
        type=float,
        help=('how many steps to save a hf model; it can be a floating '
              'point number less than 1, or an integer greater than or equal '
              "to 1. When it's a floating point, it will be multiplied by the "
              'total number of training steps.'))
    parser.add_argument(
        '--max-keep-ckpts',
        type=int,
        default=-1,
        help='the maximum number of checkpoints to keep.')
    parser.add_argument(
        '--checkpoint-drop-optimizer',
        action='store_true',
        help=('only model parameters are saved when saving a checkpoint. '
              'This can significantly reduce the size of checkpoint files, '
              'but the saved checkpoints cannot be resumed.'))
    parser.add_argument('--log-interval', default=1, type=int)
    parser.add_argument(
        '--resume', action='store_true', help='resume from the last checkpoint')
    parser.add_argument(
        '--resume-from',
        type=str,
        default=None,
        help='specify checkpoint path to be resumed from.')
    parser.add_argument(
        '--seed', type=int, default=0, help='Random seed for the training')
    parser.add_argument(
        '--debug', action='store_true', help='Set logger level to `DEBUG`')
    parser.add_argument(
        '--port', type=int, default=8888, help='port')
    parser.add_argument(
        '--reshard-after-forward',
        type=int,
        default=-1,
        help='reshard-after-forward setting for FSDP sharding (default: -1)')
    parser.add_argument('--use-hsdp', action='store_true')
    parser.add_argument(
        "--tensorboard", default=None, type=str, help="tensorboard log dir"
    )
    args = parser.parse_args(argv)
    return args


def is_interval(step, total_steps, interval):
    """Return True when ``step`` is a multiple of ``interval`` or is the last step."""
    on_multiple = step % interval == 0
    return on_multiple or step == total_steps


def rank0_first(func):
    """Decorator: run ``func`` on rank 0 first, then on every other rank.

    Rank 0 calls ``func`` before the first ``barrier()``. All other ranks call
    it after that barrier. A second ``barrier()`` keeps everyone in step
    before returning. Each rank returns the result of its own call.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        leader = dist.get_rank() == 0
        if leader:
            result = func(*args, **kwargs)

        barrier()

        if not leader:
            result = func(*args, **kwargs)

        barrier()
        return result

    return wrapper


@rank0_first
def build_config(args):
    """Load the model config from ``args.llm`` (rank 0 first, then the rest)."""
    return AutoConfig.from_pretrained(args.llm, trust_remote_code=True)


def cal(llm, cfg):
    """Print parameter counts of ``llm`` in billions.

    Buckets:
    * MoE: names containing ``'expert'``; "Other" is everything else.
    * Attention: names containing ``'attention'`` or ``'self_attn'``.
    * Activated: every parameter, except that routed-expert parameters (names
      containing ``'.experts.'``) are scaled by
      ``cfg.num_experts_per_tok / num_routed``. ``num_routed`` comes from
      ``cfg.num_routed_experts`` when present, else ``cfg.n_routed_experts``.
    """
    act = 0
    total = 0
    moe = 0
    other = 0
    attn = 0
    for pname, weight in llm.named_parameters():
        size = weight.numel()
        total += size

        if 'expert' not in pname:
            other += size
        else:
            moe += size

        if '.experts.' not in pname:
            act += size
        else:
            per_tok = cfg.num_experts_per_tok
            routed = (cfg.num_routed_experts
                      if hasattr(cfg, 'num_routed_experts')
                      else cfg.n_routed_experts)
            act += size * per_tok / routed

        if 'attention' in pname or 'self_attn' in pname:
            attn += size
    print(
        f'Total act param: {act / 1e9}, Total param: {total / 1e9}, MoE param: {moe / 1e9}, '
        f'Other param: {other / 1e9}, Attn param: {attn / 1e9}, MoE act param: {(act - other) / 1e9}'
    )


def build_llm_model(args, config, dtype=torch.float32):
    """Create the causal LM, either from pretrained weights or from ``config``.

    With ``args.load_pretrain``, weights are loaded from ``args.llm`` inside
    ``LoadWoInit()`` with ``torch_dtype=dtype`` and
    ``config.attn_implementation``. Otherwise a deep copy of ``config`` gets
    any non-None structure overrides from ``args``. The model is then built
    from that copy with ``config.torch_dtype`` and flash_attention_2.

    Rank 0 prints parameter statistics via ``cal``. The model is finally cast
    to ``dtype`` and ``use_cache`` is disabled.
    """
    if not args.load_pretrain:
        cfg = copy.deepcopy(config)
        # CLI overrides for the model structure; None keeps the config value.
        # Note: Qwen uses shared_expert_intermediate_size, which is not
        # overridden here.
        override_fields = (
            'hidden_size',
            'intermediate_size',
            'num_attention_heads',
            'num_key_value_heads',
            'num_hidden_layers',
            'vocab_size',
            'n_shared_experts',
            'num_experts_per_tok',
            'num_routed_experts',
            'head_dim',
            'aux_loss_alpha',
        )
        for field in override_fields:
            value = getattr(args, field)
            if value is not None:
                setattr(cfg, field, value)

        llm = AutoModelForCausalLM.from_config(
            config=cfg,
            trust_remote_code=True,
            torch_dtype=config.torch_dtype,
            attn_implementation='flash_attention_2')
    else:
        with LoadWoInit():
            llm = AutoModelForCausalLM.from_pretrained(
                args.llm,
                trust_remote_code=True,
                torch_dtype=dtype,
                attn_implementation=config.attn_implementation)

    if dist.get_rank() == 0:
        cal(llm, llm.config)

    # Keep all parameters (and hence optimizer state) in `dtype`;
    # FSDP will use low precision during forward.
    llm.to(dtype)
    llm.config.use_cache = False

    return llm


# from xpuyu.modelings.internlm_moe.modeling_internlm3_moe import load_balancing_loss_func

# def internlm3_moe_forward_fused_linear_ce(
#     self,
#     input_ids: torch.LongTensor = None,
#     attention_mask: Optional[torch.Tensor] = None,
#     position_ids: Optional[torch.LongTensor] = None,
#     past_key_values: Optional[List[torch.FloatTensor]] = None,
#     inputs_embeds: Optional[torch.FloatTensor] = None,
#     labels: Optional[torch.LongTensor] = None,
#     use_cache: Optional[bool] = None,
#     output_attentions: Optional[bool] = None,
#     output_hidden_states: Optional[bool] = None,
#     output_router_logits: Optional[bool] = None,
#     return_dict: Optional[bool] = None,
#     cache_position: Optional[torch.LongTensor] = None,
#     num_logits_to_keep: int = 0,
#     **loss_kwargs,
# ):
#     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
#     output_router_logits = (
#         output_router_logits if output_router_logits is not None else self.config.output_router_logits
#     )
#     output_hidden_states = (
#         output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
#     )
#     return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
#     outputs = self.model(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         position_ids=position_ids,
#         past_key_values=past_key_values,
#         inputs_embeds=inputs_embeds,
#         use_cache=use_cache,
#         output_attentions=output_attentions,
#         output_hidden_states=output_hidden_states,
#         output_router_logits=output_router_logits,
#         return_dict=return_dict,
#         cache_position=cache_position,
#     )

#     hidden_states = outputs[0]

#     loss = None
#     logits = None

#     if self.training and (labels is not None):
#         shift_hidden_states = hidden_states[..., :-1, :].contiguous()
#         shift_labels = labels[..., 1:].contiguous()

#         # flatten tokens
#         shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
#         shift_labels = shift_labels.view(-1)

#         lce = LigerFusedLinearCrossEntropyLoss()
#         loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)

#     else:
#         logits = self.lm_head(hidden_states)
#         if labels is not None:
#             # Upcast to float if we need to compute the loss to avoid potential precision issues
#             logits = logits.float()
#             # Shift so that tokens < n predict n
#             shift_logits = logits[..., :-1, :].contiguous()
#             shift_labels = labels[..., 1:].contiguous()
#             # Flatten the tokens
#             loss_fct = CrossEntropyLoss()
#             shift_logits = shift_logits.view(-1, self.config.vocab_size)
#             shift_labels = shift_labels.view(-1)
#             # Enable model parallelism
#             shift_labels = shift_labels.to(shift_logits.device)
#             loss = loss_fct(shift_logits, shift_labels)

#     aux_loss = None
#     if output_router_logits:
#         aux_loss = load_balancing_loss_func(
#             outputs.router_logits if return_dict else outputs[-1],
#             self.num_routed_experts,
#             self.num_experts_per_tok,
#             attention_mask,
#         )
#         if labels is not None:
#             loss += self.router_aux_loss_coef * aux_loss.to(loss.device)  # make sure to reside in the same device

#     if not return_dict:
#         output = (logits,) + outputs[1:]
#         if output_router_logits:
#             output = (aux_loss,) + output
#         return (loss,) + output if loss is not None else output

#     return MoeCausalLMOutputWithPast(
#         loss=loss,
#         aux_loss=aux_loss,
#         logits=logits,
#         past_key_values=outputs.past_key_values,
#         hidden_states=outputs.hidden_states,
#         attentions=outputs.attentions,
#         router_logits=outputs.router_logits,
#     )


@torch.no_grad()
def reduce_ep_grad(llm, ep_size):
    """Divide, in place, the ``w1w3`` and ``w2`` gradients of every ``GroupedLinear`` by ``ep_size``.

    Modules are matched by class name. Gradients that are ``None`` are skipped.
    """
    grouped_linears = (
        mod for mod in llm.modules() if type(mod).__name__ == 'GroupedLinear'
    )
    for mod in grouped_linears:
        for grad in (mod.w1w3.grad, mod.w2.grad):
            if grad is not None:
                grad.div_(ep_size)


from torch.nn.utils.clip_grad import _no_grad
from torch.utils._foreach_utils import (
    _device_has_foreach_support,
    _group_tensors_by_device_and_dtype,
    _has_foreach_support,
)


@_no_grad
def clip_grad_norm_(
    params,
    fsdp_mesh,
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: Optional[bool] = None,
) -> torch.Tensor:
    if isinstance(params, torch.Tensor):
        params = [params]
    params_grads = [p.grad for p in params if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if len(params_grads) == 0:
        return torch.tensor(0.0)
    first_device = params_grads[0].device
    grouped_params_grads: Dict[
        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
    ] = _group_tensors_by_device_and_dtype(
        [params_grads]
    )
    params_norms: List[Tensor] = []

    for (device, _), ([device_grads], _) in grouped_params_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            params_norms.extend(torch._foreach_norm(device_grads, norm_type))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            params_norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])

    local_sharded_params_norms = torch.linalg.vector_norm(
        torch.stack([norm.to_local().to(first_device) for norm in params_norms]), norm_type, dtype=torch.float32
    )

    if norm_type == 2:
        total_sharded_params_norms = local_sharded_params_norms**norm_type
        dist.all_reduce(total_sharded_params_norms, group=fsdp_mesh.get_group(mesh_dim=-1))
        total_norm = total_sharded_params_norms ** 0.5
    else:
        raise NotImplementedError

    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
        raise RuntimeError(
            f"The total norm of order {norm_type} for gradients from "
            "`parameters` is non-finite, so it cannot be clipped. To disable "
            "this error and scale the gradients by the non-finite norm anyway, "
            "set `error_if_nonfinite=False`"
        )
    clip_coef = max_norm / (total_norm + 1e-6)
    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
    # when the gradients do not reside in CPU memory.
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

    for (device, _), ([device_grads], _) in grouped_params_grads.items():  # type: ignore[assignment]
        if (foreach is None and _has_foreach_support(device_grads, device)) or (
            foreach and _device_has_foreach_support(device)
        ):
            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
        elif foreach:
            raise RuntimeError(
                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
            )
        else:
            clip_coef_clamped_device = clip_coef_clamped.to(device)
            for g in device_grads:
                g.mul_(clip_coef_clamped_device)

    return total_norm


# @_no_grad
# def clip_grad_norm_(
#     moe_params,
#     non_moe_params,
#     experts_fsdp_mesh,
#     max_norm: float,
#     norm_type: float = 2.0,
#     error_if_nonfinite: bool = False,
#     foreach: Optional[bool] = None,
# ) -> torch.Tensor:
#     if isinstance(moe_params, torch.Tensor):
#         moe_params = [moe_params]
#     if isinstance(non_moe_params, torch.Tensor):
#         non_moe_params = [non_moe_params]
#     moe_grads = [p.grad for p in moe_params if p.grad is not None]
#     non_moe_grads = [p.grad for p in non_moe_params if p.grad is not None]
#     max_norm = float(max_norm)
#     norm_type = float(norm_type)
#     if len(moe_grads) + len(non_moe_grads) == 0:
#         return torch.tensor(0.0)
#     first_device = non_moe_grads[0].device
#     grouped_moe_grads: Dict[
#         Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
#     ] = _group_tensors_by_device_and_dtype(
#         [moe_grads]
#     )
#     grouped_non_moe_grads: Dict[
#         Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
#     ] = _group_tensors_by_device_and_dtype(
#         [non_moe_grads]
#     )
#     moe_norms: List[Tensor] = []
#     non_moe_norms: List[Tensor] = []

#     for (device, _), ([device_grads], _) in grouped_moe_grads.items():  # type: ignore[assignment]
#         if (foreach is None and _has_foreach_support(device_grads, device)) or (
#             foreach and _device_has_foreach_support(device)
#         ):
#             moe_norms.extend(torch._foreach_norm(device_grads, norm_type))
#         elif foreach:
#             raise RuntimeError(
#                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
#             )
#         else:
#             moe_norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
    
#     for (device, _), ([device_grads], _) in grouped_non_moe_grads.items():  # type: ignore[assignment]
#         if (foreach is None and _has_foreach_support(device_grads, device)) or (
#             foreach and _device_has_foreach_support(device)
#         ):
#             non_moe_norms.extend(torch._foreach_norm(device_grads, norm_type))
#         elif foreach:
#             raise RuntimeError(
#                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
#             )
#         else:
#             non_moe_norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
    
#     local_sharded_moe_norm = torch.linalg.vector_norm(
#         torch.stack([norm.to_local().to(first_device) for norm in moe_norms]), norm_type, dtype=torch.float32
#     )
#     local_sharded_non_moe_norm = torch.linalg.vector_norm(
#         torch.stack([norm.to_local().to(first_device) for norm in non_moe_norms]), norm_type, dtype=torch.float32
#     )

#     if norm_type == 2:
#         total_sharded_moe_norm = local_sharded_moe_norm**norm_type
#         total_sharded_non_moe_norm = local_sharded_non_moe_norm**norm_type
#         dist.all_reduce(total_sharded_moe_norm)
#         dist.all_reduce(total_sharded_non_moe_norm, group=experts_fsdp_mesh.get_group(mesh_dim=0))
#         total_norm = (total_sharded_moe_norm + total_sharded_non_moe_norm) ** 0.5
#     else:
#         raise NotImplementedError

#     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
#         raise RuntimeError(
#             f"The total norm of order {norm_type} for gradients from "
#             "`parameters` is non-finite, so it cannot be clipped. To disable "
#             "this error and scale the gradients by the non-finite norm anyway, "
#             "set `error_if_nonfinite=False`"
#         )
#     clip_coef = max_norm / (total_norm + 1e-6)
#     # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
#     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
#     # when the gradients do not reside in CPU memory.
#     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
#     for (device, _), ([device_grads], _) in grouped_moe_grads.items():  # type: ignore[assignment]
#         if (foreach is None and _has_foreach_support(device_grads, device)) or (
#             foreach and _device_has_foreach_support(device)
#         ):
#             torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
#         elif foreach:
#             raise RuntimeError(
#                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
#             )
#         else:
#             clip_coef_clamped_device = clip_coef_clamped.to(device)
#             for g in device_grads:
#                 g.mul_(clip_coef_clamped_device)
    
#     for (device, _), ([device_grads], _) in grouped_non_moe_grads.items():  # type: ignore[assignment]
#         if (foreach is None and _has_foreach_support(device_grads, device)) or (
#             foreach and _device_has_foreach_support(device)
#         ):
#             torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
#         elif foreach:
#             raise RuntimeError(
#                 f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
#             )
#         else:
#             clip_coef_clamped_device = clip_coef_clamped.to(device)
#             for g in device_grads:
#                 g.mul_(clip_coef_clamped_device)

#     return total_norm


def prepare_for_monitor(model):
    """Create the MoE monitoring recorders and hook them onto ``model``.

    Two recorders are constructed: a ``GateRecorder`` with ``source='.gate'``
    and an ``ExpertActivationRecorder`` with ``source='.experts'``. Then
    ``prepare_model(model)`` is invoked on each, gate recorder first. What
    ``prepare_model`` registers is defined by those recorder classes.
    Presumably it matches submodule names by the ``source`` suffix; confirm
    against the class definitions.

    Args:
        model: The model to be monitored.

    Returns:
        tuple: ``(gate_recorder, expert_activation_recorder)``.
    """
    monitors = (
        GateRecorder(source='.gate'),
        ExpertActivationRecorder(source='.experts'),
    )
    # Same order as construction: the gate recorder is prepared before the
    # expert-activation recorder.
    for monitor in monitors:
        monitor.prepare_model(model)
    return monitors


# @logger.catch
def main(args):
    ###########################################################################
    #                           1. Environment                                #
    ###########################################################################

    with etc.collect_execute_time("init_comm_time"):
        catch_error_node(initialize_distributed_env)(
            config=args.train_cfg,
            launcher='torch',
            master_port=args.port,
            seed=args.seed,
            old_config=True
        )
            
    assert hasattr(gpc, "config") and gpc.config is not None

    # train_folder = gpc.config.data.train_folder
    # dataset_types, dataset_subset_types = get_lang_subset_types(train_folder)
    data_rank = gpc.get_local_rank(ParallelMode.DATA)
    data_world_size = gpc.get_world_size(ParallelMode.DATA)

    setup_parallel(sp_size=1, tp_size=1, ep_size=1)
    set_random_seed(args.seed)

    dp_mesh = get_dp_mesh()
    tp_mesh = get_tp_mesh()
    sp_mesh = get_sp_mesh()
    ep_mesh = get_ep_mesh()
    experts_fsdp_mesh = get_experts_fsdp_mesh()
    world_mesh = get_world_mesh()

    if ep_mesh.size() > 1:
        raise NotImplementedError

    dp_size = dp_mesh.size()
    sp_size = sp_mesh.size()
    tp_size = tp_mesh.size()
    world_size = world_mesh.size()

    if args.use_hsdp:
        hsdp_device_mesh = init_device_mesh(
            DEVICE, (world_size // 8, 8), mesh_dim_names=('internode', 'intranode'))
    else:
        hsdp_device_mesh = None
    
    print(hsdp_device_mesh)

    rank = dist.get_rank()

    mkdir_or_exist(args.work_dir)

    log_file = os.path.join(args.work_dir, f'all.log')
    vis_data_file = os.path.join(args.work_dir, 'vis_data.jsonl')
    dataset_types, dataset_subset_types = get_lang_subset_types(
        gpc.config.data.train_folder
    )

    # Change the log format printed in the terminal
    lvl = 'DEBUG' if args.debug else 'INFO'
    logger.add(sys.stderr, level=lvl, format=log_format(rank, args.debug))
    # Change the format saved in the log file
    logger.add(log_file, format=log_format(rank), backtrace=True, catch=True)

    logger.info(args)
    if rank == 0:
        env = collect_env()
        import transformers

        import xtuner
        env['Transformers'] = transformers.__version__
        env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'
        runtime_env = OrderedDict()
        runtime_env.update(env)
        runtime_env['Seed'] = args.seed
        runtime_env['World Size'] = world_size
        runtime_env['DP Size'] = dp_size
        runtime_env['SP Size'] = sp_size
        # runtime_env['Distributed launcher'] = dist_launcher

        runtime_env_info = '\n    ' + '\n    '.join(
            f'{k}: {v}' for k, v in runtime_env.items())
        dash_line = '-' * 60
        logger.info('\n' + dash_line + '\nRuntime environment:' +
                    runtime_env_info + '\n' + dash_line + '\n')
    # -------------------    Environment  End  ------------------------------ #
    if args.resume_from and args.resume is False:
        args.resume = True
    if args.resume is True and args.resume_from is None:
        # find last checkpoint
        ckpt_dirs = [d for d in os.listdir(args.work_dir) if
                     os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('ckpt-')]
        if len(ckpt_dirs) > 0:
            ckpt_dirs.sort(reverse=True)
            is_success = False
            for ckpt_dir in ckpt_dirs:
                if os.path.exists(os.path.join(args.work_dir, ckpt_dir, '.metadata')):
                    args.resume_from = os.path.join(args.work_dir, ckpt_dir)
                    is_success = True
                    break
                else:
                    os.system(f'rm -rf {os.path.join(args.work_dir, ckpt_dir)}')
            if is_success is False:
                logger.warning('Did not find last_checkpoint to be resumed. training from scratch.')
                args.resume = False
        else:
            logger.warning('Did not find last_checkpoint to be resumed. training from scratch.')
            args.resume = False
    if args.resume:
        assert not args.checkpoint_drop_optimizer, '`resume` and `checkpoint_drop_optimizer` cannot be set at the same time.'

    ###########################################################################
    #                     replace config                                     #
    ###########################################################################
    args.wd = gpc.config.adam.weight_decay
    args.lr = gpc.config.adam.lr
    args.adam_beta1 = gpc.config.adam.adam_beta1
    args.adam_beta2 = gpc.config.adam.adam_beta2
    args.adam_epsilon = gpc.config.adam.adam_eps
    args.total_steps = gpc.config.data.total_steps
    args.iters_per_step = gpc.config.data.gradient_accumulation
    args.warmup_ratio = gpc.config.lr_scheduler.warmup_steps
    args.lr_min = gpc.config.MIN_LEARNING_RATE
    logger.info(args)
    logger.info(f"data_rank: {data_rank}, data_world_size: {data_world_size}")
    
    ###########################################################################
    #                     2. Dataset & Dataloader                             #
    ###########################################################################

    start_load_data_t = time.time()

    assert varlen_attn_is_available()

    with etc.collect_execute_time("load_data_time"):
        train_dl = build_train_loader_with_data_type(
            data_cfg=gpc.config.data,
            data_rank=data_rank,
            data_world_size=data_world_size,
        )
    train_state = get_train_state(train_dl)

    load_data_cost_time = time.time() - start_load_data_t
    logger.info(f'[Dataset & Dataloader] Cost {load_data_cost_time:.2f}s')

    # -------------------    Dataset & Dataloader  End  --------------------- #

    ###########################################################################
    #                          3. FSDP                                        #
    ###########################################################################

    start_model_t = time.time()

    if args.dtype == 'auto':
        args.dtype = 'bf16' if torch.cuda.is_bf16_supported() else 'fp16'

    if args.dtype == 'fp16':
        dtype = torch.float16
        autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
        scaler = ShardedGradScaler()
    elif args.dtype == 'bf16':
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
            autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype)
            scaler = None
        else:
            raise RuntimeError('The device does not support `bf16`, '
                               'please set `dtype` to `fp16`.')
    else:
        raise RuntimeError('`dtype` only supports `fp16`, `bf16` or `auto`, '
                           f'but found {args.dtype}.')

    llm_cfg = build_config(args)
    if is_flash_attn_2_available():
        llm_cfg.attn_implementation = 'flash_attention_2'
    elif is_torch_sdpa_available():
        llm_cfg.attn_implementation = 'sdpa'

    llm_cfg.use_cache = False
    llm_cfg.torch_dtype = dtype

    # Only load parameters on rank 0 to avoid each rank repeatedly loading the
    # same model into the CPU, wasting memory
    xtuner_load_timeout = timedelta(minutes=60)
    group_gloo = dist.new_group(backend='gloo', timeout=xtuner_load_timeout)

    if rank == 0:
        with torch.device('cpu'):
            rank0_llm = build_llm_model(args, llm_cfg, dtype)
    else:
        rank0_llm = None
    
    dist.monitored_barrier(group=group_gloo, timeout=xtuner_load_timeout)
    logger.info('after barrier')

    with torch.device('meta'):
        llm = build_llm_model(args, llm_cfg, dtype=dtype)
        dispatch_hf_code(llm)
        for module in llm.modules():
            for p_name, param in module.named_parameters(recurse=False):
                if param.requires_grad:
                    param_fp32 = torch.nn.Parameter(
                        param.to(dtype=torch.float32))
                    setattr(module, p_name, param_fp32)
    
    # logger.info('dispatch internlm3_moe_forward_fused_linear_ce')
    # llm.forward = types.MethodType(internlm3_moe_forward_fused_linear_ce, llm)

    mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype)

    with profile_time_and_memory('[Parallelize LLM]'):
        megatron_internlm3_moe_casual(
            llm,
            rank0_llm,
            experts_fsdp_mesh=hsdp_device_mesh if args.use_hsdp else experts_fsdp_mesh,
            ep_mesh=ep_mesh,
            mp_policy=mp_policy,
            recompute_ratio=args.selective_recompute,
            reshard_after_forward=True if args.reshard_after_forward == -1 else args.reshard_after_forward)
        
        llm.train()
    
    # gate_recorder, expert_activation_recorder = prepare_for_monitor(llm)
    
    if rank == 0:
        logger.info(llm)
    
    # --------------------------    FSDP  End  ------------------------------ #

    ###########################################################################
    #                      4. Optimizer & Scheduler                           #
    ###########################################################################

    requried_grad_params = [
        param for param in llm.parameters() if param.requires_grad
    ]

    optimizer = AdamW(
        requried_grad_params,
        lr=args.lr,
        weight_decay=args.wd,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon)

    iters_per_step = args.iters_per_step
    total_steps = args.total_steps

    if args.checkpoint_interval == -1:
        checkpoint_interval = total_steps
    elif args.checkpoint_interval < 1:
        checkpoint_interval = int(total_steps * args.checkpoint_interval)
    else:
        checkpoint_interval = int(args.checkpoint_interval)

    if args.hf_interval == -1:
        hf_interval = total_steps
    elif args.hf_interval < 1:
        hf_interval = int(total_steps * args.hf_interval)
    else:
        hf_interval = int(args.hf_interval)

    max_keep_ckpts = args.max_keep_ckpts
    if max_keep_ckpts <= 0:
        # save all checkpoints
        max_keep_ckpts = total_steps + 100000
    save_hf_ckpt_names = []
    save_pt_ckpt_names = []
    ckpt_dirs = [os.path.join(args.work_dir, d) for d in os.listdir(args.work_dir) if
                 os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('ckpt-')]
    if len(ckpt_dirs) > 0:
        ckpt_dirs.sort()
        save_pt_ckpt_names = ckpt_dirs

    hf_dirs = [os.path.join(args.work_dir, d) for d in os.listdir(args.work_dir) if
               os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('hf-')]
    if len(hf_dirs) > 0:
        hf_dirs.sort()
        save_pt_ckpt_names = hf_dirs

    if args.warmup_ratio < 1:
        warmup_steps = int(args.warmup_ratio * total_steps)
    else:
        warmup_steps = int(args.warmup_ratio)

    def warmup_fn(x):
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(optimizer, warmup_fn)

    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)

    dp_rank = get_dp_mesh().get_local_rank()

    # ----------------    Optimizer & Scheduler End   ----------------------- #

    if args.resume:
        logger.info(f'[Resume] Resume from {args.resume_from}')
        _options = StateDictOptions(
            cpu_offload=True, ignore_frozen_params=True)
        (shard_model_state_dict,
         shard_optimizer_state_dict) = get_state_dict(
            llm, optimizer, options=_options)
        state_dict = {
            'model': shard_model_state_dict,
            'optimizer': shard_optimizer_state_dict,
            'train_state': train_state,
            'warmup_scheduler': warmup_scheduler,
            'cosine_scheduler': cosine_scheduler
        }
        # inplace state_dict
        dcp.load(
            state_dict=state_dict,
            checkpoint_id=args.resume_from,
        )

        _options = StateDictOptions(
            cpu_offload=True, strict=False)
        set_state_dict(
            llm,
            optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optimizer"],
            options=_options
        )
        if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
        ):
            sampler_states = torch.load(os.path.join(args.resume_from, "sampler.pt"))
            train_dl.batch_sampler.load_state_dict(sampler_states)
            # track the actual updates of sampler when using weighted sampling
            train_state.init_batch_sampler(train_dl.batch_sampler)

        assert hasattr(train_state, "data_state_dict") and hasattr(train_state, "batch_sampler")
        dataset_dirs = [os.path.join(args.resume_from, d) for d in os.listdir(args.resume_from) if d.startswith('dataset_')]
        dataset_dirs.sort()
        state_dict_list=[]
        for dataset_dir in dataset_dirs:
            state_dict_pre_rank=torch.load(dataset_dir)
            state_dict_list.append(state_dict_pre_rank)

        if dp_rank == 0:
            cur_dataset_consumed_tokens = state_dict_list[0].pop("dataset_consumed_tokens", {})
            train_state.data_state_dict["dataset_consumed_tokens"].update(cur_dataset_consumed_tokens)

        if len(state_dict_list) == dp_size:
            train_dl.dataset.load_state_dict(state_dict_list[dp_rank])
        else:
            if state_dict_list[0]["epochs_to_use"]:
                raise NotImplementedError(
                    "Cannot resume training if dp_size changed with `epochs_to_use` set."
                    " Try set `epochs_to_use` to None."
                )
            # if dp_rank == 0:
            #     logger.info(state_dict_list)
            # logger.info('=============================================================')
            multiple_packed_states_group: Dict[str, List[Dict]] = defaultdict(list)
            consumed_samples = defaultdict(int)
            for state_dict in state_dict_list:
                for key, value in state_dict["consumed_samples"].items():
                    consumed_samples[key] += value
                for key, value in state_dict["multiple_packed_states"].items():
                    multiple_packed_states_group[key].append(value)
            used_epochs = [state_dict["used_epochs"] for state_dict in state_dict_list]
            max_used_epochs = {k: max(d[k] for d in used_epochs) for k in used_epochs[0]}

            for key in list(multiple_packed_states_group.keys()):
                sort_metrics = [
                    (
                        state_dict["tokenization_states"]["aggregation_states"]["file_shift"],
                        state_dict["tokenization_states"]["aggregation_states"]["jsonl_states"]["line_shift"],
                        state_dict["seq_offset"],
                    )
                    for state_dict in multiple_packed_states_group[key]
                ]
                multiple_packed_states_group[key] = sorted(
                    zip(sort_metrics, multiple_packed_states_group[key]), key=lambda x: x[0]
                )[-1][-1]

            if dp_rank == 0:
                state_dict = {
                    "rng_state": np.random.RandomState(seed=args.seed).get_state(),
                    "multiple_packed_states": multiple_packed_states_group,
                    "consumed_samples": consumed_samples,
                    "used_epochs": max_used_epochs,
                }
            else:
                state_dict = {
                    "rng_state": np.random.RandomState(
                        seed=args.seed + dp_rank
                    ).get_state(),
                    "multiple_packed_states": multiple_packed_states_group,
                    "consumed_samples": {},
                    "used_epochs": max_used_epochs,
                }
            # logger.info(f" --------------- {state_dict['consumed_samples'], multiple_packed_states_group}")
            train_dl.dataset.load_state_dict(state_dict)

    # print('===============',train_state.batch_count)
    if train_state.batch_count >= total_steps:
        logger.info("Training has finished, exiting...")
        return

    gpc.train_state = train_state

    ###########################################################################
    #                          5. Training                                    #
    ###########################################################################
    if args.tensorboard is not None:
        tbwriter = SummaryWriterWrapper(log_dir=args.tensorboard+f'/rank_{dist.get_rank()}', dataset_types=dataset_types,only_rank0=not args.debug)
    else:
        tbwriter = None

    start_train_t = time.time()
    total_consumed_tokens=0
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    max_memory = torch.cuda.max_memory_allocated()
    logger.info('[Train] Begin Train Loop. The current GPU memory is '
                f'{(max_memory / 1024**3):.1f}GB')
    
    train_iter = iter(train_dl)
    time_used_by_val_and_save_ckpt = 0

    group_gloo = dist.new_group(backend='gloo')
    future = None

    for batch_count in itertools.count(train_state.batch_count):
        if train_state.batch_count >= total_steps:
            break

        if train_state.batch_count <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_last_lr()[0]
        else:
            cur_lr = args.lr

        torch.cuda.reset_peak_memory_stats()

        step_loss = 0
        step_balancing_loss = 0
        step_z_loss = 0
        step_data_time = 0
        step_start_t = time.time()
        step_consumed_tokens = 0

        _data_start_t = time.time()

        step_data_list = []
        rank_grad_tokens = 0

        # the first dim is grad acc step
        batch, train_iter = load_new_batch_with_train_state(train_dl=train_dl, train_iter=train_iter,
                                                            train_state=train_state)

        inputs, labels = batch
        input_ids = inputs['input_ids']
        cu_seqlens = inputs['cu_seqlens']
        assert input_ids.shape[0] == iters_per_step

        gpc.config.batch_count = batch_count
        train_state.batch_count = batch_count
        train_state.num_consumed_samples_in_epoch += len(batch[1])
        
        for _iter in range(iters_per_step):
            input_ids_iter = input_ids[_iter: _iter + 1]
            labels_iter = labels[_iter: _iter + 1]
            cu_seqlens_iter = cu_seqlens[_iter]
            num_token = cu_seqlens_iter[1:] - cu_seqlens_iter[:-1]

            if num_token[-1] == 0:
                num_token = num_token[:-1]

            rank_grad_tokens += (labels_iter >= 0).sum()
            step_data_list.append({"input_ids": input_ids_iter,
                                   "labels": labels_iter,
                                   "num_tokens": num_token})
        
        rank_grad_tokens = rank_grad_tokens.to(DEVICE)
        dist.all_reduce(rank_grad_tokens)
        global_grad_tokens = rank_grad_tokens / tp_size / sp_size

        step_data_time = time.time() - _data_start_t
        unreduced_losses=[]
        correct_preds=[]
        for _iter in range(iters_per_step):
            data = step_data_list[_iter]
            input_ids = data['input_ids'].to(DEVICE)
            labels = data['labels'].to(DEVICE)
            num_tokens = data['num_tokens'].to(DEVICE)

            if sp_size > 1:
                # `dim` is 1 as the shape of tensor is (bs, seq_len, ...)
                input_ids = pad_for_sequence_parallel(input_ids, 0, sp_mesh, dim=1)
                _num_pad = input_ids.numel() - num_tokens.sum()
                if _num_pad > 0:
                    _num_pad = torch.IntTensor([_num_pad]).to(DEVICE)
                    num_tokens = torch.cat([num_tokens, _num_pad], dim=-1)

                input_ids = split_for_sequence_parallel(
                    input_ids, dim=1, sp_mesh=sp_mesh)

                labels = pad_for_sequence_parallel(labels, -100, sp_mesh, dim=1)
                labels = split_for_sequence_parallel(
                    labels, dim=1, sp_mesh=sp_mesh)

            packed_ctx = packed_sequence(num_tokens, sp_mesh=sp_mesh)

            if args.debug:
                with profile(activities=[
                        ProfilerActivity.CPU, ProfilerActivity.CUDA
                ]) as prof:
                    with packed_ctx:

                        # gate_recorder.enable()
                        # expert_activation_recorder.enable()

                        ctx = MessageHub.get_instance('packed_sequence')
                        position_ids = ctx.get_info('position_ids')

                        output = llm(input_ids=input_ids, position_ids=position_ids, use_cache=False)

                        # gate_recorder.disable()
                        # expert_activation_recorder.disable()

                        logits = output.logits
                        loss = F.cross_entropy(logits.squeeze(), labels.squeeze(), reduction='none')  # 1, seqlen
                        
                        # collect for logging
                        unreduced_losses.append(loss.detach().clone())
                        pred = logits.argmax(dim=-1)  # B L
                        correct_pred = pred == labels
                        correct_preds.append(correct_pred)
                        
                        balancing_loss, z_loss = output.aux_loss
                        loss += llm.balancing_loss_coef * balancing_loss.to(loss.device)# + llm.router_z_loss_coef * z_loss.to(loss.device)

                        if sp_size > 1:
                            sp_group = sp_mesh.get_group()
                            sp_pt_loss = dist.nn.functional.all_gather(loss, sp_group)
                            sp_pt_labels = dist.nn.functional.all_gather(labels, sp_group)

                            loss = torch.cat(sp_pt_loss, dim=-1)
                            labels = torch.cat(sp_pt_labels, dim=-1)
                        
                        loss = loss.sum() / global_grad_tokens * dp_size

                        loss.backward()
                    
                def get_total_time(events):
                    t=sum([x.self_device_time_total  for x in events])
                    return t/1e6
                
                total_device_time=get_total_time(prof.key_averages())
                print('total_device_time',total_device_time)
                prof.export_chrome_trace(args.work_dir+f'trace/{dist.get_rank()}_{batch_count}_prof_{int(total_device_time*1)}.json')

            else:
                with packed_ctx:

                    # gate_recorder.enable()
                    # expert_activation_recorder.enable()

                    ctx = MessageHub.get_instance('packed_sequence')
                    position_ids = ctx.get_info('position_ids')

                    output = llm(input_ids=input_ids, position_ids=position_ids, use_cache=False)

                    # gate_recorder.disable()
                    # expert_activation_recorder.disable()

                    logits = output.logits
                    loss = F.cross_entropy(logits.squeeze(), labels.squeeze(), reduction='none')  # 1, seqlen
                    
                    # collect for logging
                    unreduced_losses.append(loss.detach().clone())
                    pred = logits.argmax(dim=-1)  # B L
                    correct_pred = pred == labels
                    correct_preds.append(correct_pred)
                    
                    balancing_loss, z_loss = output.aux_loss
                    loss += llm.balancing_loss_coef * balancing_loss.to(loss.device)# + llm.router_z_loss_coef * z_loss.to(loss.device)

                    if sp_size > 1:
                        sp_group = sp_mesh.get_group()
                        sp_pt_loss = dist.nn.functional.all_gather(loss, sp_group)
                        sp_pt_labels = dist.nn.functional.all_gather(labels, sp_group)

                        loss = torch.cat(sp_pt_loss, dim=-1)
                        labels = torch.cat(sp_pt_labels, dim=-1)
                    
                    loss = loss.sum() / global_grad_tokens * dp_size

                    loss.backward()
                    
                    # logit_before_gate_max, logit_before_gate_min, logit_before_gate_mean = gate_recorder.reduce_after_iter()
                    # expert_activation_min, expert_activation_max, expert_activation_mean = expert_activation_recorder.reduce_after_iter()
            
            step_loss += loss.item()
            step_balancing_loss += balancing_loss.item() / iters_per_step
            step_z_loss += z_loss.item() / iters_per_step

            step_consumed_tokens += num_tokens.sum() / sp_size / tp_size

            train_state.step_count += 1
        
        reduce_ep_grad(llm, ep_mesh.size())
        grad_norm = clip_grad_norm_(requried_grad_params, hsdp_device_mesh if args.use_hsdp else experts_fsdp_mesh, args.max_grad_norm, foreach=True)
        grad_norm = grad_norm.to_local() if isinstance(grad_norm, DTensor) else grad_norm

        if grad_norm.isnan() or grad_norm.isinf():
            train_state.inf_nan_skip_batches += 1
            logger.info("The grad norm is NaN or Inf, skip this batch.")
            optimizer.zero_grad()
        else:
            optimizer.step()
            optimizer.zero_grad()

        step_time = time.time() - step_start_t
        eta = step_time * (total_steps - train_state.batch_count - 1)
        eta = timedelta(seconds=int(eta))
        tgs = int(step_consumed_tokens / step_time)
        max_memory = torch.cuda.max_memory_allocated()
        
        total_consumed_tokens += step_consumed_tokens
        end2end_tgs = int(total_consumed_tokens / (time.time() - start_train_t - time_used_by_val_and_save_ckpt))

        # log to tensorboard
        if tbwriter is not None:
            tensorboard_start_time = time.time()
            tbwriter.add_data_infos(batch, train_state, batch_count)
            tbwriter.add_train_dynamics(step_loss, unreduced_losses, correct_preds, batch, batch_count)
            tbwriter.add_optimize_info(grad_norm.detach().clone(), train_state, batch_count)
            tbwriter.add_speed_info(tgs, end2end_tgs, batch_count)
            tbwriter.add_moe_info(
                model_level_info={
                    "balancing_loss": step_balancing_loss,
                    "z_loss": step_z_loss
                },
                per_layer_info={
                    # "logit_before_gate_max": logit_before_gate_max.tolist(),
                    # "logit_before_gate_min": logit_before_gate_min.tolist(),
                    # "logit_before_gate_mean": logit_before_gate_mean.tolist(),
                    # "expert_activation_min": expert_activation_min,
                    # "expert_activation_max": expert_activation_max,
                    # "expert_activation_mean": expert_activation_mean,
                },
                step=batch_count + 1,
            )
            tensorboard_time = time.time() - tensorboard_start_time
        else:
            tensorboard_time=-1

        if is_interval(train_state.batch_count + 1, total_steps, args.log_interval):
            logger.info(f'[Train] (Epoch 1) Step '
                        f'{train_state.batch_count + 1}/{total_steps}  '
                        f'lr: {cur_lr:.6f}  loss: {step_loss:.3f}  balancing_loss: {step_balancing_loss:.3f}  z_loss: {step_z_loss:.3f}  '
                        f'grad_norm: {grad_norm:.2f}  '
                        f'max_memory: {(max_memory / 1024 ** 3):.1f}GB  '
                        # f'num_token: {num_token.cpu().tolist()}'
                        f'text_tokens: {step_consumed_tokens}  '
                        f'tgs: {tgs}  tgs_end2end: {end2end_tgs}  data_time: {step_data_time:.2f}s  '
                        f'time: {step_time:.2f}s tb_time: {tensorboard_time:.2f}  '
                        f'inf_nan_skip: {train_state.inf_nan_skip_batches}  '
                        f'eta: {eta}')
        
        num_digits = len(str(abs(total_steps)))
        if is_interval(train_state.batch_count + 1, total_steps, hf_interval):
            time_before_save = time.time()
            DEVICE_MODULE.empty_cache()

            hf_dir = os.path.join(args.work_dir, f'hf-{train_state.batch_count + 1:0{num_digits}}')

            with profile_time_and_memory('[HF Checkpoint]'):

                if rank == 0:
                    llm_state_dict = {}

                for name, param in llm.state_dict().items():
                    if isinstance(param, DTensor):
                        with torch.no_grad():
                            full_param = param.full_tensor().cpu()
                    else:
                        full_param = param.cpu()

                    if rank == 0:
                        llm_state_dict[name] = full_param

                if rank == 0:
                    rank0_llm.load_state_dict(llm_state_dict)
                    rank0_llm.save_pretrained(hf_dir)
                    # tokenizer.save_pretrained(hf_dir)

                dist.barrier()

            if dist.get_rank() == 0:
                save_hf_ckpt_names.append(hf_dir)
                if len(save_hf_ckpt_names) > max_keep_ckpts:
                    remove_hf_ckpt_name = save_hf_ckpt_names.pop(0)
                    os.system(f'rm -rf {remove_hf_ckpt_name}')
        
            max_memory = torch.cuda.max_memory_allocated()
            logger.info('[HF Checkpoint] During saving HF checkpoint, the peak GPU '
                        f'memory is {max_memory / 1024 ** 3:.1f}GB.')
            
            time_used_by_val_and_save_ckpt += time.time() - time_before_save
        
        if is_interval(train_state.batch_count + 1, total_steps, checkpoint_interval):
            time_before_save = time.time()

            if args.checkpoint_drop_optimizer:
                logger.warning('The saved checkpoint cannot be resumed. '
                               'If you want to save a resumable checkpoint, '
                               'please remove `--checkpoint-drop-optimizer` '
                               'from the command.')
            else:
                with profile_time_and_memory('[PT Checkpoint]'):
                    ckpt_id = f'{train_state.batch_count + 1:0{num_digits}}-of-{total_steps:0{num_digits}}'
                    ckpt_dir = os.path.join(args.work_dir, f'ckpt-{ckpt_id}')
                    if dp_rank==0:
                        mkdir_or_exist(ckpt_dir)
                    dist.barrier()

                    if hasattr(train_state, "data_state_dict"):  # TODO:  tp/sp/pp
                        assert hasattr(train_state, "batch_sampler")
                        torch.save(train_state.data_state_dict, os.path.join(ckpt_dir, f"dataset_{dp_rank}.pt"))
                        
                        if dp_rank == 0:

                            if hasattr(train_state, "batch_sampler") and not isinstance(
                                    train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
                            ):
                                sampler_state = train_state.batch_sampler.state_dict()
                                torch.save(sampler_state, os.path.join(ckpt_dir, "sampler.pt"))
                        else:
                            train_state.data_state_dict["dataset_consumed_tokens"] = defaultdict(int)

                    dist.barrier()

                    with profile_time_and_memory('[PT Checkpoint Wait]'):
                        if future is not None:
                            wait([future])

                    with profile_time_and_memory('[PT Checkpoint of DCP ASYNC]'):
                        # FSDP cannot be saved via torch.save
                        # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html  # noqa: E501
                        _options = StateDictOptions(
                            cpu_offload=True, ignore_frozen_params=True)
                        (shard_model_state_dict,
                         shard_optimizer_state_dict) = get_state_dict(
                            llm, optimizer, options=_options)

                        state_dict = {
                            'model': shard_model_state_dict,
                            'optimizer': shard_optimizer_state_dict,
                            'warmup_scheduler': warmup_scheduler.state_dict(),
                            'cosine_scheduler': cosine_scheduler.state_dict(),
                            'train_state': train_state.state_dict(),
                        }
                        future = dcp.async_save(state_dict, checkpoint_id=ckpt_dir, process_group=group_gloo)

                        def send_to_oss_and_remove(future):
                            # send to oss and remove local file
                            # TODO: send to oss

                            if dist.get_rank() == 0:
                                save_pt_ckpt_names.append(ckpt_dir)
                                if len(save_pt_ckpt_names) > max_keep_ckpts:
                                    remove_pt_ckpt_name = save_pt_ckpt_names.pop(0)
                                    os.system(f'rm -rf {remove_pt_ckpt_name}')
                            # print('============send_to_oss_and_remove callback==================')

                        future.add_done_callback(send_to_oss_and_remove)
        
            max_memory = torch.cuda.max_memory_allocated()
            logger.info('[Checkpoint] During saving checkpoint, the peak GPU '
                        f'memory is {max_memory / 1024 ** 3:.1f}GB.')
            
            time_used_by_val_and_save_ckpt += time.time() - time_before_save

    train_cost_time = time.time() - start_train_t
    logger.info(f'[Train] Cost {timedelta(seconds=int(train_cost_time))}')
    # ------------------------    Training  End  ---------------------------- #
    if tbwriter is not None:
        tbwriter.close()

if __name__ == '__main__':
    # Script entry point: parse the command-line options and hand them to
    # `main`, which runs the training loop.

    # NOTE(review): `args` is deliberately bound at module scope here; other
    # helpers in this file may read it as a global, so keep the binding even
    # though it could be inlined into the `main(...)` call — confirm before
    # refactoring.
    args = parse_args()
    main(args)
