#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
FLAP width-only pruning baseline (no iterative depth pruning, no TALE loop)

功能：
- 加载 HF CausalLM（假定 LLaMA 风格：model.model.layers[i].self_attn / mlp）
- 用 FLAP 的 WIFV / IFV / WIFN 逻辑计算一次每层 attn.o_proj / mlp.down_proj 的通道重要性
- 使用官方 FLAP 的 AL-AM 策略（attn-MLP 联合、按参数量加权）选出需要保留的 heads & MLP neurons
- 直接将被裁剪的 heads/neuron 对应权重置零（structured channel pruning）
- 保存：
    - 剪枝后的 HF checkpoint（weights 已经置 0，结构不变）
    - flap_mask.json：记录每一层哪些 head/neuron 被保留（1）或裁剪（0）

用法示例：
python flap_width_baseline.py \
  --model_name_or_path /path/to/Meta-Llama-3.1-8B-Instruct \
  --output_dir /path/to/output/flap_width_50 \
  --calib_dataset mmlu \
  --data_dir /path/to/data_root \
  --num_calib_sample 32 \
  --seqlen 512 \
  --metrics WIFV \
  --target_keep_ratio 0.5 \
  --fp16
"""

import os
import argparse
import random
import json
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F  # noqa: F401

from datasets import load_dataset, Dataset
from tqdm import tqdm

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

import bitsandbytes as bnb
from peft.tuners.lora.bnb import Linear4bit, Linear8bitLt
from peft.utils.integrations import dequantize_bnb_weight


# ============================================================
# ===================== 数据加载部分 ==========================
# ============================================================

class TokenizerWrapper:
    """Minimal holder that exposes a tensor of token ids as ``.input_ids``.

    The loader functions in this file return an instance of it as their
    second value, wrapping the ids of the held-out / evaluation text.
    """

    def __init__(self, input_ids: torch.Tensor):
        # Stored as-is: no copy, no validation.
        self.input_ids = input_ids


def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build WikiText-2 calibration windows.

    The train split is joined with single spaces and tokenized as one long
    sequence; ``nsamples`` windows of ``seqlen`` tokens are then cut at random
    offsets (``random`` is re-seeded with ``seed``).

    Returns:
        A pair ``(samples, test_wrapper)`` where ``samples`` is a list of
        ``(input_ids, labels)`` tuples (labels are ``-100`` everywhere except
        the last position) and ``test_wrapper`` is a ``TokenizerWrapper``
        around the first ``seqlen`` tokens of the test split.
    """
    train_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    test_split = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    train_ids = tokenizer(" ".join(train_split['text']), return_tensors='pt').input_ids
    test_ids = tokenizer("\n\n".join(test_split['text']), return_tensors='pt').input_ids

    random.seed(seed)
    samples = []
    for _ in range(nsamples):
        start = random.randint(0, train_ids.shape[1] - seqlen - 1)
        window = train_ids[:, start:start + seqlen]
        labels = window.clone()
        # Mask every position but the last one.
        labels[:, :-1] = -100
        samples.append((window, labels))

    return samples, TokenizerWrapper(test_ids[:, :seqlen])


def get_c4(nsamples, seed, seqlen, tokenizer, data_dir=None):
    """Build C4 calibration windows.

    When ``data_dir`` is given, the first train/validation shards are read from
    ``<data_dir>/allenai--c4``; otherwise the ``allenai/c4`` dataset is loaded
    by name. Documents are drawn at random until one is longer than ``seqlen``
    tokens, then a random ``seqlen``-token window is cut from it.

    Returns:
        ``(samples, val_wrapper)``: a list of ``(input_ids, labels)`` tuples
        and a ``TokenizerWrapper`` around at most ``256 * seqlen`` validation
        tokens taken from the first 1100 validation documents.
    """
    if data_dir is not None:
        local_repo = os.path.join(data_dir, 'allenai--c4')
        traindata = load_dataset(
            local_repo,
            data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
            split='train'
        )
        valdata = load_dataset(
            local_repo,
            data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'},
            split='validation'
        )
    else:
        traindata = load_dataset('allenai/c4', 'allenai--c4', split='train')
        valdata = load_dataset('allenai/c4', 'allenai--c4', split='validation')

    random.seed(seed)
    samples = []
    while len(samples) < nsamples:
        doc_index = random.randint(0, len(traindata) - 1)
        doc_ids = tokenizer(traindata[doc_index]['text'], return_tensors='pt').input_ids
        if doc_ids.shape[1] <= seqlen:
            # Too short to host a full window: draw another document.
            continue
        start = random.randint(0, doc_ids.shape[1] - seqlen - 1)
        window = doc_ids[:, start:start + seqlen]
        labels = window.clone()
        labels[:, :-1] = -100
        samples.append((window, labels))

    val_text = ' '.join(valdata[:1100]['text'])
    val_ids = tokenizer(val_text, return_tensors='pt').input_ids[:, :(256 * seqlen)]
    return samples, TokenizerWrapper(val_ids)


