import os
import json
import math
import argparse
import time
from typing import Dict

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torch.nn import CrossEntropyLoss

# from data_utils import get_sft_dataset, collate_sft
from data_utils_all_label import get_sft_dataset, collate_sft

import deepspeed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)


def create_scheduler(num_training_steps, optimizer, warmup_ratio):
    """Build a linear warmup-then-decay LR scheduler for ``optimizer``.

    The number of warmup steps is ``num_training_steps * warmup_ratio``
    truncated to an int; the rest of the schedule is delegated to
    ``transformers.get_linear_schedule_with_warmup``.

    Args:
        num_training_steps: Total number of scheduler steps in the run.
        optimizer: Optimizer whose learning rates the scheduler drives.
        warmup_ratio: Fraction of ``num_training_steps`` spent warming up.

    Returns:
        The scheduler object produced by ``get_linear_schedule_with_warmup``.
    """
    warmup_steps = int(num_training_steps * warmup_ratio)
    schedule = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=num_training_steps,
    )
    return schedule


def collate_fn(batch):
    """Stack per-sample tensors into a batch dict.

    Each sample must provide ``input_ids``, ``attention_mask`` and ``labels``
    tensors of identical shape across the batch; they are stacked along a new
    leading dimension. Any other keys in the samples are ignored.
    """
    keys = ("input_ids", "attention_mask", "labels")
    return {key: torch.stack([sample[key] for sample in batch], dim=0) for key in keys}


# ============================= 小工具：时间格式化 =============================

def format_seconds(seconds: float) -> str:
    """Format a duration as ``"Xd HHh MMm SSs"``.

    The input is truncated to whole seconds with ``int()`` first. Hours,
    minutes and seconds are zero-padded to two digits; days are not padded.
    """
    total = int(seconds)
    total_minutes, secs = divmod(total, 60)
    total_hours, mins = divmod(total_minutes, 60)
    days, hrs = divmod(total_hours, 24)
    return f"{days}d {hrs:02d}h {mins:02d}m {secs:02d}s"

# ============================= 训练主逻辑 =============================

def parse_args():
    """Parse the command line for DeepSpeed fine-tuning of a pruned HF model.

    Required options: ``--model_name_or_path``, ``--output_dir`` and
    ``--deepspeed_config``. All others have defaults.

    Returns:
        argparse.Namespace holding the parsed options, in declaration order.
    """
    parser = argparse.ArgumentParser(description="Finetune pruned HF LLM with DeepSpeed")
    add = parser.add_argument

    # --- Paths ---
    add("--model_name_or_path", type=str, required=True,
        help="剪枝后 HF 格式模型路径")
    add("--output_dir", type=str, required=True,
        help="最终模型与中间模型保存目录")

    # --- Training hyperparameters ---
    add("--max_length", type=int, default=512)
    add("--per_device_train_batch_size", type=int, default=2)
    add("--num_train_epochs", type=int, default=3)
    add("--max_steps", type=int, default=-1,
        help=">0 时优先按 step 数训练；=-1 时按 epoch 训练完数据集")
    add("--learning_rate", type=float, default=5e-5)
    add("--weight_decay", type=float, default=0.01)
    add("--warmup_ratio", type=float, default=0.03)
    add("--gradient_accumulation_steps", type=int, default=1)
    add("--seed", type=int, default=42)

    # --- DeepSpeed ---
    add("--deepspeed_config", type=str, required=True,
        help="DeepSpeed json 配置文件路径")

    # --- Logging & checkpointing ---
    add("--logging_steps", type=int, default=50)
    add("--save_iter", type=int, default=-1,
        help=">0 时每 N 个 global_step 保存一次；-1 表示不保存中间结果")

    # --- Training data ---
    add("--sft_dataset", type=str, default="mmlu",
        help="选择用于 SFT 的数据集类型：'text' 为纯文本文件，'mmlu' 为 MMLU SFT。")
    add("--num_train_samples", type=int, default=None,
        help="可选：限制使用的训练样本总数（仅对部分基准数据集有效，如 mmlu）。")

    return parser.parse_args()


def setup_distributed():
    """Initialise torch.distributed through DeepSpeed unless a process group already exists."""
    if dist.is_initialized():
        return
    deepspeed.init_distributed()


