#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Iterative FLAP width-only pruning (no depth pruning, no finetune in-between).

功能：
- 从一个 HF CausalLM（假定 LLaMA 风格：model.model.layers[i].self_attn / mlp）开始
- 以小步长（如 2%）迭代增加“全局总剪枝率”：
    第 k 轮目标剪枝率 prune_ratio_total = k * step
    对应目标保留率 keep_ratio = 1 - prune_ratio_total
- 每轮都在当前“已剪枝模型（部分权重为 0）”上重新计算 FLAP 通道重要性
- 使用 AL-AM 策略选出需要保留的 heads & MLP neurons，并叠加到“累计 mask”上
    - 历史上已经被剪掉的通道在后续轮次永远不能重新激活
- 每轮更新模型权重（对应通道置 0，但结构不变）
- 当“实际权重剪枝率”第一次达到 3.xx%、4.xx%、5.xx% 等整数百分比 p.xx% 时，保存：
    - 当前模型 checkpoint（权重为 0）
    - 对应 flap_mask.json（记录累计 mask）

示例用法：

python iterative_flap_prun.py \
  --model_name_or_path /path/to/Meta-Llama-3.1-8B-Instruct \
  --output_root /path/to/output/flap_iter_llama3 \
  --calib_dataset mmlu \
  --data_dir /path/to/data_root \
  --num_calib_sample 20 \
  --seqlen 512 \
  --metrics WIFV \
  --wanda_sp \
  --start_keep_ratio 1.0 \
  --end_keep_ratio 0.5 \
  --keep_ratio_step 0.02 \
  --min_save_prune_percent 3 \
  --fp16
