#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
TALE: Task-Aware Layer+width iterative pruning (v0.1)

本脚本做的事情：
- 加载一个 HF CausalLM 模型（假定 LLaMA 系结构：model.model.layers[i].self_attn / mlp）
- 用你已有的 FLAP 逻辑计算一次 attn/mlp 的通道重要性（WIFV / IFV / WIFN）
- 在当前模型上迭代进行 depth + width 剪枝：
    - 每一轮构造候选动作：
        - depth: 删除 1 层或 2 层（前期），删除 1 层（后期）
        - width: 在若干候选 layer 中，对 attn / mlp 剪一小部分通道
    - 对每个候选动作，用 ΔLoss / ΔParams 评分（ΔLoss 在校准集上真实 forward）
    - 选择 score 最小的动作，真正修改模型参数
    - 记录并保存每个 sparsity 阶段的模型 checkpoint（dense model family）
- 当：
    - ΔLoss/ΔParams 相比早期 baseline 连续若干轮明显变大 → 进入 "fine" 小步剪枝阶段
    - depth 或 width 一侧在最近若干轮都比另一侧 "代价更高" → 冻结该侧，只剪另一侧

⚠ 重要假设：
- 模型是 LLaMA 风格，decoder 层位于 model.model.layers
- 每层含：
    - layer.self_attn.q_proj/k_proj/v_proj/o_proj (nn.Linear, 无 bias 或 bias 较简单)
    - layer.mlp.gate_proj / up_proj / down_proj
- 你当前环境已经安装 bitsandbytes / peft / datasets / transformers / tqdm