def set_seed(seed: int):
    """Seed Python ``random``, NumPy, the torch CPU RNG and all CUDA RNGs with ``seed``."""
    import random
    import numpy as np

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def build_dataloader(args, tokenizer):
    """Build the per-rank SFT training DataLoader.

    The dataset comes from ``get_sft_dataset`` (train split), configured by
    ``args.sft_dataset``, ``args.max_length``, ``args.seed`` and
    ``args.num_train_samples``. It is sharded across ranks with a shuffling
    ``DistributedSampler`` (``drop_last=True``). Batches are assembled with
    ``collate_sft``.

    The sampler is seeded with ``args.seed`` so that ``--seed`` controls the
    shuffle order. Without this, the sampler would fall back to its default seed
    of 0. The sampler requires the same seed on every rank; this holds as long
    as all ranks are launched with the same ``--seed``.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        tokenizer: Tokenizer forwarded to ``get_sft_dataset``.

    Returns:
        ``(dataloader, dataset_size)``, where ``dataset_size`` is the length of
        the full, unsharded dataset.
    """
    dataset = get_sft_dataset(
        name=args.sft_dataset,
        tokenizer=tokenizer,
        max_length=args.max_length,
        seed=args.seed,
        num_samples=args.num_train_samples,
        split="train"
    )

    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
        seed=args.seed,  # tie the shuffle to --seed instead of the sampler's default (0)
        drop_last=True,
    )

    dataloader = DataLoader(
        dataset,
        batch_size=args.per_device_train_batch_size,
        sampler=sampler,
        collate_fn=collate_sft,
        num_workers=2,
        pin_memory=True,
    )

    return dataloader, len(dataset)


def load_model_and_tokenizer(model_path: str):
    """Load a causal LM and its slow (``use_fast=False``) tokenizer from ``model_path``.

    If the tokenizer has no pad token, the EOS token is reused for padding.
    Padding is always applied on the right. The model is loaded in bfloat16
    with ``device_map=None``, leaving device placement to DeepSpeed.

    Returns:
        ``(model, tokenizer)``
    """
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    if tok.pad_token is None:
        # No dedicated pad token: fall back to EOS.
        tok.pad_token, tok.pad_token_id = tok.eos_token, tok.eos_token_id
    tok.padding_side = "right"

    load_kwargs = dict(
        torch_dtype=torch.bfloat16,  # use torch.float16 instead on hardware without bf16
        device_map=None,             # no HF auto-placement; DeepSpeed moves the model to GPU
    )
    lm = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
    return lm, tok


def save_hf_checkpoint(model_engine, tokenizer, save_dir: str):
    """Write model and tokenizer to ``save_dir`` in Hugging Face format.

    Accepts either a DeepSpeed engine (its wrapped ``.module`` is saved) or a
    bare model. Weights are written with safetensors
    (``safe_serialization=True``). ``save_dir`` is created if missing.

    NOTE(review): under ZeRO stage 3, parameters are partitioned across ranks,
    so saving ``.module`` from a single rank may not capture full weights.
    Confirm the ZeRO stage used in the DeepSpeed config.
    """
    os.makedirs(save_dir, exist_ok=True)
    # A DeepSpeed engine exposes the wrapped model as `.module`; otherwise save as-is.
    target = getattr(model_engine, "module", model_engine)
    target.save_pretrained(save_dir, safe_serialization=True)
    tokenizer.save_pretrained(save_dir)


