import os
import json
import math
import argparse
import time
from typing import Dict

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.nn import CrossEntropyLoss

from data_utils_all_label import get_sft_dataset, collate_sft

import deepspeed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)


# =============== 全局调试统计：确认剪掉部分梯度被置零 ===============

# Process-local debug counters, mutated by the gradient hooks installed in
# register_param_mask_with_grad_hook and read by the training loop for logging.
PRUNING_DEBUG_STATS: Dict[str, float] = {
    "max_pruned_grad_before_mask": 0.0,   # largest |grad| seen on pruned positions, before the mask is applied
    "max_unpruned_grad": 0.0,             # largest |grad| seen on unpruned positions
    "num_pruned_grad_updates": 0,         # number of hook calls whose mask contained at least one pruned position
}

# =============== 通用的参数 mask + grad hook 工具 ===============

def register_param_mask_with_grad_hook(tensor: torch.Tensor,
                                       mask_tensor: torch.Tensor,
                                       debug_name: str):
    """
    Apply a 0/1 mask to a parameter and keep the masked entries frozen.

    The parameter data is multiplied by the mask once, in place, so every
    position where the mask is 0 starts at 0. A gradient hook is then
    registered that multiplies each incoming gradient by the same mask, so
    those positions receive a zero gradient.

    Args:
        tensor: Parameter to constrain; it is modified in place.
        mask_tensor: Mask that must be index-compatible with the gradient of
            ``tensor`` (it is used for boolean indexing and multiplication).
        debug_name: Human-readable label for the parameter. Currently unused.

    Side effects:
        Each hook call updates the module-level PRUNING_DEBUG_STATS dict.
        Reading the maxima uses ``.item()``, and the mask is moved to the
        gradient's device on every call.
    """
    # One-off in-place zeroing of the pruned entries.
    tensor.data.mul_(mask_tensor.to(tensor.device))

    def _track_peak(stat_key, values):
        # Keep the running maximum of |values| under the given stats key.
        peak = values.abs().max().item()
        if peak > PRUNING_DEBUG_STATS[stat_key]:
            PRUNING_DEBUG_STATS[stat_key] = peak

    def _mask_gradient(grad):
        if grad is None:
            return None

        device_mask = mask_tensor.to(grad.device)
        pruned_selector = device_mask == 0

        pruned_grads = grad[pruned_selector]
        kept_grads = grad[~pruned_selector]

        if pruned_grads.numel() > 0:
            _track_peak("max_pruned_grad_before_mask", pruned_grads)
            PRUNING_DEBUG_STATS["num_pruned_grad_updates"] += 1

        if kept_grads.numel() > 0:
            _track_peak("max_unpruned_grad", kept_grads)

        # The gradient actually used for the update: zero wherever the mask is 0.
        return grad * device_mask

    tensor.register_hook(_mask_gradient)


def create_scheduler(num_training_steps, optimizer, warmup_ratio):
    """
    Build a linear warmup schedule for ``optimizer``.

    The number of warmup steps is ``num_training_steps * warmup_ratio``
    truncated to an integer; the result of
    ``get_linear_schedule_with_warmup`` is returned unchanged.
    """
    schedule_kwargs = {
        "num_warmup_steps": int(num_training_steps * warmup_ratio),
        "num_training_steps": num_training_steps,
    }
    return get_linear_schedule_with_warmup(optimizer, **schedule_kwargs)


def collate_fn(batch):
    """
    Stack per-sample tensors into batch tensors.

    Each element of ``batch`` is a mapping with "input_ids",
    "attention_mask" and "labels" tensors; the same three keys are returned
    with the samples stacked along a new leading dimension.
    """
    field_names = ("input_ids", "attention_mask", "labels")
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in field_names
    }


# ============================= 小工具：时间格式化 =============================

