import os
import json
import math
import argparse
import time
from typing import Dict

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn import CrossEntropyLoss

# from data_utils import get_sft_dataset, collate_sft
from data_utils_all_label import get_sft_dataset, collate_sft

import deepspeed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

from dataclasses import dataclass
from typing import List, Dict

# =============== Pruning-state dataclass ===============

@dataclass
class PruningState:
    """Pruning decisions loaded from ``pruning_state.json``.

    The fields are consumed by ``apply_pruning_for_finetune`` to rebuild the
    pruned structure before fine-tuning.
    """

    # Indices of decoder layers whose forward is replaced by an identity map.
    skipped_layers: List[int]
    # Layer index -> indices of attention heads to mask in that layer.
    pruned_heads: Dict[int, List[int]]
    # Layer index -> indices of MLP neurons to mask in that layer.
    pruned_neurons: Dict[int, List[int]]


# =============== Global debug stats: confirm gradients on pruned positions get zeroed ===============

# Process-local counters updated by the gradient hooks registered in
# apply_pruning_for_finetune and printed by train().
PRUNING_DEBUG_STATS = {
    "max_pruned_grad_before_mask": 0.0,   # largest |raw grad| seen on pruned positions, before the mask is applied
    "max_unpruned_grad": 0.0,             # largest |grad| seen on unpruned positions
    "num_pruned_grad_updates": 0,         # incremented once per hook call that sees a non-empty pruned region
}


# =============== Helper for fetching the decoder layers of a LLaMA-style model ===============

def get_decoder_layers(model):
    """
    Return the layer container of a LLaMA-style model.

    Looks at ``model.model.layers`` first and falls back to
    ``model.model.encoder.layers``.

    Raises:
        RuntimeError: if neither location provides a ``layers`` attribute.
            This includes the case where ``model.model.encoder`` exists but
            has no ``layers``, which previously escaped as AttributeError.
    """
    inner = getattr(model, "model", None)
    if inner is not None:
        # Probe the candidates in priority order: the inner model itself,
        # then its encoder (if any).
        for holder in (inner, getattr(inner, "encoder", None)):
            if holder is not None and hasattr(holder, "layers"):
                return holder.layers
    raise RuntimeError("[TALE] Cannot find decoder layers in model.")
    

# =============== Depth pruning: turn a layer's forward into an identity map ===============

def patch_layer_skip(layer):
    """
    Replace ``layer.forward`` with an identity map and return the original.

    The patched forward ignores every argument except the hidden states and
    returns them unchanged. The hidden states are taken from the first
    positional argument or, if none is given, from the ``hidden_states``
    keyword argument.

    Returns:
        The original ``layer.forward`` so the caller can restore it later.

    Raises (from the patched forward):
        RuntimeError: if hidden states are passed neither positionally nor
            as the ``hidden_states`` keyword.
    """
    original_forward = layer.forward

    def forward_skip(*args, **kwargs):
        # NOTE(review): the bare hidden states are returned, not a tuple.
        # Some transformers releases index the layer output with [0] --
        # confirm this matches the installed version.
        if args:
            return args[0]
        if "hidden_states" in kwargs:
            return kwargs["hidden_states"]
        raise RuntimeError("[TALE] forward_skip got no positional args.")

    layer.forward = forward_skip
    return original_forward


def create_scheduler(num_training_steps, optimizer, warmup_ratio):
    """
    Build a linear warmup schedule for ``optimizer``.

    The warmup length is ``int(num_training_steps * warmup_ratio)`` steps;
    the schedule itself comes from ``get_linear_schedule_with_warmup``.
    """
    schedule_kwargs = {
        "num_warmup_steps": int(num_training_steps * warmup_ratio),
        "num_training_steps": num_training_steps,
    }
    return get_linear_schedule_with_warmup(optimizer, **schedule_kwargs)