def _build_ds_config(args, world_size: int, is_main: bool) -> dict:
    """Load the DeepSpeed JSON config and adapt it to this script.

    Batch-size fields are overwritten from the CLI arguments. Any ``optimizer``
    or ``scheduler`` sections are removed because the script supplies its own
    ``torch.optim.AdamW`` and HF linear scheduler.
    """
    with open(args.deepspeed_config, "r", encoding="utf-8") as f:
        ds_config = json.load(f)

    per_device_bs = args.per_device_train_batch_size
    grad_accum = args.gradient_accumulation_steps

    ds_config["train_micro_batch_size_per_gpu"] = int(per_device_bs)
    ds_config["gradient_accumulation_steps"] = int(grad_accum)
    ds_config["train_batch_size"] = int(per_device_bs * grad_accum * world_size)

    # Silence DeepSpeed's periodic printing unless the config sets its own interval.
    ds_config["steps_per_print"] = ds_config.get("steps_per_print", 10_000_000)

    # Drop config-defined optimizer/scheduler. Otherwise DeepSpeed builds its own
    # (e.g. FusedAdam/CPUAdam), which requires compiling CUDA extensions (GCC/nvcc).
    if "optimizer" in ds_config:
        if is_main:
            print("[Rank 0] Remove 'optimizer' from ds_config (use torch.AdamW instead)")
        ds_config.pop("optimizer")
    if "scheduler" in ds_config:
        if is_main:
            print("[Rank 0] Remove 'scheduler' from ds_config (use HF scheduler instead)")
        ds_config.pop("scheduler")

    return ds_config