你可以按需把这个脚本拆成多个 py 文件，这里先写成一个方便你整体拷贝。
"""

import os
import sys
import re
import argparse
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from datasets import load_dataset, Dataset
from tqdm import tqdm

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

from dataclasses import dataclass
import json


import bitsandbytes as bnb
from peft.tuners.lora.bnb import Linear4bit, Linear8bitLt
from peft.utils.integrations import dequantize_bnb_weight


# ============================================================
# ===================== 数据加载部分 ==========================
# ============================================================

class TokenizerWrapper:
    """Lightweight container that mimics a tokenizer output's ``input_ids``.

    Lets callers return an already-sliced token tensor while keeping the
    ``obj.input_ids`` access pattern of a real tokenizer encoding.
    """

    def __init__(self, input_ids: torch.Tensor):
        # Kept by reference; nothing is copied or validated.
        self.input_ids = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Draw random WikiText-2 training windows for calibration.

    The train split is joined with single spaces and the test split with
    blank lines before tokenization. ``nsamples`` windows of ``seqlen``
    tokens are cut from the train stream at random offsets chosen by the
    global ``random`` module, which is seeded with ``seed``.

    Returns:
        (trainloader, testenc): ``trainloader`` is a list of ``(inp, tar)``
        pairs where ``inp`` is a ``[1, seqlen]`` slice and ``tar`` is a copy
        with every position except the last set to -100. ``testenc`` is the
        tokenizer output for the test split.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_enc = tokenizer(" ".join(train_split['text']), return_tensors='pt')
    test_enc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    last_start = train_enc.input_ids.shape[1] - seqlen - 1
    windows = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        window = train_enc.input_ids[:, start:start + seqlen]
        target = window.clone()
        target[:, :-1] = -100  # keep only the final label
        windows.append((window, target))
    return windows, test_enc


def get_c4(nsamples, seed, seqlen, tokenizer, data_dir=None):
    """Draw random C4 training windows and build a fixed validation stream.

    Only the first shard of each split is read
    (``en/c4-train.00000-of-01024.json.gz`` and
    ``en/c4-validation.00000-of-00008.json.gz``). The source is the Hub repo
    ``allenai/c4`` when ``data_dir`` is None, and ``<data_dir>/allenai--c4``
    otherwise.

    Training windows: documents are picked at random (global ``random``,
    seeded with ``seed``) until one is longer than ``seqlen`` tokens. A random
    ``seqlen``-token slice of that document becomes ``inp``. ``tar`` is a copy
    of ``inp`` with every position but the last set to -100.

    Returns:
        (trainloader, valenc): a list of ``nsamples`` ``(inp, tar)`` pairs,
        and a ``TokenizerWrapper`` around the first ``256 * seqlen`` tokens of
        the first 1100 validation documents joined with spaces.

    NOTE(review): the rejection loop never terminates if no document is
    longer than ``seqlen`` tokens.
    """
    train_files = {'train': 'en/c4-train.00000-of-01024.json.gz'}
    val_files = {'validation': 'en/c4-validation.00000-of-00008.json.gz'}
    if data_dir is None:
        # Select the same single shards from the Hub as the local branch does.
        # The former form load_dataset('allenai/c4', 'allenai--c4', split=...)
        # named a config that the allenai/c4 repo does not publish, so it
        # either failed or targeted the whole corpus.
        dataset_path = 'allenai/c4'
    else:
        dataset_path = os.path.join(data_dir, 'allenai--c4')
    traindata = load_dataset(dataset_path, data_files=train_files, split='train')
    valdata = load_dataset(dataset_path, data_files=val_files, split='validation')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Rejection-sample a document long enough to hold a full window.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc


def format_mmlu_example(dataset, include_answer: bool = True):
    """Render MMLU rows as lettered multiple-choice prompts.

    Each prompt has the form ``"<question>\\nA. <opt0>\\nB. <opt1>...\\nAnswer:"``.
    When ``include_answer`` is true, the gold answer is appended as
    ``" <letter>. <option text>"`` followed by a blank line.

    Args:
        dataset: mapping or ``datasets.Dataset`` with parallel columns
            ``"question"``, ``"choices"`` (list of option strings) and
            ``"answer"`` (integer index into ``choices``).
        include_answer: whether to append the gold answer.

    Returns:
        One prompt string per row.
    """
    # Letters for up to 26 options (MMLU itself uses A-D).
    letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    data = []
    for question, options, answer in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        lines = [question]
        lines.extend("{}. {}".format(letters[i], opt) for i, opt in enumerate(options))
        prompt = "\n".join(lines) + "\nAnswer:"
        if include_answer:
            # Previously this printed the letter twice; emit letter + option text.
            prompt += " {}. {}\n\n".format(letters[answer], options[answer])
        data.append(prompt)
    return data


def get_mmlu_trainenc(
    data_dir: str,
    seed: int,
    nsamples: int,
    tokenizer,
    seqlen: int = 512,
    num_tasks: int = None,
):
    """Tokenize a balanced sample of MMLU questions into one flat stream.

    Questions come from a local MMLU copy at ``<data_dir>/mmlu``.
    ``nsamples`` questions are spread as evenly as possible over the subjects
    (or over a random subset of ``num_tasks`` subjects). Each subject's share
    is taken from its ``train`` split, or from ``validation`` if loading
    ``train`` raises. Prompts are formatted without answers, padded or
    truncated to ``seqlen`` tokens, and flattened.

    Seeds the global ``random`` module with ``seed``.

    NOTE(review): ``padding='max_length'`` requires the tokenizer to define a
    pad token; confirm the caller sets one.

    Returns:
        The tokenizer output with ``input_ids`` and ``attention_mask``
        reshaped to ``[1, -1]``.
    """
    from tqdm import tqdm as tqdm_local

    subjects = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]

    # RNG call order (subset shuffle, then count shuffle) fixes which questions are drawn.
    random.seed(seed)
    if num_tasks is not None:
        random.shuffle(subjects)
        subjects = subjects[:num_tasks]

    n_subjects = len(subjects)
    per_subject, leftover = divmod(nsamples, n_subjects)
    counts = [per_subject + 1] * leftover + [per_subject] * (n_subjects - leftover)
    random.shuffle(counts)

    fields = ["question", "subject", "choices", "answer"]
    collected = {field: [] for field in fields}
    mmlu_root = os.path.join(data_dir, "mmlu")

    progress = tqdm_local(
        zip(counts, subjects),
        total=n_subjects,
        desc="Loading the subclass in MMLU"
    )
    for count, subject in progress:
        if count == 0:
            continue
        try:
            rows = load_dataset(mmlu_root, subject, split="train").shuffle(seed=seed)[:count]
        except Exception:
            # Some local MMLU copies have no train split for a subject.
            rows = load_dataset(mmlu_root, subject, split="validation").shuffle(seed=seed)[:count]
        for field in fields:
            collected[field].extend(rows[field])

    prompts = format_mmlu_example(Dataset.from_dict(collected), include_answer=False)

    encoded = tokenizer(
        prompts,
        return_tensors='pt',
        max_length=seqlen,
        padding='max_length',
        truncation=True,
    )
    for key in ("input_ids", "attention_mask"):
        encoded[key] = encoded[key].reshape(1, -1)
    return encoded


def get_mmlu_loader(
    data_dir: str,
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
):
    """Cut ``nsamples`` random ``seqlen`` windows out of the flat MMLU stream.

    When the stream is no longer than ``seqlen + 1`` tokens every window
    starts at 0. Otherwise start offsets come from the global ``random``
    module, which is reseeded with ``seed``.

    Returns:
        (loader, testenc): a list of ``(inp, tar)`` pairs, where ``tar`` is a
        copy of ``inp`` with all positions but the last set to -100, and a
        ``TokenizerWrapper`` around the first ``seqlen`` tokens of the stream.
    """
    stream = get_mmlu_trainenc(
        data_dir=data_dir,
        seed=seed,
        nsamples=nsamples,
        tokenizer=tokenizer,
        seqlen=seqlen,
    )["input_ids"]
    stream_len = stream.shape[1]

    random.seed(seed)
    samples = []
    for _ in range(nsamples):
        too_short = stream_len <= seqlen + 1
        start = 0 if too_short else random.randint(0, stream_len - seqlen - 1)
        window = stream[:, start:start + seqlen].clone()
        target = window.clone()
        target[:, :-1] = -100
        samples.append((window, target))
    return samples, TokenizerWrapper(stream[:, :seqlen])


def get_loaders(
    name: str,
    nsamples: int = 128,
    seed: int = 0,
    seqlen: int = 2048,
    tokenizer=None,
    data_dir: str = None,
):
    """Dispatch to a calibration-data builder by substring of ``name``.

    Substrings are checked in priority order: ``'wikitext2'``, ``'c4'``,
    ``'mmlu'``. The first one contained in ``name`` wins.

    Raises:
        ValueError: ``name`` matches no known dataset, or ``'mmlu'`` is
            requested without ``data_dir``.
    """
    def _build_mmlu():
        if data_dir is None:
            raise ValueError("Using mmlu as calib dataset requires --data_dir pointing to parent of 'mmlu'.")
        return get_mmlu_loader(data_dir, nsamples, seed, seqlen, tokenizer)

    builders = (
        ('wikitext2', lambda: get_wikitext2(nsamples, seed, seqlen, tokenizer)),
        ('c4', lambda: get_c4(nsamples, seed, seqlen, tokenizer, data_dir=data_dir)),
        ('mmlu', _build_mmlu),
    )
    for key, build in builders:
        if key in name:
            return build()
    raise ValueError(f"Unknown dataset name in get_loaders: {name}")


# ============================================================
# ============== WrappedGPT & BiasGPT（FLAP） =================
# ============================================================

class WrappedGPT:
    """Collect Wanda-style input statistics for one linear layer.

    ``scaler_row[c]`` holds the running average, over all samples seen, of
    the squared L2 norm of input channel ``c``. Feed it via ``add_batch``
    from a forward hook.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = layer.weight.device

        self.rows, self.columns = layer.out_features, layer.in_features

        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        self.layer_id, self.layer_name = layer_id, layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row`` (``out`` is unused)."""
        x = inp.unsqueeze(0) if inp.dim() == 2 else inp
        n_new = x.shape[0]
        linear_types = (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
        if isinstance(self.layer, linear_types):
            if x.dim() == 3:
                x = x.reshape((-1, x.shape[-1]))
            x = x.t()  # channels first: [in_features, tokens]

        # Rescale the old average to the new sample count, then add this batch.
        seen = self.nsamples + n_new
        self.scaler_row *= self.nsamples / seen
        self.nsamples = seen

        x = x.type(torch.float32)
        self.scaler_row += torch.norm(x, p=2, dim=1) ** 2 / self.nsamples


class BiasGPT:
    """Collect FLAP input statistics for one linear layer via ``add_batch``.

    Always tracks ``baseline_inp``, the running per-channel mean of the
    layer input. Additionally tracks either:

    * ``scaler_inp`` (metric ``"WIFN"``): the running average of squared
      per-channel input norms, or
    * ``fluc_inp`` (any other metric): a fluctuation statistic updated
      Welford-style from the old and the new mean.
    """

    def __init__(self, layer, metric):
        self.layer = layer
        self.dev = self.layer.weight.device

        self.out_dim = layer.out_features
        self.in_dim = layer.in_features

        self.type = metric
        self.nsamples = 0

        self.baseline_inp = torch.zeros((self.in_dim), device=self.dev)
        if self.type == "WIFN":
            self.scaler_inp = torch.zeros((self.in_dim), device=self.dev)
        else:
            self.fluc_inp = torch.zeros((self.in_dim), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running statistics.

        ``inp`` is the linear layer's input; ``out`` is unused.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        # Same linear types as WrappedGPT (PEFT and raw bitsandbytes wrappers).
        if isinstance(self.layer, (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()  # [in_dim, tokens]

        total = self.nsamples + batch_size

        # Snapshot the previous mean. Without clone() the in-place updates
        # below would mutate it through aliasing, and the fluctuation term
        # would silently use the new mean twice.
        old_baseline_inp = self.baseline_inp.clone()
        self.baseline_inp *= self.nsamples / total
        # The token mean covers batch_size samples, so weight it accordingly
        # (identical to before when batch_size == 1).
        self.baseline_inp += torch.mean(inp, dim=1) * batch_size / total

        if self.type == "WIFN":
            inp32 = inp.type(torch.float32)
            self.scaler_inp *= self.nsamples / total
            self.scaler_inp += torch.norm(inp32, p=2, dim=1) ** 2 / total
        elif self.nsamples == 0:
            # The first batch only seeds the mean; fluctuation starts at zero.
            self.fluc_inp = torch.zeros_like(self.baseline_inp)
        else:
            self.fluc_inp *= (self.nsamples - 1) / (total - 1)
            self.fluc_inp += torch.sum(
                (inp - self.baseline_inp.unsqueeze(1)) *
                (inp - old_baseline_inp.unsqueeze(1)),
                dim=1
            ) / total

        self.nsamples = total

    def free(self):
        """Drop the accumulated statistics and release cached CUDA memory."""
        self.baseline_inp = None
        if hasattr(self, 'fluc_inp'):
            self.fluc_inp = None
        if hasattr(self, 'scaler_inp'):
            self.scaler_inp = None
        torch.cuda.empty_cache()


# ============================================================
# ===================== 其他工具函数 ==========================
# ============================================================

def save_hf_checkpoint(model, tokenizer, output_dir, iter_idx, current_sparsity):
    """Write model and tokenizer for one pruning step in Hugging Face format.

    The target directory is
    ``<output_dir>/tale_step_<iter_idx:04d>_s<current_sparsity:.3f>``. It is
    created if needed, and both objects are saved into it with
    ``save_pretrained`` (model first).

    Returns:
        The path of the step directory.
    """
    step_name = f"tale_step_{iter_idx:04d}_s{current_sparsity:.3f}"
    step_dir = os.path.join(output_dir, step_name)
    os.makedirs(step_dir, exist_ok=True)

    print(f"[TALE] Saving HF-style checkpoint to {step_dir}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(step_dir)
    return step_dir


@dataclass
class PruningState:
    """Snapshot of depth + width pruning decisions, as written by ``save_pruning_state``."""

    # Indices of decoder layers marked inactive (depth-pruned).
    skipped_layers: List[int]
    # Layer index -> sorted indices of attention heads pruned in that layer.
    pruned_heads: Dict[int, List[int]]
    # Layer index -> sorted indices of MLP neurons pruned in that layer.
    pruned_neurons: Dict[int, List[int]]


def save_pruning_state(
    ckpt_dir: str,
    active_layers: List[bool],
    width_state: "WidthState",
):
    """Write the current depth + width pruning state to ``<ckpt_dir>/pruning_state.json``.

    JSON layout:
        ``skipped_layers``: indices ``i`` where ``active_layers[i]`` is falsy.
        ``pruned_heads``: ``{layer_idx: sorted head indices}`` for layers with
            at least one pruned head.
        ``pruned_neurons``: ``{layer_idx: sorted neuron indices}`` for layers
            with at least one pruned neuron.

    JSON object keys are always strings, so layer indices in the two dicts
    are written as e.g. ``"3"``.
    """
    def _nonempty_sorted(per_layer):
        # int() keeps numpy / torch integer indices JSON-serialisable.
        return {
            int(layer_idx): sorted(int(x) for x in items)
            for layer_idx, items in enumerate(per_layer) if items
        }

    state = PruningState(
        skipped_layers=[i for i, active in enumerate(active_layers) if not active],
        pruned_heads=_nonempty_sorted(width_state.pruned_heads),
        pruned_neurons=_nonempty_sorted(width_state.pruned_neurons),
    )
    obj = {
        "skipped_layers": state.skipped_layers,
        "pruned_heads": state.pruned_heads,
        "pruned_neurons": state.pruned_neurons,
    }

    path = os.path.join(ckpt_dir, "pruning_state.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2)
    print(f"[TALE] Saved pruning state to {path}")



def find_layers(
    module,
    layers=(nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt),
    name: str = ''
):
    """Recursively collect linear-like submodules keyed by dotted name.

    A module matches only if its exact type is in ``layers``
    (``type(m) in layers``), so subclasses are not matched unless listed.
    Matched modules are not descended into. Children whose own attribute
    name contains ``'lora'`` (e.g. PEFT ``lora_A`` / ``lora_B``) are skipped
    together with their subtrees, so adapter weights are never returned.

    Args:
        module: root module to search.
        layers: tuple of accepted module types.
        name: dotted prefix of ``module`` (empty for the root).

    Returns:
        dict mapping dotted names to matching modules.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        # Filter on the child's own name. The previous version tested the
        # parent path, so a LoRA Linear sitting directly under a non-LoRA
        # parent was still collected.
        if 'lora' in child_name:
            continue
        full_name = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=full_name))
    return found


