#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Iterative FLAP width pruning (only structured width, no depth)
- 假定 LLaMA 系模型: model.model.layers[i].self_attn / mlp
- 用 FLAP 逻辑统计 attn / mlp 通道重要性 (IFV / WIFV / WIFN)
- 迭代剪枝：
    * 每轮从全局重要性最低的 head / neuron 中剪一小部分
    * 使用 mask 置零 (unstructured at weight-level but structured at channel/head-level)
- 剪枝完成后，main() 会自动调用 model.save_pretrained(output_dir) 与 tokenizer.save_pretrained(output_dir) 保存最终模型
"""

import os
import argparse
import random
from typing import List, Tuple, Optional, Dict

import torch
import torch.nn as nn
from datasets import load_dataset, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
import numpy as np



# ============================================================
# ===================== 数据加载部分 ==========================
# ============================================================

class TokenizerWrapper:
    """Lightweight holder mimicking a tokenizer output that only exposes ``input_ids``."""

    def __init__(self, input_ids: torch.Tensor):
        # Same attribute name as a HF BatchEncoding so callers can treat both alike.
        self.input_ids = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build a WikiText-2 calibration set.

    Returns:
        (windows, test_encoding): ``nsamples`` ``(input_ids, targets)`` pairs,
        each a random ``seqlen``-token window of the tokenized training split
        (targets are -100 everywhere except the last position), and the
        tokenizer output for the whole test split.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    test_encoding = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    last_start = train_ids.shape[1] - seqlen - 1

    def _sample_window():
        start = random.randint(0, last_start)
        window = train_ids[:, start:start + seqlen]
        labels = window.clone()
        labels[:, :-1] = -100
        return window, labels

    return [_sample_window() for _ in range(nsamples)], test_encoding


def get_c4(nsamples, seed, seqlen, tokenizer, data_dir=None):
    """Build a C4 calibration set from the first train / validation shard.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed for Python's ``random`` (document and window selection).
        seqlen: tokens per calibration window.
        tokenizer: HF tokenizer.
        data_dir: optional local directory containing ``allenai--c4/en/...``
            shards; when None the shards are read from the ``allenai/c4`` hub repo.

    Returns:
        (trainloader, valenc): ``nsamples`` ``(input_ids, targets)`` pairs of
        shape ``[1, seqlen]`` (targets are -100 except the last position), and a
        :class:`TokenizerWrapper` holding the first ``256 * seqlen`` validation tokens.
    """
    # Always restrict loading to one shard per split. Without ``data_files`` the
    # hub path resolves to the full dataset (and the old 'allenai--c4' config
    # name is no longer available on the hub).
    train_files = {'train': 'en/c4-train.00000-of-01024.json.gz'}
    val_files = {'validation': 'en/c4-validation.00000-of-00008.json.gz'}
    source = 'allenai/c4' if data_dir is None else os.path.join(data_dir, 'allenai--c4')
    traindata = load_dataset(source, data_files=train_files, split='train')
    valdata = load_dataset(source, data_files=val_files, split='validation')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Rejection-sample a document long enough to hold a full window.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc


def format_mmlu_example(dataset, include_answer: bool = True):
    """Render MMLU rows as multiple-choice prompts.

    Args:
        dataset: mapping (or ``datasets.Dataset``) with parallel columns
            ``"question"``, ``"choices"`` (list of option strings) and
            ``"answer"`` (index of the correct option).
        include_answer: when True, append the correct option as
            ``" <letter>. <text>\\n\\n"`` after ``"Answer:"``.

    Returns:
        One prompt string per row, e.g. ``"Q\\nA. x\\nB. y\\nAnswer:"``.
    """
    data = []
    for ques, cho, ans in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        # Letters A, B, C, ... for however many options the row has.
        labels = [chr(ord("A") + i) for i in range(len(cho))]
        lines = [ques] + [f"{label}. {text}" for label, text in zip(labels, cho)]
        prompt = "\n".join(lines) + "\nAnswer:"
        if include_answer:
            # Letter followed by the answer text (previously the letter was repeated).
            prompt += f" {labels[ans]}. {cho[ans]}\n\n"
        data.append(prompt)
    return data


def get_mmlu_trainenc(
    data_dir: str,
    seed: int,
    nsamples: int,
    tokenizer,
    seqlen: int = 512,
    num_tasks: int = None,
):
    """Tokenize MMLU questions (without answers) into one long token row.

    ``nsamples`` questions are spread as evenly as possible over the MMLU
    subjects (optionally a random subset of ``num_tasks`` of them). Each
    subject is read from ``<data_dir>/mmlu`` using its "train" split, falling
    back to "validation" if loading "train" fails. Every prompt is padded /
    truncated to ``seqlen`` tokens, and the batch is then flattened to shape
    ``[1, n_prompts * seqlen]``.

    Returns:
        The tokenizer output with ``input_ids`` / ``attention_mask`` reshaped to one row.
    """
    subjects = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]
    random.seed(seed)
    if num_tasks is not None:
        random.shuffle(subjects)
        subjects = subjects[:num_tasks]

    fields = ["question", "subject", "choices", "answer"]
    # Distribute nsamples over subjects: the first `remainder` quotas get one extra.
    base, remainder = divmod(nsamples, len(subjects))
    per_subject = [base + 1] * remainder + [base] * (len(subjects) - remainder)
    random.shuffle(per_subject)

    collected = {field: [] for field in fields}
    mmlu_root = os.path.join(data_dir, "mmlu")

    pairs = list(zip(per_subject, subjects))
    for count, subject in tqdm(pairs, total=len(subjects), desc="Loading the subclass in MMLU"):
        if count == 0:
            continue
        try:
            rows = load_dataset(mmlu_root, subject, split="train").shuffle(seed=seed)[:count]
        except Exception:
            rows = load_dataset(mmlu_root, subject, split="validation").shuffle(seed=seed)[:count]
        for field in fields:
            collected[field].extend(rows[field])

    prompts = format_mmlu_example(Dataset.from_dict(collected), include_answer=False)

    encoding = tokenizer(
        prompts,
        return_tensors='pt',
        max_length=seqlen,
        padding='max_length',
        truncation=True,
    )
    for key in ("input_ids", "attention_mask"):
        encoding[key] = encoding[key].reshape(1, -1)
    return encoding


def get_mmlu_loader(
    data_dir: str,
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
):
    """Draw ``nsamples`` random ``seqlen``-token windows from the flattened MMLU tokens.

    If the flattened stream has at most ``seqlen + 1`` tokens, every window
    starts at offset 0. Targets are -100 except the last position.

    Returns:
        (windows, wrapper): the list of ``(input_ids, targets)`` pairs and a
        :class:`TokenizerWrapper` around the first ``seqlen`` tokens.
    """
    token_ids = get_mmlu_trainenc(
        data_dir=data_dir,
        seed=seed,
        nsamples=nsamples,
        tokenizer=tokenizer,
        seqlen=seqlen,
    )["input_ids"]
    n_tokens = token_ids.shape[1]
    random.seed(seed)
    too_short = n_tokens <= seqlen + 1

    samples = []
    for _ in range(nsamples):
        start = 0 if too_short else random.randint(0, n_tokens - seqlen - 1)
        window = token_ids[:, start:start + seqlen].clone()
        labels = window.clone()
        labels[:, :-1] = -100
        samples.append((window, labels))
    return samples, TokenizerWrapper(token_ids[:, :seqlen])


def get_loaders(
    name: str,
    nsamples: int = 128,
    seed: int = 0,
    seqlen: int = 2048,
    tokenizer=None,
    data_dir: str = None,
):
    """Dispatch to a calibration-data builder by substring match on ``name``.

    Candidates are checked in order "wikitext2", "c4", "mmlu"; the first one
    contained in ``name`` is used. MMLU additionally requires ``data_dir``.

    Raises:
        ValueError: ``name`` matches no known dataset, or MMLU is requested
            without ``data_dir``.
    """
    def _mmlu():
        if data_dir is None:
            raise ValueError("Using mmlu as calib dataset requires --data_dir pointing to parent of 'mmlu'.")
        return get_mmlu_loader(data_dir, nsamples, seed, seqlen, tokenizer)

    builders = (
        ('wikitext2', lambda: get_wikitext2(nsamples, seed, seqlen, tokenizer)),
        ('c4', lambda: get_c4(nsamples, seed, seqlen, tokenizer, data_dir=data_dir)),
        ('mmlu', _mmlu),
    )
    for key, build in builders:
        if key in name:
            return build()
    raise ValueError(f"Unknown dataset name in get_loaders: {name}")


# ============================================================
# ============== FLAP Wrapper & metrics ======================
# ============================================================

# FLAP channel-importance scores (IFV / WIFV / WIFN), keyed by metric name.
# Each takes (wrapped_layers, subset, name) and returns a per-input-channel score.
def _ifv_score(wrapped_layers, subset, name):
    return wrapped_layers[name].fluc_inp


def _wifv_score(wrapped_layers, subset, name):
    # Input fluctuation weighted by the squared L2 norm of each weight column.
    squared_col_norms = subset[name].weight.data.pow(2).sum(dim=0)
    return wrapped_layers[name].fluc_inp * squared_col_norms


def _wifn_score(wrapped_layers, subset, name):
    # |W| scaled by the RMS-style input magnitude, averaged over output rows.
    scale = wrapped_layers[name].scaler_inp.reshape((1, -1)).sqrt()
    return (subset[name].weight.data.abs() * scale).mean(axis=0)


metrics = {
    'IFV': _ifv_score,
    'WIFV': _wifv_score,
    'WIFN': _wifn_score,
}


def find_layers(
    module,
    layers=(nn.Linear,),
    name: str = ''
):
    """Recursively collect submodules whose exact type is listed in ``layers``.

    Matching uses ``type(...) in layers`` (exact type, not ``isinstance``), and
    the search does not descend into a matched module.

    Returns:
        Dict mapping dotted module paths (relative to ``module``, prefixed by
        ``name``) to the matched modules. If ``module`` itself matches, the
        result is ``{name: module}``.
    """
    if type(module) in layers:
        return {name: module}
    found = {}
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=qualified))
    return found


class BiasGPT:
    """Running input statistics for one linear layer, used by the FLAP metrics.

    Forward hooks feed the layer's input activations through :meth:`add_batch`.
    The tracker keeps a running per-input-channel mean (``baseline_inp``).
    Depending on ``metric`` it also keeps:

    * ``"WIFN"``: ``scaler_inp`` — running average (over batches) of each input
      channel's squared L2 norm across tokens.
    * anything else (``"IFV"`` / ``"WIFV"``): ``fluc_inp`` — a running,
      variance-like fluctuation estimate per input channel.

    All statistics live on the layer's weight device and are float32.
    """

    def __init__(self, layer, metric: str):
        self.layer = layer
        self.dev = self.layer.weight.device

        self.out_dim = layer.out_features
        self.in_dim = layer.in_features

        self.type = metric
        self.nsamples = 0

        self.baseline_inp = torch.zeros((self.in_dim), device=self.dev)
        if self.type == "WIFN":
            self.scaler_inp = torch.zeros((self.in_dim), device=self.dev)
        else:
            self.fluc_inp = torch.zeros((self.in_dim), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running statistics.

        Args:
            inp: layer input, ``[tokens, in_dim]`` or ``[batch, tokens, in_dim]``.
                ``nsamples`` advances by the leading batch dimension.
            out: layer output (unused; kept for the hook signature).
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        if isinstance(self.layer, nn.Linear):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()  # -> [in_dim, tokens]

        # Accumulate in float32 regardless of activation dtype (fp16/bf16 models).
        inp = inp.float()
        total = self.nsamples + batch_size

        # Snapshot the previous mean BEFORE the in-place update below. Without
        # clone() this name aliases self.baseline_inp, so the fluctuation update
        # would silently use the new mean twice instead of old and new.
        old_baseline_inp = self.baseline_inp.clone()
        self.baseline_inp *= self.nsamples / total
        self.baseline_inp += torch.mean(inp, dim=1) / total

        if self.type == "WIFN":
            self.scaler_inp *= self.nsamples / total
            self.scaler_inp += torch.norm(inp, p=2, dim=1) ** 2 / total
        else:
            if self.nsamples == 0:
                self.fluc_inp = torch.zeros_like(self.baseline_inp)
            else:
                # Welford-style update: (x - new_mean) * (x - old_mean).
                self.fluc_inp *= (self.nsamples - 1) / (total - 1)
                self.fluc_inp += torch.sum(
                    (inp - self.baseline_inp.unsqueeze(1)) *
                    (inp - old_baseline_inp.unsqueeze(1)),
                    dim=1
                ) / total

        self.nsamples += batch_size

    def free(self):
        """Drop the accumulated statistics tensors and release cached CUDA memory."""
        self.baseline_inp = None
        if hasattr(self, 'fluc_inp'):
            self.fluc_inp = None
        if hasattr(self, 'scaler_inp'):
            self.scaler_inp = None
        torch.cuda.empty_cache()


# ============================================================
# ================== Calibration 输入收集 =====================
# ============================================================

@torch.no_grad()
def prepare_calibration_input(model, dataloader, device, nsamples, seqlen):
    """Capture the hidden states that enter the first decoder layer.

    Temporarily wraps ``model.model.layers[0]`` with a catcher that records its
    input and aborts the forward pass. It records at most ``nsamples`` inputs.
    The original layer and ``config.use_cache`` are always restored, even if
    the model raises.

    Args:
        model: LLaMA-style causal LM with ``model.model.layers`` and
            ``config.hidden_size`` / ``config.use_cache``.
        dataloader: iterable of batches whose element 0 is ``input_ids``.
        device: device for the input ids and the captured buffer.
        nsamples: number of sequences to capture.
        seqlen: sequence length of each captured input.

    Returns:
        (inps, outs, attention_mask, position_ids): ``inps`` is
        ``[nsamples, seqlen, hidden_size]`` in the model's parameter dtype,
        ``outs`` is a zero buffer of the same shape, and the last seen
        ``attention_mask`` / ``position_ids`` kwargs (or None).
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    dtype = next(iter(model.parameters())).dtype
    hidden_size = model.config.hidden_size

    inps = torch.zeros((nsamples, seqlen, hidden_size), dtype=dtype, device=device)
    inps.requires_grad = False
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class _StopForward(Exception):
        """Private signal used to abort the forward pass once layer-0 input is captured.

        A dedicated type avoids swallowing genuine ValueErrors raised by the model.
        """

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            if cache['i'] < nsamples:
                inps[cache['i']] = inp
                cache['i'] += 1
                cache['attention_mask'] = kwargs.get('attention_mask', None)
                cache['position_ids'] = kwargs.get('position_ids', None)
            raise _StopForward

    layers[0] = Catcher(layers[0])
    try:
        for batch in dataloader:
            if cache['i'] >= nsamples:
                break
            try:
                model(batch[0].to(device))
            except _StopForward:
                pass
    finally:
        # Undo the temporary wrapping even if the forward pass failed.
        layers[0] = layers[0].module
        model.config.use_cache = use_cache

    outs = torch.zeros_like(inps)
    return inps, outs, cache['attention_mask'], cache['position_ids']


