#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import time
import argparse
import logging

import torch
from utils.data_utils import get_trainloaders
from utils.util import *   # 你原工程依赖的 logging / ModelHander 构造通常也在这里


def _resolve_manual_layers(manual_layers, num_layers, layers_are_1based: bool):
    """
    将用户输入层号解析为 0-based，并做合法性校验。
    remove_layers 会改变索引，所以必须按从大到小删，避免 index 漂移。
    """
    layers = [int(x) for x in manual_layers]
    if layers_are_1based:
        layers = [x - 1 for x in layers]

    for x in layers:
        if x < 0 or x >= num_layers:
            raise ValueError(f"Invalid layer idx={x} for current num_layers={num_layers}")

    # 去重 + 从大到小
    layers = sorted(set(layers), reverse=True)
    return layers


def main_func(args, modelhander):
    """
    Remove the requested layers, save the resulting model, and write a JSON record.

    Steps:
      1) Resolve ``args.manual_remove_layers`` into unique 0-based indices, descending.
      2) Call ``modelhander.remove_layers`` once per index, highest index first.
      3) Save through ``modelhander.save`` into
         ``<save_path>/<save_name>_manualrm_<layer count after removal>``.
      4) Write ``manual_remove_info.json`` into that same directory.

    Args:
        args: Parsed CLI namespace. Reads ``calibration_dataset``, ``nsamples``,
            ``seed``, ``manual_remove_layers``, ``manual_layers_are_1based``,
            ``save_path`` and ``save_name``.
        modelhander: Handler exposing ``config.num_hidden_layers``, ``model.seqlen``,
            ``remove_layers(removal_list=...)`` and ``save(path=...)``.

    Raises:
        ValueError: Propagated from ``_resolve_manual_layers`` for an out-of-range index.
    """
    # 1) Work out which layers to drop.
    layers_before = int(modelhander.config.num_hidden_layers)
    targets_desc = _resolve_manual_layers(
        manual_layers=args.manual_remove_layers,
        num_layers=layers_before,
        layers_are_1based=args.manual_layers_are_1based,
    )
    logging.info(f"Current num_hidden_layers = {layers_before}")
    logging.info(f"Manual remove layers (0-based, desc) = {targets_desc}")

    # Run record; keys are inserted in the order they should appear in the JSON file.
    info = {}
    info["timestamp"] = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    info["calibration_dataset"] = args.calibration_dataset
    info["nsamples"] = int(args.nsamples)
    info["seed"] = int(args.seed)
    info["seqlen"] = int(modelhander.model.seqlen)
    info["num_layers_before"] = layers_before
    info["manual_remove_layers_input"] = [int(raw) for raw in args.manual_remove_layers]
    info["manual_layers_are_1based"] = bool(args.manual_layers_are_1based)
    info["resolved_remove_layers_0based_desc"] = targets_desc

    # 2) Remove one layer per call, highest index first, so pending indices stay valid.
    dropped = []
    for layer_idx in targets_desc:
        logging.info(f"Removing layer idx={layer_idx} (current model space, 0-based)")
        # The handler is given a single int per call, as in the code this was adapted from.
        modelhander.remove_layers(removal_list=layer_idx)
        dropped.append(layer_idx)

    layers_after = int(modelhander.config.num_hidden_layers)
    info.update(num_layers_after=layers_after, removed_layers_0based_desc=dropped)

    # 3) Save the pruned model.
    #    NOTE(review): presumably save() wraps save_pretrained for model + tokenizer — confirm.
    os.makedirs(args.save_path, exist_ok=True)
    target_dir = os.path.join(args.save_path, f"{args.save_name}_manualrm_{layers_after}")
    modelhander.save(path=target_dir)
    logging.info(f"[Saved] HF model saved to: {target_dir}")
    info["save_dir"] = target_dir

    # 4) Keep a record next to the weights of exactly which layers were removed.
    record_path = os.path.join(target_dir, "manual_remove_info.json")
    with open(record_path, "w", encoding="utf-8") as fh:
        json.dump(info, fh, indent=2, ensure_ascii=False)
    logging.info(f"[Saved] Info json saved to: {record_path}")


def build_parser():
    """
    Build the command-line parser for the manual layer-removal script.

    Returns:
        argparse.ArgumentParser: Parser exposing the calibration-data, layer-selection,
        output-location and model-loading options.
    """
    # The summary must be passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the program name in the usage line
    # with this sentence and leave the help text without a description.
    ap = argparse.ArgumentParser(
        description="Manual remove specific layers and save HF model"
    )

    # ===== Calibration-data arguments (same names as the existing pipeline) =====
    ap.add_argument("--calibration_dataset", type=str, required=True)
    ap.add_argument("--nsamples", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)

    # ===== Arguments specific to manual layer removal =====
    ap.add_argument("--manual_remove_layers", type=int, nargs="+", required=True,
                    help="Layer indices to remove. Example: --manual_remove_layers 18 19")
    ap.add_argument("--manual_layers_are_1based", action="store_true",
                    help="If set, treat manual_remove_layers as 1-based indices.")

    ap.add_argument("--save_path", type=str, required=True)
    ap.add_argument("--save_name", type=str, required=True)

    # ===== Model-loading arguments =====
    # Adjust or drop these if your handler construction needs different inputs.
    ap.add_argument("--model_name_or_path", type=str, required=True,
                    help="HF model path to load before pruning (dense or already pruned).")
    ap.add_argument("--dtype", type=str, default="fp16", choices=["fp16", "fp32"],
                    help="Model dtype for loading.")

    return ap


def main():
    """
    Script entry point: parse CLI arguments, build the model handler, run the removal.

    The handler returned by ``get_llmhander`` must provide:
      - ``model`` / ``tokenizer`` / ``config`` attributes
      - ``remove_layers(removal_list=...)``
      - ``save(path=...)``

    Raises:
        RuntimeError: If ``get_llmhander`` cannot be imported from ``utils.model_utils``.
    """
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
    args = build_parser().parse_args()

    # NOTE(review): --dtype is parsed but not forwarded; get_llmhander() is called with
    # only the model path and concat_merge. Wire it through if your handler accepts a dtype.

    # Only the import is guarded, and with ImportError: that is what a missing module or
    # a missing name in ``from ... import ...`` raises. Errors raised while constructing
    # the handler propagate unchanged so they are not mislabelled as an import problem.
    try:
        from utils.model_utils import get_llmhander
    except ImportError as e:
        raise RuntimeError(
            "Cannot import `get_llmhander` from utils.model_utils. "
            "Please replace `get_llmhander(...)` with your actual handler constructor.\n"
            "Your handler must support: remove_layers(removal_list=...), save(path=...)"
        ) from e

    modelhander = get_llmhander(args.model_name_or_path, concat_merge=False)

    main_func(args, modelhander)


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