def collate_fn(batch):
    """
    Stack per-sample tensors into batch tensors along a new leading dim.

    Each sample must provide ``input_ids``, ``attention_mask`` and
    ``labels``; any other keys are ignored.
    """
    wanted = ("input_ids", "attention_mask", "labels")
    return {
        field: torch.stack([sample[field] for sample in batch], dim=0)
        for field in wanted
    }


# ============================= Utility: time formatting =============================

def format_seconds(seconds: float) -> str:
    """Format a duration in seconds as ``Xd XXh XXm XXs``.

    Fractional seconds are truncated. Negative durations are clamped to
    zero; unclamped, floor division would render e.g. -5 as
    ``-1d 23h 59m 55s``.
    """
    total = max(int(seconds), 0)
    days, rem = divmod(total, 86400)
    hours, rem = divmod(rem, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{days}d {hours:02d}h {minutes:02d}m {secs:02d}s"


def load_pruning_state(ckpt_dir: str) -> PruningState:
    """
    Read ``pruning_state.json`` from ``ckpt_dir`` into a ``PruningState``.

    Missing fields default to empty. The layer keys of ``pruned_heads`` and
    ``pruned_neurons`` are converted from JSON strings to ints.

    Raises:
        FileNotFoundError: if the JSON file is not present in ``ckpt_dir``.
    """
    state_file = os.path.join(ckpt_dir, "pruning_state.json")
    if not os.path.exists(state_file):
        raise FileNotFoundError(f"[TALE] pruning_state.json not found in {ckpt_dir}")

    with open(state_file, "r", encoding="utf-8") as handle:
        raw = json.load(handle)

    def _int_keyed(field):
        # JSON object keys are always strings; restore the integer layer ids.
        return {int(layer): items for layer, items in raw.get(field, {}).items()}

    return PruningState(
        skipped_layers=list(raw.get("skipped_layers", [])),
        pruned_heads=_int_keyed("pruned_heads"),
        pruned_neurons=_int_keyed("pruned_neurons"),
    )

def apply_pruning_for_finetune(model, pruning_state: PruningState):
    """
    Re-impose a saved pruning structure on ``model`` before fine-tuning.

    - Depth: every layer in ``pruning_state.skipped_layers`` has its forward
      patched to an identity map and all of its parameters frozen.
    - Width: for every pruned attention head / MLP neuron, the matching rows
      or columns of the projection weights are zeroed once, and a gradient
      hook is registered that multiplies the incoming gradient by a 0/1 mask.
      The hooks also record the largest raw gradient seen on pruned and
      unpruned positions in ``PRUNING_DEBUG_STATS``.

    Implementation notes:
    - Masks are stored as 1-D keep vectors and broadcast over the weight,
      then cached per device. A full weight-sized mask is never built or
      copied between devices on each backward pass.
    - ``head_dim`` comes from ``config.head_dim`` when the config defines
      it, otherwise from ``hidden_size // num_attention_heads``.
    - ``k_proj`` / ``v_proj`` are laid out per key/value head. When
      ``config.num_key_value_heads`` is smaller than the number of query
      heads, a key/value head is masked only if every query head that
      shares it is pruned. With equal head counts this is one K/V head per
      query head.
    - Out-of-range layer, head and neuron indices are ignored silently.

    Raises:
        RuntimeError: if the number of attention heads cannot be inferred.
    """
    layers = get_decoder_layers(model)
    num_layers = len(layers)
    config = model.config

    num_heads = getattr(config, "num_attention_heads", None)
    if num_heads is None:
        num_heads = getattr(layers[0].self_attn, "num_heads", None)
        if num_heads is None:
            raise RuntimeError("[TALE] Cannot infer num_heads in apply_pruning_for_finetune.")
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads
    num_kv_heads = getattr(config, "num_key_value_heads", None) or num_heads
    # Number of query heads that share one key/value head (1 without GQA).
    kv_group = max(num_heads // num_kv_heads, 1)

    # -------- 1) depth: skipped layers become identity and stop training --------
    for l in pruning_state.skipped_layers:
        if l < 0 or l >= num_layers:
            continue
        print(f"[TALE][finetune] Skipping layer {l} (depth-pruned).")
        layer = layers[l]
        patch_layer_skip(layer)
        # The layer no longer contributes to the output; freeze it as well.
        for p in layer.parameters():
            p.requires_grad = False

    def _register_mask_hook(tensor, pruned_index, dim: int, debug_name: str):
        """
        Zero the slices ``pruned_index`` of ``tensor`` along ``dim`` and keep
        them at zero gradient.

        tensor: 2-D weight parameter to mask.
        pruned_index: flat indices along ``dim`` (0 = rows, 1 = columns).
        debug_name: label used in error messages.
        """
        if not pruned_index:
            return  # nothing to mask for this weight

        size = tensor.size(dim)
        keep_cpu = torch.ones(size, dtype=torch.bool)
        keep_cpu[torch.tensor(pruned_index, dtype=torch.long)] = False
        # Shape that broadcasts the 1-D keep vector over the 2-D weight.
        view_shape = (size, 1) if dim == 0 else (1, size)

        # Zero the pruned slices of the weight once, up front.
        tensor.data.mul_(keep_cpu.view(view_shape).to(device=tensor.device, dtype=tensor.dtype))

        keep_by_device = {}

        def _hook(grad):
            # ``grad`` is the raw gradient, before masking.
            if grad is None:
                return None
            if grad.dim() != 2 or grad.size(dim) != size:
                raise RuntimeError(
                    f"[TALE] Unexpected grad shape {tuple(grad.shape)} for {debug_name}; "
                    "cannot apply pruning mask."
                )

            keep = keep_by_device.get(grad.device)
            if keep is None:
                # Copy the small keep vector to this device only once.
                keep = keep_cpu.to(grad.device)
                keep_by_device[grad.device] = keep

            # Debug statistics: reduce to one max per row/column first, then
            # split into pruned and kept slices.
            if grad.numel() > 0:
                slice_max = grad.detach().abs().amax(dim=1 - dim)

                pruned_max = slice_max[~keep]
                if pruned_max.numel() > 0:
                    max_pruned = pruned_max.max().item()
                    if max_pruned > PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"]:
                        PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"] = max_pruned
                    PRUNING_DEBUG_STATS["num_pruned_grad_updates"] += 1

                kept_max = slice_max[keep]
                if kept_max.numel() > 0:
                    max_unpruned = kept_max.max().item()
                    if max_unpruned > PRUNING_DEBUG_STATS["max_unpruned_grad"]:
                        PRUNING_DEBUG_STATS["max_unpruned_grad"] = max_unpruned

            # Gradient actually used for the update: zero on pruned slices.
            return grad * keep.view(view_shape).to(grad.dtype)

        tensor.register_hook(_hook)

    def _head_slices(head_ids, head_count, limit):
        """Expand head indices into the flat indices of their ``head_dim`` slices."""
        index = []
        for h in sorted(set(head_ids)):
            if h < 0 or h >= head_count:
                continue
            start = h * head_dim
            stop = start + head_dim
            if stop <= limit:
                index.extend(range(start, stop))
        return index

    # -------- 2) width: masks + grad hooks for heads and neurons --------

    # 2.1 attention heads
    for l, heads in pruning_state.pruned_heads.items():
        if l < 0 or l >= num_layers:
            continue
        if len(heads) == 0:
            continue
        layer = layers[l]
        if not hasattr(layer, "self_attn"):
            continue
        attn = layer.self_attn
        print(f"[TALE][finetune] Layer {l}: masking attn heads {heads}")

        pruned_q = {h for h in heads if 0 <= h < num_heads}
        # A key/value head can be removed only when all of its query heads are.
        pruned_kv = [
            g for g in range(num_kv_heads)
            if all((g * kv_group + j) in pruned_q for j in range(kv_group))
        ]

        # (projection, masked dim, head indices, number of heads in that projection)
        # q/k/v split heads along out_features (rows); o_proj along in_features (columns).
        targets = (
            ("q_proj", 0, heads, num_heads),
            ("k_proj", 0, pruned_kv, num_kv_heads),
            ("v_proj", 0, pruned_kv, num_kv_heads),
            ("o_proj", 1, heads, num_heads),
        )
        for name, dim, head_ids, head_count in targets:
            if not hasattr(attn, name):
                continue
            weight = getattr(attn, name).weight
            index = _head_slices(head_ids, head_count, weight.size(dim))
            _register_mask_hook(weight, index, dim, f"layer{l}.{name}")

    # 2.2 MLP neurons
    for l, neurons in pruning_state.pruned_neurons.items():
        if l < 0 or l >= num_layers:
            continue
        if len(neurons) == 0:
            continue
        layer = layers[l]
        if not hasattr(layer, "mlp"):
            continue
        mlp = layer.mlp
        print(f"[TALE][finetune] Layer {l}: masking mlp neurons {neurons}")

        # gate_proj / up_proj: a neuron is a row; down_proj: a neuron is a column.
        for name, dim in (("gate_proj", 0), ("up_proj", 0), ("down_proj", 1)):
            if not hasattr(mlp, name):
                continue
            weight = getattr(mlp, name).weight
            limit = weight.size(dim)
            index = sorted({n for n in neurons if 0 <= n < limit})
            _register_mask_hook(weight, index, dim, f"layer{l}.{name}")

    print("[TALE][finetune] Pruning masks applied. Only unpruned weights will be updated.")


# ============================= Main training logic =============================

def parse_args():
    """
    Parse the command-line arguments of the fine-tuning script.

    Required: ``--model_name_or_path``, ``--output_dir`` and
    ``--deepspeed_config``. Everything else has a default.

    ``--local_rank`` (also accepted as ``--local-rank``) exists so the script
    can be started by launchers that append it to the command line. Without
    it argparse aborts with "unrecognized arguments". It defaults to -1,
    meaning "not provided".
    """
    parser = argparse.ArgumentParser(description="Finetune pruned HF LLM with DeepSpeed")

    # Basic arguments
    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="剪枝后 HF 格式模型路径")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="最终模型与中间模型保存目录")

    # Training hyper-parameters
    parser.add_argument("--max_length", type=int, default=512)
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--max_steps", type=int, default=-1,
                        help=">0 时优先按 step 数训练；=-1 时按 epoch 训练完数据集")
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--seed", type=int, default=42)

    # DeepSpeed
    parser.add_argument("--deepspeed_config", type=str, required=True,
                        help="DeepSpeed json 配置文件路径")
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank", type=int, default=-1,
                        help="Local rank passed in by the distributed launcher; -1 if not provided.")

    # Logging & checkpointing
    parser.add_argument("--logging_steps", type=int, default=50)
    parser.add_argument("--save_iter", type=int, default=-1,
                        help=">0 时每 N 个 global_step 保存一次；-1 表示不保存中间结果")

    # Training data
    parser.add_argument("--sft_dataset", type=str, default="text",
                        choices=["text", "mmlu"],
                        help="选择用于 SFT 的数据集类型：'text' 为纯文本文件，'mmlu' 为 MMLU SFT。")

    parser.add_argument("--num_train_samples", type=int, default=None,
                        help="可选：限制使用的训练样本总数（仅对部分基准数据集有效，如 mmlu）。")

    return parser.parse_args()