def format_seconds(seconds: float) -> str:
    """Render a duration in seconds as ``"{d}d {hh}h {mm}m {ss}s"`` (fraction truncated)."""
    whole_seconds = int(seconds)
    # Peel off units from the smallest upwards.
    total_minutes, secs = divmod(whole_seconds, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hrs = divmod(total_hours, 24)
    return f"{days}d {hrs:02d}h {mins:02d}m {secs:02d}s"

def is_main_process() -> bool:
    """
    Return True on rank 0.

    When torch.distributed is unavailable or not yet initialised (e.g. a
    single-GPU run), the current process is treated as the main process.
    """
    distributed_active = dist.is_available() and dist.is_initialized()
    return dist.get_rank() == 0 if distributed_active else True


def log_once(msg: str) -> None:
    """Print ``msg`` (flushed immediately) on the main process only."""
    if not is_main_process():
        return
    # Must be a plain print: calling log_once here would recurse forever.
    print(msg, flush=True)

# ============================= FLAP mask 读取与应用 =============================
from typing import Dict, Tuple

def _collect_layer_masks(section, section_name, list_label, dim_label,
                         expected_len, num_layers, module_suffix):
    """Turn one JSON section ({"layer_i": [0/1, ...]}) into {module_name: float32 tensor}."""
    masks = {}
    for layer_idx in range(num_layers):
        key = f"layer_{layer_idx}"
        if key not in section:
            raise KeyError(f"[FLAP] {section_name}['{key}'] not found in flap_mask.json")

        keep_flags = section[key]
        if len(keep_flags) != expected_len:
            raise ValueError(
                f"[FLAP] layer_{layer_idx}: len({list_label})={len(keep_flags)} "
                f"!= {dim_label}={expected_len}"
            )

        # If the model nests its layers one level deeper (model.model.layers.N...),
        # the module path built here has to be changed accordingly.
        masks[f"model.layers.{layer_idx}.{module_suffix}"] = torch.tensor(
            keep_flags, dtype=torch.float32
        )
    return masks


def load_flap_masks(ckpt_dir: str) -> Tuple[Dict[str, torch.Tensor],
                                            Dict[str, torch.Tensor],
                                            Dict]:
    """
    Load the ``flap_mask.json`` written by FLAP pruning.

    Expected JSON layout:

        {
          "attn_heads": {
             "layer_0": [0/1, 0/1, ..., (num_heads)],
             "layer_1": [...],
             ...
          },
          "mlp_neurons": {
             "layer_0": [0/1, 0/1, ..., (mlp_dim)],
             ...
          },
          "meta": {
             "prune_type": ...,
             "metric": ...,
             "target_keep_ratio": ...,
             "num_layers": ...,
             "hidden_size": ...,
             "num_heads": ...,
             "mlp_dim": ...
          }
        }

    Returns a 3-tuple:
      - attn_head_masks:  { "model.layers.0.self_attn": tensor[num_heads], ... }
      - mlp_neuron_masks: { "model.layers.0.mlp":       tensor[mlp_dim],   ... }
      - meta: the raw "meta" dict ({} when absent)

    A section is parsed only when its size is present in ``meta`` and the
    section itself is non-empty; otherwise the corresponding dict is empty.

    Raises:
        FileNotFoundError: ``flap_mask.json`` does not exist in ``ckpt_dir``.
        ValueError: ``meta`` lacks ``num_layers``, or a per-layer list has the
            wrong length.
        KeyError: a parsed section is missing one of the ``layer_{i}`` keys.
    """
    mask_path = os.path.join(ckpt_dir, "flap_mask.json")
    if not os.path.exists(mask_path):
        raise FileNotFoundError(f"[FLAP] flap_mask.json not found in {ckpt_dir}")
    with open(mask_path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    meta = payload.get("meta", {})
    head_section = payload.get("attn_heads", {})
    neuron_section = payload.get("mlp_neurons", {})

    num_layers = meta.get("num_layers", None)
    num_heads = meta.get("num_heads", None)
    mlp_dim = meta.get("mlp_dim", None)

    if num_layers is None:
        raise ValueError("[FLAP] 'meta' in flap_mask.json must contain 'num_layers'.")
    if num_heads is None:
        log_once("[FLAP][warn] 'num_heads' missing in meta, attn head pruning may be disabled.")
    if mlp_dim is None:
        log_once("[FLAP][warn] 'mlp_dim' missing in meta, mlp neuron pruning may be disabled.")

    attn_head_masks: Dict[str, torch.Tensor] = {}
    mlp_neuron_masks: Dict[str, torch.Tensor] = {}

    # --- 1) attention head masks ---
    if num_heads is not None and head_section:
        attn_head_masks = _collect_layer_masks(
            head_section, "attn_heads", "head_keep_list", "num_heads",
            num_heads, num_layers, "self_attn",
        )
        log_once(f"[FLAP] Loaded ATTENTION head mask for {len(attn_head_masks)} layers.")

    # --- 2) MLP neuron masks ---
    if mlp_dim is not None and neuron_section:
        mlp_neuron_masks = _collect_layer_masks(
            neuron_section, "mlp_neurons", "neuron_keep_list", "mlp_dim",
            mlp_dim, num_layers, "mlp",
        )
        log_once(f"[FLAP] Loaded MLP neuron mask for {len(mlp_neuron_masks)} layers.")

    log_once(
        f"[FLAP] meta: prune_type={meta.get('prune_type')}, "
        f"metric={meta.get('metric')}, "
        f"target_keep_ratio={meta.get('target_keep_ratio')}"
    )

    return attn_head_masks, mlp_neuron_masks, meta

def check_pruned_mlp_zero(model, mlp_neuron_masks: Dict[str, torch.Tensor], atol=1e-6):
    """
    Verify that every pruned MLP neuron still has all-zero weights.

    For each ``{module_name: mask}`` entry, entries with ``mask[i] <= 0`` mark
    pruned neurons. Their row in ``gate_proj`` / ``up_proj`` and their column
    in ``down_proj`` must be within ``atol`` of zero. Modules absent from the
    model, missing projections, and projections whose neuron dimension does
    not match the mask length are skipped (the latter with a warning).

    The check is done with one vectorised comparison per projection instead
    of one ``allclose`` per neuron, and failures are raised explicitly so the
    check still runs under ``python -O``.

    Args:
        model: Module exposing ``named_modules()``.
        mlp_neuron_masks: Maps an MLP module name to its keep mask.
            Assumes each mask is 1-D (one entry per intermediate neuron).
        atol: Absolute tolerance for "zero".

    Raises:
        AssertionError: A pruned neuron has a non-zero weight; the message
            names the first (lowest-index) offending neuron.
    """
    module_dict = dict(model.named_modules())
    for mlp_name, neuron_mask in mlp_neuron_masks.items():
        if mlp_name not in module_dict:
            continue
        mlp = module_dict[mlp_name]
        mlp_dim = int(neuron_mask.numel())
        # Ascending indices of pruned neurons.
        pruned_idx = torch.nonzero(neuron_mask.float() <= 0.0).flatten()

        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            if not hasattr(mlp, proj_name):
                continue
            W = getattr(mlp, proj_name).weight.data    # [out_dim, in_dim]
            out_dim, in_dim = W.shape

            # gate/up hold one neuron per row, down holds one neuron per column.
            if proj_name == "down_proj":
                neuron_axis, axis_label, axis_size, line_kind = 1, "in_dim", in_dim, "col"
            else:
                neuron_axis, axis_label, axis_size, line_kind = 0, "out_dim", out_dim, "row"

            if axis_size != mlp_dim:
                log_once(
                    f"[FLAP][check][warn] {mlp_name}.{proj_name}: {axis_label}={axis_size} != mlp_dim={mlp_dim}, skip check."
                )
                continue
            if pruned_idx.numel() == 0:
                continue

            # Gather only the pruned rows/columns, then test them all at once.
            pruned_lines = W.index_select(neuron_axis, pruned_idx.to(W.device))
            near_zero = torch.isclose(pruned_lines, torch.zeros_like(pruned_lines), atol=atol)
            line_ok = near_zero.all(dim=1 - neuron_axis)
            if not bool(line_ok.all()):
                first_bad = int(torch.nonzero(~line_ok)[0, 0])
                neuron = int(pruned_idx[first_bad])
                raise AssertionError(
                    f"{mlp_name}.{proj_name} neuron {neuron} {line_kind} not zero!"
                )


def check_pruned_heads_zero(model, flap_mask: Dict[str, torch.Tensor], atol=1e-6):
    """
    Verify that every pruned attention head still has all-zero weights.

    For each ``{module_name: head_mask}`` entry, entries with
    ``head_mask[h] <= 0`` mark pruned heads. Head ``h`` owns rows
    ``[h*head_dim, (h+1)*head_dim)`` of ``q_proj`` / ``k_proj`` / ``v_proj``
    and the same range of columns of ``o_proj``, where ``head_dim`` is the
    relevant weight dimension divided by the mask length. Modules absent from
    the model and missing projections are skipped.

    Failures are raised explicitly rather than via ``assert`` so the check
    still runs under ``python -O``.

    Args:
        model: Module exposing ``named_modules()``.
        flap_mask: Maps a self-attention module name to its head keep mask.
            Assumes each mask is 1-D (one entry per head).
        atol: Absolute tolerance for "zero".

    Raises:
        AssertionError: A pruned head has a non-zero weight.
    """
    module_dict = dict(model.named_modules())
    for attn_name, head_mask in flap_mask.items():
        if attn_name not in module_dict:
            continue
        attn = module_dict[attn_name]
        num_heads = int(head_mask.numel())
        pruned_heads = [h for h, keep in enumerate(head_mask.tolist()) if keep <= 0.0]

        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            if not hasattr(attn, proj_name):
                continue
            W = getattr(attn, proj_name).weight.data    # [out_dim, in_dim]
            out_dim, in_dim = W.shape

            # q/k/v split heads along the output rows, o_proj along the input columns.
            heads_on_rows = proj_name != "o_proj"
            head_dim = (out_dim if heads_on_rows else in_dim) // num_heads

            for h in pruned_heads:
                s = h * head_dim
                e = (h + 1) * head_dim
                head_block = W[s:e, :] if heads_on_rows else W[:, s:e]
                if not torch.allclose(head_block, torch.zeros_like(head_block), atol=atol):
                    raise AssertionError(f"{attn_name}.{proj_name} head {h} not zero!")

def _axis_keep_mask(weight: torch.Tensor, pruned: torch.Tensor, axis: int) -> torch.Tensor:
    """Return a ones-like mask of ``weight`` with the ``pruned`` rows (axis 0) or columns (axis 1) zeroed."""
    mask_2d = torch.ones_like(weight)
    pruned = pruned.to(weight.device)
    if axis == 0:
        mask_2d[pruned, :] = 0.0
    else:
        mask_2d[:, pruned] = 0.0
    return mask_2d


def apply_flap_mask_for_finetune(
    model,
    attn_head_masks: Dict[str, torch.Tensor],
    mlp_neuron_masks: Dict[str, torch.Tensor],
):
    """
    Install FLAP width-pruning constraints on ``model`` for fine-tuning.

    1) Attention (head level):
       - ``attn_head_masks[name]`` is a 0/1 vector with one entry per head,
         where ``name`` is a self-attention module name.
       - q_proj / k_proj / v_proj: head ``h`` owns output rows
         ``[h*head_dim : (h+1)*head_dim]``; o_proj: head ``h`` owns the same
         range of input columns. ``head_dim`` is the relevant weight dimension
         divided by the mask length.
       - NOTE(review): every projection is split by the same mask length, so
         a model whose k_proj/v_proj use fewer heads than q_proj
         (grouped-query attention) would be sliced with a different head_dim
         there - confirm the target architecture.

    2) MLP (neuron level):
       - ``mlp_neuron_masks[name]`` is a 0/1 vector with one entry per
         intermediate neuron, where ``name`` is an MLP module name.
       - gate_proj / up_proj: neuron ``i`` is row ``i``;
         down_proj: neuron ``i`` is column ``i``.

    For every affected Linear weight, a full-size 0/1 mask is built and passed
    to ``register_param_mask_with_grad_hook``. Entries with mask value
    ``<= 0`` are treated as pruned. Modules named in the mask dicts but absent
    from the model are skipped with a warning; missing projections are
    skipped silently.

    Masks are built with one indexed assignment per weight rather than a
    Python loop over heads/neurons.

    Raises:
        ValueError: A q/k/v/o dimension is not divisible by the number of
            heads in the mask, or an MLP neuron dimension differs from the
            mask length.
    """
    module_dict = dict(model.named_modules())

    total_masked_linears = 0

    # ==================== 1) Attention head pruning ====================
    for attn_name, head_mask in attn_head_masks.items():
        if attn_name not in module_dict:
            log_once(f"[FLAP][warn] attn module '{attn_name}' not found in model.named_modules(), skip.")
            continue

        attn = module_dict[attn_name]
        pruned_heads = head_mask.float() <= 0.0
        num_heads = int(head_mask.numel())

        log_once(f"[FLAP][finetune] Apply head mask on {attn_name}: num_heads(mask)={num_heads}")

        for proj_name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            if not hasattr(attn, proj_name):
                continue

            proj = getattr(attn, proj_name)
            W = proj.weight.data        # [out_dim, in_dim]
            out_dim, in_dim = W.shape

            # q/k/v split heads along the output rows, o_proj along the input columns.
            if proj_name == "o_proj":
                head_axis, axis_label, axis_size = 1, "in_dim", in_dim
            else:
                head_axis, axis_label, axis_size = 0, "out_dim", out_dim

            if axis_size % num_heads != 0:
                raise ValueError(
                    f"[FLAP][finetune] {attn_name}.{proj_name}: "
                    f"{axis_label}={axis_size} not divisible by num_heads={num_heads}"
                )
            head_dim = axis_size // num_heads

            # Expand the per-head flags to one flag per row/column.
            mask_2d = _axis_keep_mask(W, pruned_heads.repeat_interleave(head_dim), head_axis)

            log_once(
                f"[FLAP][finetune] Masking {attn_name}.{proj_name}: "
                f"weight_shape={tuple(W.shape)}, head_dim={head_dim}, "
                f"pruned_elems={(mask_2d == 0).sum().item()}"
            )

            register_param_mask_with_grad_hook(
                proj.weight,
                mask_2d,
                f"{attn_name}.{proj_name}"
            )
            total_masked_linears += 1

    # ==================== 2) MLP neuron pruning ====================
    for mlp_name, neuron_mask in mlp_neuron_masks.items():
        if mlp_name not in module_dict:
            log_once(f"[FLAP][warn] mlp module '{mlp_name}' not found in model.named_modules(), skip.")
            continue

        mlp = module_dict[mlp_name]
        pruned_neurons = neuron_mask.float() <= 0.0
        mlp_dim = int(neuron_mask.numel())

        log_once(f"[FLAP][finetune] Apply MLP neuron mask on {mlp_name}: mlp_dim(mask)={mlp_dim}")

        for proj_name in ("gate_proj", "up_proj", "down_proj"):
            if not hasattr(mlp, proj_name):
                continue

            proj = getattr(mlp, proj_name)
            W = proj.weight.data    # [out_dim, in_dim]
            out_dim, in_dim = W.shape

            # gate/up hold the intermediate dimension on rows, down on columns.
            if proj_name == "down_proj":
                neuron_axis, axis_label, axis_size = 1, "in_dim", in_dim
            else:
                neuron_axis, axis_label, axis_size = 0, "out_dim", out_dim

            if axis_size != mlp_dim:
                raise ValueError(
                    f"[FLAP][finetune] {mlp_name}.{proj_name}: "
                    f"{axis_label}={axis_size} != mlp_dim={mlp_dim}"
                )

            mask_2d = _axis_keep_mask(W, pruned_neurons, neuron_axis)

            log_once(
                f"[FLAP][finetune] Masking {mlp_name}.{proj_name}: "
                f"weight_shape={tuple(W.shape)}, "
                f"pruned_elems={(mask_2d == 0).sum().item()}"
            )

            register_param_mask_with_grad_hook(
                proj.weight,
                mask_2d,
                f"{mlp_name}.{proj_name}"
            )
            total_masked_linears += 1

    log_once(f"[FLAP][finetune] All FLAP masks applied. "
             f"Total Linear weights with mask: {total_masked_linears}")

# ============================= 训练主逻辑 =============================

def parse_args():
    """Parse the command-line options of the fine-tuning script from ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Finetune FLAP-pruned HF LLM with DeepSpeed")

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = [
        # Basic paths
        ("--model_name_or_path", dict(type=str, required=True,
                                      help="剪枝后 HF 格式模型路径（内含 flap_mask.json）")),
        ("--output_dir", dict(type=str, required=True,
                              help="最终模型与中间模型保存目录")),
        # Training hyper-parameters
        ("--max_length", dict(type=int, default=512)),
        ("--per_device_train_batch_size", dict(type=int, default=2)),
        ("--num_train_epochs", dict(type=int, default=3)),
        ("--max_steps", dict(type=int, default=-1,
                             help=">0 时优先按 step 数训练；=-1 时按 epoch 训练完数据集")),
        ("--learning_rate", dict(type=float, default=5e-5)),
        ("--weight_decay", dict(type=float, default=0.01)),
        ("--warmup_ratio", dict(type=float, default=0.03)),
        ("--gradient_accumulation_steps", dict(type=int, default=1)),
        ("--seed", dict(type=int, default=42)),
        # DeepSpeed
        ("--deepspeed_config", dict(type=str, required=True,
                                    help="DeepSpeed json 配置文件路径")),
        # Logging and checkpointing
        ("--logging_steps", dict(type=int, default=50)),
        ("--save_iter", dict(type=int, default=-1,
                             help=">0 时每 N 个 global_step 保存一次；-1 表示不保存中间结果")),
        # Training data
        ("--sft_dataset", dict(type=str, default="mmlu",
                               help="选择用于 SFT 的数据集类型：'text' 为纯文本文件，'mmlu' 为 MMLU SFT。")),
        ("--num_train_samples", dict(type=int, default=None,
                                     help="可选：限制使用的训练样本总数（仅对部分基准数据集有效，如 mmlu）。")),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()


def setup_distributed():
    """Initialise the distributed backend through DeepSpeed unless it is already initialised."""
    if dist.is_initialized():
        return
    deepspeed.init_distributed()


def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices) with ``seed``."""
    import random
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def build_dataloader(args, tokenizer):
    """
    Build the per-rank training DataLoader.

    The dataset comes from ``get_sft_dataset`` and is sharded across ranks by
    a shuffling ``DistributedSampler`` with ``drop_last=True``. Requires an
    initialised process group (world size and rank are read from
    ``torch.distributed``).

    Returns:
        ``(dataloader, len(dataset))`` - the length is that of the full
        dataset, not of this rank's shard.
    """
    train_set = get_sft_dataset(
        name=args.sft_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.num_train_samples,
    )

    shard_sampler = DistributedSampler(
        train_set,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
        drop_last=True,
    )

    loader = DataLoader(
        train_set,
        batch_size=args.per_device_train_batch_size,
        sampler=shard_sampler,
        collate_fn=collate_sft,
        num_workers=2,
        pin_memory=True,
    )
    return loader, len(train_set)


def load_model_and_tokenizer(model_path: str):
    """
    Load the causal LM and its tokenizer from ``model_path``.

    The tokenizer is set to right padding and falls back to the EOS token as
    pad token when none is defined. The model is loaded in bfloat16 with no
    device map, leaving device placement to the caller.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tokenizer.pad_token is None:
        # No dedicated pad token: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    load_kwargs = {
        "torch_dtype": torch.bfloat16,   # or torch.float16
        "device_map": None,              # placement is handled later by DeepSpeed
    }
    model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return model, tokenizer


def save_hf_checkpoint(model_engine, tokenizer, save_dir: str):
    """
    Save model and tokenizer to ``save_dir`` via ``save_pretrained``.

    If ``model_engine`` has a ``module`` attribute, that wrapped module is
    saved; otherwise ``model_engine`` itself is. Model weights are written
    with ``safe_serialization=True``. The directory is created if needed.
    """
    os.makedirs(save_dir, exist_ok=True)
    unwrapped = getattr(model_engine, "module", model_engine)
    unwrapped.save_pretrained(save_dir, safe_serialization=True)
    tokenizer.save_pretrained(save_dir)


def train(args):
    """
    Fine-tune a FLAP-pruned causal LM with DeepSpeed, keeping pruned weights at zero.

    Steps:
      1. Initialise distributed state, seed each rank with ``args.seed + rank``.
      2. Load model/tokenizer and the FLAP masks from ``args.model_name_or_path``,
         then install the weight masks and gradient hooks.
      3. Build a torch AdamW optimizer and a linear-warmup scheduler, and hand
         both to ``deepspeed.initialize``. Any "optimizer"/"scheduler" section
         of the DeepSpeed JSON is dropped, and the batch-size fields are
         overwritten from the command-line arguments.
      4. Run the training loop, logging every ``args.logging_steps`` steps and
         saving every ``args.save_iter`` steps (rank 0 only).
      5. Verify that pruned heads/neurons are still zero and save the final
         model under ``<output_dir>/final`` (rank 0 only).

    ``global_step`` counts optimizer steps, i.e. it advances only on gradient
    accumulation boundaries. This keeps it in the same unit as
    ``num_training_steps``, the LR schedule and ``--max_steps`` when
    ``gradient_accumulation_steps > 1``. With the default of 1, every
    micro-batch is a step.

    Args:
        args: Namespace produced by ``parse_args``. ``args.deepspeed`` (if
            present) and ``args.deepspeed_config`` are reset to ``None``
            before DeepSpeed initialisation, because the config is passed as a
            dict instead.

    Raises:
        AssertionError: From the final zero-checks if a pruned weight is
            non-zero after training.
    """
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = (rank == 0)

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        log_once("====== Parsed args ======")
        for k, v in vars(args).items():
            log_once(f"{k}: {v}")
        log_once("=========================")

    set_seed(args.seed + rank)

    # ====== 1. Load model and tokenizer ======
    if is_main:
        log_once(f"[Rank 0] Loading model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

    # ====== 1.1 Load the FLAP masks and apply them to the model ======
    ckpt_dir = args.model_name_or_path
    attn_head_masks, mlp_neuron_masks, _flap_meta = load_flap_masks(ckpt_dir)
    if is_main:
        log_once("[Rank 0] Detected FLAP flap_mask.json, applying FLAP width-only pruning constraints.")
        log_once(f"  - #self_attn modules with head mask: {len(attn_head_masks)}")
        log_once(f"  - #mlp modules with neuron mask:     {len(mlp_neuron_masks)}")

    apply_flap_mask_for_finetune(model, attn_head_masks, mlp_neuron_masks)

    if is_main:
        log_once("[Rank 0] Pruning constraints applied. Start building optimizer/DeepSpeed engine...")

    # ====== 2. Base AdamW optimizer ======
    base_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=getattr(args, "weight_decay", 0.0),
    )

    # ====== 3. Read the DeepSpeed config and override batch-related fields ======
    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    per_device_bs = args.per_device_train_batch_size
    grad_accum = args.gradient_accumulation_steps

    ds_config["train_micro_batch_size_per_gpu"] = int(per_device_bs)
    ds_config["gradient_accumulation_steps"] = int(grad_accum)
    ds_config["train_batch_size"] = int(per_device_bs * grad_accum * world_size)

    ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10_000_000)

    if "optimizer" in ds_config:
        if is_main:
            log_once("[Rank 0] Remove 'optimizer' from ds_config (use torch.AdamW instead)")
        ds_config.pop("optimizer")
    if "scheduler" in ds_config:
        if is_main:
            log_once("[Rank 0] Remove 'scheduler' from ds_config (use HF scheduler instead)")
        ds_config.pop("scheduler")

    # ====== 4. DataLoader and step bookkeeping (all in optimizer steps) ======
    train_dataloader, dataset_size = build_dataloader(args, tokenizer)

    steps_per_epoch = math.ceil(
        dataset_size / (args.per_device_train_batch_size * world_size * args.gradient_accumulation_steps)
    )
    if args.max_steps > 0:
        num_training_steps = args.max_steps
        num_train_epochs = math.ceil(num_training_steps / steps_per_epoch)
    else:
        num_train_epochs = args.num_train_epochs
        num_training_steps = steps_per_epoch * num_train_epochs

    if is_main:
        log_once(f"World size: {world_size}")
        log_once(f"Dataset size: {dataset_size}")
        log_once(f"Steps per epoch: {steps_per_epoch}")
        log_once(f"Train epochs: {num_train_epochs}")
        log_once(f"Total training steps: {num_training_steps}")

    # ====== 5. HF scheduler ======
    scheduler = create_scheduler(num_training_steps, base_optimizer, args.warmup_ratio)

    # ====== 6. Initialise DeepSpeed ======
    # The config is passed as a dict below, so clear the path-style fields on args.
    if hasattr(args, "deepspeed"):
        args.deepspeed = None
    if hasattr(args, "deepspeed_config"):
        args.deepspeed_config = None

    model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        optimizer=base_optimizer,
        args=args,
        lr_scheduler=scheduler,
        config_params=ds_config,
    )

    global_step = 0
    model_engine.train()

    if is_main:
        train_start_time = time.time()
    else:
        train_start_time = None

    for epoch in range(num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)

        for batch in train_dataloader:
            if is_main:
                step_start_time = time.time()

            batch = {k: v.to(model_engine.local_rank) for k, v in batch.items()}

            outputs = model_engine(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                use_cache=False,
            )
            loss = outputs.loss

            model_engine.backward(loss)
            # Query the boundary before step(): per the DeepSpeed engine API the
            # answer refers to the current micro-batch, and step() moves on to
            # the next one.
            at_boundary = model_engine.is_gradient_accumulation_boundary()
            model_engine.step()

            # Only a boundary micro-batch triggers an optimizer update; logging,
            # saving and the max_steps check are all expressed in optimizer steps.
            if not at_boundary:
                continue

            global_step += 1

            if is_main and args.logging_steps > 0 and (global_step % args.logging_steps == 0):
                lr = lr_scheduler.get_last_lr()[0]
                # step_time covers the last micro-batch only; loss is that micro-batch's loss.
                step_time = time.time() - step_start_time
                elapsed = time.time() - train_start_time
                avg_step_time = elapsed / max(global_step, 1)
                remaining_steps = max(num_training_steps - global_step, 0)
                eta_seconds = remaining_steps * avg_step_time

                log_once(
                    f"[Epoch {epoch}] step {global_step}/{num_training_steps} | "
                    f"loss {loss.item():.4f} | lr {lr:.6e} | "
                    f"step_time {step_time:.3f}s | "
                    f"avg_step_time {avg_step_time:.3f}s | "
                    f"ETA {format_seconds(eta_seconds)}"
                )

                max_p = PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"]
                max_u = PRUNING_DEBUG_STATS["max_unpruned_grad"]
                num_p = PRUNING_DEBUG_STATS["num_pruned_grad_updates"]
                log_once(
                    f"[PruneDebug] max raw grad on pruned positions (before mask): {max_p:.4e} | "
                    f"max grad on unpruned positions: {max_u:.4e} | "
                    f"steps with pruned-grad flow: {num_p}"
                )
                log_once(
                    "             Note: after masking, grads on pruned positions are exactly 0 "
                    "(grad * mask), so这些位置不会被更新。"
                )

            if args.save_iter > 0 and (global_step % args.save_iter == 0):
                if is_main:
                    save_dir = os.path.join(args.output_dir, f"step-{global_step}")
                    log_once(f"[Rank 0] Saving checkpoint to {save_dir}")
                    save_hf_checkpoint(model_engine, tokenizer, save_dir)
                # Every rank reaches this point on the same step, so the barrier is matched.
                dist.barrier()

            if args.max_steps > 0 and global_step >= args.max_steps:
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    # Sanity check: pruned heads/neurons must still be exactly (within atol) zero.
    check_pruned_heads_zero(model_engine.module, attn_head_masks)
    check_pruned_mlp_zero(model_engine.module, mlp_neuron_masks)

    if is_main:
        final_dir = os.path.join(args.output_dir, "final")
        total_elapsed = time.time() - train_start_time

        max_p = PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"]
        max_u = PRUNING_DEBUG_STATS["max_unpruned_grad"]
        num_p = PRUNING_DEBUG_STATS["num_pruned_grad_updates"]
        log_once("========== [PruneDebug Summary] ==========")
        log_once(f"Max raw grad on pruned positions (before mask): {max_p:.4e}")
        log_once(f"Max grad on unpruned positions:                {max_u:.4e}")
        log_once(f"Steps with pruned-grad flow (before mask):     {num_p}")
        log_once("All pruned positions use grad * 0 -> 0, so它们在整个训练过程中都没有被更新。")
        log_once("==========================================")

        log_once(f"[Rank 0] Training finished. Total time: {format_seconds(total_elapsed)}")
        log_once(f"[Rank 0] Saving final model to {final_dir}")
        save_hf_checkpoint(model_engine, tokenizer, final_dir)

    dist.barrier()


def main():
    args = parse_args()
    train(args)


if __name__ == "__main__":
    main()