# ============================================================
# ====================== compress 函数 ========================
# ============================================================

def compress(layer, attn_mask, mlp_mask, attn_mean_inp, mlp_mean_inp, device, bias=True, unstr=True):
    """Apply FLAP head / neuron pruning (with optional bias compensation) to one decoder layer.

    Args:
        layer: decoder layer exposing ``self_attn.{q,k,v,o}_proj`` and
            ``mlp.{up,gate,down}_proj`` linear modules.
        attn_mask: bool tensor ``[num_heads]`` (True = keep), or None to leave
            attention untouched. ``head_dim`` is derived as
            ``q_proj.out_features // num_heads``.
        mlp_mask: bool tensor ``[intermediate_size]`` (True = keep), or None.
        attn_mean_inp: per-channel mean input of ``o_proj``; None disables
            attention bias compensation.
        mlp_mean_inp: per-channel mean input of ``down_proj``; None disables
            MLP bias compensation.
        device: device the masks are moved to before use.
        bias: if True, set ``o_proj`` / ``down_proj`` bias to
            ``(mean_inp * pruned) @ W_out.T``, cast to ``W_out``'s dtype. Any
            existing bias on those modules is replaced.
        unstr: True -> zero pruned rows (and biases) of q/k/v and up/gate in
            place; shapes are unchanged, so repeated calls with cumulative
            masks are safe. False -> physically remove pruned rows/columns; call
            once, with masks in the layer's current index space.

    Raises:
        ValueError: q_proj width is not divisible by the number of heads in ``attn_mask``.
        NotImplementedError: k/v projections narrower than q (grouped-query attention).
    """
    def _compensation_bias(mean_inp, keep, out_weight):
        # Fold the mean contribution of the removed input channels into a bias.
        # Cast to the weight's dtype: baselines may be fp16 while the model is fp32/bf16.
        pruned = (~keep).to(device=out_weight.device, dtype=out_weight.dtype)
        mean = mean_inp.to(device=out_weight.device, dtype=out_weight.dtype)
        return (mean * pruned) @ out_weight.T

    def _mask_rows(projs, keep):
        # Zero the output rows (and bias entries) of pruned channels in place.
        for proj in projs:
            proj.weight.data *= keep.to(proj.weight.device).unsqueeze(-1)
            if proj.bias is not None:
                proj.bias.data *= keep.to(proj.bias.device)

    def _select_rows(projs, keep_idx):
        # Keep only the retained output rows (and bias entries).
        for proj in projs:
            idx = keep_idx.to(proj.weight.device)
            proj.weight.data = proj.weight.data[idx]
            if proj.bias is not None:
                proj.bias.data = proj.bias.data[idx]
            proj.out_features = idx.numel()

    if attn_mask is not None:
        attn = layer.self_attn
        keep_heads = attn_mask.to(device=device, dtype=torch.bool)
        num_heads = keep_heads.numel()
        q_rows = attn.q_proj.weight.shape[0]
        if num_heads == 0 or q_rows % num_heads != 0:
            raise ValueError(
                f"q_proj has {q_rows} output rows, not divisible by {num_heads} heads in attn_mask."
            )
        head_dim = q_rows // num_heads
        if attn.k_proj.weight.shape[0] != q_rows or attn.v_proj.weight.shape[0] != q_rows:
            raise NotImplementedError(
                "compress only supports multi-head attention where q/k/v have the same width; "
                "grouped-query attention is not handled."
            )
        keep_channels = keep_heads.repeat_interleave(head_dim)  # [num_heads * head_dim]
        qkv = (attn.q_proj, attn.k_proj, attn.v_proj)
        output_weight = attn.o_proj.weight.data

        if bias and attn_mean_inp is not None:
            attn.o_proj.bias = nn.Parameter(
                _compensation_bias(attn_mean_inp, keep_channels, output_weight)
            )

        if unstr:
            _mask_rows(qkv, keep_channels)
        else:
            keep_idx = torch.where(keep_channels)[0]
            _select_rows(qkv, keep_idx)
            attn.o_proj.weight.data = output_weight[:, keep_idx.to(output_weight.device)]
            attn.o_proj.in_features = keep_idx.numel()
            # Keep head bookkeeping consistent for attention implementations that read it.
            retained_heads = int(keep_heads.sum().item())
            for attr in ("num_heads", "num_key_value_heads"):
                if hasattr(attn, attr):
                    setattr(attn, attr, retained_heads)
            if hasattr(attn, "hidden_size"):
                attn.hidden_size = retained_heads * head_dim

    if mlp_mask is not None:
        mlp = layer.mlp
        keep_neurons = mlp_mask.to(device=device, dtype=torch.bool)
        up_gate = (mlp.up_proj, mlp.gate_proj)
        output_weight = mlp.down_proj.weight.data

        if bias and mlp_mean_inp is not None:
            mlp.down_proj.bias = nn.Parameter(
                _compensation_bias(mlp_mean_inp, keep_neurons, output_weight)
            )

        if unstr:
            _mask_rows(up_gate, keep_neurons)
        else:
            keep_idx = torch.where(keep_neurons)[0]
            _select_rows(up_gate, keep_idx)
            mlp.intermediate_size = keep_idx.numel()
            mlp.down_proj.weight.data = output_weight[:, keep_idx.to(output_weight.device)]
            mlp.down_proj.in_features = keep_idx.numel()

    torch.cuda.empty_cache()