"""

import os
import json
import argparse
from typing import Optional, Tuple, Dict

import torch
import torch.nn as nn
import bitsandbytes as bnb

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    set_seed,
)

# 从你现有的单次 FLAP 脚本中 import 关键函数
from flap_prun import (
    compute_flap_importance,
    apply_head_prune,
    apply_mlp_neuron_prune,
    get_decoder_layers,
    count_model_params,
    standardize_per_layer,  # 如果没有暴露，可以把 compute_flap_masks_alam 的逻辑复制一份
)

def compute_flap_masks_alam_iterative(
    attn_metric: torch.Tensor,      # [L, hidden_size]
    mlp_metric: torch.Tensor,       # [L, mlp_dim]
    model,
    target_keep_ratio: float,
    prev_attn_keep: Optional[torch.Tensor],  # [L, num_heads] bool
    prev_mlp_keep: Optional[torch.Tensor],   # [L, mlp_dim]  bool
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    AL-AM mask selection that honours a cumulative ("never re-activate") mask.

    Steps:
    - Sum ``attn_metric`` over ``head_dim`` to get one score per head
      (``[L, num_heads]``); ``mlp_metric`` stays ``[L, mlp_dim]``.
    - Pass both through ``standardize_per_layer`` and rank all heads and
      neurons of all layers together.
    - Weight every head entry by 512/3 and every neuron by 1 so the keep
      budget is expressed in (approximate) parameter count.
    - Entries that are False in ``prev_attn_keep`` / ``prev_mlp_keep`` are
      scored ``-inf`` *before* the threshold search, so they sort last and do
      not consume keep budget. Without this, an already-pruned entry could
      rank above the threshold, be counted as "kept", and then be removed by
      the final AND, making the real keep ratio undershoot the target.

    Args:
        attn_metric: Per-channel attention importance, ``[L, hidden_size]``.
        mlp_metric: Per-neuron MLP importance, ``[L, mlp_dim]``.
        model: Model providing ``config.hidden_size``,
            ``config.num_attention_heads`` and the decoder layer list.
        target_keep_ratio: Fraction of the total compression weight to keep.
        prev_attn_keep: Cumulative head keep mask, or None for no history.
        prev_mlp_keep: Cumulative neuron keep mask, or None for no history.

    Returns:
        ``(attn_keep, mlp_keep)`` bool tensors of shape ``[L, num_heads]`` and
        ``[L, mlp_dim]``; True means keep. They live on the metrics' device.

    Raises:
        AssertionError: If the metric shapes do not match the model.
    """
    layers = get_decoder_layers(model)
    num_layers = len(layers)
    hidden_size = model.config.hidden_size
    num_heads = model.config.num_attention_heads
    head_dim = hidden_size // num_heads

    assert attn_metric.shape[0] == num_layers, "attn_metric L != num_layers"
    assert attn_metric.shape[1] == hidden_size, "attn_metric second dim != hidden_size"
    assert mlp_metric.shape[0] == num_layers, "mlp_metric L != num_layers"
    mlp_dim = mlp_metric.shape[1]

    # 1) ATTENTION: [L, hidden_size] -> [L, num_heads] by summing each head's channels.
    attn_metric_heads = attn_metric.reshape(num_layers, num_heads, head_dim).sum(dim=2)

    # 2) Per-layer standardisation (presumably a z-score; defined in flap_prun — TODO confirm).
    attn_std = standardize_per_layer(attn_metric_heads)   # [L, num_heads]
    mlp_std = standardize_per_layer(mlp_metric)           # [L, mlp_dim]

    # 3) Flatten. Ranking is done in float32 so that the cumulative weight sum
    #    below cannot overflow if the metrics arrive in half precision.
    attn_flat = attn_std.reshape(-1).float()   # num_layers * num_heads
    mlp_flat = mlp_std.reshape(-1).float()     # num_layers * mlp_dim

    # The cumulative masks may live on another device (the driver builds them
    # on CPU); align device/dtype once so masking and the final AND are valid.
    prev_attn = None
    if prev_attn_keep is not None:
        prev_attn = prev_attn_keep.to(device=attn_flat.device, dtype=torch.bool)
        prev_attn = prev_attn.reshape(num_layers, num_heads)
        attn_flat = attn_flat.masked_fill(~prev_attn.reshape(-1), float("-inf"))
    prev_mlp = None
    if prev_mlp_keep is not None:
        prev_mlp = prev_mlp_keep.to(device=mlp_flat.device, dtype=torch.bool)
        prev_mlp = prev_mlp.reshape(num_layers, mlp_dim)
        mlp_flat = mlp_flat.masked_fill(~prev_mlp.reshape(-1), float("-inf"))

    prune_metric = torch.cat([attn_flat, mlp_flat], dim=0)  # larger = more important

    # 4) Compression weight: 512/3 approximates the head-vs-neuron parameter ratio.
    # NOTE(review): 512/3 equals 4 * head_dim / 3 for head_dim == 128 with plain
    # multi-head attention; confirm it is still the intended ratio for models
    # with another head_dim or with grouped-query attention.
    compression_weight = torch.ones_like(prune_metric)
    num_attn_entries = attn_flat.numel()
    compression_weight[:num_attn_entries] = 512.0 / 3.0

    total_weight = compression_weight.sum()
    target_weight_keep = total_weight * float(target_keep_ratio)

    # 5) Sort by score (descending), accumulate weights, and pick the position
    #    whose cumulative weight is closest to the keep budget.
    sorted_scores, sorted_idx = torch.sort(prune_metric, descending=True)
    cum_weight = torch.cumsum(compression_weight[sorted_idx], dim=0)
    best_pos = torch.argmin(torch.abs(cum_weight - target_weight_keep)).item()
    threshold = sorted_scores[best_pos]

    # 6) Keep everything at or above the threshold (True = keep, False = prune).
    keep_mask_flat = prune_metric >= threshold
    attn_keep = keep_mask_flat[:num_attn_entries].reshape(num_layers, num_heads)
    mlp_keep = keep_mask_flat[num_attn_entries:].reshape(num_layers, mlp_dim)

    # 7) Final guard: if the budget exceeds what is still alive, the threshold
    #    is -inf and everything compares True, so re-apply the history here.
    if prev_attn is not None:
        attn_keep = attn_keep & prev_attn
    if prev_mlp is not None:
        mlp_keep = mlp_keep & prev_mlp

    print(f"[FLAP-iter] Target keep ratio (weight-level): {target_keep_ratio:.4f}")
    print("[FLAP-iter] NOTE: historical pruned entries are forced to remain pruned.")

    return attn_keep.bool(), mlp_keep.bool()


def compute_linear_weight_sparsity(model: nn.Module) -> float:
    """
    Return the fraction of exactly-zero entries across linear-layer weights.

    Every submodule that is an ``nn.Linear``, ``bnb.nn.Linear4bit`` or
    ``bnb.nn.Linear8bitLt`` contributes all elements of its ``weight``.
    Returns 0.0 when the model contains no such module.
    """
    linear_types = (nn.Linear, bnb.nn.Linear4bit, bnb.nn.Linear8bitLt)

    zeros_seen = 0
    elements_seen = 0
    for submodule in model.modules():
        if not isinstance(submodule, linear_types):
            continue
        weight = submodule.weight.data
        elements_seen += weight.numel()
        zeros_seen += (weight == 0).sum().item()

    # Guard against division by zero for models without linear layers.
    return zeros_seen / float(elements_seen) if elements_seen else 0.0