def dequantize(layer):
    """Return a layer's weight, dequantized when it carries a bitsandbytes quant state.

    PEFT-style wrappers are unwrapped through ``get_base_layer()`` first. If
    the weight has a non-None ``quant_state`` it is dequantized with
    ``dequantize_bnb_weight``; otherwise the weight object itself is returned.
    """
    base = layer.get_base_layer() if hasattr(layer, "get_base_layer") else layer
    weight = base.weight
    quant_state = getattr(weight, "quant_state", None)
    if quant_state is None:
        return weight
    return dequantize_bnb_weight(weight, state=quant_state)


def count_model_params(model: nn.Module) -> int:
    """Return the total element count of all tensors in ``model.parameters()``.

    Zeroed (pruned) entries still count.
    NOTE(review): for bitsandbytes-quantized parameters this is the stored
    element count, which may differ from the logical weight count; confirm
    before comparing dense and quantized models.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


# ============================================================
# ============= FLAP-style importance 计算 ====================
# ============================================================

@torch.no_grad()
def compute_flap_importance(
    model,
    tokenizer,
    device,
    calib_dataset: str,
    data_dir: Optional[str],
    num_calib_sample: int,
    seqlen: int,
    metrics: str = "WIFV",
    wanda_sp: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    """Compute FLAP-style per-input-channel importance for every decoder layer.

    Equivalent to the first half of ``prune_flap(..., only_importance=True)``.
    Forward hooks on each ``self_attn.o_proj`` and ``mlp.down_proj`` feed a
    statistics collector (``BiasGPT`` for FLAP, ``WrappedGPT`` when
    ``wanda_sp``) while up to ``num_calib_sample`` calibration windows are run
    through the model. The statistics are then turned into one score per
    input channel of those projections.

    ``model.config.use_cache`` is forced to False for the computation and
    restored afterwards. The hooks are removed even if a forward pass raises.

    Args:
        model: HF causal LM with decoder layers at ``model.model.layers``
            (or ``model.model.encoder.layers``).
        tokenizer: tokenizer forwarded to ``get_loaders``.
        device: device the calibration ``input_ids`` are moved to.
        calib_dataset: dataset name understood by ``get_loaders``.
        data_dir: optional local data root forwarded to ``get_loaders``.
        num_calib_sample: number of calibration windows to run.
        seqlen: calibration window length in tokens.
        metrics: FLAP metric, one of ``"IFV"``, ``"WIFV"``, ``"WIFN"``
            (ignored when ``wanda_sp``).
        wanda_sp: use the Wanda-SP score instead of a FLAP metric.

    Returns:
        attn_metric: ``[L, H_attn_input]`` unnormalised o_proj input-channel
            scores (the FLAP metric is squared on this side).
        mlp_metric: ``[L, H_mlp_input]`` down_proj input-channel scores.
        attn_baseline_list: L fp16 CPU vectors holding the mean o_proj input
            (FLAP) or the ``scaler_row`` statistic (Wanda-SP).
        mlp_baseline_list: the same for down_proj.

    Raises:
        RuntimeError: no o_proj / down_proj linears or no decoder layers found.
        KeyError: a layer's projection cannot be matched by name, or (FLAP
            mode) ``metrics`` is not a known metric.
    """
    use_cache = getattr(model.config, "use_cache", False)
    model.config.use_cache = False
    model.eval()
    try:
        print(f"[TALE] Loading calibration data for FLAP from {calib_dataset}")
        dataloader, _ = get_loaders(
            calib_dataset,
            nsamples=num_calib_sample,
            seed=42,
            tokenizer=tokenizer,
            seqlen=seqlen,
            data_dir=data_dir,
        )
        print(f"[TALE] Calibration dataset ready, #batches={len(dataloader)}")

        all_linears: Dict[str, nn.Module] = find_layers(model)
        target_names = [
            name for name in all_linears
            if name.endswith("self_attn.o_proj") or name.endswith("mlp.down_proj")
        ]
        if not target_names:
            raise RuntimeError("[TALE] Cannot find any 'self_attn.o_proj' or 'mlp.down_proj' linear layers.")

        # One statistics collector per target linear (WrappedGPT or BiasGPT).
        wrapped_layers = {
            name: (WrappedGPT(all_linears[name]) if wanda_sp else BiasGPT(all_linears[name], metrics))
            for name in target_names
        }

        def make_hook(name):
            def hook(module, inp, out):
                x_in = inp[0] if isinstance(inp, (tuple, list)) else inp
                wrapped_layers[name].add_batch(x_in.detach(), out.detach())
            return hook

        handles = [all_linears[name].register_forward_hook(make_hook(name)) for name in target_names]
        try:
            print("[TALE] Running calibration forward passes for FLAP importance ...")
            for idx, (inp, _) in enumerate(tqdm(dataloader, desc="FLAP calib")):
                if idx >= num_calib_sample:
                    break
                model(inp.to(device))
        finally:
            # Hooks left behind would keep feeding the collectors on every later forward.
            for h in handles:
                h.remove()

        num_layers = len(get_decoder_layers(model))

        def metric_IFV(module_name: str):
            return wrapped_layers[module_name].fluc_inp

        def metric_WIFV(module_name: str):
            layer_mod = all_linears[module_name]
            return wrapped_layers[module_name].fluc_inp * torch.sum(dequantize(layer_mod).data.pow(2), dim=0)

        def metric_WIFN(module_name: str):
            layer_mod = all_linears[module_name]
            return (
                torch.abs(dequantize(layer_mod).data) *
                torch.sqrt(wrapped_layers[module_name].scaler_inp.reshape((1, -1)))
            ).mean(axis=0)

        metrics_dict = {
            "IFV": metric_IFV,
            "WIFV": metric_WIFV,
            "WIFN": metric_WIFN,
        }

        def metric_wanda_sp(module_name: str):
            weight = dequantize(all_linears[module_name])
            return (
                torch.abs(weight.data) *
                torch.sqrt(wrapped_layers[module_name].scaler_row.reshape((1, -1)))
            ).mean(axis=0)

        def get_module_name(layer_idx: int, suffix: str) -> str:
            pattern = f"layers.{layer_idx}.{suffix}"
            for n in target_names:
                # Anchor at a '.' boundary so only a component literally named
                # "layers" matches (not e.g. "decoder_layers").
                if n == pattern or n.endswith("." + pattern):
                    return n
            raise KeyError(f"[TALE] Cannot find module name with pattern '{pattern}' in target_names.")

        attn_metric_list, mlp_metric_list = [], []
        attn_baseline_inp_list, mlp_baseline_inp_list = [], []

        for i in tqdm(range(num_layers), desc="[TALE] Computing FLAP metric per layer"):
            attn_name = get_module_name(i, "self_attn.o_proj")
            mlp_name = get_module_name(i, "mlp.down_proj")

            # ATTENTION (FLAP squares the attention-side metric only).
            if not wanda_sp:
                W_attn = metrics_dict[metrics](attn_name) ** 2
                attn_baseline = wrapped_layers[attn_name].baseline_inp.type(torch.float16)
            else:
                W_attn = metric_wanda_sp(attn_name)
                attn_baseline = wrapped_layers[attn_name].scaler_row.type(torch.float16)
            attn_metric_list.append(W_attn.cpu())
            attn_baseline_inp_list.append(attn_baseline.cpu())

            # MLP
            if not wanda_sp:
                W_mlp = metrics_dict[metrics](mlp_name)
                mlp_baseline = wrapped_layers[mlp_name].baseline_inp.type(torch.float16)
            else:
                W_mlp = metric_wanda_sp(mlp_name)
                mlp_baseline = wrapped_layers[mlp_name].scaler_row.type(torch.float16)
            mlp_metric_list.append(W_mlp.cpu())
            mlp_baseline_inp_list.append(mlp_baseline.cpu())

            # Release per-layer statistics (only BiasGPT defines free()).
            for name in (attn_name, mlp_name):
                if hasattr(wrapped_layers[name], 'free'):
                    wrapped_layers[name].free()

        attn_metric = torch.stack(attn_metric_list)   # [L, H_attn_input]
        mlp_metric = torch.stack(mlp_metric_list)     # [L, H_mlp_input]
    finally:
        model.config.use_cache = use_cache
        torch.cuda.empty_cache()

    return attn_metric, mlp_metric, attn_baseline_inp_list, mlp_baseline_inp_list


# ============================================================
# ===================== Loss 计算函数 =========================
# ============================================================

@torch.no_grad()
def evaluate_loss(model, device, dataloader, max_batches: int = 10) -> float:
    """Average the model's loss over the first ``max_batches`` batches.

    Each batch is an ``(input_ids, labels)`` pair. Both are moved to
    ``device`` and passed as ``model(input_ids=..., labels=...)``. Every
    batch's ``.loss`` gets equal weight in the average.

    NOTE(review): the loaders built by ``get_loaders`` in this file set every
    label except the last position to -100. Presumably only the final token
    of each window contributes to this loss; confirm that is the intended
    pruning signal.

    Returns:
        The mean loss, or ``float("inf")`` if no batch was evaluated.
    """
    model.eval()
    running_total = 0.0
    seen = 0
    for batch_idx, (input_ids, labels) in enumerate(dataloader):
        if batch_idx >= max_batches:
            break
        outputs = model(input_ids=input_ids.to(device), labels=labels.to(device))
        running_total += outputs.loss.detach().float().item()
        seen += 1
    return running_total / seen if seen else float("inf")


# ============================================================
# ========= Layer 冗余度 / 候选层池 / 剪枝算子 =================
# ============================================================

def compute_layer_redundancy_scores(
    attn_metric: torch.Tensor,
    mlp_metric: torch.Tensor,
    bottom_q: float = 0.2,
) -> np.ndarray:
    """Score each layer by the mean of its least-important channels.

    For layer ``i``, the ``max(1, int(numel * bottom_q))`` smallest entries
    of ``attn_metric[i]`` and of ``mlp_metric[i]`` are pooled and averaged.
    A smaller score marks a more redundant layer.

    Args:
        attn_metric: ``[L, H_attn]`` per-channel attention importance.
        mlp_metric: ``[L, H_mlp]`` per-channel MLP importance.
        bottom_q: fraction of lowest-scoring channels per side to average.

    Returns:
        Array of ``attn_metric.size(0)`` floats.
    """
    def _smallest(row: torch.Tensor) -> torch.Tensor:
        count = max(1, int(row.numel() * bottom_q))
        return torch.topk(row, k=count, largest=False).values

    layer_scores = [
        torch.cat([_smallest(attn_metric[idx]), _smallest(mlp_metric[idx])]).mean().item()
        for idx in range(attn_metric.size(0))
    ]
    return np.array(layer_scores)


def get_decoder_layers(model):
    """Return the model's stack of transformer layers.

    Looks for ``model.model.layers`` first, then
    ``model.model.encoder.layers``.

    Raises:
        RuntimeError: neither location exists.
    """
    if not hasattr(model, "model"):
        raise RuntimeError("[TALE] Cannot find decoder layers in model.")
    inner = model.model
    if hasattr(inner, "layers"):
        return inner.layers
    if hasattr(inner, "encoder"):
        return inner.encoder.layers
    raise RuntimeError("[TALE] Cannot find decoder layers in model.")


# === 记录已剪 head / neuron，避免重复剪同一个 ===

class WidthState:
    """Per-layer record of which attention heads and MLP neurons are already pruned.

    ``pruned_heads[i]`` and ``pruned_neurons[i]`` are independent sets of
    indices for layer ``i``.
    """

    def __init__(self, num_layers: int, num_heads: int, mlp_dim: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        # A fresh set per layer (never one shared set object).
        self.pruned_heads = [set() for _layer in range(num_layers)]
        self.pruned_neurons = [set() for _layer in range(num_layers)]


# === 应用 & 恢复 layer 删除（通过 patch forward） ===

def patch_layer_skip(layer, return_tuple: bool = False):
    """Turn a decoder layer into an identity mapping by overriding ``layer.forward``.

    The replacement ignores every argument except the hidden states. It takes
    them from the first positional argument, or else from a
    ``hidden_states=`` keyword, and returns them unchanged.

    Args:
        layer: module whose ``forward`` attribute is replaced on the instance.
        return_tuple: if True, return ``(hidden_states,)`` instead of the bare
            tensor. NOTE(review): the default only suits decoder layers whose
            own forward returns a tensor. Layers that return a tuple (whose
            caller reads ``outputs[0]``) need ``return_tuple=True``; otherwise
            ``[0]`` would slice the batch dimension. Confirm against the
            installed transformers version.

    Returns:
        The previous ``layer.forward``, so it can be restored later.

    NOTE(review): the skipped layer never touches any KV cache. Whether the
    model's cache object tolerates a missing layer index depends on the
    transformers version; confirm before running generation with cache.
    """
    original_forward = layer.forward

    def forward_skip(*args, **kwargs):
        if args:
            hidden_states = args[0]
        elif "hidden_states" in kwargs:
            # Some callers pass hidden_states only by keyword.
            hidden_states = kwargs["hidden_states"]
        else:
            raise RuntimeError("[TALE] forward_skip got no hidden_states (positional or keyword).")
        return (hidden_states,) if return_tuple else hidden_states

    layer.forward = forward_skip
    return original_forward


def restore_layer_forward(layer, original_forward):
    """Reinstall ``original_forward`` as ``layer.forward`` (undoes ``patch_layer_skip``)."""
    setattr(layer, "forward", original_forward)


# === 应用 / 恢复 width 剪枝（简单版：置零权重，不做 bias 补偿） ===

def apply_head_prune(layer, head_indices, hidden_size, num_heads):
    """Zero out whole attention heads of one decoder layer, in place.

    ``head_dim`` is ``hidden_size // num_heads``. Indices outside
    ``[0, num_heads)`` are ignored. For each valid query head ``h`` with
    slice ``[h*head_dim, (h+1)*head_dim)``:

    * q_proj: those output rows, and matching bias entries if any, are zeroed.
    * o_proj: those input columns are zeroed. This alone removes the head's
      contribution to the layer output.
    * k_proj / v_proj: those rows (and bias entries) are zeroed only when
      k/v have as many output rows as q (one KV head per query head). Under
      grouped-query attention, k/v rows are shared KV heads that do not line
      up with query head ``h``, so they are left untouched.

    NOTE(review): weights are edited through ``.weight.data`` slicing. For
    packed bitsandbytes 4/8-bit weights this does not address logical rows;
    confirm the layers are unquantized.
    """
    head_dim = hidden_size // num_heads

    def _out_rows(mod):
        return getattr(mod, "out_features", mod.weight.shape[0])

    def _zero_rows(mod, start, end):
        mod.weight.data[start:end, :] = 0
        bias = getattr(mod, "bias", None)
        if bias is not None:
            bias.data[start:end] = 0

    with torch.no_grad():
        attn = layer.self_attn
        q, k, v, o = attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj
        # True for plain multi-head attention, False for grouped-query attention.
        kv_per_query_head = _out_rows(k) == _out_rows(q) and _out_rows(v) == _out_rows(q)

        for h in head_indices:
            if h < 0 or h >= num_heads:
                continue
            start, end = h * head_dim, (h + 1) * head_dim
            _zero_rows(q, start, end)
            if kv_per_query_head:
                _zero_rows(k, start, end)
                _zero_rows(v, start, end)
            # o: out_features x in_features; heads lie along in_features.
            o.weight.data[:, start:end] = 0


def apply_mlp_neuron_prune(layer, neuron_indices, hidden_size, mlp_dim):
    """Zero out FFN neurons of one decoder layer, in place.

    For every index ``n`` in ``[0, mlp_dim)`` (others are ignored;
    duplicates are harmless):

    * gate_proj / up_proj: row ``n``, and bias entry ``n`` if present, are zeroed.
    * down_proj: column ``n`` is zeroed.

    ``hidden_size`` is accepted for interface symmetry with
    ``apply_head_prune`` and is not used.

    NOTE(review): edits go through ``.weight.data`` indexing, which does not
    address logical rows of packed bitsandbytes weights; confirm the layers
    are unquantized.
    """
    valid = sorted({int(n) for n in neuron_indices if 0 <= n < mlp_dim})
    if not valid:
        return
    with torch.no_grad():
        mlp = layer.mlp
        for proj in (mlp.gate_proj, mlp.up_proj):
            # One advanced-indexing write instead of a Python loop per neuron.
            proj.weight.data[valid, :] = 0
            bias = getattr(proj, "bias", None)
            if bias is not None:
                bias.data[valid] = 0
        mlp.down_proj.weight.data[:, valid] = 0


# ============================================================
# ==================== 迭代剪枝核心逻辑 ======================
# ============================================================
def _lowest_scoring_available(scores, excluded, k, size):
    """
    Return up to ``k`` indices in ``[0, size)`` not in ``excluded``, lowest score first.

    ``scores`` is a 1-D tensor (or array) indexed by head / neuron id. Ties are
    broken by ascending index (stable sort), matching a stable ``sorted`` over
    ``range(size)``. Scores are compared in float64 on the CPU, so the ranking
    costs a single device transfer instead of one ``.item()`` sync per index.
    """
    if k <= 0 or size <= 0:
        return []
    values = torch.as_tensor(scores).detach().reshape(-1)[:size].to(device="cpu", dtype=torch.float64)
    available = torch.ones(values.numel(), dtype=torch.bool)
    blocked = [int(i) for i in excluded if 0 <= int(i) < values.numel()]
    if blocked:
        available[torch.tensor(blocked, dtype=torch.long)] = False
    candidate_ids = torch.nonzero(available).flatten()
    if candidate_ids.numel() == 0:
        return []
    order = torch.sort(values[candidate_ids], stable=True).indices
    return candidate_ids[order[:k]].tolist()


def _depth_action(layers_removed, delta_loss, delta_params, width_params_removed):
    """
    Build a ``remove_layers`` candidate whose ΔParams excludes already-counted width params.

    ``delta_params`` is the full theoretical size of the listed layers. Params
    that earlier width actions removed from these layers were already added to
    the running total, so they are subtracted here. The score is then
    recomputed as ``max(ΔLoss, 0) / max(ΔParams, 1)``.
    """
    already_counted = sum(width_params_removed[l] for l in layers_removed)
    effective_params = max(delta_params - already_counted, 0)
    score = max(delta_loss, 0.0) / max(effective_params, 1)
    return {
        "type": "remove_layers",
        "layers": list(layers_removed),
        "ΔLoss": delta_loss,
        "ΔParams": effective_params,
        "score": score,
    }


def tale_iterative_pruning(args, model, tokenizer, device):
    """
    Run the TALE iterative depth + width pruning loop on ``model`` in place.

    Each iteration:
    - re-measures the calibration loss of the current (already pruned) model;
    - ranks active layers with ``compute_layer_redundancy_scores`` and keeps the
      ``args.K_layer_candidates`` most redundant ones. The first two and last
      two decoder layers are never considered;
    - builds candidate actions and scores each by clipped ΔLoss / ΔParams.
      Candidates are: remove one layer; remove two-layer pairs (coarse stage
      only); prune the lowest-importance heads or MLP neurons of a candidate
      layer;
    - applies the lowest-score action permanently, updates the theoretical
      sparsity, periodically recomputes FLAP importance, and updates the
      coarse -> fine stage and the stop_depth / stop_width switches;
    - saves a checkpoint the first time each integer sparsity percentage is
      reached.

    The loop ends at ``args.target_sparsity``, at ``args.max_iters``, or when
    no candidate action remains. Sparsity is "theoretical removed params /
    original total params". Params a width action already counted for a layer
    are not counted again if that layer is later removed by depth pruning.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        model: causal LM whose decoder layers come from ``get_decoder_layers``;
            modified in place.
        tokenizer: tokenizer used for calibration data and checkpoints.
        device: device used for loss evaluation.
    """
    from itertools import combinations

    layers = get_decoder_layers(model)
    num_layers = len(layers)
    hidden_size = model.config.hidden_size
    mlp_dim = model.config.intermediate_size

    # ---- Resolve num_heads: config, then the attention module, then hidden_size / head_dim ----
    if hasattr(model.config, "num_attention_heads"):
        num_heads = model.config.num_attention_heads
    else:
        attn_mod = layers[0].self_attn
        if hasattr(attn_mod, "num_heads"):
            num_heads = attn_mod.num_heads
        else:
            # 128 is only a fallback guess when the config has no head_dim.
            head_dim_guess = getattr(model.config, "head_dim", 128)
            num_heads = hidden_size // head_dim_guess

    print(f"[INFO] Detected num_heads = {num_heads}")

    print(f"[TALE] Model has {num_layers} layers, hidden={hidden_size}, mlp_dim={mlp_dim}, heads={num_heads}")
    total_params = count_model_params(model)
    print(f"[TALE] Total params = {total_params/1e6:.2f}M")

    # Whether each layer is still active, i.e. not yet removed by a depth action.
    active_layers = [True] * num_layers
    # Theoretical params already removed from each layer by width actions, so a
    # later depth removal of the same layer does not count them a second time.
    width_params_removed = [0] * num_layers

    # Calibration data for loss evaluation.
    calib_loader, _ = get_loaders(
        args.calib_dataset, nsamples=args.num_loss_sample,
        seed=args.seed, tokenizer=tokenizer,
        seqlen=args.seqlen, data_dir=args.data_dir
    )
    print(f"[TALE] Loss calibration loader ready, #batches={len(calib_loader)}")

    base_loss = evaluate_loss(model, device, calib_loader, max_batches=args.num_loss_sample)
    print(f"[TALE] Baseline loss on calib set: {base_loss:.4f}")

    # Initial FLAP importance.
    attn_metric, mlp_metric, attn_baseline_list, mlp_baseline_list = compute_flap_importance(
        model=model,
        tokenizer=tokenizer,
        device=device,
        calib_dataset=args.calib_dataset,
        data_dir=args.data_dir,
        num_calib_sample=args.num_calib_sample,
        seqlen=args.seqlen,
        metrics=args.metrics,
        wanda_sp=args.wanda_sp,
    )
    print("[TALE] Initial FLAP importance computed.")

    width_state = WidthState(num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim)

    # Score histories used for stage switching and for the stop decisions.
    history_scores: List[float] = []
    history_best_depth: List[float] = []
    history_best_width: List[float] = []

    total_removed_params = 0
    sparsity_ratio = 0.0

    stage = "coarse"
    stop_depth = False
    stop_width = False

    # The user-selected prune mode disables one side up front.
    prune_mode = getattr(args, "prune_mode", "both")
    if prune_mode == "depth":
        stop_width = True
        print("[TALE] Prune mode = DEPTH ONLY (no width pruning, 类 shortgpt).")
    elif prune_mode == "width":
        stop_depth = True
        print("[TALE] Prune mode = WIDTH ONLY (no depth pruning, 类 FLAP).")
    else:
        print("[TALE] Prune mode = BOTH (TALE depth+width).")

    # Layers [0, 2) and [num_layers - 2, num_layers) are never pruned
    # (for a 32-layer model only layers 2..29 are candidates).
    protected_edge = 2
    allowed_min = protected_edge
    allowed_max = num_layers - 1 - protected_edge

    it = 0

    # Integer sparsity percentages already checkpointed (e.g. 17 means 17.xx%).
    saved_sparsity_levels = set()
    # Only start saving once this percentage is reached (e.g. set to 17.0).
    min_save_percent = 0.0

    while sparsity_ratio < args.target_sparsity and it < args.max_iters:
        it += 1
        print("=" * 80)
        print(f"[TALE] Iteration {it} | current sparsity={sparsity_ratio*100:.2f}% | stage={stage} | "
              f"stop_depth={stop_depth} | stop_width={stop_width}")

        # Baseline loss of the permanently pruned model. On the first iteration
        # the model is unchanged since base_loss was measured, so reuse it.
        if it == 1:
            base_loss_iter = base_loss
        else:
            base_loss_iter = evaluate_loss(model, device, calib_loader, max_batches=args.num_loss_sample)
        print(f"[TALE] Current baseline loss (iter {it}): {base_loss_iter:.4f}")

        # ---- Stop once every layer has been removed by depth pruning ----
        active_indices = np.where(np.array(active_layers))[0]
        if len(active_indices) == 0:
            print("[TALE] No active layers remaining for depth/width pruning. Stop.")
            break
        # ---- Restrict both depth and width candidates to the allowed range ----
        mask = (active_indices >= allowed_min) & (active_indices <= allowed_max)
        active_indices = active_indices[mask]

        # Per-layer redundancy from the current importance metrics (lower = more redundant).
        layer_scores = compute_layer_redundancy_scores(attn_metric, mlp_metric, bottom_q=args.layer_bottom_q)

        # Keep the K most redundant active layers as candidates.
        active_layer_scores = layer_scores[active_indices]
        order_in_active = np.argsort(active_layer_scores)
        ordered_active_layers = active_indices[order_in_active]
        K = min(args.K_layer_candidates, len(ordered_active_layers))
        candidate_layers = ordered_active_layers[:K]
        print(f"[TALE] Active layers: {active_indices.tolist()}")
        print(f"[TALE] Candidate layers (most redundant among active): {candidate_layers.tolist()}")

        candidate_actions = []

        # ---------------- depth candidates ----------------
        if not stop_depth:
            # Single-layer removals are considered in every stage.
            for l in candidate_layers:
                l = int(l)
                if not active_layers[l]:
                    continue
                d_loss, d_params, _ = simulate_remove_layers(
                    model, device, calib_loader, layers_to_remove=[l],
                    hidden_size=hidden_size, num_heads=num_heads,
                    mlp_dim=mlp_dim, max_batches=args.num_loss_sample,
                    base_loss=base_loss_iter,
                )
                candidate_actions.append(_depth_action([l], d_loss, d_params, width_params_removed))

            # Two-layer removals are only considered in the coarse stage.
            if stage == "coarse":
                for l1, l2 in combinations(candidate_layers, 2):
                    l1, l2 = int(l1), int(l2)
                    if not (active_layers[l1] and active_layers[l2]):
                        continue
                    d_loss, d_params, _ = simulate_remove_layers(
                        model, device, calib_loader,
                        layers_to_remove=[l1, l2],
                        hidden_size=hidden_size, num_heads=num_heads,
                        mlp_dim=mlp_dim, max_batches=args.num_loss_sample,
                        base_loss=base_loss_iter,
                    )
                    candidate_actions.append(
                        _depth_action([l1, l2], d_loss, d_params, width_params_removed)
                    )

        # ---------------- width candidates ----------------
        if not stop_width:
            # Per-iteration width step (fraction of heads / neurons) depends on the stage.
            p_width = args.p_width_coarse if stage == "coarse" else args.p_width_fine
            n_head_to_prune = max(1, int(p_width * num_heads))
            n_mlp_to_prune = max(1, int(p_width * mlp_dim))
            head_dim = hidden_size // num_heads

            for l in candidate_layers:
                l = int(l)
                if not active_layers[l]:
                    continue  # layers removed by depth pruning get no width actions

                # ---- attention width: least important heads not yet pruned ----
                # Assumes attn_metric[l] has hidden_size entries (head-major order).
                head_importance = attn_metric[l].reshape(num_heads, head_dim).sum(dim=1)
                head_candidates = _lowest_scoring_available(
                    head_importance, width_state.pruned_heads[l], n_head_to_prune, num_heads
                )
                if head_candidates:
                    d_loss, d_params, score = simulate_prune_heads(
                        model, device, calib_loader,
                        layer_idx=l,
                        head_indices=head_candidates,
                        hidden_size=hidden_size,
                        num_heads=num_heads,
                        max_batches=args.num_loss_sample,
                        base_loss=base_loss_iter,
                    )
                    candidate_actions.append({
                        "type": "prune_attn",
                        "layer": l,
                        "heads": head_candidates,
                        "ΔLoss": d_loss,
                        "ΔParams": d_params,
                        "score": score,
                    })

                # ---- MLP width: least important neurons not yet pruned ----
                neuron_candidates = _lowest_scoring_available(
                    mlp_metric[l], width_state.pruned_neurons[l], n_mlp_to_prune, mlp_dim
                )
                if neuron_candidates:
                    d_loss, d_params, score = simulate_prune_mlp_neurons(
                        model, device, calib_loader,
                        layer_idx=l,
                        neuron_indices=neuron_candidates,
                        hidden_size=hidden_size,
                        mlp_dim=mlp_dim,
                        max_batches=args.num_loss_sample,
                        base_loss=base_loss_iter,
                    )
                    candidate_actions.append({
                        "type": "prune_mlp",
                        "layer": l,
                        "neurons": neuron_candidates,
                        "ΔLoss": d_loss,
                        "ΔParams": d_params,
                        "score": score,
                    })

        if len(candidate_actions) == 0:
            print("[TALE] No more valid candidate actions, stop.")
            break

        # Best depth / width scores this round (inf when a side produced no candidate).
        depth_scores = [a["score"] for a in candidate_actions if a["type"] == "remove_layers"]
        width_scores = [a["score"] for a in candidate_actions if a["type"] != "remove_layers"]
        best_depth_score = min(depth_scores, default=float("inf"))
        best_width_score = min(width_scores, default=float("inf"))

        history_best_depth.append(best_depth_score)
        history_best_width.append(best_width_score)

        # Globally best action.
        best_action = min(candidate_actions, key=lambda x: x["score"])
        history_scores.append(best_action["score"])

        print(f"[TALE] Selected action: {best_action}")
        print(f"[TALE] ΔLoss={best_action['ΔLoss']:.6f}, ΔParams={best_action['ΔParams']}, "
              f"score={best_action['score']:.6e}")

        # Apply the action permanently.
        if best_action["type"] == "remove_layers":
            # Only skip layers that are still active, and mark them inactive.
            real_new_layers = []
            for l in best_action["layers"]:
                l = int(l)
                if not active_layers[l]:
                    continue
                print(f"[TALE] Permanently skipping layer {l}")
                patch_layer_skip(layers[l])  # permanent: never restored
                active_layers[l] = False
                real_new_layers.append(l)

            if len(real_new_layers) == 0:
                print("[TALE] Warning: best_action remove_layers but no new active layer was actually removed.")

        elif best_action["type"] == "prune_attn":
            l = best_action["layer"]
            heads = best_action["heads"]
            print(f"[TALE] Permanently pruning attention heads at layer {l}: {heads}")
            apply_head_prune(layers[l], heads, hidden_size, num_heads)
            for h in heads:
                width_state.pruned_heads[l].add(h)
            width_params_removed[l] += best_action["ΔParams"]

        elif best_action["type"] == "prune_mlp":
            l = best_action["layer"]
            neurons = best_action["neurons"]
            print(f"[TALE] Permanently pruning MLP neurons at layer {l}: {neurons}")
            apply_mlp_neuron_prune(layers[l], neurons, hidden_size, mlp_dim)
            for n in neurons:
                width_state.pruned_neurons[l].add(n)
            width_params_removed[l] += best_action["ΔParams"]
        else:
            raise ValueError(f"[TALE] Unknown action type: {best_action['type']}")

        # Theoretical accounting. Depth ΔParams already excludes params counted by
        # earlier width actions on the same layers, and inactive layers never yield
        # new actions, so nothing is counted twice.
        total_removed_params += best_action["ΔParams"]
        sparsity_ratio = total_removed_params / total_params
        print(f"[TALE] Total removed params (theoretical) = {total_removed_params/1e6:.2f}M "
              f"({sparsity_ratio*100:.2f}% of original)")

        # Periodically recompute FLAP importance (a non-positive interval disables it).
        if args.recompute_interval > 0 and it % args.recompute_interval == 0:
            print("[TALE] Recomputing FLAP importance after several iterations ...")
            attn_metric, mlp_metric, attn_baseline_list, mlp_baseline_list = compute_flap_importance(
                model=model,
                tokenizer=tokenizer,
                device=device,
                calib_dataset=args.calib_dataset,
                data_dir=args.data_dir,
                num_calib_sample=args.num_calib_sample,
                seqlen=args.seqlen,
                metrics=args.metrics,
                wanda_sp=args.wanda_sp,
            )

        # -------- Stage switching: coarse -> fine ----------
        if stage == "coarse" and len(history_scores) >= args.M_window:
            early_scores = history_scores[:args.M_window]
            baseline = np.median(early_scores)
            recent_scores = history_scores[-args.M_window:]
            recent_median = np.median(recent_scores)
            print(f"[TALE] Stage-switch check: recent_median={recent_median:.3e}, "
                  f"baseline={baseline:.3e}, ratio={recent_median / (baseline + 1e-12):.2f}")
            if recent_median > args.tau1 * baseline:
                stage = "fine"
                print("[TALE] Switch to fine-grained pruning stage.")

        # -------- Decide when to stop depth / width ----------
        if len(history_best_depth) >= args.M_window:
            recent_depth = history_best_depth[-args.M_window:]
            recent_width = history_best_width[-args.M_window:]
            min_depth = min(recent_depth)
            min_width = min(recent_width)
            print(f"[TALE] Depth/Width best score window: depth={min_depth:.3e}, width={min_width:.3e}")

            if (not stop_depth) and (min_depth > args.tau2 * (min_width + 1e-12)):
                stop_depth = True
                print("[TALE] Stop pruning depth; width still has much more redundancy.")
            if (not stop_width) and (min_width > args.tau2 * (min_depth + 1e-12)):
                stop_width = True
                print("[TALE] Stop pruning width; depth still has much more redundancy.")

        # -------- Save once per integer sparsity level --------
        current_percent = sparsity_ratio * 100.0
        if current_percent >= min_save_percent:
            # Integer percentage, e.g. 17.3% -> 17.
            level = int(current_percent // 1)
            if level not in saved_sparsity_levels:
                # First time entering [level, level + 1): save one checkpoint.
                print(f"[TALE][ckpt] First time reaching sparsity >= {level}%, saving checkpoint...")
                ckpt_path = save_hf_checkpoint(model, tokenizer, args.output_dir, it, sparsity_ratio)
                save_pruning_state(ckpt_path, active_layers, width_state)
                saved_sparsity_levels.add(level)
        # To checkpoint every iteration instead, call save_hf_checkpoint /
        # save_pruning_state unconditionally here.

    print("[TALE] Iterative pruning finished.")
    return


# ============================================================
# ============== 模拟动作：ΔLoss / ΔParams ===================
# ============================================================
@torch.no_grad()
def simulate_remove_layers(
    model,
    device,
    calib_loader,
    layers_to_remove: List[int],
    hidden_size: int,
    num_heads: int,
    mlp_dim: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """
    Estimate the cost of removing ``layers_to_remove`` by temporarily skipping them.

    Each listed layer's ``forward`` is replaced via ``patch_layer_skip`` and the
    calibration loss is measured. The saved ``forward`` is then reinstalled in a
    ``finally`` block, so an exception during evaluation (e.g. OOM) cannot leave
    layers permanently skipped. Duplicate indices are ignored: patching a layer
    twice would save the already-patched forward as the "original" and make the
    skip permanent.

    Args:
        model: model whose decoder layers come from ``get_decoder_layers``.
        device: device for loss evaluation.
        calib_loader: calibration batches passed to ``evaluate_loss``.
        layers_to_remove: decoder-layer indices to skip.
        hidden_size, mlp_dim: used for the theoretical parameter count.
        num_heads: unused; kept for interface symmetry.
        max_batches: batch limit forwarded to ``evaluate_loss``.
        base_loss: loss of the unmodified model.

    Returns:
        ``(ΔLoss, ΔParams, score)`` where ``ΔLoss = loss_after - base_loss``;
        ``ΔParams = (4*hidden_size**2 + 3*hidden_size*mlp_dim)`` per distinct
        layer (assumes square q/k/v/o and three MLP projections; norms and
        biases ignored); and ``score = max(ΔLoss, 0) / max(ΔParams, 1)``.
    """
    layers = get_decoder_layers(model)
    # dict.fromkeys keeps first-seen order while dropping duplicates.
    unique_layers = list(dict.fromkeys(int(l) for l in layers_to_remove))

    original_forwards = {}
    try:
        for l in unique_layers:
            original_forwards[l] = layers[l].forward
            patch_layer_skip(layers[l])
        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        # Restore everything that was saved, even if patching/evaluation failed midway.
        for l, forward in original_forwards.items():
            restore_layer_forward(layers[l], forward)

    ΔLoss = loss_after - base_loss

    params_per_layer = 4 * hidden_size * hidden_size + 3 * hidden_size * mlp_dim
    ΔParams = params_per_layer * len(unique_layers)

    ΔLoss_eff = max(ΔLoss, 0.0)  # a loss decrease is treated as zero cost
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score


@torch.no_grad()
def simulate_prune_heads(
    model,
    device,
    calib_loader,
    layer_idx: int,
    head_indices: List[int],
    hidden_size: int,
    num_heads: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """
    Estimate the cost of pruning ``head_indices`` in layer ``layer_idx``.

    The weights are zeroed by calling ``apply_head_prune``, so the simulated
    edit is exactly the one the permanent action performs. The calibration
    loss is then measured. The q/k/v/o weights are restored from full clones
    in a ``finally`` block, so a failed evaluation cannot leave the heads
    zeroed.

    Args:
        model: model whose decoder layers come from ``get_decoder_layers``.
        device: device for loss evaluation.
        calib_loader: calibration batches passed to ``evaluate_loss``.
        layer_idx: decoder-layer index.
        head_indices: query-head indices; out-of-range and duplicate entries
            are ignored.
        hidden_size, num_heads: define ``head_dim = hidden_size // num_heads``.
        max_batches: batch limit forwarded to ``evaluate_loss``.
        base_loss: loss of the unmodified model.

    Returns:
        ``(ΔLoss, ΔParams, score)`` where ``ΔLoss = loss_after - base_loss``
        and ``score = max(ΔLoss, 0) / max(ΔParams, 1)``. ``ΔParams`` counts,
        per head, one q_proj row block and one o_proj column block
        (``hidden_size * head_dim`` each). It also counts the k_proj / v_proj
        row blocks when those are as wide as q_proj (one K/V head per query
        head). Returns ``(inf, 0, inf)`` when the layer lacks ``self_attn`` or
        any of the four projections.
    """
    layers = get_decoder_layers(model)
    layer = layers[layer_idx]
    if not hasattr(layer, "self_attn"):
        return float("inf"), 0, float("inf")

    attn = layer.self_attn
    proj_names = ("q_proj", "k_proj", "v_proj", "o_proj")
    if not all(hasattr(attn, name) for name in proj_names):
        # The permanent action needs all four projections; never pick an
        # action that could not be applied.
        return float("inf"), 0, float("inf")

    head_dim = hidden_size // num_heads
    heads = sorted({int(h) for h in head_indices if 0 <= h < num_heads})

    backup = {name: getattr(attn, name).weight.data.clone() for name in proj_names}
    try:
        apply_head_prune(layer, heads, hidden_size, num_heads)
        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        for name, saved in backup.items():
            getattr(attn, name).weight.data.copy_(saved)

    ΔLoss = loss_after - base_loss

    q_rows = attn.q_proj.weight.size(0)
    kv_per_query_head = (
        attn.k_proj.weight.size(0) == q_rows and attn.v_proj.weight.size(0) == q_rows
    )
    # q rows + o columns always; k/v rows only with a 1:1 head layout.
    params_per_head = (4 if kv_per_query_head else 2) * hidden_size * head_dim
    ΔParams = params_per_head * len(heads)

    ΔLoss_eff = max(ΔLoss, 0.0)
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score


@torch.no_grad()
def simulate_prune_mlp_neurons(
    model,
    device,
    calib_loader,
    layer_idx: int,
    neuron_indices: List[int],
    hidden_size: int,
    mlp_dim: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """
    Estimate the cost of pruning ``neuron_indices`` in the MLP of layer ``layer_idx``.

    The weights are zeroed by calling ``apply_mlp_neuron_prune``, so the
    simulated edit is exactly the permanent one. Negative indices are skipped
    rather than wrapping around. The calibration loss is then measured, and
    the gate/up/down weights are restored from full clones in a ``finally``
    block.

    Args:
        model: model whose decoder layers come from ``get_decoder_layers``.
        device: device for loss evaluation.
        calib_loader: calibration batches passed to ``evaluate_loss``.
        layer_idx: decoder-layer index.
        neuron_indices: intermediate-neuron indices; entries outside
            ``[0, mlp_dim)`` and duplicates are ignored.
        hidden_size: used for the parameter count.
        mlp_dim: intermediate size.
        max_batches: batch limit forwarded to ``evaluate_loss``.
        base_loss: loss of the unmodified model.

    Returns:
        ``(ΔLoss, ΔParams, score)`` with ``ΔLoss = loss_after - base_loss``,
        ``ΔParams = 3 * hidden_size`` per distinct valid neuron, and
        ``score = max(ΔLoss, 0) / max(ΔParams, 1)``. Returns ``(inf, 0, inf)``
        when the layer lacks ``mlp`` or any of the three projections.
    """
    layers = get_decoder_layers(model)
    layer = layers[layer_idx]
    if not hasattr(layer, "mlp"):
        return float("inf"), 0, float("inf")

    mlp = layer.mlp
    proj_names = ("gate_proj", "up_proj", "down_proj")
    if not all(hasattr(mlp, name) for name in proj_names):
        # The permanent action needs all three projections.
        return float("inf"), 0, float("inf")

    neurons = sorted({int(n) for n in neuron_indices if 0 <= n < mlp_dim})

    backup = {name: getattr(mlp, name).weight.data.clone() for name in proj_names}
    try:
        apply_mlp_neuron_prune(layer, neurons, hidden_size, mlp_dim)
        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        for name, saved in backup.items():
            getattr(mlp, name).weight.data.copy_(saved)

    ΔLoss = loss_after - base_loss

    # One gate_proj row + one up_proj row + one down_proj column per neuron,
    # assuming each spans hidden_size entries.
    params_per_neuron = 3 * hidden_size
    ΔParams = params_per_neuron * len(neurons)

    ΔLoss_eff = max(ΔLoss, 0.0)
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score



# ============================================================
# ========================= CLI 部分 =========================
# ============================================================

def _positive_int(text):
    """argparse ``type`` that accepts only integers >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected an integer >= 1, got {text!r}")
    return value