def format_mmlu_example(dataset, include_answer: bool = True):
    """Render MMLU rows as multiple-choice prompt strings.

    Args:
        dataset: Mapping (or HF ``Dataset``) with parallel columns
            ``"question"``, ``"choices"`` and ``"answer"``; ``answer`` is the
            integer index of the correct choice.
        include_answer: When true, append the gold answer as
            ``" <letter>. <answer text>"`` followed by a blank line (few-shot
            style). When false, the prompt ends with ``"Answer:"``.

    Returns:
        A list with one prompt string per row.

    Raises:
        IndexError: If a row has more than four choices.
    """
    letters = ["A", "B", "C", "D"]
    prompts = []
    for question, options, answer in zip(
        dataset["question"], dataset["choices"], dataset["answer"]
    ):
        parts = [question]
        for idx, option in enumerate(options):
            parts.append("\n{}. {}".format(letters[idx], option))
        parts.append("\nAnswer:")
        if include_answer:
            # Letter followed by the answer *text*; the previous version
            # printed the letter twice ("A. A").
            parts.append(" {}. {}\n\n".format(letters[answer], options[answer]))
        prompts.append("".join(parts))
    return prompts


def get_mmlu_trainenc(
    data_dir: str,
    seed: int,
    nsamples: int,
    tokenizer,
    seqlen: int = 512,
    num_tasks: int = None,
):
    """Sample MMLU questions across subjects and tokenize them into one row.

    ``nsamples`` questions are spread as evenly as possible over the subject
    list (optionally restricted to ``num_tasks`` randomly chosen subjects).
    Each subject is loaded from ``<data_dir>/mmlu`` using the ``train`` split,
    falling back to ``validation`` if that fails. Prompts are rendered without
    answers, padded/truncated to ``seqlen`` tokens each, and finally flattened
    so that ``input_ids`` and ``attention_mask`` both have shape ``(1, -1)``.

    Returns:
        The tokenizer output with flattened ``input_ids`` / ``attention_mask``.
    """
    from tqdm import tqdm as progress_bar

    subjects = [
        'abstract_algebra', 'anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge',
        'college_biology', 'college_chemistry', 'college_computer_science', 'college_mathematics',
        'college_medicine', 'college_physics', 'computer_security', 'conceptual_physics',
        'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic',
        'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science',
        'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics',
        'high_school_macroeconomics', 'high_school_mathematics', 'high_school_microeconomics',
        'high_school_physics', 'high_school_psychology', 'high_school_statistics',
        'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality',
        'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management',
        'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios',
        'nutrition', 'philosophy', 'prehistory', 'professional_accounting', 'professional_law',
        'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies',
        'sociology', 'us_foreign_policy', 'virology', 'world_religions'
    ]

    random.seed(seed)
    if num_tasks is not None:
        random.shuffle(subjects)
        subjects = subjects[:num_tasks]

    columns = ["question", "subject", "choices", "answer"]

    # Even split of nsamples over subjects; the remainder goes to a random
    # subset of subjects (one extra question each).
    base, remainder = divmod(nsamples, len(subjects))
    quotas = [base + 1] * remainder + [base] * (len(subjects) - remainder)
    random.shuffle(quotas)

    collected = {column: [] for column in columns}
    mmlu_root = os.path.join(data_dir, "mmlu")

    for quota, subject in progress_bar(
        zip(quotas, subjects),
        total=len(subjects),
        desc="Loading the subclass in MMLU"
    ):
        if quota == 0:
            continue
        try:
            rows = load_dataset(mmlu_root, subject, split="train").shuffle(seed=seed)[:quota]
        except Exception:
            # Fall back to the validation split when the train split cannot be loaded.
            rows = load_dataset(mmlu_root, subject, split="validation").shuffle(seed=seed)[:quota]
        for column in columns:
            collected[column].extend(rows[column])

    prompts = format_mmlu_example(Dataset.from_dict(collected), include_answer=False)

    encoded = tokenizer(
        prompts,
        return_tensors='pt',
        max_length=seqlen,
        padding='max_length',
        truncation=True,
    )
    # Concatenate all padded prompts into a single long row.
    encoded["input_ids"] = encoded["input_ids"].reshape(1, -1)
    encoded["attention_mask"] = encoded["attention_mask"].reshape(1, -1)
    return encoded