def build_mask_json(
    attn_keep_mask: torch.Tensor,  # [L, num_heads] bool
    mlp_keep_mask: torch.Tensor,   # [L, mlp_dim]  bool
    model,
    metric: str,
    iter_idx: int,
    target_keep_ratio: float,
    actual_sparsity: float,
) -> Dict:
    """
    Build the JSON-serialisable content of ``flap_mask.json``.

    The result has three top-level keys:
    - ``"attn_heads"``: ``layer_<i>`` -> list of 0/1 ints, one per head.
    - ``"mlp_neurons"``: ``layer_<i>`` -> list of 0/1 ints, one per neuron.
    - ``"meta"``: run information (metric, iteration index, target keep
      ratio, measured sparsity) and the model dimensions read from
      ``model.config``.
    """
    def _row_to_int_list(row: torch.Tensor) -> list:
        # bool row -> plain Python ints so that json can serialise it
        return row.cpu().numpy().astype(int).tolist()

    layer_count = len(get_decoder_layers(model))
    cfg = model.config
    width = cfg.hidden_size
    head_count = cfg.num_attention_heads
    ffn_width = cfg.intermediate_size

    meta = {
        "prune_type": "flap_width_iterative",
        "metric": metric,
        "iter_index": iter_idx,
        "target_keep_ratio_this_iter": float(target_keep_ratio),
        "actual_linear_weight_sparsity": float(actual_sparsity),
        "num_layers": layer_count,
        "hidden_size": width,
        "num_heads": head_count,
        "mlp_dim": ffn_width,
    }

    layer_keys = [f"layer_{idx}" for idx in range(layer_count)]
    return {
        "attn_heads": {
            key: _row_to_int_list(attn_keep_mask[idx])
            for idx, key in enumerate(layer_keys)
        },
        "mlp_neurons": {
            key: _row_to_int_list(mlp_keep_mask[idx])
            for idx, key in enumerate(layer_keys)
        },
        "meta": meta,
    }