# ============================================================
# ============= 计算通道重要性 (一次性) ======================
# ============================================================
def get_decoder_layers(model):
    """Return the stack of transformer blocks of ``model``.

    Looks for ``model.model.layers`` first, then ``model.model.encoder.layers``.

    Raises:
        RuntimeError: neither attribute path exists.
    """
    if hasattr(model, "model"):
        backbone = model.model
        if hasattr(backbone, "layers"):
            return backbone.layers
        if hasattr(backbone, "encoder"):
            return backbone.encoder.layers
    raise RuntimeError("[TALE] Cannot find decoder layers in model.")

def dequantize(layer_mod: nn.Linear) -> torch.Tensor:
    """Return the dense weight of ``layer_mod``.

    Placeholder: quantized layers (e.g. bitsandbytes) are not handled. The
    module's ``weight`` is returned unchanged. Real de-quantization logic
    belongs here if quantized models need to be supported.
    """
    weight = layer_mod.weight
    return weight


@torch.no_grad()
def compute_flap_importance(
    model,
    tokenizer,
    device,
    calib_dataset: str,
    data_dir: Optional[str],
    num_calib_sample: int,
    seqlen: int,
    metrics: str = "WIFV",
    wanda_sp: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    """Collect FLAP per-channel importance scores for every decoder layer.

    Runs up to ``num_calib_sample`` calibration sequences through ``model``.
    Forward hooks on every ``self_attn.o_proj`` and ``mlp.down_proj`` linear
    accumulate input statistics with :class:`BiasGPT`. Each input channel of
    those layers is then scored.

    Args:
        model: LLaMA-style causal LM (decoder layers found via ``get_decoder_layers``).
        tokenizer: tokenizer passed to the calibration loader.
        device: device the calibration input ids are moved to.
        calib_dataset: dataset name understood by ``get_loaders``.
        data_dir: local data root forwarded to ``get_loaders`` (may be None).
        num_calib_sample: number of calibration sequences.
        seqlen: calibration sequence length.
        metrics: "IFV", "WIFV" or "WIFN". (The parameter name shadows the
            module-level ``metrics`` dict inside this function.)
        wanda_sp: if True, score channels Wanda-SP style,
            ``mean_rows(|W| * sqrt(E||x||^2))``, using the WIFN statistics. The
            attention score is then not squared.

    Returns:
        attn_metric: ``[L, o_proj.in_features]``, not standardized; squared
            unless ``wanda_sp``.
        mlp_metric: ``[L, down_proj.in_features]``.
        attn_baseline_inp_list: L mean-input tensors of o_proj, on CPU, in
            o_proj's weight dtype.
        mlp_baseline_inp_list: L mean-input tensors of down_proj, likewise.

    Raises:
        ValueError: unknown ``metrics`` (checked before any forward pass).
        RuntimeError: no target linears / decoder layers found.
        KeyError: a decoder layer lacks one of the target modules.
    """
    if metrics not in ("IFV", "WIFV", "WIFN"):
        raise ValueError(f"[TALE] Unknown FLAP metric '{metrics}', expected one of IFV / WIFV / WIFN.")
    # Wanda-SP needs the squared-norm statistics BiasGPT only tracks in WIFN mode.
    stat_type = "WIFN" if wanda_sp else metrics

    use_cache = getattr(model.config, "use_cache", False)
    model.config.use_cache = False
    model.eval()
    try:
        print(f"[TALE] Loading calibration data for FLAP from {calib_dataset}")
        dataloader, _ = get_loaders(
            calib_dataset,
            nsamples=num_calib_sample,
            seed=42,
            tokenizer=tokenizer,
            seqlen=seqlen,
            data_dir=data_dir,
        )
        print(f"[TALE] Calibration dataset ready, #batches={len(dataloader)}")

        all_linears: Dict[str, nn.Module] = find_layers(model)
        target_names = [
            name for name in all_linears
            if name.endswith("self_attn.o_proj") or name.endswith("mlp.down_proj")
        ]
        if not target_names:
            raise RuntimeError("[TALE] Cannot find any 'self_attn.o_proj' or 'mlp.down_proj' linear layers.")

        wrapped_layers: Dict[str, BiasGPT] = {
            name: BiasGPT(all_linears[name], stat_type) for name in target_names
        }

        def make_hook(name):
            def hook(module, inp, out):
                x_in = inp[0] if isinstance(inp, (tuple, list)) else inp
                wrapped_layers[name].add_batch(x_in.detach(), out.detach())
            return hook

        handles = [all_linears[name].register_forward_hook(make_hook(name)) for name in target_names]
        try:
            print("[TALE] Running calibration forward passes for FLAP importance ...")
            for idx, (inp, _) in enumerate(tqdm(dataloader, desc="FLAP calib")):
                if idx >= num_calib_sample:
                    break
                model(inp.to(device))
        finally:
            # Never leave hooks attached, even if a forward pass fails.
            for h in handles:
                h.remove()

        layers = get_decoder_layers(model)
        num_layers = len(layers)

        def get_module_name(layer_idx: int, suffix: str) -> str:
            # Match the exact "...layers.{i}.{suffix}" tail rather than any substring.
            pattern = f"layers.{layer_idx}.{suffix}"
            for n in target_names:
                if n == pattern or n.endswith("." + pattern):
                    return n
            raise KeyError(f"[TALE] Cannot find module name with pattern '{pattern}' in target_names.")

        def channel_score(module_name: str) -> torch.Tensor:
            stats = wrapped_layers[module_name]
            if stat_type == "IFV":
                return stats.fluc_inp
            weight = dequantize(all_linears[module_name]).data.float()
            if stat_type == "WIFV":
                return stats.fluc_inp * weight.pow(2).sum(dim=0)
            return (weight.abs() * stats.scaler_inp.reshape((1, -1)).sqrt()).mean(axis=0)

        def mean_input(module_name: str) -> torch.Tensor:
            # Match the consuming weight's dtype so bias compensation can matmul directly.
            target_dtype = all_linears[module_name].weight.dtype
            return wrapped_layers[module_name].baseline_inp.to(target_dtype).cpu()

        attn_metric_list, mlp_metric_list = [], []
        attn_baseline_inp_list, mlp_baseline_inp_list = [], []

        for i in tqdm(range(num_layers), desc="[TALE] Computing FLAP metric per layer"):
            attn_name = get_module_name(i, "self_attn.o_proj")
            mlp_name = get_module_name(i, "mlp.down_proj")

            W_attn = channel_score(attn_name)
            if not wanda_sp:
                # FLAP squares the attention channel score before per-head aggregation.
                W_attn = W_attn ** 2
            attn_metric_list.append(W_attn.cpu())
            attn_baseline_inp_list.append(mean_input(attn_name))

            mlp_metric_list.append(channel_score(mlp_name).cpu())
            mlp_baseline_inp_list.append(mean_input(mlp_name))

            wrapped_layers[attn_name].free()
            wrapped_layers[mlp_name].free()

        attn_metric = torch.stack(attn_metric_list)   # [L, H_attn_input]
        mlp_metric = torch.stack(mlp_metric_list)     # [L, H_mlp_input]
        return attn_metric, mlp_metric, attn_baseline_inp_list, mlp_baseline_inp_list
    finally:
        model.config.use_cache = use_cache
        torch.cuda.empty_cache()

@torch.no_grad()
def compute_flap_metrics_once(args, model, tokenizer, device):
    """Compute FLAP importance once and aggregate attention scores per head.

    Calls ``compute_flap_importance`` with the calibration settings from
    ``args``. It then sums the per-channel attention scores within each head.

    Returns:
        attn_metric_heads: ``[L, num_heads]``.
        mlp_metric_neurons: ``[L, mlp_dim]`` (per-neuron scores, unchanged).
        attn_baseline_inp_list: L tensors ``[num_heads * head_dim]``.
        mlp_baseline_inp_list: L tensors ``[mlp_dim]``.

    Raises:
        ValueError: the measured o_proj width is not divisible by the head count.
    """
    attn_metric, mlp_metric, attn_baseline_inp_list, mlp_baseline_inp_list = compute_flap_importance(
        model=model,
        tokenizer=tokenizer,
        device=device,
        calib_dataset=args.calib_dataset,
        data_dir=args.data_dir,
        num_calib_sample=args.num_calib_sample,
        seqlen=args.seqlen,
        metrics=args.metrics,
    )

    layers = get_decoder_layers(model)
    num_layers = len(layers)

    if hasattr(model.config, "num_attention_heads"):
        num_heads = model.config.num_attention_heads
    elif hasattr(layers[0].self_attn, "num_heads"):
        num_heads = layers[0].self_attn.num_heads
    else:
        num_heads = model.config.hidden_size // getattr(model.config, "head_dim", 128)

    # Derive head_dim from the measured o_proj input width instead of
    # hidden_size: some configs set head_dim so that num_heads * head_dim != hidden_size.
    attn_channels = attn_metric.shape[1]
    if attn_channels % num_heads != 0:
        raise ValueError(
            f"[TALE] o_proj input width {attn_channels} is not divisible by num_heads={num_heads}."
        )
    head_dim = attn_channels // num_heads

    attn_metric_heads = attn_metric.reshape(num_layers, num_heads, head_dim).sum(dim=-1)  # [L, num_heads]
    mlp_metric_neurons = mlp_metric                                                        # [L, mlp_dim]

    return attn_metric_heads, mlp_metric_neurons, attn_baseline_inp_list, mlp_baseline_inp_list

# ============================================================
# =================== 迭代式 FLAP 剪枝 =======================
# ============================================================

def _save_hf_checkpoint(model, tokenizer, output_dir, iter_idx, current_sparsity):
    """Save an intermediate checkpoint under ``output_dir`` and return its directory.

    The directory is named ``iter{iter_idx:03d}_sparsity{current_sparsity:.3f}``.
    NOTE(review): compensation biases added by ``compress`` are included in the
    saved state dict, but a config without attention/MLP bias may drop them on
    reload — confirm with the transformers version in use.
    """
    ckpt_dir = os.path.join(output_dir, f"iter{iter_idx:03d}_sparsity{current_sparsity:.3f}")
    os.makedirs(ckpt_dir, exist_ok=True)
    model.save_pretrained(ckpt_dir)
    if tokenizer is not None:
        tokenizer.save_pretrained(ckpt_dir)
    return ckpt_dir


def iterative_flap_prune(args, model, tokenizer, device=torch.device("cuda:0")):
    """Iteratively prune attention heads and MLP neurons using FLAP scores.

    - Width only: heads and MLP neurons are pruned; no layers are removed.
    - Importance is computed once up front. Each iteration then prunes the
      globally lowest-scoring ``args.step_pruning_ratio`` fraction of the
      remaining heads / neurons (at least one), never exceeding
      ``args.target_pruning_ratio`` of each. The loop stops when both targets
      are reached or after ``args.max_iters`` iterations.
    - Keep-masks are indexed in the original channel space, so iterations
      always apply them in mask mode (pruned rows zeroed, shapes unchanged).
      When ``args.unstr`` is False, the final masks are applied structurally
      once, after the loop.
    - A checkpoint is saved under ``args.output_dir`` whenever the overall
      sparsity crosses one or more new 10% levels (up to the target).
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    # 1) Global importance, computed once.
    (
        attn_metric_heads,      # [L, num_heads]
        mlp_metric_neurons,     # [L, mlp_dim]
        attn_baseline_inp_list,
        mlp_baseline_inp_list,
    ) = compute_flap_metrics_once(args, model, tokenizer, device)

    layers = get_decoder_layers(model)
    L, num_heads = attn_metric_heads.shape
    _, mlp_dim = mlp_metric_neurons.shape

    print(f"[FLAP-Iter] L={L}, num_heads={num_heads}, mlp_dim={mlp_dim}")

    # 2) Keep-masks (True = keep); everything starts kept.
    attn_keep = torch.ones_like(attn_metric_heads, dtype=torch.bool)  # [L, num_heads]
    mlp_keep = torch.ones_like(mlp_metric_neurons, dtype=torch.bool)  # [L, mlp_dim]

    total_heads = L * num_heads
    total_neurons = L * mlp_dim

    target_ratio = float(args.target_pruning_ratio)  # e.g. 0.5
    max_pruned_heads = int(total_heads * target_ratio)
    max_pruned_neurons = int(total_neurons * target_ratio)

    print(
        f"[FLAP-Iter] target_ratio={target_ratio:.2f}, "
        f"max_pruned_heads={max_pruned_heads}, max_pruned_neurons={max_pruned_neurons}"
    )

    pruned_heads = 0
    pruned_neurons = 0

    step_ratio = float(args.step_pruning_ratio)  # fraction of the remaining units per iteration
    max_iters = int(args.max_iters)

    # Checkpoint levels every 10% up to the target, e.g. 0.5 -> [0.1, 0.2, 0.3, 0.4, 0.5].
    save_step = 0.1
    save_levels = [round(x, 3) for x in np.arange(save_step, target_ratio + 1e-6, save_step)]
    saved_levels = set()
    print(f"[FLAP-Iter] Will save checkpoints at sparsity levels: {save_levels}")

    if not args.unstr:
        print(
            "[FLAP-Iter] --unstr not set: iterations use masking; "
            "structural compression is applied once after the final iteration."
        )

    def prune_lowest(metric, keep, k):
        """Flip the ``k`` lowest-scoring still-kept entries of ``keep`` to False."""
        masked = metric.clone()
        masked[~keep] = float("inf")  # never re-select already pruned entries
        _, flat_idx = torch.topk(masked.view(-1), k=k, largest=False)
        keep.view(-1)[flat_idx] = False

    def apply_masks(unstr):
        """Apply the current cumulative masks (with bias compensation) to every layer."""
        for idx in range(L):
            compress(
                layers[idx],
                attn_keep[idx].to(device),
                mlp_keep[idx].to(device),
                attn_baseline_inp_list[idx].to(device),
                mlp_baseline_inp_list[idx].to(device),
                device,
                bias=True,
                unstr=unstr,
            )
        torch.cuda.empty_cache()

    it = 0
    while it < max_iters:
        it += 1
        print("=" * 60)
        print(
            f"[FLAP-Iter] Iter {it}, "
            f"pruned_heads={pruned_heads}/{max_pruned_heads}, "
            f"pruned_neurons={pruned_neurons}/{max_pruned_neurons}"
        )

        if pruned_heads >= max_pruned_heads and pruned_neurons >= max_pruned_neurons:
            print("[FLAP-Iter] Reached target pruning ratio, stop.")
            break

        # Prune a fraction of what is left (at least one), never past the budget.
        remaining_heads = total_heads - pruned_heads
        remaining_neurons = total_neurons - pruned_neurons
        step_head = max(1, int(remaining_heads * step_ratio))
        step_neuron = max(1, int(remaining_neurons * step_ratio))
        step_head = max(0, min(step_head, max_pruned_heads - pruned_heads))
        step_neuron = max(0, min(step_neuron, max_pruned_neurons - pruned_neurons))

        if step_head > 0:
            prune_lowest(attn_metric_heads, attn_keep, step_head)
            pruned_heads += step_head
            print(f"[FLAP-Iter] Iter {it}: prune {step_head} heads.")

        if step_neuron > 0:
            prune_lowest(mlp_metric_neurons, mlp_keep, step_neuron)
            pruned_neurons += step_neuron
            print(f"[FLAP-Iter] Iter {it}: prune {step_neuron} neurons.")

        # Masks live in the original index space, so always apply them without reshaping.
        apply_masks(unstr=True)

        cur_ratio_heads = pruned_heads / total_heads
        cur_ratio_neurons = pruned_neurons / total_neurons
        overall_ratio = 0.5 * (cur_ratio_heads + cur_ratio_neurons)

        print(
            f"[FLAP-Iter] Iter {it} done. "
            f"head_pruned={cur_ratio_heads:.3f}, "
            f"mlp_pruned={cur_ratio_neurons:.3f}, "
            f"overall≈{overall_ratio:.3f}"
        )

        # int() truncation of the budgets can leave overall_ratio just below the
        # last level, so once both targets are hit every remaining level counts.
        reached_target = pruned_heads >= max_pruned_heads and pruned_neurons >= max_pruned_neurons
        newly_reached = [
            lv for lv in save_levels
            if lv not in saved_levels and (overall_ratio >= lv or reached_target)
        ]
        if newly_reached:
            ckpt_dir = _save_hf_checkpoint(
                model,
                tokenizer,
                args.output_dir,
                iter_idx=it,
                current_sparsity=overall_ratio,
            )
            print(
                f"[FLAP-Iter] Saved checkpoint at ~{max(newly_reached)*100:.1f}% sparsity "
                f"to {ckpt_dir}"
            )
            # One save covers every level crossed in this iteration.
            saved_levels.update(newly_reached)

    if not args.unstr:
        print("[FLAP-Iter] Applying final masks structurally (unstr=False).")
        apply_masks(unstr=False)

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
    print("[FLAP-Iter] Iterative FLAP pruning finished.")


# ============================================================
# ========================= CLI 部分 =========================
# ============================================================

def parse_args():
    """Parse and validate command-line arguments for iterative FLAP pruning.

    Exits with status 2 (via ``parser.error``) when:
    - ``--target_pruning_ratio`` or ``--step_pruning_ratio`` is outside [0, 1];
    - ``--num_calib_sample`` < 1 or ``--max_iters`` < 0;
    - both ``--fp16`` and ``--bf16`` are given.
    """
    parser = argparse.ArgumentParser(description="Iterative FLAP width pruning")

    parser.add_argument("--model_name_or_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)

    parser.add_argument("--calib_dataset", type=str, default="wikitext2",
                        choices=["wikitext2", "c4", "mmlu"])
    parser.add_argument("--data_dir", type=str, default=None)

    # Not read anywhere; the calibration count actually used is --num_calib_sample.
    parser.add_argument("--nsamples", type=int, default=128,
                        help="Unused (kept for backward compatibility); "
                             "the calibration sample count is --num_calib_sample.")
    parser.add_argument("--seqlen", type=int, default=2048,
                        help="校准序列长度")
    parser.add_argument("--metrics", type=str, default="WIFV",
                        choices=["IFV", "WIFV", "WIFN"])
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--target_pruning_ratio", type=float, default=0.5,
                        help="最终总剪枝比例（0~1），简单按通道个数近似。")
    parser.add_argument("--step_pruning_ratio", type=float, default=0.05,
                        help="每一轮在剩余通道里剪掉的比例（0~1）。")
    parser.add_argument("--max_iters", type=int, default=50,
                        help="迭代最多轮数（安全上限）。")
    parser.add_argument("--unstr", action="store_true",
                        help="只做 mask，不改变层结构。建议迭代时开启。")
    parser.add_argument("--num_calib_sample", type=int, default=512,
                        help="Number of calibration sequences used for FLAP importance.")

    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")

    args = parser.parse_args()

    if not 0.0 <= args.target_pruning_ratio <= 1.0:
        parser.error("--target_pruning_ratio must be in [0, 1].")
    if not 0.0 <= args.step_pruning_ratio <= 1.0:
        parser.error("--step_pruning_ratio must be in [0, 1].")
    if args.num_calib_sample < 1:
        parser.error("--num_calib_sample must be >= 1.")
    if args.max_iters < 0:
        parser.error("--max_iters must be >= 0.")
    if args.fp16 and args.bf16:
        parser.error("--fp16 and --bf16 are mutually exclusive.")
    return args


def main():
    """Command-line entry point.

    Parses arguments and loads the model and tokenizer from
    ``--model_name_or_path``. It then runs ``iterative_flap_prune`` and saves
    the pruned model and tokenizer to ``--output_dir``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[FLAP-Iter] Using device: {device}")

    # Half precision is only honoured on GPU; --fp16 takes precedence over --bf16.
    on_gpu = device.type == "cuda"
    if args.fp16 and on_gpu:
        dtype = torch.float16
    elif args.bf16 and on_gpu:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"[FLAP-Iter] Loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        device_map=None,
    ).to(device)
    # Some FLAP utilities read model.seqlen.
    model.seqlen = args.seqlen

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=True)
    if tokenizer.pad_token is None:
        # Reuse EOS as padding for tokenizers that ship without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    iterative_flap_prune(args, model, tokenizer, device)

    print(f"[FLAP-Iter] Saving pruned model to {args.output_dir}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(args.output_dir)
    print("[FLAP-Iter] All done.")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