def get_mmlu_loader(
    data_dir: str,
    nsamples: int,
    seed: int,
    seqlen: int,
    tokenizer,
):
    """Cut ``nsamples`` random ``seqlen``-token windows from the MMLU stream.

    The flattened token row comes from ``get_mmlu_trainenc``. When the row is
    not longer than ``seqlen + 1`` tokens, every window starts at offset 0.

    Returns:
        ``(loader, wrapper)``: a list of ``(input_ids, labels)`` tuples (labels
        are ``-100`` except at the last position) and a ``TokenizerWrapper``
        around the first ``seqlen`` tokens of the row.
    """
    ids = get_mmlu_trainenc(
        data_dir=data_dir,
        seed=seed,
        nsamples=nsamples,
        tokenizer=tokenizer,
        seqlen=seqlen,
    )["input_ids"]
    total_len = ids.shape[1]
    has_room = total_len > seqlen + 1

    random.seed(seed)
    loader = []
    for _ in range(nsamples):
        start = random.randint(0, total_len - seqlen - 1) if has_room else 0
        window = ids[:, start:start + seqlen].clone()
        labels = window.clone()
        labels[:, :-1] = -100
        loader.append((window, labels))

    return loader, TokenizerWrapper(ids[:, :seqlen])


def get_loaders(
    name: str,
    nsamples: int = 128,
    seed: int = 0,
    seqlen: int = 2048,
    tokenizer=None,
    data_dir: Optional[str] = None,
):
    """Dispatch to the calibration loader selected by ``name``.

    Matching is by substring, checked in this order: ``wikitext2``, ``c4``,
    ``mmlu``, ``boolq``.

    Args:
        name: Dataset selector.
        nsamples: Number of calibration windows to draw.
        seed: Seed forwarded to the loader.
        seqlen: Window length in tokens.
        tokenizer: Tokenizer forwarded to the loader.
        data_dir: Local data root; mandatory for ``mmlu`` (parent directory of
            the ``mmlu`` folder), optional for ``c4``.

    Returns:
        Whatever the selected loader returns, i.e. ``(samples, wrapper)``.

    Raises:
        ValueError: For ``mmlu`` without ``data_dir`` or for an unknown name.
        NotImplementedError: For ``boolq``, which has no loader in this file.
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if 'c4' in name:
        return get_c4(nsamples, seed, seqlen, tokenizer, data_dir=data_dir)
    if 'mmlu' in name:
        if data_dir is None:
            raise ValueError("Using mmlu as calib dataset requires --data_dir pointing to parent of 'mmlu'.")
        return get_mmlu_loader(data_dir, nsamples, seed, seqlen, tokenizer)
    if 'boolq' in name:
        # Previously this returned None, which only surfaced later as an
        # unpacking TypeError at the call site; fail here with a clear message.
        raise NotImplementedError(
            f"Calibration dataset '{name}' is not implemented in get_loaders."
        )
    raise ValueError(f"Unknown dataset name in get_loaders: {name}")


# ============================================================
# ============== WrappedGPT & BiasGPT（FLAP） =================
# ============================================================

class WrappedGPT:
    """Accumulates a running per-input-feature squared L2 norm for one layer.

    ``scaler_row`` has one entry per input feature of the wrapped layer and is
    updated by ``add_batch`` as a running average weighted by batch size.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = layer.weight.device

        self.rows, self.columns = layer.out_features, layer.in_features

        self.scaler_row = torch.zeros(self.columns, device=self.dev)
        self.nsamples = 0

        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into ``scaler_row``; ``out`` is unused."""
        if inp.dim() == 2:
            inp = inp.unsqueeze(0)
        batch = inp.shape[0]

        linear_types = (nn.Linear, Linear4bit, Linear8bitLt,
                        bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
        if isinstance(self.layer, linear_types):
            if inp.dim() == 3:
                # Flatten (batch, seq, features) into (batch * seq, features).
                inp = inp.reshape((-1, inp.shape[-1]))
            # Features first, so the reduction below runs over dim=1.
            inp = inp.t()

        # Rescale the running value before adding the new contribution.
        self.scaler_row *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch

        squared_norm = torch.norm(inp.type(torch.float32), p=2, dim=1) ** 2
        self.scaler_row += squared_norm / self.nsamples


class BiasGPT:
    """Running input statistics of one linear layer, used by the FLAP metrics.

    Always tracks ``baseline_inp`` (a running average of the per-feature input
    mean). Depending on ``metric`` it additionally tracks either
    ``scaler_inp`` (running squared L2 norm, for ``"WIFN"``) or ``fluc_inp``
    (a running fluctuation / variance-like term, for every other metric).
    """

    def __init__(self, layer, metric):
        self.layer = layer
        self.dev = self.layer.weight.device

        self.out_dim = layer.out_features
        self.in_dim = layer.in_features

        self.type = metric
        self.nsamples = 0

        # One accumulator entry per input feature of the wrapped layer.
        self.baseline_inp = torch.zeros((self.in_dim), device=self.dev)
        if self.type == "WIFN":
            self.scaler_inp = torch.zeros((self.in_dim), device=self.dev)
        else:
            self.fluc_inp = torch.zeros((self.in_dim), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one batch of layer inputs into the running statistics.

        ``out`` is accepted for hook compatibility but is not used.
        """
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        batch_size = inp.shape[0]
        if isinstance(self.layer, (nn.Linear, Linear4bit, Linear8bitLt)):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            # After the transpose, dim 0 is the feature axis and dim 1 the
            # flattened batch*sequence axis.
            inp = inp.t()

        # NOTE(review): this binds a second name to the SAME tensor object; the
        # in-place `*=` / `+=` on the next two lines mutate it, so by the time
        # `old_baseline_inp` is read below it already holds the updated value.
        # A `.clone()` would keep the previous mean, but would also change the
        # resulting metric values -- confirm which behaviour is intended.
        old_baseline_inp = self.baseline_inp
        self.baseline_inp *= self.nsamples / (self.nsamples + batch_size)
        self.baseline_inp += torch.mean(inp, dim=1) / (self.nsamples + batch_size)

        if self.type == "WIFN":
            # Running average of the per-feature squared L2 norm (in float32).
            inp32 = inp.type(torch.float32)
            self.scaler_inp *= self.nsamples / (self.nsamples + batch_size)
            self.scaler_inp += torch.norm(inp32, p=2, dim=1) ** 2 / (self.nsamples + batch_size)
        else:
            if self.nsamples == 0:
                # The very first batch only initialises the accumulator; it
                # contributes nothing to the fluctuation term.
                self.fluc_inp = torch.zeros_like(self.baseline_inp)
            else:
                self.fluc_inp *= (self.nsamples - 1) / (self.nsamples + batch_size - 1)
                self.fluc_inp += torch.sum(
                    (inp - self.baseline_inp.unsqueeze(1)) *
                    (inp - old_baseline_inp.unsqueeze(1)),
                    dim=1
                ) / (self.nsamples + batch_size)

        self.nsamples += batch_size

    def free(self):
        """Drop references to the accumulators and empty the CUDA cache."""
        self.baseline_inp = None
        if hasattr(self, 'fluc_inp'):
            self.fluc_inp = None
        if hasattr(self, 'scaler_inp'):
            self.scaler_inp = None
        torch.cuda.empty_cache()


