#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
TALE: Task-Aware Layer+width iterative pruning (v0.1)

本脚本做的事情：
- 加载一个 HF CausalLM 模型（假定 LLaMA 系结构：model.model.layers[i].self_attn / mlp）
- 用你已有的 FLAP 逻辑计算一次 attn/mlp 的通道重要性（WIFV / IFV / WIFN）
- 在当前模型上迭代进行 depth + width 剪枝：
    - 每一轮：
        - width: 在若干候选 layer（按 FLAP 冗余度排序）中，每层按 FLAP 重要性剪掉最不重要的一小部分 attention head / MLP neuron
        - depth: 每轮最多删除 1 层（通过把该层 forward 改成恒等映射来跳过）；对每个候选层用 ΔLoss / ΔParams 评分（ΔLoss 在校准集上真实 forward）
        - 选择 score 最小的候选层，真正将其跳过
    - 记录并保存每个 sparsity 阶段的模型 checkpoint（dense model family）
- 当：
    - 删层动作的 ΔLoss/ΔParams 相比早期 baseline 连续若干轮明显变大 → 进入 "fine" 小步剪枝阶段
    - 进入 "fine" 阶段 / stop_depth 之后 → 不再删层，只继续做 width 剪枝（不再比较 depth 与 width 谁更冗余）

⚠ 重要假设：
- 模型是 LLaMA 风格，decoder 层位于 model.model.layers
- 每层含：
    - layer.self_attn.q_proj/k_proj/v_proj/o_proj (nn.Linear, 无 bias 或 bias 较简单)
    - layer.mlp.gate_proj / up_proj / down_proj
- 你当前环境已经安装 bitsandbytes / datasets / transformers