def _compute_schedule(args, micro_batches_per_epoch: int):
    """Compute the training length in *optimizer* steps.

    One optimizer (and LR-scheduler) step happens every
    ``gradient_accumulation_steps`` micro-batches. DeepSpeed counts micro-steps
    over the whole run rather than per epoch, so the totals are derived from the
    per-rank micro-batch count.

    Args:
        args: Parsed CLI namespace.
        micro_batches_per_epoch: ``len(train_dataloader)`` on this rank.

    Returns:
        ``(steps_per_epoch, num_train_epochs, num_training_steps)``

    Raises:
        ValueError: if gradient accumulation is < 1 or the dataloader is empty.
    """
    grad_accum = args.gradient_accumulation_steps
    if grad_accum < 1:
        raise ValueError(f"gradient_accumulation_steps must be >= 1, got {grad_accum}")
    if micro_batches_per_epoch <= 0:
        raise ValueError(
            "Training dataloader is empty; check the dataset and --num_train_samples."
        )

    steps_per_epoch = max(micro_batches_per_epoch // grad_accum, 1)
    if args.max_steps > 0:
        num_training_steps = args.max_steps
        # Enough epochs to supply max_steps * grad_accum micro-batches.
        num_train_epochs = math.ceil(num_training_steps * grad_accum / micro_batches_per_epoch)
    else:
        num_train_epochs = args.num_train_epochs
        num_training_steps = max(micro_batches_per_epoch * num_train_epochs // grad_accum, 1)
    return steps_per_epoch, num_train_epochs, num_training_steps


def train(args):
    """Fine-tune the model with DeepSpeed, using torch AdamW and an HF linear scheduler.

    Flow: initialise distributed, load the model and tokenizer, build the
    optimizer, adapt the DeepSpeed config, build the sharded DataLoader, size
    the LR schedule, wrap everything with ``deepspeed.initialize``, then run the
    training loop with periodic logging and saving. The final model is saved to
    ``<output_dir>/final``.

    ``global_step`` counts optimizer steps (one per
    ``gradient_accumulation_steps`` micro-batches). ``--max_steps``,
    ``--logging_steps`` and ``--save_iter`` therefore share the unit of the LR
    schedule.
    """
    # ====== Distributed initialisation ======
    setup_distributed()
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    is_main = (rank == 0)

    if is_main:
        os.makedirs(args.output_dir, exist_ok=True)
        print("====== Parsed args ======")
        for k, v in vars(args).items():
            print(f"{k}: {v}")
        print("=========================")

    # Offset by rank so that ranks do not share RNG streams.
    set_seed(args.seed + rank)

    # ====== 1. Model and tokenizer ======
    if is_main:
        print(f"[Rank 0] Loading model from {args.model_name_or_path}")
    model, tokenizer = load_model_and_tokenizer(args.model_name_or_path)

    # ====== 2. Plain torch AdamW (handed to DeepSpeed below) ======
    base_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=getattr(args, "weight_decay", 0.0),
    )

    # ====== 3. DeepSpeed config (read before args.deepspeed_config is cleared) ======
    ds_config = _build_ds_config(args, world_size, is_main)

    # ====== 4. DataLoader and schedule length (in optimizer steps) ======
    train_dataloader, dataset_size = build_dataloader(args, tokenizer)
    steps_per_epoch, num_train_epochs, num_training_steps = _compute_schedule(
        args, len(train_dataloader)
    )

    if is_main:
        print(f"World size: {world_size}")
        print(f"Dataset size: {dataset_size}")
        print(f"Steps per epoch: {steps_per_epoch}")
        print(f"Train epochs: {num_train_epochs}")
        print(f"Total training steps: {num_training_steps}")

    # ====== 5. HF scheduler on the un-wrapped torch optimizer ======
    scheduler = create_scheduler(num_training_steps, base_optimizer, args.warmup_ratio)

    # ====== 6. DeepSpeed ======
    # The config goes in via config_params. Clear the args-based config fields so
    # deepspeed.initialize does not see two config sources and assert.
    if hasattr(args, "deepspeed"):
        args.deepspeed = None
    if hasattr(args, "deepspeed_config"):
        args.deepspeed_config = None

    # The scheduler is handed to DeepSpeed, so model_engine.step() advances it at
    # each gradient-accumulation boundary. Do NOT call lr_scheduler.step() manually.
    model_engine, _, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        optimizer=base_optimizer,      # torch.AdamW
        args=args,
        lr_scheduler=scheduler,        # HF scheduler
        config_params=ds_config,
    )

    global_step = 0
    model_engine.train()

    train_start_time = time.time()
    last_step_end = train_start_time

    for epoch in range(num_train_epochs):
        # Vary the sampler's shuffle per epoch.
        train_dataloader.sampler.set_epoch(epoch)

        for batch in train_dataloader:
            batch = {k: v.to(model_engine.local_rank) for k, v in batch.items()}

            outputs = model_engine(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                use_cache=False,
            )
            loss = outputs.loss

            # Query before step(): step() advances DeepSpeed's micro-step counter,
            # after which the flag would describe the *next* micro-batch.
            is_boundary = model_engine.is_gradient_accumulation_boundary()
            model_engine.backward(loss)
            model_engine.step()

            if not is_boundary:
                # Still accumulating gradients: no optimizer/scheduler step happened.
                continue

            global_step += 1
            now = time.time()
            step_time = now - last_step_end  # wall time of this optimizer step (all its micro-batches)
            last_step_end = now

            # Logging (rank 0 only).
            if is_main and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                lr = lr_scheduler.get_last_lr()[0]
                elapsed = now - train_start_time
                avg_step_time = elapsed / global_step
                remaining_steps = max(num_training_steps - global_step, 0)
                eta_seconds = remaining_steps * avg_step_time

                print(
                    f"[Epoch {epoch}] step {global_step}/{num_training_steps} | "
                    f"loss {loss.item():.4f} | lr {lr:.6e} | "
                    f"step_time {step_time:.3f}s | "
                    f"avg_step_time {avg_step_time:.3f}s | "
                    f"ETA {format_seconds(eta_seconds)}"
                )

            # Intermediate save, enabled only when save_iter > 0. Every rank evaluates
            # the same condition, so all ranks reach the barrier together.
            if args.save_iter > 0 and global_step % args.save_iter == 0:
                if is_main:
                    save_dir = os.path.join(args.output_dir, f"step-{global_step}")
                    print(f"[Rank 0] Saving checkpoint to {save_dir}")
                    save_hf_checkpoint(model_engine, tokenizer, save_dir)
                dist.barrier()

            # Step-budgeted run: stop once max_steps optimizer steps are done.
            if args.max_steps > 0 and global_step >= args.max_steps:
                break

        if args.max_steps > 0 and global_step >= args.max_steps:
            break

    # ====== Final save ======
    if is_main:
        final_dir = os.path.join(args.output_dir, "final")
        total_elapsed = time.time() - train_start_time
        print(f"[Rank 0] Training finished. Total time: {format_seconds(total_elapsed)}")
        print(f"[Rank 0] Saving final model to {final_dir}")
        save_hf_checkpoint(model_engine, tokenizer, final_dir)
    dist.barrier()


def main():
    """Script entry point: parse the command line and run training."""
    train(parse_args())


if __name__ == "__main__":
    # Start training only when run as a script, not when imported as a module.
    main()