# ============================================================
# ===================== 其他工具函数 ==========================
# ============================================================

def find_layers(
    module,
    layers=(nn.Linear, Linear4bit, Linear8bitLt, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt),
    name: str = ''
):
    """Recursively collect sub-modules whose exact type is listed in ``layers``.

    Args:
        module: Root module to search.
        layers: Types to collect; compared with ``type(m) in layers``, so
            subclasses of these types are not matched.
        name: Dotted path of ``module`` (empty for the root).

    Returns:
        Dict mapping dotted module path to module. Any child whose dotted path
        contains ``'lora'`` is skipped together with its whole subtree.
    """
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        child_path = f"{name}.{child_name}" if name != '' else child_name
        # Test the CHILD's path. The previous check looked at the parent's
        # path only, so a directly matching child called e.g. `lora_A` was
        # still returned.
        if 'lora' in child_path:
            continue
        found.update(find_layers(child, layers=layers, name=child_path))
    return found


def dequantize(layer):
    """Return the layer's weight, dequantized when it carries a quant state.

    If ``layer`` exposes ``get_base_layer()``, the weight of that base layer is
    used. A weight without a ``quant_state`` attribute (or with it set to
    ``None``) is returned unchanged.
    """
    source = layer.get_base_layer() if hasattr(layer, "get_base_layer") else layer
    weight = source.weight

    quant_state = getattr(weight, "quant_state", None)
    if quant_state is None:
        return weight
    return dequantize_bnb_weight(weight, state=quant_state)