你可以按需把这个脚本拆成多个 py 文件，这里先写成一个方便你整体拷贝。
"""

import os
import sys
import re
import argparse
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from datasets import load_dataset, Dataset
from tqdm import tqdm

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

import bitsandbytes as bnb
from peft.tuners.lora.bnb import Linear4bit, Linear8bitLt
from peft.utils.integrations import dequantize_bnb_weight


# ============================================================
# ===================== 数据加载部分 ==========================
# ============================================================

class TokenizerWrapper:
    """Lightweight holder exposing only an ``input_ids`` attribute.

    Used where callers expect an object with ``.input_ids`` (the attribute a
    tokenizer encoding exposes) but only a raw token tensor is available.
    """

    def __init__(self, input_ids: torch.Tensor):
        # Kept by reference; no copy is made.
        self.input_ids: torch.Tensor = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Draw WikiText-2 calibration windows and tokenize the test split.

    The whole train split is joined with spaces and tokenized once; then
    ``nsamples`` random windows of ``seqlen`` tokens are cut from it.

    Args:
        nsamples: number of calibration windows.
        seed: seed applied to the global ``random`` module (reseeds it).
        seqlen: window length in tokens.
        tokenizer: HF tokenizer.

    Returns:
        (trainloader, testenc): ``trainloader`` is a list of ``(inp, tar)``
        pairs where ``tar`` is a copy of ``inp`` with every position except
        the last set to -100. ``testenc`` is the tokenizer output for the
        test split joined with blank lines.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    testenc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    max_start = train_ids.shape[1] - seqlen - 1

    def draw_window():
        start = random.randint(0, max_start)
        window = train_ids[:, start:start + seqlen]
        target = window.clone()
        target[:, :-1] = -100  # only the final token contributes to the loss
        return window, target

    trainloader = [draw_window() for _ in range(nsamples)]
    return trainloader, testenc


def get_c4(nsamples, seed, seqlen, tokenizer, data_dir=None):
    """Draw C4 calibration windows and build a fixed validation token stream.

    Only the first shard of each split is used, both for the hub and for a
    local copy.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed applied to the global ``random`` module (reseeds it).
        seqlen: window length in tokens.
        tokenizer: HF tokenizer.
        data_dir: optional local root containing an ``allenai--c4`` copy;
            when ``None`` the shards are fetched from the ``allenai/c4`` repo.

    Returns:
        (trainloader, valenc): ``trainloader`` is a list of ``(inp, tar)``
        pairs, each ``inp`` a ``seqlen``-token slice of one document and
        ``tar`` a copy with every position except the last set to -100.
        ``valenc`` wraps the first ``256 * seqlen`` tokens of the first 1100
        validation documents joined with spaces.
    """
    train_files = {'train': 'en/c4-train.00000-of-01024.json.gz'}
    val_files = {'validation': 'en/c4-validation.00000-of-00008.json.gz'}
    if data_dir is None:
        # Restrict the hub download to the same single shards used for the
        # local copy. Without ``data_files`` the loader resolves the entire
        # corpus. NOTE(review): the legacy 'allenai--c4' config name that was
        # passed here may be rejected by recent `datasets` releases.
        traindata = load_dataset('allenai/c4', data_files=train_files, split='train')
        valdata = load_dataset('allenai/c4', data_files=val_files, split='validation')
    else:
        traindata = load_dataset(
            os.path.join(data_dir, 'allenai--c4'),
            data_files=train_files,
            split='train'
        )
        valdata = load_dataset(
            os.path.join(data_dir, 'allenai--c4'),
            data_files=val_files,
            split='validation'
        )

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Resample documents until one is strictly longer than seqlen tokens,
        # so a full window plus at least one token of slack fits.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100  # only the final token contributes to the loss
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc


def format_mmlu_example(dataset, include_answer: bool = True):
    """Render MMLU rows as multiple-choice prompts.

    Each prompt is the question, then one ``"\\n<letter>. <choice>"`` line per
    choice, then ``"\\nAnswer:"``. With ``include_answer`` the gold answer is
    appended as ``" <letter>. <choice text>\\n\\n"``. (The previous version
    repeated the letter instead of the choice text.) Letters run A, B, C, ...,
    so rows with more than four choices are supported.

    Args:
        dataset: mapping or ``datasets.Dataset`` with ``"question"``,
            ``"choices"`` (sequence of strings per row) and ``"answer"``
            (0-based index into that row's choices) columns.
        include_answer: append the gold answer (few-shot style).

    Returns:
        A list of prompt strings, one per row.
    """
    import string

    letters = string.ascii_uppercase
    data = []
    for ques, cho, ans in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        lines = [ques] + [f"{letters[i]}. {text}" for i, text in enumerate(cho)]
        prompt = "\n".join(lines) + "\nAnswer:"
        if include_answer:
            prompt += f" {letters[ans]}. {cho[ans]}\n\n"
        data.append(prompt)
    return data


def get_mmlu_trainenc(
    data_dir: str,
    seed: int,
    nsamples: int,
    tokenizer,
    seqlen: int = 512,
    num_tasks: int = None,
):
    """Tokenize ``nsamples`` MMLU questions (no answers) as one flat stream.

    Samples are spread over the subjects as evenly as possible: each gets
    ``nsamples // n_subjects`` and a random subset gets one extra. Each
    subject's ``train`` split is tried first, falling back to
    ``validation``. Every prompt is padded/truncated to ``seqlen`` tokens,
    then ``input_ids`` and ``attention_mask`` are flattened to ``[1, -1]``
    (padding tokens stay inside the stream).

    Args:
        data_dir: parent directory of the local ``mmlu`` dataset.
        seed: seed for the global ``random`` module (reseeds it) and for
            each subject's ``Dataset.shuffle``.
        nsamples: total number of questions to draw.
        tokenizer: HF tokenizer (must support ``padding='max_length'``).
        seqlen: per-prompt token length.
        num_tasks: if given, use only this many randomly chosen subjects.

    Returns:
        The tokenizer output with flattened ``input_ids``/``attention_mask``.
    """
    subjects = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]
    random.seed(seed)
    if num_tasks is not None:
        random.shuffle(subjects)
        subjects = subjects[:num_tasks]

    per_subject, extra = divmod(nsamples, len(subjects))
    counts = [per_subject + 1] * extra + [per_subject] * (len(subjects) - extra)
    random.shuffle(counts)

    fields = ["question", "subject", "choices", "answer"]
    collected = {field: [] for field in fields}
    mmlu_path = os.path.join(data_dir, "mmlu")

    progress = tqdm(
        zip(counts, subjects),
        total=len(subjects),
        desc="Loading the subclass in MMLU"
    )
    for count, subject in progress:
        if count == 0:
            continue
        try:
            rows = load_dataset(mmlu_path, subject, split="train").shuffle(seed=seed)[:count]
        except Exception:
            rows = load_dataset(mmlu_path, subject, split="validation").shuffle(seed=seed)[:count]
        for field in fields:
            collected[field].extend(rows[field])

    prompts = format_mmlu_example(Dataset.from_dict(collected), include_answer=False)

    trainenc = tokenizer(
        prompts,
        return_tensors='pt',
        max_length=seqlen,
        padding='max_length',
        truncation=True,
    )
    for key in ("input_ids", "attention_mask"):
        trainenc[key] = trainenc[key].reshape(1, -1)
    return trainenc


def get_mmlu_loader(
    data_dir: str,
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
):
    """Cut ``nsamples`` random ``seqlen``-token windows from the MMLU stream.

    If the stream is at most ``seqlen + 1`` tokens long, every window starts
    at 0 and no random number is drawn.

    Returns:
        (loader, wrapper): ``loader`` is a list of ``(inp, tar)`` pairs where
        ``tar`` is ``inp`` with every position except the last set to -100.
        ``wrapper`` is a ``TokenizerWrapper`` over the first ``seqlen`` tokens.
    """
    ids = get_mmlu_trainenc(
        data_dir=data_dir,
        seed=seed,
        nsamples=nsamples,
        tokenizer=tokenizer,
        seqlen=seqlen,
    )["input_ids"]
    max_start = ids.shape[1] - seqlen - 1
    random.seed(seed)

    def draw_window():
        start = random.randint(0, max_start) if max_start > 0 else 0
        window = ids[:, start:start + seqlen].clone()
        target = window.clone()
        target[:, :-1] = -100  # only the final token contributes to the loss
        return window, target

    loader = [draw_window() for _ in range(nsamples)]
    return loader, TokenizerWrapper(ids[:, :seqlen])


def get_loaders(
    name: str,
    nsamples: int = 128,
    seed: int = 0,
    seqlen: int = 2048,
    tokenizer=None,
    data_dir: str = None,
):
    """Dispatch to a calibration loader by substring match on ``name``.

    Keys are checked in order ``wikitext2``, ``c4``, ``mmlu``; the first one
    contained in ``name`` wins.

    Raises:
        ValueError: ``mmlu`` requested without ``data_dir``, or no key matches.
    """
    def build_mmlu():
        if data_dir is None:
            raise ValueError("Using mmlu as calib dataset requires --data_dir pointing to parent of 'mmlu'.")
        return get_mmlu_loader(data_dir, nsamples, seed, seqlen, tokenizer)

    builders = (
        ('wikitext2', lambda: get_wikitext2(nsamples, seed, seqlen, tokenizer)),
        ('c4', lambda: get_c4(nsamples, seed, seqlen, tokenizer, data_dir=data_dir)),
        ('mmlu', build_mmlu),
    )
    for key, build in builders:
        if key in name:
            return build()
    raise ValueError(f"Unknown dataset name in get_loaders: {name}")


# ============================================================
# ============== WrappedGPT & BiasGPT（FLAP） =================
# ============================================================

class WrappedGPT:
    """Accumulates a running per-input-column mean of squared L2 norms.

    ``scaler_row[c]`` holds ``sum ||x_c||^2 / nsamples`` over all batches seen
    so far, where ``nsamples`` counts leading-batch entries.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = layer.weight.device
        self.rows, self.columns = layer.out_features, layer.in_features
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0
        self.layer_id, self.layer_name = layer_id, layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``.

        ``out`` is unused; it is accepted so this can be fed from a forward hook.
        """
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        n_new = inp.shape[0]
        linear_types = (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
        if isinstance(self.layer, linear_types):
            # Flatten leading dims and put channels first: [in_features, tokens].
            if inp.dim() == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        n_old = self.nsamples
        self.scaler_row *= n_old / (n_old + n_new)
        self.nsamples = n_old + n_new

        inp = inp.float()
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples


class BiasGPT:
    """Running per-input-channel statistics of a linear layer's input (FLAP).

    Always tracks ``baseline_inp``, the running mean of each input channel.
    Also tracks either ``scaler_inp`` (running mean squared L2 norm per
    channel) when ``metric == "WIFN"``, or ``fluc_inp`` (a running
    fluctuation / variance-style estimate per channel) for any other metric.
    """

    def __init__(self, layer, metric):
        self.layer = layer
        self.dev = self.layer.weight.device

        self.out_dim = layer.out_features
        self.in_dim = layer.in_features

        self.type = metric
        self.nsamples = 0

        self.baseline_inp = torch.zeros((self.in_dim), device=self.dev)
        if self.type == "WIFN":
            self.scaler_inp = torch.zeros((self.in_dim), device=self.dev)
        else:
            self.fluc_inp = torch.zeros((self.in_dim), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running statistics.

        Args:
            inp: the hooked layer's input. A 2-D tensor is treated as a batch
                of one. For linear layers it is flattened to
                ``[in_dim, tokens]`` before accumulation (assumes the last
                dim is ``in_dim``).
            out: the layer output; unused, accepted for hook compatibility.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        # Same layer types as WrappedGPT / find_layers. Plain bitsandbytes
        # layers must be flattened too, otherwise the per-channel means below
        # are taken over the wrong axis.
        if isinstance(self.layer, (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        total = self.nsamples + batch_size
        # Snapshot the previous mean. The in-place updates below would
        # otherwise rewrite it too, turning (x - new)(x - old) into (x - new)^2.
        old_baseline_inp = self.baseline_inp.clone()
        self.baseline_inp *= self.nsamples / total
        # Weight the batch mean by the number of samples it represents
        # (identical to the unweighted form for batch_size == 1).
        self.baseline_inp += torch.mean(inp, dim=1) * batch_size / total

        if self.type == "WIFN":
            inp32 = inp.type(torch.float32)
            self.scaler_inp *= self.nsamples / total
            self.scaler_inp += torch.norm(inp32, p=2, dim=1) ** 2 / total
        else:
            if self.nsamples == 0:
                # The first batch only initialises the mean; no fluctuation yet.
                self.fluc_inp = torch.zeros_like(self.baseline_inp)
            else:
                self.fluc_inp *= (self.nsamples - 1) / (total - 1)
                self.fluc_inp += torch.sum(
                    (inp - self.baseline_inp.unsqueeze(1)) *
                    (inp - old_baseline_inp.unsqueeze(1)),
                    dim=1
                ) / total

        self.nsamples += batch_size

    def free(self):
        """Drop the accumulated statistics and release cached CUDA memory."""
        self.baseline_inp = None
        if hasattr(self, 'fluc_inp'):
            self.fluc_inp = None
        if hasattr(self, 'scaler_inp'):
            self.scaler_inp = None
        torch.cuda.empty_cache()


# ============================================================
# ===================== 其他工具函数 ==========================
# ============================================================

def save_hf_checkpoint(model, tokenizer, output_dir, iter_idx, current_sparsity):
    """Save ``model`` and ``tokenizer`` with ``save_pretrained`` under a step dir.

    The directory is ``<output_dir>/tale_step_<iter:04d>_s<sparsity:.3f>``;
    it is created if missing (existing directories are reused).

    Returns:
        The path of the step directory.
    """
    dir_name = f"tale_step_{iter_idx:04d}_s{current_sparsity:.3f}"
    step_dir = os.path.join(output_dir, dir_name)
    os.makedirs(step_dir, exist_ok=True)

    print(f"[TALE] Saving HF-style checkpoint to {step_dir}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(step_dir)
    return step_dir

def find_layers(
    module,
    layers=(nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt),
    name: str = ''
):
    """Recursively collect submodules whose exact type is in ``layers``.

    Matching is by exact ``type(...)``, not ``isinstance``. Recursion stops at
    a matched module. Any child whose dotted path contains ``'lora'`` is
    skipped together with its subtree, so LoRA adapter weights are never
    returned as prunable layers. (Previously only the parent's path was
    checked, so a LoRA child registered directly as a matched linear type
    slipped through.)

    Args:
        module: root module to search.
        layers: tuple of module types to collect.
        name: dotted path of ``module`` (used as a prefix for children).

    Returns:
        Dict mapping dotted module path to module.
    """
    if type(module) in layers:
        return {name: module}
    res = {}
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        if 'lora' in full_name:
            continue
        res.update(find_layers(child, layers=layers, name=full_name))
    return res


def dequantize(layer):
    """Return ``layer``'s weight, dequantized when it carries a bnb quant state.

    PEFT-style wrappers (anything with ``get_base_layer``) are unwrapped first.
    If the weight has a non-``None`` ``quant_state`` it is passed through
    ``dequantize_bnb_weight``; otherwise the weight object itself is returned.
    """
    base = layer.get_base_layer() if hasattr(layer, "get_base_layer") else layer
    weight = base.weight
    quant_state = getattr(weight, "quant_state", None)
    if quant_state is None:
        return weight
    return dequantize_bnb_weight(weight, state=quant_state)


def count_model_params(model: nn.Module) -> int:
    """Return the total element count over ``model.parameters()``.

    Parameters shared between submodules are counted once, because
    ``parameters()`` yields each one a single time.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


# ============================================================
# ============= FLAP-style importance 计算 ====================
# ============================================================

def _run_hooked_calibration(model, device, dataloader, modules, collectors, max_batches):
    """Run calibration forwards while feeding each hooked module's input to its collector.

    Hooks are removed in a ``finally`` block, so a failing forward pass does
    not leave them attached to the model.

    Args:
        model: model to run.
        device: device each input batch is moved to.
        dataloader: iterable of ``(inp, tar)`` pairs; only ``inp`` is used.
        modules: name -> module to hook (must contain every key of ``collectors``).
        collectors: name -> object with ``add_batch(inp, out)``; hooks are
            registered in this dict's iteration order.
        max_batches: stop after this many batches.
    """
    def make_hook(name):
        def hook(module, inp, out):
            x_in = inp[0] if isinstance(inp, (tuple, list)) else inp
            collectors[name].add_batch(x_in.detach(), out.detach())
        return hook

    handles = [modules[name].register_forward_hook(make_hook(name)) for name in collectors]
    print("[TALE] Running calibration forward passes for FLAP importance ...")
    try:
        for idx, (inp, _) in enumerate(tqdm(dataloader, desc="FLAP calib")):
            if idx >= max_batches:
                break
            model(inp.to(device))
    finally:
        for h in handles:
            h.remove()


@torch.no_grad()
def compute_flap_importance(
    model,
    tokenizer,
    device,
    calib_dataset: str,
    data_dir: Optional[str],
    num_calib_sample: int,
    seqlen: int,
    metrics: str = "WIFV",
    wanda_sp: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    """Compute FLAP-style input-channel importance for every decoder layer.

    Mirrors the upper half of ``prune_flap(..., only_importance=True)``.
    Calibration samples are run through ``model`` while forward hooks on every
    ``self_attn.o_proj`` / ``mlp.down_proj`` linear accumulate input
    statistics (``BiasGPT``, or ``WrappedGPT`` when ``wanda_sp``). These are
    then turned into one score per input channel of those projections.
    ``model.config.use_cache`` is disabled during the pass and always
    restored, and hooks are always removed, even on error.

    Args:
        model: HF causal LM with decoder layers at ``model.model.layers``
            (or ``model.model.encoder.layers``).
        tokenizer: tokenizer forwarded to ``get_loaders``.
        device: device the calibration inputs are moved to.
        calib_dataset: dataset name understood by ``get_loaders``.
        data_dir: optional local data root forwarded to ``get_loaders``.
        num_calib_sample: number of calibration samples to load and run.
        seqlen: calibration window length in tokens.
        metrics: ``"IFV"``, ``"WIFV"`` or ``"WIFN"`` (unused when ``wanda_sp``).
        wanda_sp: use the ``|W| * sqrt(scaler_row)`` score instead.

    Returns:
        attn_metric: ``[L, o_proj.in_features]`` CPU tensor, not standardized.
            With the FLAP metrics the attention score is squared.
        mlp_metric: ``[L, down_proj.in_features]`` CPU tensor (not squared).
        attn_baseline_list / mlp_baseline_list: length-L lists of fp16 CPU
            vectors (``baseline_inp``, or ``scaler_row`` when ``wanda_sp``).

    Raises:
        RuntimeError: no target projections found, or decoder layers not found.
        KeyError: a decoder layer index has no matching target projection, or
            ``metrics`` is not a known metric name.
    """
    use_cache = getattr(model.config, "use_cache", False)
    model.config.use_cache = False
    model.eval()
    try:
        print(f"[TALE] Loading calibration data for FLAP from {calib_dataset}")
        dataloader, _ = get_loaders(
            calib_dataset,
            nsamples=num_calib_sample,
            seed=42,
            tokenizer=tokenizer,
            seqlen=seqlen,
            data_dir=data_dir,
        )
        print(f"[TALE] Calibration dataset ready, #batches={len(dataloader)}")

        all_linears: Dict[str, nn.Module] = find_layers(model)
        target_names = [
            name for name in all_linears.keys()
            if name.endswith("self_attn.o_proj") or name.endswith("mlp.down_proj")
        ]
        if not target_names:
            raise RuntimeError("[TALE] Cannot find any 'self_attn.o_proj' or 'mlp.down_proj' linear layers.")

        # One statistics collector per target projection.
        wrapped_layers = {}
        for name in target_names:
            if wanda_sp:
                wrapped_layers[name] = WrappedGPT(all_linears[name])
            else:
                wrapped_layers[name] = BiasGPT(all_linears[name], metrics)

        _run_hooked_calibration(
            model, device, dataloader, all_linears, wrapped_layers, num_calib_sample
        )

        # Decoder layers (LLaMA-style layout assumed).
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            layers = model.model.layers
        elif hasattr(model, "model") and hasattr(model.model, "encoder"):
            layers = model.model.encoder.layers
        else:
            raise RuntimeError("[TALE] Cannot find model layers (expected model.model.layers or .model.encoder.layers).")
        num_layers = len(layers)

        def metric_IFV(module_name: str):
            return wrapped_layers[module_name].fluc_inp

        def metric_WIFV(module_name: str):
            layer_mod = all_linears[module_name]
            return wrapped_layers[module_name].fluc_inp * torch.sum(dequantize(layer_mod).data.pow(2), dim=0)

        def metric_WIFN(module_name: str):
            layer_mod = all_linears[module_name]
            return (
                torch.abs(dequantize(layer_mod).data) *
                torch.sqrt(wrapped_layers[module_name].scaler_inp.reshape((1, -1)))
            ).mean(axis=0)

        metrics_dict = {
            "IFV": metric_IFV,
            "WIFV": metric_WIFV,
            "WIFN": metric_WIFN,
        }

        def get_module_name(layer_idx: int, suffix: str) -> str:
            pattern = f"layers.{layer_idx}.{suffix}"
            for n in target_names:
                if pattern in n:
                    return n
            raise KeyError(f"[TALE] Cannot find module name with pattern '{pattern}' in target_names.")

        def wanda_sp_score(module_name: str):
            weight = dequantize(all_linears[module_name])
            return (
                torch.abs(weight.data) *
                torch.sqrt(wrapped_layers[module_name].scaler_row.reshape((1, -1)))
            ).mean(axis=0)

        attn_metric_list, mlp_metric_list = [], []
        attn_baseline_inp_list, mlp_baseline_inp_list = [], []

        for i in tqdm(range(num_layers), desc="[TALE] Computing FLAP metric per layer"):
            attn_name = get_module_name(i, "self_attn.o_proj")
            mlp_name = get_module_name(i, "mlp.down_proj")

            # ATTENTION (squared for the FLAP metrics, as in prune_flap)
            if not wanda_sp:
                W_attn = metrics_dict[metrics](attn_name) ** 2
                attn_baseline = wrapped_layers[attn_name].baseline_inp.type(torch.float16)
            else:
                W_attn = wanda_sp_score(attn_name)
                attn_baseline = wrapped_layers[attn_name].scaler_row.type(torch.float16)

            attn_metric_list.append(W_attn.cpu())
            attn_baseline_inp_list.append(attn_baseline.cpu())

            # MLP
            if not wanda_sp:
                W_mlp = metrics_dict[metrics](mlp_name)
                mlp_baseline = wrapped_layers[mlp_name].baseline_inp.type(torch.float16)
            else:
                W_mlp = wanda_sp_score(mlp_name)
                mlp_baseline = wrapped_layers[mlp_name].scaler_row.type(torch.float16)

            mlp_metric_list.append(W_mlp.cpu())
            mlp_baseline_inp_list.append(mlp_baseline.cpu())

            # Release per-layer statistics early (only BiasGPT defines free()).
            if hasattr(wrapped_layers[attn_name], 'free'):
                wrapped_layers[attn_name].free()
            if hasattr(wrapped_layers[mlp_name], 'free'):
                wrapped_layers[mlp_name].free()
    finally:
        model.config.use_cache = use_cache
    torch.cuda.empty_cache()

    attn_metric = torch.stack(attn_metric_list)   # [L, H_attn_input]
    mlp_metric = torch.stack(mlp_metric_list)     # [L, H_mlp_input]

    return attn_metric, mlp_metric, attn_baseline_inp_list, mlp_baseline_inp_list


# ============================================================
# ===================== Loss 计算函数 =========================
# ============================================================

@torch.no_grad()
def evaluate_loss(model, device, dataloader, max_batches: int = 10) -> float:
    """Mean of the per-batch ``model(...).loss`` over at most ``max_batches`` batches.

    Each batch is an ``(inp, tar)`` pair passed as ``input_ids`` / ``labels``.
    Puts the model in eval mode. Returns ``inf`` when no batch was evaluated.
    """
    model.eval()
    total = 0.0
    count = 0
    for batch_idx, (input_ids, labels) in enumerate(dataloader):
        if batch_idx >= max_batches:
            break
        outputs = model(input_ids=input_ids.to(device), labels=labels.to(device))
        total += outputs.loss.detach().float().item()
        count += 1
    return total / count if count else float("inf")


# ============================================================
# ========= Layer 冗余度 / 候选层池 / 剪枝算子 =================
# ============================================================

def compute_layer_redundancy_scores(
    attn_metric: torch.Tensor,
    mlp_metric: torch.Tensor,
    bottom_q: float = 0.2,
) -> np.ndarray:
    """Score each layer by the mean of its least-important channel scores.

    For layer ``l`` the lowest ``max(1, int(n * bottom_q))`` values of
    ``attn_metric[l]`` and of ``mlp_metric[l]`` are concatenated and averaged.
    Lower scores mean a more redundant layer.

    Args:
        attn_metric: ``[L, H_attn]`` per-channel attention scores.
        mlp_metric: ``[L, H_mlp]`` per-channel MLP scores (rows indexed like ``attn_metric``).
        bottom_q: fraction of each row treated as the "bottom".

    Returns:
        ``np.ndarray`` of length ``L``.
    """
    def lowest(row: torch.Tensor) -> torch.Tensor:
        k = max(1, int(row.numel() * bottom_q))
        return torch.topk(row, k=k, largest=False).values

    per_layer = [
        torch.cat([lowest(attn_metric[idx]), lowest(mlp_metric[idx])]).mean().item()
        for idx in range(attn_metric.size(0))
    ]
    return np.array(per_layer)  # smaller = more redundant


def get_decoder_layers(model):
    """Return ``model.model.layers``, else ``model.model.encoder.layers``.

    Raises:
        RuntimeError: when neither attribute path exists.
    """
    if hasattr(model, "model"):
        inner = model.model
        if hasattr(inner, "layers"):
            return inner.layers
        if hasattr(inner, "encoder"):
            return inner.encoder.layers
    raise RuntimeError("[TALE] Cannot find decoder layers in model.")


# === 记录已剪 head / neuron，避免重复剪同一个 ===

class WidthState:
    """Per-layer record of which attention heads / MLP neurons are already pruned."""

    def __init__(self, num_layers: int, num_heads: int, mlp_dim: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim

        def fresh_sets():
            # One independent set per layer (never a shared, aliased set).
            return [set() for _ in range(num_layers)]

        # pruned_heads[l]: head indices already pruned in layer l
        self.pruned_heads = fresh_sets()
        # pruned_neurons[l]: MLP neuron indices already pruned in layer l
        self.pruned_neurons = fresh_sets()


# === 应用 & 恢复 layer 删除（通过 patch forward） ===

def patch_layer_skip(layer):
    """Make ``layer`` an identity op by replacing its ``forward`` attribute.

    The replacement returns its first positional argument unchanged and
    ignores all other arguments.

    NOTE(review): this returns a bare tensor. Some transformers versions
    expect decoder layers to return a tuple and read ``outputs[0]``; with
    those versions, ``outputs[0]`` would index the batch dimension instead.
    Confirm against the installed transformers version.

    Returns:
        The previous ``forward``, so it can be reinstated later with
        ``restore_layer_forward``.

    Raises:
        RuntimeError: if the patched forward is called with no positional args.
    """
    saved_forward = layer.forward

    def forward_skip(*args, **kwargs):
        # The first positional argument is taken to be hidden_states.
        if not args:
            raise RuntimeError("[TALE] forward_skip got no positional args.")
        return args[0]

    layer.forward = forward_skip
    return saved_forward


def restore_layer_forward(layer, original_forward):
    """Reinstate a ``forward`` previously returned by ``patch_layer_skip``."""
    setattr(layer, "forward", original_forward)


# === 应用 / 恢复 width 剪枝（简单版：置零权重，不做 bias 补偿） ===

def apply_head_prune(layer, head_indices, hidden_size, num_heads):
    """Zero out the given attention heads of one decoder layer, in place.

    With ``head_dim = hidden_size // num_heads``, for each query head ``h``:
    - q_proj: rows ``[h*head_dim, (h+1)*head_dim)`` are zeroed;
    - o_proj: the matching input columns are zeroed, which removes the head's
      contribution to the layer output;
    - k_proj / v_proj: the same rows are zeroed only when they hold one slice
      per query head (plain multi-head attention). Under grouped-query / multi-
      query attention a K/V head is shared by several query heads and is not
      addressed by the query index, so K/V are left untouched there.

    Out-of-range indices are ignored. No bias compensation is applied.

    NOTE(review): for bitsandbytes-quantized weights ``.weight.data`` is
    quantized storage, so row/column slicing may not map to heads. Confirm
    this is only used on unquantized projections.

    Args:
        layer: decoder layer exposing ``self_attn.{q,k,v,o}_proj`` with ``.weight``.
        head_indices: iterable of query-head indices to prune.
        hidden_size: model hidden size.
        num_heads: number of query heads.
    """
    head_dim = hidden_size // num_heads
    attn = layer.self_attn
    q, k, v, o = attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj
    # K/V rows line up with query heads only when there is no head grouping.
    kv_per_query_head = (
        k.weight.shape[0] == num_heads * head_dim
        and v.weight.shape[0] == num_heads * head_dim
    )
    with torch.no_grad():
        for h in head_indices:
            if h < 0 or h >= num_heads:
                continue
            start, end = h * head_dim, (h + 1) * head_dim
            # q/k/v weights are [out_features, in_features]; heads along rows.
            q.weight.data[start:end, :] = 0
            if kv_per_query_head:
                k.weight.data[start:end, :] = 0
                v.weight.data[start:end, :] = 0
            # o_proj: heads along in_features (columns).
            o.weight.data[:, start:end] = 0


def apply_mlp_neuron_prune(layer, neuron_indices, hidden_size, mlp_dim):
    """Zero out the given FFN neurons of one decoder layer, in place.

    For each valid neuron index ``n`` (``0 <= n < mlp_dim``): row ``n`` of
    ``gate_proj`` and ``up_proj`` and column ``n`` of ``down_proj`` are set to
    zero. Out-of-range indices are ignored. ``hidden_size`` is unused.
    """
    valid = [n for n in neuron_indices if 0 <= n < mlp_dim]
    if not valid:
        return
    mlp = layer.mlp
    with torch.no_grad():
        # Zero all selected rows/columns with one advanced-index assignment each.
        mlp.gate_proj.weight.data[valid, :] = 0
        mlp.up_proj.weight.data[valid, :] = 0
        mlp.down_proj.weight.data[:, valid] = 0


# ============================================================
# ==================== 迭代剪枝核心逻辑 ======================
# ============================================================
def tale_iterative_pruning(args, model, tokenizer, device):
    """
    TALE 迭代剪枝主过程（修正版，满足：
      - 不再比较 depth vs width “谁更冗余”；
      - 每一轮最多只删除一层；
      - 逻辑上：前期删层 + 删宽度，后期只删宽度；
      - width 一定是“在每层内部选最不重要的 head/neuron 再剪”。
    """
    layers = get_decoder_layers(model)
    num_layers = len(layers)
    hidden_size = model.config.hidden_size
    mlp_dim = model.config.intermediate_size

    # ---- 统一获取 num_heads ----
    if hasattr(model.config, "num_attention_heads"):
        num_heads = model.config.num_attention_heads
    else:
        attn_mod = layers[0].self_attn
        if hasattr(attn_mod, "num_heads"):
            num_heads = attn_mod.num_heads
        else:
            hidden_size = getattr(model.config, "hidden_size", attn_mod.o_proj.in_features)
            head_dim = getattr(model.config, "head_dim", 128)
            num_heads = hidden_size // head_dim

    print(f"[INFO] Detected num_heads = {num_heads}")
    print(f"[TALE] Model has {num_layers} layers, hidden={hidden_size}, mlp_dim={mlp_dim}, heads={num_heads}")
    total_params = count_model_params(model)
    print(f"[TALE] Total params = {total_params/1e6:.2f}M")

    # 维护哪些层 still active（没被 depth 删掉）
    active_layers = [True] * num_layers

    # ---- 校准数据（算 loss 用） ----
    calib_loader, _ = get_loaders(
        args.calib_dataset, nsamples=args.num_loss_sample,
        seed=args.seed, tokenizer=tokenizer,
        seqlen=args.seqlen, data_dir=args.data_dir
    )
    print(f"[TALE] Loss calibration loader ready, #batches={len(calib_loader)}")

    base_loss = evaluate_loss(model, device, calib_loader, max_batches=args.num_loss_sample)
    print(f"[TALE] Baseline loss on calib set: {base_loss:.4f}")

    # ---- 初次计算 FLAP importance ----
    attn_metric, mlp_metric, attn_baseline_list, mlp_baseline_list = compute_flap_importance(
        model=model,
        tokenizer=tokenizer,
        device=device,
        calib_dataset=args.calib_dataset,
        data_dir=args.data_dir,
        num_calib_sample=args.num_calib_sample,
        seqlen=args.seqlen,
        metrics=args.metrics,
        wanda_sp=args.wanda_sp,
    )
    print("[TALE] Initial FLAP importance computed.")

    # width 状态（每层哪些 head / neuron 已经剪过）
    width_state = WidthState(num_layers=num_layers, num_heads=num_heads, mlp_dim=mlp_dim)

    # 只记录 depth（删层）动作的 score，用于 stage 切换
    history_depth_scores: List[float] = []

    total_removed_params = 0
    sparsity_ratio = 0.0

    stage = "coarse"      # "coarse"：删层 + width 都在动；"fine"：只动 width
    stop_depth = False    # 一旦 True，后面不再删层（只删 width）

    # 只保存“第一次到达某个整数剪枝率”的 checkpoint
    saved_sparsity_levels = set()
    min_save_percent = 0.0  # 如果你只想从 17% 开始存，可以改成 17.0

    # depth stop 的 warmup：在前多少轮不做“depth 是否太疼”的判断
    depth_stop_warmup = max(3, args.M_window)

    it = 0
    while sparsity_ratio < args.target_sparsity and it < args.max_iters:
        it += 1
        print("=" * 80)
        print(f"[TALE] Iteration {it} | current sparsity={sparsity_ratio*100:.2f}% | "
              f"stage={stage} | stop_depth={stop_depth}")

        # ---- 基于当前已剪模型的 baseline loss（用于 ΔLoss）----
        base_loss_iter = evaluate_loss(
            model, device, calib_loader, max_batches=args.num_loss_sample
        )
        print(f"[TALE] Current baseline loss (iter {it}): {base_loss_iter:.4f}")

        # ---- 过滤 active 层 ----
        active_indices = np.where(np.array(active_layers))[0]
        if len(active_indices) == 0:
            print("[TALE] No active layers remaining for depth/width pruning. Stop.")
            break

        # 只允许删 [2, num_layers-2] 区间的层
        allowed_min = 2
        allowed_max = num_layers - 2
        mask = (active_indices >= allowed_min) & (active_indices <= allowed_max)
        active_indices = active_indices[mask]
        if len(active_indices) == 0:
            print("[TALE] No active layers in allowed range, stop.")
            break

        # ---- 基于 FLAP 的 per-layer 冗余度（不和 width 比，只用来排序候选层）----
        layer_scores = compute_layer_redundancy_scores(
            attn_metric, mlp_metric, bottom_q=args.layer_bottom_q
        )
        active_layer_scores = layer_scores[active_indices]
        order_in_active = np.argsort(active_layer_scores)
        ordered_active_layers = active_indices[order_in_active]
        K = min(args.K_layer_candidates, len(ordered_active_layers))
        candidate_layers = ordered_active_layers[:K]
        print(f"[TALE] Active layers: {active_indices.tolist()}")
        print(f"[TALE] Candidate layers (most redundant among active): {candidate_layers.tolist()}")

        # ==================== 1) width 剪枝（每层内部选最小的 head / neuron） ====================
        any_width_pruned = False
        total_width_removed_params = 0

        head_dim = hidden_size // num_heads

        # 粗/细阶段的每轮宽度步长
        if stage == "coarse":
            n_head_to_prune = max(1, int(args.p_width_coarse * num_heads))
            n_mlp_to_prune  = max(1, int(args.p_width_coarse * mlp_dim))
        else:
            n_head_to_prune = max(1, int(args.p_width_fine * num_heads))
            n_mlp_to_prune  = max(1, int(args.p_width_fine * mlp_dim))

        for l in candidate_layers:
            l = int(l)
            if not active_layers[l]:
                continue

            layer = layers[l]

            # ---- 1.1 这一层内部：选最不重要的 heads ----
            attn_scores = attn_metric[l]  # [hidden_size]
            head_importance = attn_scores.view(num_heads, head_dim).sum(dim=1)  # [num_heads]

            available_heads = [
                h for h in range(num_heads)
                if h not in width_state.pruned_heads[l]
            ]
            if len(available_heads) > 0 and n_head_to_prune > 0:
                head_scores = [(h, head_importance[h].item()) for h in available_heads]
                head_scores.sort(key=lambda x: x[1])  # 从小到大，越小越不重要
                chosen_heads = [h for h, _ in head_scores[:n_head_to_prune]]

                if len(chosen_heads) > 0:
                    any_width_pruned = True
                    print(f"[TALE][width] layer {l} prune heads: {chosen_heads}")
                    apply_head_prune(layer, chosen_heads, hidden_size, num_heads)
                    width_state.pruned_heads[l].update(chosen_heads)

                    params_per_head = 4 * hidden_size * head_dim  # q/k/v/o
                    total_width_removed_params += params_per_head * len(chosen_heads)

            # ---- 1.2 这一层内部：选最不重要的 MLP neurons ----
            mlp_scores = mlp_metric[l]  # [mlp_dim]
            available_neurons = [
                n for n in range(mlp_dim)
                if n not in width_state.pruned_neurons[l]
            ]
            if len(available_neurons) > 0 and n_mlp_to_prune > 0:
                neuron_scores = [(n, mlp_scores[n].item()) for n in available_neurons]
                neuron_scores.sort(key=lambda x: x[1])  # 从小到大
                chosen_neurons = [n for n, _ in neuron_scores[:n_mlp_to_prune]]

                if len(chosen_neurons) > 0:
                    any_width_pruned = True
                    print(f"[TALE][width] layer {l} prune mlp neurons: {chosen_neurons}")
                    apply_mlp_neuron_prune(layer, chosen_neurons, hidden_size, mlp_dim)
                    width_state.pruned_neurons[l].update(chosen_neurons)

                    params_per_neuron = 3 * hidden_size  # gate/up/down 三个矩阵的近似参数量
                    total_width_removed_params += params_per_neuron * len(chosen_neurons)

        if any_width_pruned:
            total_removed_params += total_width_removed_params
            sparsity_ratio = total_removed_params / total_params
            print(f"[TALE][width] Removed params (this iter, width only) "
                  f"= {total_width_removed_params/1e6:.2f}M, "
                  f"total sparsity={sparsity_ratio*100:.2f}%")

        # ==================== 2) depth：每轮最多只删一层 ====================
        best_action = None

        if (not stop_depth) and (stage in ["coarse", "fine"]):  # fine 阶段一般已经 stop_depth=True
            depth_candidates = []

            for l in candidate_layers:
                l = int(l)
                if not active_layers[l]:
                    continue

                ΔLoss, ΔParams, score = simulate_remove_layers(
                    model, device, calib_loader,
                    layers_to_remove=[l],   # 关键：每次只考虑 “单层删层”
                    hidden_size=hidden_size, num_heads=num_heads,
                    mlp_dim=mlp_dim, max_batches=args.num_loss_sample,
                    base_loss=base_loss_iter,
                )
                depth_candidates.append({
                    "layer": l,
                    "ΔLoss": ΔLoss,
                    "ΔParams": ΔParams,
                    "score": score,
                })

            if len(depth_candidates) > 0:
                # 选 score 最小的那个层（ΔLoss/ΔParams 最优）
                best_action = min(depth_candidates, key=lambda x: x["score"])
                history_depth_scores.append(best_action["score"])

                l_rm = int(best_action["layer"])
                print(f"[TALE] Selected depth action: remove layer {l_rm}")
                print(f"[TALE] ΔLoss={best_action['ΔLoss']:.6f}, ΔParams={best_action['ΔParams']}, "
                      f"score={best_action['score']:.6e}")

                if active_layers[l_rm]:
                    print(f"[TALE] Permanently skipping layer {l_rm}")
                    patch_layer_skip(layers[l_rm])   # 不再恢复
                    active_layers[l_rm] = False

                    total_removed_params += best_action["ΔParams"]
                    sparsity_ratio = total_removed_params / total_params
                    print(f"[TALE] Total removed params (depth+width) = {total_removed_params/1e6:.2f}M "
                          f"({sparsity_ratio*100:.2f}% of original)")
                else:
                    print(f"[TALE] Warning: chosen layer {l_rm} already inactive, skip depth this iter.")
                    best_action = None

        # 如果这一轮既没有 width 剪枝，也没有删层，说明已经没什么可做了，退出防死循环
        if (not any_width_pruned) and (best_action is None):
            print("[TALE] No valid depth or width action in this iteration, stop.")
            break

        # ==================== 3) 视情况重算 FLAP importance ====================
        if it % args.recompute_interval == 0:
            print("[TALE] Recomputing FLAP importance after several iterations ...")
            attn_metric, mlp_metric, attn_baseline_list, mlp_baseline_list = compute_flap_importance(
                model=model,
                tokenizer=tokenizer,
                device=device,
                calib_dataset=args.calib_dataset,
                data_dir=args.data_dir,
                num_calib_sample=args.num_calib_sample,
                seqlen=args.seqlen,
                metrics=args.metrics,
                wanda_sp=args.wanda_sp,
            )

        # ==================== 4) depth 的 stage 切换 & stop_depth 逻辑 ====================
        # 这里只看 depth 的 ΔLoss/ΔParams（history_depth_scores），
        # 不再和 width 比较，满足你说的“冗余率先 layer 后 width”的直觉：
        #   - 前期删层 cost 低 → 多删层；
        #   - 后期删层 cost 明显变高 → 停止删层，只删宽度。
        if (not stop_depth) and len(history_depth_scores) >= args.M_window and it >= depth_stop_warmup:
            early_scores = history_depth_scores[:args.M_window]
            baseline_med = np.median(early_scores)
            recent_scores = history_depth_scores[-args.M_window:]
            recent_med = np.median(recent_scores)

            print(f"[TALE] depth stop-check: recent_med={recent_med:.3e}, "
                  f"baseline_med={baseline_med:.3e}, "
                  f"ratio={recent_med / (baseline_med + 1e-12):.2f}")

            # depth 的 ΔLoss/ΔParams 比“早期 baseline” 放大到 tau2 倍以上，就认为删层已经太疼 → 停止删层
            if recent_med > args.tau2 * baseline_med:
                stop_depth = True
                stage = "fine"
                print("[TALE] >>> Stop DEPTH pruning. Switch to width-only (fine) stage.")

        # ==================== 5) 按剪枝率等级保存 checkpoint ====================
        current_percent = sparsity_ratio * 100.0
        if current_percent >= min_save_percent:
            level = int(current_percent // 1)
            if level not in saved_sparsity_levels:
                print(f"[TALE][ckpt] First time reaching sparsity >= {level}%, saving checkpoint...")
                ckpt_path = save_hf_checkpoint(model, tokenizer, args.output_dir, it, sparsity_ratio)
                saved_sparsity_levels.add(level)

    print("[TALE] Iterative pruning finished.")
    return

# ============================================================
# ============== 模拟动作：ΔLoss / ΔParams ===================
# ============================================================
@torch.no_grad()
def simulate_remove_layers(
    model,
    device,
    calib_loader,
    layers_to_remove: List[int],
    hidden_size: int,
    num_heads: int,
    mlp_dim: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """Estimate the cost of skipping decoder layers without changing the model.

    The selected layers are temporarily patched with ``patch_layer_skip``.
    The calibration loss is measured, and the original ``forward`` methods
    are then restored. The restore also happens if evaluation raises.

    Args:
        model: causal LM whose decoder layers come from ``get_decoder_layers``.
        device: device passed through to ``evaluate_loss``.
        calib_loader: calibration data passed through to ``evaluate_loss``.
        layers_to_remove: indices of decoder layers to skip. Duplicates are
            ignored.
        hidden_size: model hidden size, used for the parameter estimate.
        num_heads: unused here; kept for signature parity with the other
            ``simulate_*`` helpers.
        mlp_dim: MLP intermediate size, used for the parameter estimate.
        max_batches: number of calibration batches for ``evaluate_loss``.
        base_loss: loss of the current (unpatched) model.

    Returns:
        ``(ΔLoss, ΔParams, score)``:

        - ``ΔLoss = loss_after - base_loss``. May be negative.
        - ``ΔParams`` is the approximate parameter count of the skipped layers.
        - ``score = max(ΔLoss, 0) / max(ΔParams, 1)``.

    NOTE(review): ``ΔParams`` assumes q/k/v/o are all hidden x hidden (no GQA).
    It also ignores heads/neurons that were already width-pruned in these
    layers, so it may over-estimate.
    """
    layers = get_decoder_layers(model)

    # De-duplicate while preserving order. A repeated index would otherwise
    # back up the *already patched* forward and "restore" the skip
    # permanently, and it would also double-count ΔParams.
    unique_layers = list(dict.fromkeys(int(idx) for idx in layers_to_remove))

    original_forwards = {}
    try:
        for idx in unique_layers:
            original_forwards[idx] = layers[idx].forward
            patch_layer_skip(layers[idx])
        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        # Always undo the temporary patch so a failed evaluation cannot leave
        # layers silently skipped. Only restore layers that were backed up.
        for idx, fwd in original_forwards.items():
            restore_layer_forward(layers[idx], fwd)

    ΔLoss = loss_after - base_loss

    # Approximate per-layer size: 4 attention projections + 3 MLP projections.
    params_per_layer = 4 * hidden_size * hidden_size + 3 * hidden_size * mlp_dim
    ΔParams = params_per_layer * len(unique_layers)

    ΔLoss_eff = max(ΔLoss, 0.0)  # clip negative loss changes to 0 for scoring
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score


@torch.no_grad()
def simulate_prune_heads(
    model,
    device,
    calib_loader,
    layer_idx: int,
    head_indices: List[int],
    hidden_size: int,
    num_heads: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """Estimate the cost of removing attention heads in one decoder layer.

    The heads are masked by temporarily zeroing the projection weights.
    The calibration loss is measured, and the weights are then restored.
    The restore also happens if evaluation raises.

    Masking layout (``nn.Linear`` weight is ``[out_features, in_features]``):

    - ``q_proj`` / ``k_proj`` / ``v_proj``: a head owns output *rows*
      ``[h*head_dim, (h+1)*head_dim)``.
    - ``o_proj``: a head owns input *columns* of the same span.

    When ``k_proj`` / ``v_proj`` have fewer outputs than ``q_proj`` (grouped
    KV heads), their rows are shared by several query heads. In that case
    they are left untouched. Zeroing the ``o_proj`` columns already removes
    the head's contribution.

    Returns:
        ``(ΔLoss, ΔParams, score)``:

        - ``ΔLoss = loss_after - base_loss``. May be negative.
        - ``ΔParams = 4 * hidden_size * head_dim`` per (unique) head.
        - ``score = max(ΔLoss, 0) / max(ΔParams, 1)``.

        If the layer has no ``self_attn``, returns ``(inf, 0, inf)``.
    """
    layers = get_decoder_layers(model)
    layer = layers[layer_idx]
    if not hasattr(layer, "self_attn"):
        return float("inf"), 0, float("inf")

    attn = layer.self_attn
    head_dim = hidden_size // num_heads

    # Drop duplicate head indices so ΔParams is not over-counted.
    heads = list(dict.fromkeys(int(h) for h in head_indices))
    spans = [(h * head_dim, (h + 1) * head_dim) for h in heads]

    proj_names = [
        name for name in ("q_proj", "k_proj", "v_proj", "o_proj")
        if hasattr(attn, name)
    ]
    backup = {name: getattr(attn, name).weight.data.clone() for name in proj_names}
    q_out = backup["q_proj"].size(0) if "q_proj" in backup else None

    try:
        for name in proj_names:
            W = getattr(attn, name).weight.data
            if name == "o_proj":
                # o_proj maps concatenated head outputs -> hidden, so a head is
                # a block of input columns (zeroing rows would wipe hidden
                # channels for *all* heads).
                for start, end in spans:
                    if W.size(1) >= end:
                        W[:, start:end] = 0
                continue
            if name in ("k_proj", "v_proj") and q_out is not None and W.size(0) != q_out:
                # Grouped KV projection: rows are shared across query heads.
                continue
            for start, end in spans:
                if W.size(0) >= end:
                    W[start:end, :] = 0

        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        # Always restore the original weights, even if evaluation fails.
        for name, saved in backup.items():
            getattr(attn, name).weight.data.copy_(saved)

    ΔLoss = loss_after - base_loss

    params_per_head = 4 * hidden_size * head_dim  # q/k/v/o
    ΔParams = params_per_head * len(heads)

    ΔLoss_eff = max(ΔLoss, 0.0)  # clip negative loss changes to 0 for scoring
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score


@torch.no_grad()
def simulate_prune_mlp_neurons(
    model,
    device,
    calib_loader,
    layer_idx: int,
    neuron_indices: List[int],
    hidden_size: int,
    mlp_dim: int,
    max_batches: int,
    base_loss: float,
) -> Tuple[float, int, float]:
    """Estimate the cost of removing MLP intermediate neurons in one layer.

    Zeroing is temporary and is always undone, even if evaluation raises:

    - Rows of ``gate_proj`` / ``up_proj`` (the neuron's output rows).
    - Columns of ``down_proj`` (the neuron's input columns).

    The calibration loss is measured while the neurons are zeroed.

    Returns:
        ``(ΔLoss, ΔParams, score)``:

        - ``ΔLoss = loss_after - base_loss``. May be negative.
        - ``ΔParams = 3 * hidden_size`` per (unique) neuron, a rough
          gate/up/down estimate.
        - ``score = max(ΔLoss, 0) / max(ΔParams, 1)``.

        If the layer has no ``mlp``, returns ``(inf, 0, inf)``.
    """
    layers = get_decoder_layers(model)
    layer = layers[layer_idx]
    if not hasattr(layer, "mlp"):
        return float("inf"), 0, float("inf")

    mlp = layer.mlp

    # Drop duplicate neuron indices so ΔParams is not over-counted.
    neurons = list(dict.fromkeys(int(n) for n in neuron_indices))

    proj_names = [
        name for name in ("gate_proj", "up_proj", "down_proj")
        if hasattr(mlp, name)
    ]
    backup = {name: getattr(mlp, name).weight.data.clone() for name in proj_names}

    try:
        # gate/up: one output row per neuron. Zero them in one indexed write.
        for name in ("gate_proj", "up_proj"):
            if name in backup:
                W = getattr(mlp, name).weight.data
                rows = [n for n in neurons if n < W.size(0)]
                if rows:
                    W[rows, :] = 0

        # down: one input column per neuron.
        if "down_proj" in backup:
            W = mlp.down_proj.weight.data
            cols = [n for n in neurons if n < W.size(1)]
            if cols:
                W[:, cols] = 0

        loss_after = evaluate_loss(model, device, calib_loader, max_batches=max_batches)
    finally:
        # Always restore the original weights, even if evaluation fails.
        for name, saved in backup.items():
            getattr(mlp, name).weight.data.copy_(saved)

    ΔLoss = loss_after - base_loss

    params_per_neuron = 3 * hidden_size  # rough: gate row + up row + down column
    ΔParams = params_per_neuron * len(neurons)

    ΔLoss_eff = max(ΔLoss, 0.0)  # clip negative loss changes to 0 for scoring
    score = ΔLoss_eff / max(ΔParams, 1)
    return ΔLoss, ΔParams, score



# ============================================================
# ========================= CLI 部分 =========================
# ============================================================

def parse_args():
    """Build the TALE command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with these option groups:

        - model/output paths;
        - calibration-data options;
        - FLAP importance options;
        - iterative-pruning hyperparameters;
        - precision/seed flags.
    """
    parser = argparse.ArgumentParser(description="TALE iterative depth+width pruning")
    add = parser.add_argument

    # --- model / output paths ---
    add("--model_name_or_path", type=str, required=True)
    add("--output_dir", type=str, required=True)

    # --- calibration data ---
    add("--calib_dataset", type=str, default="mmlu", choices=["wikitext2", "c4", "mmlu"])
    add("--data_dir", type=str, default=None)
    add("--num_calib_sample", type=int, default=20,
        help="用于 FLAP 的校准样本数（序列条数）")
    add("--num_loss_sample", type=int, default=8,
        help="评估 Loss 时使用的 batch 数（越大越准，越小越快）")
    add("--seqlen", type=int, default=512)

    # --- FLAP importance ---
    add("--metrics", type=str, default="WIFV", choices=["IFV", "WIFV", "WIFN"])
    add("--wanda_sp", action="store_true")

    # --- iterative pruning hyperparameters ---
    add("--target_sparsity", type=float, default=0.8,
        help="目标剪枝比例（理论参数层面）")
    add("--max_iters", type=int, default=50,
        help="最大迭代轮数（安全上限）")
    add("--K_layer_candidates", type=int, default=4)
    add("--layer_bottom_q", type=float, default=0.2)
    add("--p_width_coarse", type=float, default=0.02)
    add("--p_width_fine", type=float, default=0.01)
    add("--tau1", type=float, default=2.0,
        help="coarse → fine 的 score 放大因子阈值")
    add("--tau2", type=float, default=2.0,
        help="停止 depth 或 width 的冗余度阈值因子")
    add("--M_window", type=int, default=3,
        help="判断冗余度与阶段切换时使用的滑动窗口长度")
    add("--recompute_interval", type=int, default=10,
        help="每隔多少轮重算一次 FLAP importance")

    # --- precision / reproducibility ---
    add("--fp16", action="store_true")
    add("--bf16", action="store_true")
    add("--seed", type=int, default=42)

    return parser.parse_args()


def main():
    """Script entry point: parse flags, load model and tokenizer, then run TALE.

    Precision selection: ``--fp16`` wins over ``--bf16``. Either flag only
    takes effect on CUDA; otherwise the model is loaded in float32.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[TALE] Using device: {device}")

    on_cuda = device.type == "cuda"
    if on_cuda and args.fp16:
        dtype = torch.float16
    elif on_cuda and args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"[TALE] Loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        dtype=dtype,
        device_map=None,
    )
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    # Causal-LM tokenizers often ship without a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    tale_iterative_pruning(args, model, tokenizer, device)

    print("[TALE] All done.")


if __name__ == "__main__":
    main()