def setup_distributed():
    """Initialise the distributed backend through DeepSpeed, at most once."""
    if dist.is_initialized():
        return
    deepspeed.init_distributed()


def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    import random
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


def build_dataloader(args, tokenizer):
    """
    Build the distributed training DataLoader.

    The dataset comes from ``get_sft_dataset`` and is sharded across ranks
    with a shuffling ``DistributedSampler`` (``drop_last=True``). Batches are
    assembled by ``collate_sft``.

    Returns:
        ``(dataloader, dataset_size)`` where ``dataset_size`` is the length
        of the full (unsharded) dataset.
    """
    # Load the SFT dataset through the shared data-utils entry point.
    train_set = get_sft_dataset(
        name=args.sft_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.num_train_samples,
    )

    shard_sampler = DistributedSampler(
        train_set,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
        drop_last=True,
    )

    loader = DataLoader(
        train_set,
        batch_size=args.per_device_train_batch_size,
        sampler=shard_sampler,
        collate_fn=collate_sft,
        num_workers=2,
        pin_memory=True,
    )
    return loader, len(train_set)


def load_model_and_tokenizer(model_path: str):
    """
    Load the causal LM and its tokenizer from ``model_path``.

    The tokenizer pads on the right and falls back to the EOS token when it
    has no pad token. The model is loaded in bfloat16 with no device map;
    device placement is left to DeepSpeed.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "right"

    load_kwargs = {
        "torch_dtype": torch.bfloat16,  # or torch.float16, depending on hardware
        "device_map": None,             # let DeepSpeed place the model on GPU
    }
    lm = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return lm, tok


def save_hf_checkpoint(model_engine, tokenizer, save_dir: str):
    """
    Save model and tokenizer to ``save_dir`` in Hugging Face format.

    ``model_engine`` may be a wrapper exposing the real model as ``.module``
    (as a DeepSpeed engine does) or the model itself. The directory is
    created if needed and weights are written with safetensors.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Unwrap the engine when it carries the underlying model in `.module`.
    unwrapped = getattr(model_engine, "module", model_engine)
    unwrapped.save_pretrained(save_dir, safe_serialization=True)
    tokenizer.save_pretrained(save_dir)