def count_model_params(model: nn.Module) -> int:
    """Return the total number of elements over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def get_decoder_layers(model):
    """Return the layer container of ``model``.

    Looks for ``model.model.layers`` first, then ``model.model.encoder.layers``.

    Raises:
        RuntimeError: If neither attribute path exists.
    """
    if hasattr(model, "model"):
        backbone = model.model
        if hasattr(backbone, "layers"):
            return backbone.layers
        if hasattr(backbone, "encoder"):
            return backbone.encoder.layers
    raise RuntimeError("Cannot find decoder layers in model.")


# ============================================================
# ============= FLAP-style importance 计算 ====================
# ============================================================

@torch.no_grad()
def compute_flap_importance(
    model,
    tokenizer,
    device,
    calib_dataset: str,
    data_dir: Optional[str],
    num_calib_sample: int,
    seqlen: int,
    metrics: str = "WIFV",
    wanda_sp: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
    """Compute FLAP channel importance once, without pruning anything.

    Forward hooks on every ``self_attn.o_proj`` and ``mlp.down_proj`` linear
    collect input statistics over the calibration set; the selected metric is
    then evaluated per decoder layer.

    Args:
        model: Causal LM exposing its decoder layers via ``get_decoder_layers``.
        tokenizer: Tokenizer passed to the calibration loader.
        device: Device the calibration inputs are moved to.
        calib_dataset: Dataset selector understood by ``get_loaders``.
        data_dir: Local data root forwarded to ``get_loaders``.
        num_calib_sample: Number of calibration sequences.
        seqlen: Calibration sequence length.
        metrics: One of ``"IFV"``, ``"WIFV"``, ``"WIFN"``; ignored when
            ``wanda_sp`` is set.
        wanda_sp: Use ``WrappedGPT`` statistics (``scaler_row``) instead of
            ``BiasGPT``.

    Returns:
        ``(attn_metric, mlp_metric, attn_baselines, mlp_baselines)`` where the
        first two are stacked per-layer scores over the input channels of
        ``o_proj`` / ``down_proj`` and the last two are per-layer lists of
        float16 baseline vectors. Everything is moved to CPU.

    Raises:
        RuntimeError: If no target linear layer is found.
        KeyError: If a decoder layer has no matching target module, or if
            ``metrics`` is unknown while ``wanda_sp`` is false.
    """
    use_cache = getattr(model.config, "use_cache", False)
    model.config.use_cache = False
    model.eval()

    attn_metric_list, mlp_metric_list = [], []
    attn_baseline_inp_list, mlp_baseline_inp_list = [], []

    # The outer try/finally guarantees `use_cache` is restored even when
    # calibration or metric evaluation raises.
    try:
        print(f"[FLAP] Loading calibration data from {calib_dataset}")
        dataloader, _ = get_loaders(
            calib_dataset,
            nsamples=num_calib_sample,
            seed=42,
            tokenizer=tokenizer,
            seqlen=seqlen,
            data_dir=data_dir,
        )
        print(f"[FLAP] Calibration dataset ready, #batches={len(dataloader)}")

        all_linears: Dict[str, nn.Module] = find_layers(model)
        target_names = [
            name for name in all_linears.keys()
            if name.endswith("self_attn.o_proj") or name.endswith("mlp.down_proj")
        ]
        if len(target_names) == 0:
            raise RuntimeError("[FLAP] Cannot find 'self_attn.o_proj' or 'mlp.down_proj' linear layers.")

        # One statistics collector per target linear.
        wrapped_layers = {}
        for name in target_names:
            if wanda_sp:
                wrapped_layers[name] = WrappedGPT(all_linears[name])
            else:
                wrapped_layers[name] = BiasGPT(all_linears[name], metrics)

        def make_hook(name):
            def hook(module, inp, out):
                x_in = inp[0] if isinstance(inp, (tuple, list)) else inp
                wrapped_layers[name].add_batch(x_in.detach(), out.detach())
            return hook

        handles = []
        # The inner try/finally guarantees the hooks never outlive this call.
        try:
            for name in target_names:
                handles.append(all_linears[name].register_forward_hook(make_hook(name)))

            print("[FLAP] Running calibration forward passes for importance ...")
            for idx, (inp, _) in enumerate(tqdm(dataloader, desc="FLAP calib")):
                if idx >= num_calib_sample:
                    break
                model(inp.to(device))
        finally:
            for handle in handles:
                handle.remove()

        num_layers = len(get_decoder_layers(model))

        def metric_IFV(module_name: str):
            return wrapped_layers[module_name].fluc_inp

        def metric_WIFV(module_name: str):
            layer_mod = all_linears[module_name]
            return wrapped_layers[module_name].fluc_inp * torch.sum(dequantize(layer_mod).data.pow(2), dim=0)

        def metric_WIFN(module_name: str):
            layer_mod = all_linears[module_name]
            return (
                torch.abs(dequantize(layer_mod).data) *
                torch.sqrt(wrapped_layers[module_name].scaler_inp.reshape((1, -1)))
            ).mean(axis=0)

        metrics_dict = {
            "IFV": metric_IFV,
            "WIFV": metric_WIFV,
            "WIFN": metric_WIFN,
        }

        def get_module_name(layer_idx: int, suffix: str) -> str:
            pattern = f"layers.{layer_idx}.{suffix}"
            for n in target_names:
                if pattern in n:
                    return n
            raise KeyError(f"[FLAP] Cannot find module with pattern '{pattern}' in target_names.")

        def score_and_baseline(module_name: str, square: bool):
            """Return (importance score, float16 baseline) for one target linear."""
            wrapper = wrapped_layers[module_name]
            if wanda_sp:
                weight = dequantize(all_linears[module_name])
                score = (
                    torch.abs(weight.data) *
                    torch.sqrt(wrapper.scaler_row.reshape((1, -1)))
                ).mean(axis=0)
                return score, wrapper.scaler_row.type(torch.float16)
            score = metrics_dict[metrics](module_name)
            if square:
                score = score ** 2
            return score, wrapper.baseline_inp.type(torch.float16)

        for i in tqdm(range(num_layers), desc="[FLAP] Computing metric per layer"):
            attn_name = get_module_name(i, "self_attn.o_proj")
            mlp_name = get_module_name(i, "mlp.down_proj")

            # Attention: the FLAP metric is squared here (not in the WANDA path).
            W_attn, attn_baseline = score_and_baseline(attn_name, square=True)
            attn_metric_list.append(W_attn.cpu())
            attn_baseline_inp_list.append(attn_baseline.cpu())

            # MLP: the metric is used as-is.
            W_mlp, mlp_baseline = score_and_baseline(mlp_name, square=False)
            mlp_metric_list.append(W_mlp.cpu())
            mlp_baseline_inp_list.append(mlp_baseline.cpu())

            # Release accumulator memory as soon as a layer is done.
            if hasattr(wrapped_layers[attn_name], 'free'):
                wrapped_layers[attn_name].free()
            if hasattr(wrapped_layers[mlp_name], 'free'):
                wrapped_layers[mlp_name].free()
    finally:
        model.config.use_cache = use_cache

    torch.cuda.empty_cache()

    attn_metric = torch.stack(attn_metric_list)   # [L, H_attn]
    mlp_metric = torch.stack(mlp_metric_list)     # [L, H_mlp]

    return attn_metric, mlp_metric, attn_baseline_inp_list, mlp_baseline_inp_list


# ============================================================
# =============== 应用 width 剪枝：head / neuron ==============
# ============================================================

def apply_head_prune(layer, head_indices, hidden_size, num_heads):
    """Zero out the weights of the given attention (query) heads in ``layer``.

    - ``q_proj``: rows belonging to each pruned head are zeroed.
    - ``o_proj``: input columns belonging to each pruned head are zeroed.
    - ``k_proj`` / ``v_proj``: a key/value head is zeroed only when *every*
      query head that shares it is pruned. With grouped-query attention
      (fewer KV heads than query heads) a KV head serves several query heads,
      so indexing K/V by the query-head id would wipe out KV heads that kept
      query heads still need. With plain multi-head attention (one KV head
      per query head) this is the same as zeroing K/V rows per pruned head.

    Args:
        layer: Decoder layer exposing ``self_attn.{q,k,v,o}_proj``.
        head_indices: Query-head indices to prune; out-of-range values are
            ignored.
        hidden_size: Model hidden size; ``head_dim = hidden_size // num_heads``.
        num_heads: Number of query heads.
    """
    head_dim = hidden_size // num_heads
    if head_dim <= 0:
        return

    pruned = {int(h) for h in head_indices if 0 <= h < num_heads}
    if not pruned:
        return

    attn = layer.self_attn
    q, k, v, o = attn.q_proj, attn.k_proj, attn.v_proj, attn.o_proj

    # Number of KV heads is read off the K projection itself, so no extra
    # config argument is needed.
    num_kv_heads = k.weight.shape[0] // head_dim
    group_size = max(num_heads // max(num_kv_heads, 1), 1)

    with torch.no_grad():
        for h in sorted(pruned):
            start, end = h * head_dim, (h + 1) * head_dim
            # q: [out_features, in_features]; heads live on the output rows.
            q.weight.data[start:end, :] = 0
            # o: [out_features, in_features]; heads live on the input columns.
            o.weight.data[:, start:end] = 0

        for kv_head in range(num_kv_heads):
            first = kv_head * group_size
            members = range(first, min(first + group_size, num_heads))
            if len(members) == 0 or any(m not in pruned for m in members):
                continue
            start, end = kv_head * head_dim, (kv_head + 1) * head_dim
            k.weight.data[start:end, :] = 0
            v.weight.data[start:end, :] = 0


def apply_mlp_neuron_prune(layer, neuron_indices, hidden_size, mlp_dim):
    """Zero out the weights of the given FFN neurons in ``layer``.

    - ``gate_proj`` / ``up_proj``: the rows of each pruned neuron are zeroed.
    - ``down_proj``: the input columns of each pruned neuron are zeroed.

    All pruned neurons are written with one indexed assignment per matrix
    rather than one Python-level write per neuron.

    Args:
        layer: Decoder layer exposing ``mlp.{gate,up,down}_proj``.
        neuron_indices: Neuron indices to prune; values outside
            ``[0, mlp_dim)`` are ignored.
        hidden_size: Unused; kept for interface compatibility.
        mlp_dim: Intermediate (FFN) dimension.
    """
    valid = [int(n) for n in neuron_indices if 0 <= n < mlp_dim]
    if not valid:
        return

    mlp = layer.mlp
    with torch.no_grad():
        index = torch.tensor(valid, dtype=torch.long)
        for proj in (mlp.gate_proj, mlp.up_proj):
            proj.weight.data[index.to(proj.weight.device), :] = 0
        down = mlp.down_proj
        down.weight.data[:, index.to(down.weight.device)] = 0


# ============================================================
# ========= 计算 FLAP AL-AM 风格的 head/neuron mask ==========
# ============================================================

def standardize_per_layer(x: torch.Tensor) -> torch.Tensor:
    """Z-score each row of ``x`` (shape ``[L, D]``) independently.

    A small constant (1e-6) is added to the row standard deviation so that a
    constant row maps to zeros instead of dividing by zero.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    scale = x.std(dim=1, keepdim=True) + 1e-6
    return centered / scale