def parse_args():
    """
    Parse the TALE command-line options from ``sys.argv``.

    ``--recompute_interval`` and ``--M_window`` must be >= 1. The first is
    used as a modulus (0 would raise ZeroDivisionError). The second sizes
    ``history[-M:]`` windows, where 0 would silently select the whole history.

    Returns:
        argparse.Namespace with all options.
    """
    parser = argparse.ArgumentParser(description="TALE iterative depth+width pruning")

    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    parser.add_argument("--calib_dataset", type=str, default="mmlu",
                        choices=["wikitext2", "c4", "mmlu"])
    parser.add_argument("--data_dir", type=str, default=None)

    parser.add_argument("--num_calib_sample", type=int, default=20,
                        help="用于 FLAP 的校准样本数（序列条数）")
    parser.add_argument("--num_loss_sample", type=int, default=8,
                        help="评估 Loss 时使用的 batch 数（越大越准，越小越快）")
    parser.add_argument("--seqlen", type=int, default=512)

    parser.add_argument("--metrics", type=str, default="WIFV",
                        choices=["IFV", "WIFV", "WIFN"])
    parser.add_argument("--wanda_sp", action="store_true")

    # Iterative pruning hyper-parameters.
    parser.add_argument("--target_sparsity", type=float, default=0.8,
                        help="目标剪枝比例（理论参数层面）")
    parser.add_argument("--max_iters", type=int, default=50,
                        help="最大迭代轮数（安全上限）")

    parser.add_argument("--K_layer_candidates", type=int, default=4)
    parser.add_argument("--layer_bottom_q", type=float, default=0.2)

    parser.add_argument("--p_width_coarse", type=float, default=0.02)
    parser.add_argument("--p_width_fine", type=float, default=0.01)

    parser.add_argument("--tau1", type=float, default=2.0,
                        help="coarse → fine 的 score 放大因子阈值")
    parser.add_argument("--tau2", type=float, default=2.0,
                        help="停止 depth 或 width 的冗余度阈值因子")
    parser.add_argument("--M_window", type=_positive_int, default=3,
                        help="判断冗余度与阶段切换时使用的滑动窗口长度")

    parser.add_argument("--recompute_interval", type=_positive_int, default=10,
                        help="每隔多少轮重算一次 FLAP importance")

    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument(
        "--prune_mode",
        type=str,
        default="both",
        choices=["both", "depth", "width"],
        help=(
            "剪枝模式："
            "'both' = 同时做 depth+width（原 TALE）；"
            "'depth' = 只删层（接近 shortgpt）；"
            "'width' = 只剪通道（接近 FLAP）。"
        ),
    )

    return parser.parse_args()


def main():
    """
    CLI entry point: parse options, load model and tokenizer, run TALE pruning.

    Half precision is used only on CUDA; if both ``--fp16`` and ``--bf16`` are
    given, fp16 wins. Otherwise the model is loaded in float32.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[TALE] Using device: {device}")

    on_cuda = device.type == "cuda"
    if on_cuda and args.fp16:
        dtype = torch.float16
    elif on_cuda and args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"[TALE] Loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        dtype=dtype,
        device_map=None,
    )
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    tale_iterative_pruning(args, model, tokenizer, device)

    print("[TALE] All done.")


if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())