def save_checkpoint_with_mask(
    save_dir: str,
    model,
    tokenizer,
    attn_keep_mask: torch.Tensor,
    mlp_keep_mask: torch.Tensor,
    metric: str,
    iter_idx: int,
    target_keep_ratio: float,
    actual_sparsity: float,
):
    """
    Write a checkpoint bundle into ``save_dir`` (created if missing).

    The bundle consists of:
    - the model, via ``model.save_pretrained``
    - the tokenizer, via ``tokenizer.save_pretrained``
    - ``flap_mask.json`` holding the cumulative keep masks and run metadata
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"[FLAP-iter] Saving checkpoint to {save_dir}")
    for artifact in (model, tokenizer):
        artifact.save_pretrained(save_dir)

    mask_kwargs = {
        "attn_keep_mask": attn_keep_mask,
        "mlp_keep_mask": mlp_keep_mask,
        "model": model,
        "metric": metric,
        "iter_idx": iter_idx,
        "target_keep_ratio": target_keep_ratio,
        "actual_sparsity": actual_sparsity,
    }
    payload = build_mask_json(**mask_kwargs)

    mask_path = os.path.join(save_dir, "flap_mask.json")
    with open(mask_path, "w", encoding="utf-8") as handle:
        handle.write(json.dumps(payload, indent=2))
    print(f"[FLAP-iter] Saved mask JSON to {mask_path}")


def _new_checkpoint_percent(prune_pct_real: float,
                            min_save_pct: int,
                            saved_prune_percent: set) -> Optional[int]:
    """
    Return the integer prune-percent bucket to checkpoint now, or None.

    A checkpoint is due only when the bucket the model is in right now
    (the floor of ``prune_pct_real``) is at least ``min_save_pct`` and has not
    been saved before. Buckets that were jumped over are never returned, so a
    directory named ``prune_XXpct`` always holds a model whose measured
    sparsity is XX.yy%.
    """
    cur_int_pct = int(prune_pct_real)  # floor for non-negative input, e.g. 3.97 -> 3
    if cur_int_pct < min_save_pct or cur_int_pct in saved_prune_percent:
        return None
    return cur_int_pct


def run_iterative_flap_width_pruning(args):
    """
    Main loop of iterative FLAP width pruning.

    Reads from ``args``: seed, fp16, bf16, model_name_or_path, output_root,
    calib_dataset, data_dir, num_calib_sample, seqlen, metrics, wanda_sp,
    start_keep_ratio, end_keep_ratio, keep_ratio_step, max_iters and
    min_save_prune_percent.

    Each iteration recomputes the importance scores on the current model,
    selects a keep mask for the iteration's target keep ratio (combined with
    the cumulative mask), applies the head / neuron pruning helpers to every
    decoder layer, measures the linear-weight sparsity, and writes a
    checkpoint the first time a new integer sparsity percentage is reached.

    Raises:
        ValueError: If ``args.keep_ratio_step`` is not strictly positive.
    """
    step = float(args.keep_ratio_step)
    if step <= 0.0:
        # A non-positive step would never lower the target keep ratio and would
        # only repeat expensive calibration passes until max_iters is hit.
        raise ValueError(f"--keep_ratio_step must be > 0, got {args.keep_ratio_step}")

    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[FLAP-iter] Using device: {device}")

    # Load the model. Half precision is only used on CUDA; when both flags are
    # given, --fp16 takes precedence over --bf16.
    dtype = torch.float32
    if args.fp16 and device.type == "cuda":
        dtype = torch.float16
    elif args.bf16 and device.type == "cuda":
        dtype = torch.bfloat16

    print(f"[FLAP-iter] Loading base model from {args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=dtype,
        device_map=None,
    ).to(device)

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # Fall back to EOS as the padding token when none is defined.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

    layers = get_decoder_layers(model)
    num_layers = len(layers)
    hidden_size = model.config.hidden_size
    num_heads = model.config.num_attention_heads
    mlp_dim = model.config.intermediate_size

    print(f"[FLAP-iter] Model: {num_layers} layers, hidden={hidden_size}, mlp_dim={mlp_dim}, heads={num_heads}")
    total_params = count_model_params(model)
    print(f"[FLAP-iter] Total params = {total_params/1e6:.2f}M")

    os.makedirs(args.output_root, exist_ok=True)

    # Cumulative keep masks, updated after every iteration (True = still kept).
    prev_attn_keep = torch.ones((num_layers, num_heads), dtype=torch.bool)
    prev_mlp_keep = torch.ones((num_layers, mlp_dim), dtype=torch.bool)

    # Integer prune percentages (3, 4, 5, ...) for which a checkpoint exists.
    saved_prune_percent = set()

    start_keep_ratio = float(args.start_keep_ratio)
    min_save_pct = int(args.min_save_prune_percent)

    for iter_idx in range(args.max_iters):
        # Derive the target from the iteration index instead of subtracting the
        # step repeatedly, so floating-point error does not accumulate.
        keep_ratio = start_keep_ratio - iter_idx * step
        if keep_ratio < args.end_keep_ratio - 1e-8:
            break

        prune_ratio_target = 1.0 - keep_ratio
        print("=" * 80)
        print(f"[FLAP-iter] Iteration {iter_idx} | target_prune_ratio={prune_ratio_target:.4f} "
              f"(keep_ratio={keep_ratio:.4f})")

        # 1) Recompute the FLAP importance scores on the current model.
        attn_metric, mlp_metric, _, _ = compute_flap_importance(
            model=model,
            tokenizer=tokenizer,
            device=device,
            calib_dataset=args.calib_dataset,
            data_dir=args.data_dir,
            num_calib_sample=args.num_calib_sample,
            seqlen=args.seqlen,
            metrics=args.metrics,
            wanda_sp=args.wanda_sp,
        )
        print("[FLAP-iter] Importance computed on current pruned model.")

        # 2) Select this iteration's masks, combined with the cumulative masks.
        #    The cumulative masks start on CPU, so move them next to the metrics
        #    first; combining tensors from different devices raises an error.
        attn_keep_mask, mlp_keep_mask = compute_flap_masks_alam_iterative(
            attn_metric=attn_metric,
            mlp_metric=mlp_metric,
            model=model,
            target_keep_ratio=keep_ratio,
            prev_attn_keep=prev_attn_keep.to(attn_metric.device),
            prev_mlp_keep=prev_mlp_keep.to(mlp_metric.device),
        )

        # 3) Apply the pruning helpers layer by layer (per this script's design
        #    they zero the selected weights and leave the architecture unchanged).
        print("[FLAP-iter] Applying head/neuron pruning for this iteration ...")
        for l, layer in enumerate(layers):
            head_row = attn_keep_mask[l]
            neuron_row = mlp_keep_mask[l]

            # Indices whose keep flag is False, in ascending order.
            pruned_heads = torch.nonzero(~head_row).flatten().tolist()
            pruned_neurons = torch.nonzero(~neuron_row).flatten().tolist()

            print(f"[FLAP-iter] Layer {l}: keep_heads={int(head_row.sum())}/{head_row.numel()}, "
                  f"keep_neurons={int(neuron_row.sum())}/{neuron_row.numel()}")

            apply_head_prune(layer, pruned_heads, hidden_size, num_heads)
            apply_mlp_neuron_prune(layer, pruned_neurons, hidden_size, mlp_dim)

        # Update the cumulative masks.
        prev_attn_keep = attn_keep_mask
        prev_mlp_keep = mlp_keep_mask

        # 4) Measure the real sparsity (linear weights only).
        weight_sparsity = compute_linear_weight_sparsity(model)
        prune_pct_real = weight_sparsity * 100.0
        print(f"[FLAP-iter] Current linear-weight sparsity = {prune_pct_real:.2f}%")

        # 5) Checkpoint the first time the model lands in a new integer bucket.
        p_star = _new_checkpoint_percent(prune_pct_real, min_save_pct, saved_prune_percent)
        if p_star is not None:
            saved_prune_percent.add(p_star)
            save_dir = os.path.join(args.output_root, f"prune_{p_star:02d}pct")
            save_checkpoint_with_mask(
                save_dir=save_dir,
                model=model,
                tokenizer=tokenizer,
                attn_keep_mask=prev_attn_keep,
                mlp_keep_mask=prev_mlp_keep,
                metric=args.metrics,
                iter_idx=iter_idx,
                target_keep_ratio=keep_ratio,
                actual_sparsity=weight_sparsity,
            )
        else:
            print("[FLAP-iter] No new integer prune-percent threshold crossed in this iteration.")

    print("[FLAP-iter] Iterative pruning finished.")
    print(f"[FLAP-iter] Saved prune-percent checkpoints: {sorted(list(saved_prune_percent))}")


def parse_args():
    """
    Parse the command-line options of the iterative FLAP pruning script.

    Returns:
        argparse.Namespace with the model / output paths, calibration
        settings, importance metric, keep-ratio schedule, checkpoint
        threshold, precision flags and random seed.

    Note:
        argparse expands help strings with ``%``-formatting, so every literal
        percent sign in a help text must be written as ``%%``; a bare ``%``
        makes ``--help`` fail with a ValueError.
    """
    parser = argparse.ArgumentParser(description="Iterative FLAP width-only pruning (no finetune).")

    parser.add_argument("--model_name_or_path", type=str, required=True,
                        help="原始（未剪枝）HF CausalLM 模型路径，例如 Llama-3.1-8B-Instruct。")
    parser.add_argument("--output_root", type=str, required=True,
                        help="所有迭代剪枝 checkpoint 的根目录，将在其下创建 prune_XXpct 子目录。")

    parser.add_argument("--calib_dataset", type=str, default="mmlu",
                        choices=["wikitext2", "c4", "mmlu"])
    parser.add_argument("--data_dir", type=str, default=None,
                        help="如果使用 mmlu，data_dir 应该是包含 mmlu 子目录的父目录。")

    parser.add_argument("--num_calib_sample", type=int, default=20,
                        help="用于 FLAP 的校准样本数（序列条数）。")
    parser.add_argument("--seqlen", type=int, default=512)

    parser.add_argument("--metrics", type=str, default="WIFV",
                        choices=["IFV", "WIFV", "WIFN"],
                        help="FLAP 通道重要性度量方式。")
    parser.add_argument("--wanda_sp", action="store_true",
                        help="若为 True，使用 WANDA-style scaler_row 而非 BiasGPT。")

    # Iterative pruning schedule.
    parser.add_argument("--start_keep_ratio", type=float, default=1.0,
                        help="起始保留率，例如 1.0 表示从 0%% 剪枝开始。")
    parser.add_argument("--end_keep_ratio", type=float, default=0.5,
                        help="迭代终止的最小保留率，例如 0.5 表示最多剪到 50%%。")
    parser.add_argument("--keep_ratio_step", type=float, default=0.02,
                        help="每轮降低的保留率步长，例如 0.02 对应每轮增加 2%% 剪枝率。")
    parser.add_argument("--max_iters", type=int, default=100,
                        help="最大迭代轮数保护。")

    parser.add_argument("--min_save_prune_percent", type=int, default=3,
                        help="第一次达到 >= 该剪枝百分比时开始保存 checkpoint，例如 3 表示从 3.xx%% 开始。")

    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--seed", type=int, default=42)

    return parser.parse_args()


def main():
    """Script entry point: parse the CLI options and start the pruning run."""
    run_iterative_flap_width_pruning(parse_args())


# Run the pruning CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