def compute_flap_masks_alam(
    attn_metric: torch.Tensor,  # [L, H_attn]
    mlp_metric: torch.Tensor,   # [L, H_mlp]
    model,
    target_keep_ratio: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Select heads and MLP neurons to keep with a joint, cost-weighted ranking.

    Steps:
    - ``attn_metric`` is summed over each head's ``head_dim`` channels, giving
      one score per head (``[L, num_heads]``); ``mlp_metric`` stays per neuron.
    - Both score matrices are standardized per layer and concatenated.
    - Each head entry is given cost ``512/3`` and each neuron entry cost ``1``.
      Entries are sorted by score (descending) and the cut is placed where the
      cumulative cost is closest to ``target_keep_ratio * total_cost``.
    - Every entry scoring at least as high as the cut entry is kept.

    Returns:
        ``(attn_keep, mlp_keep)``: boolean masks of shape ``[L, num_heads]``
        and ``[L, mlp_dim]``; ``True`` means keep.
    """
    num_layers = len(get_decoder_layers(model))
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    head_dim = hidden_size // num_heads

    assert attn_metric.shape[0] == num_layers, "attn_metric L != num_layers"
    assert attn_metric.shape[1] == hidden_size, "attn_metric second dim != hidden_size"
    assert mlp_metric.shape[0] == num_layers, "mlp_metric L != num_layers"

    # Per-head score: sum over the head_dim channels of each head.
    per_head = attn_metric.reshape(num_layers, num_heads, head_dim).sum(dim=2)

    # Standardize per layer, then flatten heads and neurons into one ranking.
    head_scores = standardize_per_layer(per_head).reshape(-1)
    neuron_scores = standardize_per_layer(mlp_metric).reshape(-1)
    scores = torch.cat([head_scores, neuron_scores], dim=0)   # higher = more important

    # Per-entry cost: heads weigh 512/3, neurons weigh 1.
    num_head_entries = head_scores.numel()
    cost = torch.ones_like(scores)
    cost[:num_head_entries] = 512.0 / 3.0

    total_cost = cost.sum()
    keep_budget = total_cost * float(target_keep_ratio)

    # Walk the ranking from best to worst and stop where the accumulated cost
    # is closest to the budget.
    ranked_scores, order = torch.sort(scores, descending=True)
    running_cost = torch.cumsum(cost[order], dim=0)
    cut_position = torch.argmin(torch.abs(running_cost - keep_budget)).item()
    threshold = ranked_scores[cut_position]

    keep_flat = scores >= threshold

    attn_keep = keep_flat[:num_head_entries].reshape(num_layers, num_heads)
    mlp_keep = keep_flat[num_head_entries:].reshape(num_layers, mlp_metric.shape[1])

    print(f"[FLAP] Target keep ratio (weight-level): {target_keep_ratio:.4f}")
    print(f"[FLAP] Actual keep ratio (weight-level): {cost[keep_flat].sum() / total_cost:.4f}")

    return attn_keep.bool(), mlp_keep.bool()


# ============================================================
# =================== 一次性 FLAP 宽度剪枝 ====================
# ============================================================

def run_flap_width_pruning(args, model, tokenizer, device):
    """Run the one-shot FLAP width-pruning baseline end to end.

    Computes channel importance, derives head / neuron keep masks, zeroes the
    pruned weights in place (the architecture is unchanged), reports the
    resulting linear-weight sparsity and writes the model, the tokenizer and
    ``flap_mask.json`` into ``args.output_dir``.
    """
    layers = get_decoder_layers(model)
    num_layers = len(layers)
    config = model.config
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    mlp_dim = config.intermediate_size

    print(f"[FLAP] Model: {num_layers} layers, hidden={hidden_size}, mlp_dim={mlp_dim}, heads={num_heads}")
    total_params = count_model_params(model)
    print(f"[FLAP] Total params = {total_params/1e6:.2f}M")

    # 1) Channel importance (the baseline vectors are not used here).
    attn_metric, mlp_metric, _attn_base, _mlp_base = compute_flap_importance(
        model=model,
        tokenizer=tokenizer,
        device=device,
        calib_dataset=args.calib_dataset,
        data_dir=args.data_dir,
        num_calib_sample=args.num_calib_sample,
        seqlen=args.seqlen,
        metrics=args.metrics,
        wanda_sp=args.wanda_sp,
    )
    print("[FLAP] Importance computed.")

    # 2) Per-layer keep masks for heads and neurons.
    attn_keep_mask, mlp_keep_mask = compute_flap_masks_alam(
        attn_metric=attn_metric,
        mlp_metric=mlp_metric,
        model=model,
        target_keep_ratio=args.target_keep_ratio,
    )

    # 3) Zero the pruned weights and record the masks for the JSON dump.
    meta = {
        "prune_type": "flap_width_only",
        "metric": args.metrics,
        "target_keep_ratio": args.target_keep_ratio,
        "num_layers": num_layers,
        "hidden_size": hidden_size,
        "num_heads": num_heads,
        "mlp_dim": mlp_dim,
    }
    mask_json = {"attn_heads": {}, "mlp_neurons": {}, "meta": meta}

    for layer_idx in range(num_layers):
        head_keep = attn_keep_mask[layer_idx].cpu().numpy().astype(int).tolist()
        neuron_keep = mlp_keep_mask[layer_idx].cpu().numpy().astype(int).tolist()

        layer_key = f"layer_{layer_idx}"
        mask_json["attn_heads"][layer_key] = head_keep
        mask_json["mlp_neurons"][layer_key] = neuron_keep

        dropped_heads = [idx for idx, kept in enumerate(head_keep) if kept == 0]
        dropped_neurons = [idx for idx, kept in enumerate(neuron_keep) if kept == 0]

        print(f"[FLAP] Layer {layer_idx}: keep_heads={sum(head_keep)}/{len(head_keep)}, "
              f"keep_neurons={sum(neuron_keep)}/{len(neuron_keep)}")

        layer = layers[layer_idx]
        apply_head_prune(layer, dropped_heads, hidden_size, num_heads)
        apply_mlp_neuron_prune(layer, dropped_neurons, hidden_size, mlp_dim)

    # 4) Sparsity over linear-layer weights only.
    linear_types = (nn.Linear, Linear4bit, Linear8bitLt,
                    bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)
    zero_cnt, total_cnt = 0, 0
    for module in model.modules():
        if not isinstance(module, linear_types):
            continue
        weight = module.weight.data
        zero_cnt += (weight == 0).sum().item()
        total_cnt += weight.numel()
    weight_sparsity = zero_cnt / max(total_cnt, 1)
    print(f"[FLAP] Final linear-weight sparsity = {weight_sparsity*100:.2f}%")

    # 5) Persist the model, tokenizer and mask.
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"[FLAP] Saving pruned model to {args.output_dir}")
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    mask_path = os.path.join(args.output_dir, "flap_mask.json")
    with open(mask_path, "w", encoding="utf-8") as f:
        json.dump(mask_json, f, indent=2)
    print(f"[FLAP] Saved mask JSON to {mask_path}")


# ============================================================
# ========================= CLI 部分 =========================
# ============================================================

def parse_args():
    """Parse the command-line options of the FLAP width-only baseline."""
    parser = argparse.ArgumentParser(description="FLAP width-only pruning baseline (no iterative TALE)")
    add = parser.add_argument

    # Model input / output locations.
    add("--model_name_or_path", type=str, required=True)
    add("--output_dir", type=str, required=True)

    # Calibration data.
    add("--calib_dataset", type=str, default="mmlu",
        choices=["wikitext2", "c4", "mmlu"])
    add("--data_dir", type=str, default=None,
        help="如果使用 mmlu，data_dir 应该是包含 mmlu 子目录的父目录")
    add("--num_calib_sample", type=int, default=20,
        help="用于 FLAP 的校准样本数（序列条数）")
    add("--seqlen", type=int, default=512)

    # Importance metric.
    add("--metrics", type=str, default="WIFV",
        choices=["IFV", "WIFV", "WIFN"],
        help="FLAP 通道重要性度量方式")
    add("--wanda_sp", action="store_true",
        help="若为 True，使用 WANDA-style scaler_row 而非 BiasGPT")

    # Share of the parameter-weighted heads + neurons to keep,
    # e.g. 0.5 keeps channels worth roughly 50% of the parameters.
    add("--target_keep_ratio", type=float, default=0.5)

    # Precision and seeding.
    add("--fp16", action="store_true")
    add("--bf16", action="store_true")
    add("--seed", type=int, default=42)

    return parser.parse_args()


def main():
    """CLI entry point: load the model and tokenizer, then run the pruning."""
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[FLAP] Using device: {device}")

    # Half precision is only honoured on CUDA; --fp16 wins over --bf16.
    on_cuda = device.type == "cuda"
    if on_cuda and args.fp16:
        dtype = torch.float16
    elif on_cuda and args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"[FLAP] Loading model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        device_map=None,
    )
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=True,
    )
    # Padding is needed by the MMLU loader; fall back to EOS when undefined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    run_flap_width_pruning(args, model, tokenizer, device)

    print("[FLAP] Done.")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
