#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Standalone FLAP width-pruning / importance script.

功能：
- 加载一个 HF CausalLM 模型（本地或 Hub）
- 从给定数据集（wikitext2 / c4 / mmlu）构造校准样本
- 完整复刻 FLAP 中的 prune_flap 逻辑：
    - metrics: IFV / WIFV / WIFN
    - structure: UL-UM / UL-MM / AL-MM / AL-AM
    - wanda_sp: True / False
    - only_importance: 只返回 attn/mlp importance + baseline
    - 否则返回包含所有 width mask 和 bias 的字典

输出：
- 若 only_importance=True：
    output_dir/attn_metric.npy      # [num_layers, attn_hidden]
    output_dir/mlp_metric.npy       # [num_layers, mlp_hidden]
    output_dir/attn_baseline.npy    # list of arrays
    output_dir/mlp_baseline.npy

- 若 only_importance=False：
    output_dir/width_mask.npy       # dict[name][ratio] -> mask array (entry is None for o_proj / down_proj)
    output_dir/bias.npy             # dict[name][ratio] -> bias array (entry is None for q/k/v/up/gate_proj)
"""

import os
import sys
import re
import argparse
import random
from typing import Dict

import numpy as np
import torch
import torch.nn as nn

from datasets import load_dataset, Dataset
from tqdm import tqdm

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

import bitsandbytes as bnb
from peft.tuners.lora.bnb import Linear4bit, Linear8bitLt
from peft.utils.integrations import dequantize_bnb_weight


# ===================== 数据加载部分 =====================

class TokenizerWrapper:
    """Tiny holder exposing a pre-computed ``input_ids`` tensor as an attribute.

    Lets raw id tensors be handed to code that reads ``.input_ids`` the way it
    would from a tokenizer's output.
    """

    def __init__(self, input_ids: torch.Tensor):
        # Stored by reference; no copy or validation is performed.
        self.input_ids = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Draw random WikiText-2 calibration windows and tokenize the test split.

    Args:
        nsamples: number of windows to draw from the concatenated train text.
        seed: seed for Python's ``random`` module (drives window offsets).
        seqlen: tokens per window.
        tokenizer: tokenizer called with ``return_tensors='pt'``.

    Returns:
        ``(trainloader, testenc)``: ``trainloader`` is a list of ``(inp, tar)``
        pairs where ``tar`` copies ``inp`` but is ``-100`` everywhere except the
        last position; ``testenc`` is the tokenizer output for the test split
        joined with blank lines.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    testenc = tokenizer("\n\n".join(test_split['text']), return_tensors='pt')

    random.seed(seed)
    last_start = train_ids.shape[1] - seqlen - 1
    samples = []
    for _ in range(nsamples):
        start = random.randint(0, last_start)
        window = train_ids[:, start:start + seqlen]
        labels = window.clone()
        # Only the final token is kept as a target; earlier positions are ignored.
        labels[:, :-1] = -100
        samples.append((window, labels))
    return samples, testenc


def get_c4(nsamples, seed, seqlen, tokenizer, data_dir=None):
    """Sample C4 calibration windows and build a C4 validation encoding.

    Only the first shard of each split is read
    (``en/c4-train.00000-of-01024.json.gz`` and
    ``en/c4-validation.00000-of-00008.json.gz``). The shards come from the
    ``allenai/c4`` Hub repository when ``data_dir`` is None, otherwise from
    ``<data_dir>/allenai--c4``.

    Args:
        nsamples: number of calibration windows to draw.
        seed: seed for Python's ``random`` module (document and offset choice).
        seqlen: tokens per window; only documents longer than this are used.
        tokenizer: tokenizer called with ``return_tensors='pt'``.
        data_dir: optional local root containing an ``allenai--c4`` directory.

    Returns:
        ``(trainloader, valenc)``: ``trainloader`` is a list of ``(inp, tar)``
        pairs, where ``tar`` is ``-100`` except at the last position. ``valenc``
        wraps the first ``256 * seqlen`` tokens of the first 1100 validation
        documents joined by spaces.
    """
    train_files = {'train': 'en/c4-train.00000-of-01024.json.gz'}
    val_files = {'validation': 'en/c4-validation.00000-of-00008.json.gz'}
    # Restrict the Hub path to the same single shard as the local path. Without
    # data_files, load_dataset would resolve the whole C4 corpus just to sample a
    # handful of documents.
    source = 'allenai/c4' if data_dir is None else os.path.join(data_dir, 'allenai--c4')
    traindata = load_dataset(source, data_files=train_files, split='train')
    valdata = load_dataset(source, data_files=val_files, split='validation')

    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Re-draw documents until one is strictly longer than seqlen tokens.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc


def format_mmlu_example(dataset, include_answer: bool = True):
    """Render MMLU rows as lettered multiple-choice prompts.

    Args:
        dataset: mapping or ``Dataset`` with parallel columns "question",
            "choices" (list of option strings per row) and "answer" (index of
            the correct option).
        include_answer: when True, append the correct option as
            ``" <letter>. <option text>\\n\\n"`` after ``"Answer:"``.

    Returns:
        One prompt string per row, e.g. ``"Q\\nA. x\\nB. y\\nAnswer:"``.
    """

    def letter(idx: int) -> str:
        # "A", "B", "C", ... so questions with more than four options also work.
        return chr(ord("A") + idx)

    data = []
    for ques, cho, ans in zip(dataset["question"], dataset["choices"], dataset["answer"]):
        options = "".join("\n{}. {}".format(letter(i), text) for i, text in enumerate(cho))
        prompt = ques + options + "\nAnswer:"
        if include_answer:
            # Letter followed by the option's text (previously the letter was repeated).
            prompt += " {}. {}\n\n".format(letter(ans), cho[ans])
        data.append(prompt)
    return data


def get_mmlu_trainenc(
    data_dir: str,
    seed: int,
    nsamples: int,
    tokenizer,
    seqlen: int = 512,
    num_tasks: int = None,
):
    """Tokenize a seeded sample of MMLU questions into one long token stream.

    ``nsamples`` questions are spread as evenly as possible over the selected
    subjects (all 57, or ``num_tasks`` randomly chosen ones). Each subject is
    read from ``<data_dir>/mmlu`` using its "train" split, falling back to
    "validation" when loading "train" fails. Every prompt is padded/truncated to
    ``seqlen`` tokens, and the batch is flattened to shape ``(1, n * seqlen)``.

    Args:
        data_dir: parent directory of the local ``mmlu`` dataset.
        seed: seeds subject/count shuffling and each subject's ``shuffle``.
        nsamples: total number of questions to draw.
        tokenizer: tokenizer supporting ``padding='max_length'``.
        seqlen: per-question token length after padding/truncation.
        num_tasks: optional number of subjects to sample.

    Returns:
        The tokenizer output with ``input_ids`` and ``attention_mask`` reshaped
        to ``(1, -1)``.

    Raises:
        ValueError: if ``num_tasks`` selects no subjects (e.g. ``num_tasks=0``).
    """
    subclass = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]
    random.seed(seed)
    if num_tasks is not None:
        random.shuffle(subclass)
        subclass = subclass[:num_tasks]
    if not subclass:
        # Guard the division below, which would otherwise fail with ZeroDivisionError.
        raise ValueError(f"num_tasks={num_tasks} selects no MMLU subjects.")

    keys = ["question", "subject", "choices", "answer"]
    # Distribute nsamples over subjects; `extra` subjects get one more question.
    subnum, extra = divmod(nsamples, len(subclass))
    num_list = [subnum + 1] * extra + [subnum] * (len(subclass) - extra)
    random.shuffle(num_list)

    traindata = {key: [] for key in keys}
    mmlu_path = os.path.join(data_dir, "mmlu")

    for num, subject in tqdm(
        zip(num_list, subclass),
        total=len(subclass),
        desc="Loading the subclass in MMLU"
    ):
        if num == 0:
            continue
        try:
            data = load_dataset(mmlu_path, subject, split="train").shuffle(seed=seed)[:num]
        except Exception:
            # Best-effort fallback: some local copies presumably lack a "train" split.
            data = load_dataset(mmlu_path, subject, split="validation").shuffle(seed=seed)[:num]
        for key in keys:
            traindata[key].extend(data[key])

    dataset = Dataset.from_dict(traindata)
    all_data = format_mmlu_example(dataset, include_answer=False)

    trainenc = tokenizer(
        all_data,
        return_tensors='pt',
        max_length=seqlen,
        padding='max_length',
        truncation=True,
    )
    # Concatenate all padded prompts into a single (1, n * seqlen) stream.
    trainenc["input_ids"] = trainenc["input_ids"].reshape(1, -1)
    trainenc["attention_mask"] = trainenc["attention_mask"].reshape(1, -1)
    return trainenc


def get_mmlu_loader(
    data_dir: str,
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
):
    """Cut ``nsamples`` random ``seqlen``-token windows out of the MMLU token stream.

    Returns:
        ``(loader, TokenizerWrapper(first_window))``: ``loader`` holds
        ``(inp, tar)`` pairs where ``tar`` is ``-100`` except at the last
        position. When the stream is at most ``seqlen + 1`` tokens long every
        window starts at offset 0.
    """
    token_ids = get_mmlu_trainenc(
        data_dir=data_dir,
        seed=seed,
        nsamples=nsamples,
        tokenizer=tokenizer,
        seqlen=seqlen,
    )["input_ids"]
    stream_len = token_ids.shape[1]
    random.seed(seed)

    def draw_window():
        # No random draw at all for short streams, matching the fixed offset 0.
        start = 0 if stream_len <= seqlen + 1 else random.randint(0, stream_len - seqlen - 1)
        window = token_ids[:, start:start + seqlen].clone()
        labels = window.clone()
        labels[:, :-1] = -100
        return window, labels

    loader = [draw_window() for _ in range(nsamples)]
    return loader, TokenizerWrapper(token_ids[:, :seqlen])


def get_loaders(
    name: str,
    nsamples: int = 128,
    seed: int = 0,
    seqlen: int = 2048,
    tokenizer=None,
    data_dir: str = None,
):
    """Dispatch to the calibration loader whose key is a substring of ``name``.

    Keys are tested in the order wikitext2, c4, mmlu; the first match wins.

    Raises:
        ValueError: for an unrecognised ``name``, or for mmlu without ``data_dir``.
    """
    dataset_key = next((key for key in ('wikitext2', 'c4', 'mmlu') if key in name), None)

    if dataset_key == 'wikitext2':
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if dataset_key == 'c4':
        return get_c4(nsamples, seed, seqlen, tokenizer, data_dir=data_dir)
    if dataset_key == 'mmlu':
        if data_dir is None:
            raise ValueError("Using mmlu as calib dataset requires --data_dir pointing to parent of 'mmlu'.")
        return get_mmlu_loader(data_dir, nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown dataset name in get_loaders: {name}")


# ===================== WrappedGPT & BiasGPT =====================

class WrappedGPT:
    """Accumulate Wanda-style input statistics for one linear layer.

    ``scaler_row[j]`` holds the sum of squared L2 norms of input channel ``j``
    over every batch seen, divided by the running batch-row count ``nsamples``.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = layer.weight.device

        self.rows, self.columns = layer.out_features, layer.in_features

        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        self.layer_id, self.layer_name = layer_id, layer_name

    def add_batch(self, inp, out):
        """Fold one forward input (``out`` is unused) into ``scaler_row``."""
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]

        linear_types = (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
        if isinstance(self.layer, linear_types):
            # Flatten leading dims and put channels first: (in_features, tokens).
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        seen = self.nsamples + batch
        self.scaler_row *= self.nsamples / seen
        self.nsamples = seen

        self.scaler_row += torch.norm(inp.type(torch.float32), p=2, dim=1) ** 2 / seen


class BiasGPT:
    """Accumulate FLAP input statistics for one linear layer.

    Always tracks ``baseline_inp`` (running per-channel input mean). Depending
    on ``metric`` it additionally tracks:
      * "WIFN": ``scaler_inp``, a running mean of squared per-channel L2 norms;
      * otherwise: ``fluc_inp``, a Welford-style fluctuation estimate.
    """

    def __init__(self, layer, metric):
        self.layer = layer
        self.dev = self.layer.weight.device

        self.out_dim = layer.out_features
        self.in_dim = layer.in_features

        self.type = metric
        self.nsamples = 0

        self.baseline_inp = torch.zeros((self.in_dim), device=self.dev)
        if self.type == "WIFN":
            self.scaler_inp = torch.zeros((self.in_dim), device=self.dev)
        else:
            self.fluc_inp = torch.zeros((self.in_dim), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one forward input (``out`` is unused) into the running statistics."""
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        # Same type list as WrappedGPT / find_layers.
        if isinstance(self.layer, (nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()  # channels first: (in_dim, tokens)
        # Accumulate in float32 regardless of the model's (possibly half) dtype.
        inp = inp.type(torch.float32)

        total = self.nsamples + batch_size
        # Snapshot the previous mean. A plain assignment would alias the same
        # tensor, and the in-place updates below would turn "old" into "new",
        # breaking the (x - new_mean) * (x - old_mean) term.
        old_baseline_inp = self.baseline_inp.clone()
        self.baseline_inp *= self.nsamples / total
        self.baseline_inp += torch.mean(inp, dim=1) / total

        if self.type == "WIFN":
            self.scaler_inp *= self.nsamples / total
            self.scaler_inp += torch.norm(inp, p=2, dim=1) ** 2 / total
        else:
            if self.nsamples == 0:
                self.fluc_inp = torch.zeros_like(self.baseline_inp)
            else:
                self.fluc_inp *= (self.nsamples - 1) / (total - 1)
                self.fluc_inp += torch.sum(
                    (inp - self.baseline_inp.unsqueeze(1)) *
                    (inp - old_baseline_inp.unsqueeze(1)),
                    dim=1
                ) / total

        self.nsamples = total

    def free(self):
        """Drop the accumulated tensors and release cached CUDA memory."""
        self.baseline_inp = None
        if hasattr(self, 'fluc_inp'):
            self.fluc_inp = None
        if hasattr(self, 'scaler_inp'):
            self.scaler_inp = None
        torch.cuda.empty_cache()


# ===================== 其他工具函数 =====================

def find_layers(
    module,
    layers=(nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt),
    name: str = ''
):
    """Recursively collect sub-modules whose exact type appears in ``layers``.

    Args:
        module: module to search.
        layers: module classes, matched with ``type(m) in layers`` (exact type,
            not ``isinstance``).
        name: dotted path of ``module`` relative to the search root.

    Returns:
        Dict mapping dotted paths to matching modules. A matching module is
        returned without descending into its children.
    """
    if type(module) in layers:
        return {name: module}
    # Children of any path containing "lora" are never visited, presumably to
    # keep LoRA adapter sub-layers out of the result.
    if 'lora' in name:
        return {}
    found = {}
    for child_name, child in module.named_children():
        child_path = f"{name}.{child_name}" if name else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found


def dequantize(layer):
    """Return the (dense) weight tensor behind ``layer``.

    Layers exposing ``get_base_layer()`` are unwrapped first. If the weight
    carries a non-None ``quant_state`` it is expanded with
    ``dequantize_bnb_weight``; otherwise the weight object itself is returned.
    """
    source = layer.get_base_layer() if hasattr(layer, "get_base_layer") else layer
    weight = source.weight
    state = getattr(weight, "quant_state", None)
    if state is None:
        return weight
    return dequantize_bnb_weight(weight, state=state)


def cal_remove_neuron(args, model, head_dim: int = 128):
    """Return how many MLP neurons to remove to meet the pruning budget.

    Params per layer are estimated as ``3 * intermediate * hidden`` (MLP) plus
    ``4 * hidden * hidden`` (attention). The budget is ``pruning_ratio`` of that,
    minus ``4 * hidden * head_dim`` per removed head. Each MLP neuron accounts
    for ``3 * hidden`` parameters.

    For ``structure == "UL-MM"`` the budget is per layer, with
    ``remove_heads // num_layers`` heads removed per layer. Otherwise it is
    computed over all layers.

    NOTE(review): the ``4 * hidden * hidden`` attention term assumes k/v
    projections as wide as q; confirm for grouped-query models.

    Args:
        args: reads ``structure``, ``pruning_ratio`` and ``remove_heads``.
        model: reads ``config.intermediate_size``, ``config.hidden_size`` and
            ``config.num_hidden_layers``.
        head_dim: attention head width; the default 128 matches the original code.

    Returns:
        The neuron count truncated toward zero. It can be negative when the
        removed heads already exceed the budget.
    """
    config = model.config
    intermediate_size = config.intermediate_size
    hidden_size = config.hidden_size
    num_layers = config.num_hidden_layers
    params_per_layer = intermediate_size * hidden_size * 3 + hidden_size * hidden_size * 4
    if args.structure == "UL-MM":
        remove_params = args.pruning_ratio * params_per_layer
        heads_removed = args.remove_heads // num_layers
    else:
        remove_params = num_layers * args.pruning_ratio * params_per_layer
        heads_removed = args.remove_heads
    remove_head_params = hidden_size * 4 * heads_removed * head_dim
    return int((remove_params - remove_head_params) / (hidden_size * 3))


# ===================== prune_flap（用 hooks 复刻 FLAP 逻辑） =====================

def _flap_layer_id(module_name: str) -> int:
    """Return the decoder-layer index encoded as ``layers.<idx>`` in ``module_name``."""
    match = re.search(r'layers\.(\d+)', module_name)
    if match is None:
        raise RuntimeError(f"Cannot parse a layer index from module name '{module_name}'.")
    return int(match.group(1))


def _flap_infer_heads(layer, model, attn_in_features: int):
    """Return ``(num_heads, head_dim)`` for one decoder layer's attention block.

    Tries, in order, ``layer.self_attn.num_heads``, ``layer.self_attn.head_dim``
    and ``model.config.num_attention_heads`` (``attn_in_features`` is o_proj's
    input width). It falls back to the 32 x 128 layout the original code
    hard-coded.
    """
    self_attn = getattr(layer, "self_attn", None)
    num_heads = getattr(self_attn, "num_heads", None)
    if num_heads:
        return num_heads, attn_in_features // num_heads
    attn_head_dim = getattr(self_attn, "head_dim", None)
    if attn_head_dim:
        return attn_in_features // attn_head_dim, attn_head_dim
    num_heads = getattr(model.config, "num_attention_heads", None)
    if num_heads:
        return num_heads, attn_in_features // num_heads
    return 32, 128


def _flap_keep_mask(metric: torch.Tensor, num_remove: int) -> torch.Tensor:
    """Boolean keep-mask dropping the ``num_remove`` lowest scores of ``metric`` (all entries pooled).

    Entries strictly greater than the ``num_remove``-th smallest score are kept,
    so ties at the threshold are dropped as well. ``num_remove <= 0`` keeps
    everything. Indexing ``sorted[-0]`` would otherwise select the *largest*
    score and drop every entry.
    """
    if num_remove <= 0:
        return torch.ones_like(metric, dtype=torch.bool)
    sorted_desc = torch.sort(metric.view(-1), descending=True)[0]
    return metric > sorted_desc[-num_remove]


def _flap_compensation_bias(baseline, removed, proj, device, zero: bool):
    """Return ``(baseline * removed) @ W.T`` as numpy, where W is ``proj``'s (dequantized) weight.

    This is the bias that replaces the average contribution of the pruned input
    channels. With ``zero`` (wanda_sp mode) an all-zero bias of the same shape
    is returned.
    """
    weight = proj.weight
    quant_state = getattr(weight, "quant_state", None)
    if quant_state is not None:
        weight = dequantize_bnb_weight(weight, state=quant_state)
    output_bias = (baseline * removed.to(device)).to(weight) @ weight.T
    if zero:
        output_bias = torch.zeros_like(output_bias)
    return output_bias.cpu().numpy()


def _flap_collect_statistics(args, model, tokenizer, device):
    """Run the calibration forwards, accumulating input statistics of every o_proj / down_proj.

    Returns:
        ``(all_linears, target_names, wrapped_layers)``: every linear found by
        ``find_layers``, the names of the target projections, and their
        ``WrappedGPT`` (wanda_sp) or ``BiasGPT`` accumulators.
    """
    print(f"loading calibration data from {args.calib_dataset}")
    dataloader, _ = get_loaders(
        args.calib_dataset,
        nsamples=args.num_calib_sample,
        # 42 was previously hard-coded; an explicit args.seed now overrides it.
        seed=getattr(args, "seed", 42),
        tokenizer=tokenizer,
        seqlen=args.seqlen,
        data_dir=args.data_dir,
    )
    print("dataset loading complete")

    all_linears: Dict[str, nn.Module] = find_layers(model)
    target_names = [
        name for name in all_linears
        if name.endswith("self_attn.o_proj") or name.endswith("mlp.down_proj")
    ]
    if not target_names:
        raise RuntimeError("Cannot find any 'self_attn.o_proj' or 'mlp.down_proj' linear layers in the model.")

    wrapped_layers = {
        name: WrappedGPT(all_linears[name]) if args.wanda_sp else BiasGPT(all_linears[name], args.metrics)
        for name in target_names
    }

    def make_hook(name):
        stats = wrapped_layers[name]

        def hook(module, inp, out):
            # Forward hooks receive the positional inputs as a tuple.
            x_in = inp[0] if isinstance(inp, (tuple, list)) else inp
            stats.add_batch(x_in.detach(), out.detach())
        return hook

    handles = [all_linears[name].register_forward_hook(make_hook(name)) for name in target_names]
    try:
        # Full model forwards let HF handle RoPE / position ids itself.
        print("[INFO] Running calibration forward passes ...")
        with torch.no_grad():
            for idx, (inp, _) in enumerate(dataloader):
                if idx >= args.num_calib_sample:
                    break
                model(inp.to(device))
    finally:
        # Always detach the hooks, even if a forward pass fails.
        for handle in handles:
            handle.remove()
    return all_linears, target_names, wrapped_layers


def _flap_prune(args, model, tokenizer, device):
    """Body of ``prune_flap``; the caller restores ``model.config.use_cache``."""
    all_linears, target_names, wrapped_layers = _flap_collect_statistics(args, model, tokenizer, device)

    # ===== decoder layers =====
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        layers = model.model.layers
    elif hasattr(model, "model") and hasattr(model.model, "encoder"):
        layers = model.model.encoder.layers
    else:
        raise RuntimeError("Cannot find model layers (expected model.model.layers or model.model.encoder.layers).")
    num_layers = len(layers)

    # ===== FLAP metrics, keyed by module name =====
    def metric_IFV(module_name: str):
        return wrapped_layers[module_name].fluc_inp

    def metric_WIFV(module_name: str):
        weight = dequantize(all_linears[module_name])
        return wrapped_layers[module_name].fluc_inp * torch.sum(weight.data.pow(2), dim=0)

    def metric_WIFN(module_name: str):
        weight = dequantize(all_linears[module_name])
        return (
            torch.abs(weight.data) *
            torch.sqrt(wrapped_layers[module_name].scaler_inp.reshape((1, -1)))
        ).mean(axis=0)

    metrics = {
        "IFV": metric_IFV,
        "WIFV": metric_WIFV,
        "WIFN": metric_WIFN,
    }

    def get_module_name(layer_idx: int, suffix: str) -> str:
        # suffix: "self_attn.o_proj" or "mlp.down_proj"
        pattern = f"layers.{layer_idx}.{suffix}"
        for n in target_names:
            if pattern in n:
                return n
        raise KeyError(f"Cannot find module name with pattern '{pattern}' in target_names.")

    # Depends only on args and model.config, so compute it once, not per layer.
    remove_neuron = cal_remove_neuron(args, model) if args.structure in ("UL-MM", "AL-MM") else 0

    # ===== per-layer metrics / baselines / (UL-* masks) =====
    attn_metric_list, mlp_metric_list = [], []
    attn_baseline_inp_list, mlp_baseline_inp_list = [], []
    ul_attn_mask_list, ul_mlp_mask_list = [], []
    head_dim = 128

    for i in tqdm(range(num_layers), desc="Computing FLAP metrics per layer"):
        layer = layers[i]
        attn_name = get_module_name(i, "self_attn.o_proj")
        mlp_name = get_module_name(i, "mlp.down_proj")
        attn_stats = wrapped_layers[attn_name]
        mlp_stats = wrapped_layers[mlp_name]

        # --------- attention ---------
        if not args.wanda_sp:
            W_attn = metrics[args.metrics](attn_name) ** 2
            attn_baseline = attn_stats.baseline_inp.type(torch.half)
        else:
            weight_attn = dequantize(all_linears[attn_name])
            W_attn = (
                torch.abs(weight_attn.data) *
                torch.sqrt(attn_stats.scaler_row.reshape((1, -1)))
            ).mean(axis=0)
            attn_baseline = attn_stats.scaler_row.type(torch.half)

        num_heads, head_dim = _flap_infer_heads(layer, model, all_linears[attn_name].in_features)

        if args.structure in ("UL-UM", "UL-MM"):
            W_head = W_attn.reshape(-1, head_dim).sum(dim=1)
            if args.structure == "UL-UM":
                k = int(args.pruning_ratio * num_heads)
            else:
                k = args.remove_heads // num_layers
            thresh = torch.sort(W_head.to(device))[0][k].cpu()
            ul_attn_mask_list.append(W_head >= thresh)
        else:
            attn_metric_list.append(W_attn.cpu())
        attn_baseline_inp_list.append(attn_baseline)

        # --------- MLP ---------
        if not args.wanda_sp:
            W_mlp = metrics[args.metrics](mlp_name)
            mlp_baseline = mlp_stats.baseline_inp.type(torch.half)
        else:
            weight_mlp = dequantize(all_linears[mlp_name])
            W_mlp = (
                torch.abs(weight_mlp.data) *
                torch.sqrt(mlp_stats.scaler_row.reshape((1, -1)))
            ).mean(axis=0)
            mlp_baseline = mlp_stats.scaler_row.type(torch.half)

        if args.structure in ("UL-UM", "UL-MM"):
            if args.structure == "UL-UM":
                k = int(W_mlp.numel() * args.pruning_ratio)
            else:
                k = remove_neuron
            thresh = torch.sort(W_mlp.to(device))[0][k].cpu()
            ul_mlp_mask_list.append(W_mlp >= thresh)
        else:
            mlp_metric_list.append(W_mlp.cpu())
        mlp_baseline_inp_list.append(mlp_baseline)

        # Release the accumulators as soon as this layer is done.
        for stats in (attn_stats, mlp_stats):
            if hasattr(stats, 'free'):
                stats.free()

    # ===== importance-only mode =====
    if getattr(args, "only_importance", False):
        if len(attn_metric_list) == 0 or len(mlp_metric_list) == 0:
            raise RuntimeError(
                "only_importance=True 目前只在 structure 为 AL-* 时有意义，"
                "因为 UL-* 分支是直接在 metric 内部阈值化的。"
            )
        attn_metric = torch.stack(attn_metric_list)  # [L, attn_hidden]
        mlp_metric = torch.stack(mlp_metric_list)    # [L, mlp_hidden]
        return {
            "attn_metric": attn_metric.cpu().numpy(),
            "mlp_metric": mlp_metric.cpu().numpy(),
            "attn_baseline": [x.cpu().numpy() for x in attn_baseline_inp_list],
            "mlp_baseline": [x.cpu().numpy() for x in mlp_baseline_inp_list],
        }

    # ===== masks + compensation biases =====
    def standarlization(x: torch.Tensor) -> torch.Tensor:
        # Per-layer (row-wise) z-score.
        return (x - torch.mean(x, axis=1, keepdim=True)) / torch.std(x, axis=1, keepdim=True)

    mask: Dict[str, Dict[float, np.ndarray]] = {}
    bias: Dict[str, Dict[float, np.ndarray]] = {}
    for name, module in model.named_modules():
        if (
            isinstance(module, (torch.nn.Linear, Linear4bit, Linear8bitLt))
            and 'lm_head' not in name
            and 'lora' not in name
            and 'base_layer' not in name
        ):
            mask[name] = {}
            bias[name] = {}

    pruning_ratio_list = [0.8, 0.65, 0.5]

    for pruning_ratio in pruning_ratio_list:  # remaining ratio
        if args.structure in ["AL-MM", "AL-AM"]:
            attn_metric = standarlization(torch.stack(attn_metric_list))
            # Average the standardized channel scores within each head.
            attn_metric = attn_metric.reshape(num_layers, -1, head_dim).mean(dim=2)
            mlp_metric = standarlization(torch.stack(mlp_metric_list))

            if args.structure == "AL-MM":
                attn_mask = _flap_keep_mask(attn_metric, int(args.remove_heads))
                mlp_mask = _flap_keep_mask(mlp_metric, remove_neuron)
            else:  # AL-AM
                prune_metric = torch.cat([attn_metric.view(-1), mlp_metric.view(-1)])
                sorted_prune, indices = torch.sort(prune_metric, descending=True)
                compression_weight = torch.ones_like(indices, dtype=torch.float32)
                # A head frees 4 * hidden * head_dim params and a neuron 3 * hidden,
                # so one head weighs 4 * head_dim / 3 neurons (512 / 3 for head_dim=128).
                compression_weight[indices < attn_metric.numel()] = 4.0 * head_dim / 3
                threshold = sorted_prune[
                    torch.argmin(
                        torch.abs(
                            torch.cumsum(compression_weight, 0) -
                            torch.sum(compression_weight) * pruning_ratio
                        )
                    )
                ]
                attn_mask = (attn_metric > threshold)
                mlp_mask = (mlp_metric > threshold)
        else:
            # UL-* structures: reuse the per-layer masks computed above.
            attn_mask = torch.stack(ul_attn_mask_list)
            mlp_mask = torch.stack(ul_mlp_mask_list)

        mlp_ratio = mlp_mask.view(-1).sum() / mlp_mask.numel()
        attn_ratio = attn_mask.view(-1).sum() / attn_mask.numel()
        print('remaining ratio:', pruning_ratio,
              'mlp non_zero ratio:', mlp_ratio.item(),
              'attn non_zero ratio:', attn_ratio.item())

        for name in mask:
            if 'mlp' in name:
                if 'mlp.up_proj' in name or 'mlp.gate_proj' in name:
                    layer_id = _flap_layer_id(name)
                    mask[name][pruning_ratio] = mlp_mask[layer_id].cpu().numpy()
                    bias[name] = None
                elif 'mlp.down_proj' in name:
                    layer_id = _flap_layer_id(name)
                    mask[name] = None
                    bias[name][pruning_ratio] = _flap_compensation_bias(
                        mlp_baseline_inp_list[layer_id],
                        ~mlp_mask[layer_id],
                        layers[layer_id].mlp.down_proj,
                        device,
                        args.wanda_sp,
                    )
                else:
                    raise RuntimeError(f'No such param: {name}, please check the model')

            elif 'self_attn' in name:
                if 'self_attn.q_proj' in name or 'self_attn.k_proj' in name or 'self_attn.v_proj' in name:
                    layer_id = _flap_layer_id(name)
                    # NOTE(review): length is num_heads * head_dim; for grouped-query
                    # models k_proj / v_proj have fewer output channels - confirm.
                    mask[name][pruning_ratio] = attn_mask[layer_id].repeat_interleave(head_dim).cpu().numpy()
                    bias[name] = None
                elif 'self_attn.o_proj' in name:
                    layer_id = _flap_layer_id(name)
                    mask[name] = None
                    bias[name][pruning_ratio] = _flap_compensation_bias(
                        attn_baseline_inp_list[layer_id],
                        ~attn_mask[layer_id].repeat_interleave(head_dim),
                        layers[layer_id].self_attn.o_proj,
                        device,
                        args.wanda_sp,
                    )
                else:
                    raise RuntimeError(f'No such param: {name}, please check the model')
            else:
                raise RuntimeError(f'No such param: {name}, please check the model')

    return mask, bias


def prune_flap(args, model, tokenizer, device=torch.device("cuda:0")):
    """Compute FLAP importance scores and, optionally, width masks plus compensation biases.

    Follows the original FLAP ``prune_flap``, with one difference: inputs of
    every ``self_attn.o_proj`` / ``mlp.down_proj`` are captured with forward
    hooks during full ``model(...)`` calls. Calling decoder layers by hand is
    avoided, so position ids / rotary embeddings are handled by the model.

    Args:
        args: namespace-like object. Reads calib_dataset, num_calib_sample,
            seqlen, data_dir, metrics, structure, pruning_ratio, remove_heads,
            wanda_sp, plus optional only_importance (default False) and seed
            (default 42).
        model: causal LM with decoder layers at ``model.model.layers`` (or
            ``model.model.encoder.layers``).
        tokenizer: tokenizer forwarded to the calibration loader.
        device: device the calibration input ids are moved to.

    Returns:
        If ``args.only_importance``: a dict with "attn_metric" / "mlp_metric"
        (stacked numpy arrays) and "attn_baseline" / "mlp_baseline" (lists of
        numpy arrays). Otherwise ``(mask, bias)``: dicts keyed by module name,
        mapping each remaining ratio in ``[0.8, 0.65, 0.5]`` to a numpy array.
        The whole entry is None for modules without a mask (o_proj, down_proj)
        or without a bias (q/k/v/up/gate_proj).

    Raises:
        RuntimeError: target projections or decoder layers not found,
            only_importance requested with a UL-* structure, or an unexpected
            Linear module name.
    """
    use_cache = getattr(model.config, "use_cache", False)
    model.config.use_cache = False
    model.eval()
    try:
        return _flap_prune(args, model, tokenizer, device)
    finally:
        # Restore the caller's setting even when pruning fails.
        model.config.use_cache = use_cache
        torch.cuda.empty_cache()


# ===================== CLI =====================

def parse_args():
    """Define this script's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Standalone FLAP pruning / importance")
    add = parser.add_argument

    # --- model / output ---
    add("--model_name_or_path", type=str, required=True,
        help="HF 格式模型路径（本地或 hub 名称）")
    add("--output_dir", type=str, required=True,
        help="输出目录")

    # --- calibration data ---
    add("--calib_dataset", type=str, default="mmlu",
        choices=["wikitext2", "c4", "mmlu"],
        help="用于 FLAP 校准的重要性数据集")
    add("--data_dir", type=str, default=None,
        help="本地数据根目录（用于 mmlu 或本地 c4），例如 /seu_nvme/ogai/datasets")
    add("--num_calib_sample", type=int, default=20,
        help="校准样本数量（按序列条数计）")
    add("--seqlen", type=int, default=512,
        help="序列长度")

    # --- FLAP options ---
    add("--metrics", type=str, default="WIFV",
        choices=["IFV", "WIFV", "WIFN"],
        help="FLAP 中的 metrics 类型")
    add("--structure", type=str, default="AL-AM",
        choices=["UL-UM", "UL-MM", "AL-MM", "AL-AM"],
        help="FLAP 中的结构选项")
    add("--pruning_ratio", type=float, default=0.5,
        help="用于 UL-* 结构的 pruning_ratio")
    add("--remove_heads", type=int, default=0,
        help="用于 UL-MM / AL-MM 的头数参数")
    add("--wanda_sp", action="store_true",
        help="使用 wanda_sp 模式")
    add("--only_importance", action="store_true",
        help="只输出 importance（attn/mlp_metric + baseline），不算 mask/bias")

    # --- runtime ---
    add("--fp16", action="store_true",
        help="用 fp16 加载模型")
    add("--bf16", action="store_true",
        help="用 bf16 加载模型（如支持）")
    add("--seed", type=int, default=42,
        help="随机种子")

    return parser.parse_args()


def main():
    """CLI entry point: load model and tokenizer, run prune_flap and save the results.

    With ``--only_importance`` it writes attn_metric.npy, mlp_metric.npy,
    attn_baseline.npy and mlp_baseline.npy. Otherwise it writes
    width_mask.npy and bias.npy. All files go under ``--output_dir``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    set_seed(args.seed)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"[INFO] Using device: {device}")

    # Half precision is only honoured on CUDA; say so instead of silently ignoring the flag.
    dtype = torch.float32
    if device.type == "cuda":
        if args.fp16:
            dtype = torch.float16
        elif args.bf16:
            dtype = torch.bfloat16
    elif args.fp16 or args.bf16:
        print("[WARN] --fp16/--bf16 requested but CUDA is unavailable; loading the model in float32")

    print(f"[INFO] Loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    # Pad on the right.
    if tokenizer.padding_side != "right":
        tokenizer.padding_side = "right"

    # Options consumed by prune_flap. `unstr` is a placeholder kept from the
    # original code. `seed` is forwarded so --seed can also drive calibration sampling.
    flap_args = argparse.Namespace(
        num_calib_sample=args.num_calib_sample,
        seqlen=args.seqlen,
        metrics=args.metrics,
        structure=args.structure,
        pruning_ratio=args.pruning_ratio,
        remove_heads=args.remove_heads,
        wanda_sp=args.wanda_sp,
        only_importance=args.only_importance,
        calib_dataset=args.calib_dataset,
        data_dir=args.data_dir,
        unstr=False,
        seed=args.seed,
    )

    def as_object_array(arrays):
        # np.array(list_of_equal_shape_arrays, dtype=object) would build a 2-D
        # object matrix of scalars. Fill a 1-D object array so each entry stays
        # one layer's ndarray.
        out = np.empty(len(arrays), dtype=object)
        for k, arr in enumerate(arrays):
            out[k] = arr
        return out

    print("[INFO] Start FLAP computation ...")
    if flap_args.only_importance:
        importance = prune_flap(flap_args, model, tokenizer, device=device)
        np.save(os.path.join(args.output_dir, "attn_metric.npy"), importance["attn_metric"])
        np.save(os.path.join(args.output_dir, "mlp_metric.npy"), importance["mlp_metric"])
        np.save(os.path.join(args.output_dir, "attn_baseline.npy"),
                as_object_array(importance["attn_baseline"]), allow_pickle=True)
        np.save(os.path.join(args.output_dir, "mlp_baseline.npy"),
                as_object_array(importance["mlp_baseline"]), allow_pickle=True)
        print("[DONE] Saved importance to attn_metric.npy / mlp_metric.npy (+ baselines)")
    else:
        mask, bias = prune_flap(flap_args, model, tokenizer, device=device)
        np.save(os.path.join(args.output_dir, "width_mask.npy"), mask, allow_pickle=True)
        np.save(os.path.join(args.output_dir, "bias.npy"), bias, allow_pickle=True)
        print("[DONE] Saved width_mask.npy and bias.npy")

    print("[ALL DONE]")


if __name__ == "__main__":
    # main() returns None, so a successful run still exits with status 0.
    sys.exit(main())