def train(args):
    """
    Fine-tune a pruned causal LM with DeepSpeed.

    Steps: initialise distributed state, load model and tokenizer, re-apply
    the pruning constraints, build a torch AdamW optimizer and a linear
    warmup schedule, hand both to ``deepspeed.initialize``, run the training
    loop, and save the final model under ``<output_dir>/final`` on rank 0.

    Step accounting:
    - ``global_step`` counts optimizer updates, i.e. gradient-accumulation
      boundaries, not micro-batches. ``--max_steps``, ``--logging_steps``,
      ``--save_iter``, the ETA and the scheduler length all use this unit.
    - The scheduler is passed to ``deepspeed.initialize``, so the engine
      advances it inside ``model_engine.step()``. It must not be stepped
      again here, or the schedule would run too fast.

    Side effects: mutates ``args`` (``deepspeed`` / ``deepspeed_config`` are
    set to None before engine creation), writes checkpoints, prints logs.

    Raises:
        RuntimeError: if the training dataloader yields no batches.
    """
    # ====== Distributed initialisation ======
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = (rank == 0)

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print("====== Parsed args ======")
        for k, v in vars(args).items():
            print(f"{k}: {v}")
        print("=========================")

    set_seed(args.seed + rank)

    # ====== 1. Load model and tokenizer ======
    if is_main:
        print(f"[Rank 0] Loading model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

    # ====== 1.1 Read the pruning state and apply it to the model ======
    # pruning_state.json is expected to sit next to the model files.
    pruning_state = load_pruning_state(args.model_name_or_path)
    if is_main:
        print("[Rank 0] Loaded pruning_state:")
        print(f"  - skipped_layers: {pruning_state.skipped_layers}")
        print(f"  - #layers with pruned_heads: {len(pruning_state.pruned_heads)}")
        print(f"  - #layers with pruned_neurons: {len(pruning_state.pruned_neurons)}")

    apply_pruning_for_finetune(model, pruning_state)

    if is_main:
        print("[Rank 0] Pruning constraints applied. Start building optimizer/DeepSpeed engine...")

    # ====== 2. Plain torch AdamW as the base optimizer ======
    base_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=getattr(args, "weight_decay", 0.0),
    )

    # ====== 3. Read the DeepSpeed config and override batch-related fields ======
    with open(args.deepspeed_config, "r") as f:
        ds_config = json.load(f)

    per_device_bs = args.per_device_train_batch_size
    grad_accum = args.gradient_accumulation_steps

    ds_config["train_micro_batch_size_per_gpu"] = int(per_device_bs)
    ds_config["gradient_accumulation_steps"] = int(grad_accum)
    ds_config["train_batch_size"] = int(per_device_bs * grad_accum * world_size)

    # Keep DeepSpeed's own periodic printing effectively off unless configured.
    ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10_000_000)

    # Drop optimizer / scheduler sections so DeepSpeed does not build its own
    # (the torch optimizer and HF scheduler created here are used instead).
    if "optimizer" in ds_config:
        if is_main:
            print("[Rank 0] Remove 'optimizer' from ds_config (use torch.AdamW instead)")
        ds_config.pop("optimizer")
    if "scheduler" in ds_config:
        if is_main:
            print("[Rank 0] Remove 'scheduler' from ds_config (use HF scheduler instead)")
        ds_config.pop("scheduler")

    # ====== 4. Build the DataLoader and derive the number of optimizer steps ======
    train_dataloader, dataset_size = build_dataloader(args, tokenizer)

    # Use the real per-rank batch count: the sampler drops the uneven tail
    # and the last batch of an epoch may be partial.
    micro_batches_per_epoch = len(train_dataloader)
    if micro_batches_per_epoch == 0:
        raise RuntimeError("[TALE] Training dataloader is empty; nothing to train on.")

    steps_per_epoch = micro_batches_per_epoch // grad_accum
    if args.max_steps > 0:
        num_training_steps = args.max_steps
        # Enough epochs to supply max_steps * grad_accum micro-batches.
        num_train_epochs = math.ceil(num_training_steps * grad_accum / micro_batches_per_epoch)
    else:
        num_train_epochs = args.num_train_epochs
        # The engine's micro-step counter carries over between epochs, so
        # updates are counted over the whole run, not per epoch.
        num_training_steps = max((micro_batches_per_epoch * num_train_epochs) // grad_accum, 1)

    if is_main:
        print(f"World size: {world_size}")
        print(f"Dataset size: {dataset_size}")
        print(f"Steps per epoch: {steps_per_epoch}")
        print(f"Train epochs: {num_train_epochs}")
        print(f"Total training steps: {num_training_steps}")

    # ====== 5. Build the HF scheduler on the not-yet-wrapped optimizer ======
    scheduler = create_scheduler(num_training_steps, base_optimizer, args.warmup_ratio)

    # ====== 6. Initialise DeepSpeed ======
    # Avoid the "config given both via args and initialize()" assertion.
    if hasattr(args, "deepspeed"):
        args.deepspeed = None
    if hasattr(args, "deepspeed_config"):
        args.deepspeed_config = None

    model_engine, _, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        optimizer=base_optimizer,      # torch AdamW
        args=args,
        lr_scheduler=scheduler,        # HF scheduler, stepped by the engine
        config_params=ds_config,
    )

    global_step = 0  # number of optimizer updates performed so far
    model_engine.train()

    # Wall-clock references for the ETA (only used on rank 0).
    train_start_time = time.time() if is_main else None
    step_start_time = time.time()

    for epoch in range(num_train_epochs):
        # Reshuffle the sampler for this epoch.
        train_dataloader.sampler.set_epoch(epoch)

        for batch in train_dataloader:
            # Move the batch to this rank's device.
            batch = {k: v.to(model_engine.local_rank) for k, v in batch.items()}

            outputs = model_engine(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                use_cache=False,
            )
            loss = outputs.loss

            model_engine.backward(loss)
            # Query the boundary BEFORE step(): step() advances the engine's
            # micro-step counter, which would change the answer.
            at_boundary = model_engine.is_gradient_accumulation_boundary()
            # step() applies the update at a boundary and also advances the
            # lr scheduler; no manual lr_scheduler.step() here.
            model_engine.step()

            if not at_boundary:
                continue

            global_step += 1

            # Logging (rank 0 only). The loss shown is that of the last
            # micro-batch of the accumulation window.
            if is_main and (global_step % args.logging_steps == 0):
                lr = lr_scheduler.get_last_lr()[0]
                step_time = time.time() - step_start_time
                elapsed = time.time() - train_start_time
                avg_step_time = elapsed / max(global_step, 1)
                remaining_steps = max(num_training_steps - global_step, 0)
                eta_seconds = remaining_steps * avg_step_time

                print(
                    f"[Epoch {epoch}] step {global_step}/{num_training_steps} | "
                    f"loss {loss.item():.4f} | lr {lr:.6e} | "
                    f"step_time {step_time:.3f}s | "
                    f"avg_step_time {avg_step_time:.3f}s | "
                    f"ETA {format_seconds(eta_seconds)}"
                )

                # Also print the pruning-gradient debug statistics.
                max_p = PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"]
                max_u = PRUNING_DEBUG_STATS["max_unpruned_grad"]
                num_p = PRUNING_DEBUG_STATS["num_pruned_grad_updates"]
                print(
                    f"[PruneDebug] max raw grad on pruned positions (before mask): {max_p:.4e} | "
                    f"max grad on unpruned positions: {max_u:.4e} | "
                    f"steps with pruned-grad flow: {num_p}"
                )
                print(
                    "             Note: after masking, grads on pruned positions are exactly 0 "
                    "(grad * mask), so这些位置不会被更新。"
                )

            # Intermediate checkpoints, only when save_iter > 0.
            # NOTE(review): only rank 0 writes the model; if the DeepSpeed
            # config partitions parameters across ranks this may not contain
            # the full weights -- confirm against the config in use.
            if args.save_iter > 0 and (global_step % args.save_iter == 0):
                if is_main:
                    save_dir = os.path.join(args.output_dir, f"step-{global_step}")
                    print(f"[Rank 0] Saving checkpoint to {save_dir}")
                    save_hf_checkpoint(model_engine, tokenizer, save_dir)
                dist.barrier()

            # Start timing the next optimizer step.
            step_start_time = time.time()

            # Stop as soon as the requested number of updates is reached.
            if args.max_steps > 0 and global_step >= args.max_steps:
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    # Save the final model once training is over.
    if is_main:
        final_dir = os.path.join(args.output_dir, "final")
        total_elapsed = time.time() - train_start_time
        # Summarise the pruning-gradient debug statistics.
        max_p = PRUNING_DEBUG_STATS["max_pruned_grad_before_mask"]
        max_u = PRUNING_DEBUG_STATS["max_unpruned_grad"]
        num_p = PRUNING_DEBUG_STATS["num_pruned_grad_updates"]
        print("========== [PruneDebug Summary] ==========")
        print(f"Max raw grad on pruned positions (before mask): {max_p:.4e}")
        print(f"Max grad on unpruned positions:                {max_u:.4e}")
        print(f"Steps with pruned-grad flow (before mask):     {num_p}")
        print("All pruned positions use grad * 0 -> 0, so它们在整个训练过程中都没有被更新。")
        print("==========================================")

        print(f"[Rank 0] Training finished. Total time: {format_seconds(total_elapsed)}")
        print(f"[Rank 0] Saving final model to {final_dir}")
        save_hf_checkpoint(model_engine, tokenizer, final_dir)

    dist.barrier()


def main():
    """Script entry point: parse the CLI arguments and run training."""
    train(parse_args())


# Run training only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
