import argparse
import csv
import os

import torch
import torch.nn.functional as F
from dataset import get_examples
from utils import get_model, set_seed


def register_bi_hooks(model):
    """Attach forward hooks that record a BI score for every transformer block.

    One hook is registered on each module in ``model.model.layers``. Every time
    a layer runs, its hook computes ``1 - cosine_similarity(input, output)``
    along the last dimension, averages the result over all remaining
    dimensions, and appends that single float to the layer's list.

    The similarity is evaluated in float32 regardless of the dtype the model
    runs in: in half precision (fp16 / bf16) values close to 1 are quantized so
    coarsely that ``1 - cos`` would be dominated by rounding error.

    Args:
        model: Object exposing ``model.model.layers`` as an iterable of
            ``torch.nn.Module``. Each layer must receive its hidden states as
            the first positional argument, and return either a tensor or a
            tuple whose first element is the output hidden states.

    Returns:
        A ``(handles, block_info)`` tuple:
          - ``handles``: list of hook handles; call ``.remove()`` on each to
            detach the hooks again.
          - ``block_info``: dict mapping ``"model.layers.<idx>"`` to a list of
            floats, one entry per forward call of that layer. A key only
            appears once its layer has run at least once.
    """
    block_info = {}

    def make_hook(name):
        # The factory binds ``name`` per layer; a bare closure over the loop
        # variable would make every hook report under the last layer's name.
        def hook(module, inputs, output):
            # Forward hooks only see positional args, so the hidden states
            # must be the first positional input of the layer.
            hidden_in = inputs[0]
            # A layer may return a bare tensor or a tuple (hidden_states, ...).
            hidden_out = output[0] if isinstance(output, tuple) else output

            # Detach (pure measurement, no autograd graph) and upcast so the
            # similarity is not computed in fp16 / bf16.
            hidden_in = hidden_in.detach().float()
            hidden_out = hidden_out.detach().float()

            # Cosine similarity over the feature dimension, then one scalar
            # mean over everything else.
            cos = F.cosine_similarity(hidden_in, hidden_out, dim=-1)
            score = (1.0 - cos).mean().item()

            block_info.setdefault(name, []).append(score)

        return hook

    handles = [
        layer.register_forward_hook(make_hook(f"model.layers.{idx}"))
        for idx, layer in enumerate(model.model.layers)
    ]

    return handles, block_info


if __name__ == "__main__":
    # Script entry point: run calibration data through the model, score every
    # transformer block with 1 - cosine_similarity(input, output), and write
    # the per-block scores plus the resulting block order to CSV files in
    # --output_dir. If block_order.csv already exists there, nothing is
    # recomputed.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_model",
        type=str,
        default="baffo32/decapoda-research-llama-7B-hf",
        help="base model name",
    )
    parser.add_argument(
        "--tokenizer", type=str, default=None, help="if None, base model name is used"
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="pretrain",
        choices=["pretrain", "pruneLLM", "tune_pruneLLM"],
    )
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--lora_ckpt", type=str, default=None)
    parser.add_argument("--device", type=str, default="cpu", help="device")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--max_seq_len", type=int, default=128)
    parser.add_argument("--num_calib_data", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output_block_sensitivity/llama-1-7b/bi_n10",
    )
    # The next two options are not used by the BI method; they are accepted
    # only to stay compatible with existing command lines.
    parser.add_argument("--norm_power", type=int, default=1, help="unused for BI")
    parser.add_argument(
        "--weight_reduction", type=str, default="sum", help="unused for BI"
    )
    # Restricting the choices makes a typo fail at parse time instead of after
    # the model has been loaded, run, and the CSV files partially written.
    parser.add_argument(
        "--block_reduction",
        type=str,
        default="mean",
        choices=["sum", "mean", "max", "prod"],
        help="sum, mean, max, prod",
    )
    parser.add_argument(
        "--fix_decapoda_config",
        default=False,
        action="store_true",
        help="fix tokenizer config of baffo32/decapoda-research-llama-7B-hf",
    )
    parser.add_argument(
        "--add_bos_to_every",
        default=False,
        action="store_true",
        help="whether to add BOS token to every sample in calibration dataset",
    )
    parser.add_argument("--use_bfloat", default=False, action="store_true")
    args = parser.parse_args()

    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # How the per-batch scores of one block are collapsed into a single score.
    block_reduction = args.block_reduction
    block_reducers = {
        "sum": torch.sum,
        "mean": torch.mean,
        "max": torch.max,
        "prod": torch.prod,
    }
    reduce_block = block_reducers[block_reduction]

    # Output files (names are kept stable so downstream scripts can reuse them).
    result_csv_block = os.path.join(args.output_dir, "block_score_all.csv")
    result_csv_block_detail = os.path.join(args.output_dir, "block_score_detail.csv")
    result_csv_block_sort = os.path.join(args.output_dir, "block_score_sorted.csv")
    block_order_path = os.path.join(args.output_dir, "block_order.csv")

    if not os.path.exists(block_order_path):
        # Prepare the model and the calibration data.
        model, tokenizer, _ = get_model(
            base_model=args.base_model,
            ckpt=args.ckpt,
            lora_ckpt=args.lora_ckpt,
            tokenizer=args.tokenizer,
            model_type=args.model_type,
            device=args.device,
            fix_decapoda_config=args.fix_decapoda_config,
            use_bfloat=args.use_bfloat,
        )
        model.eval()

        example_prompts = get_examples(
            dataset="bookcorpus",
            tokenizer=tokenizer,
            n_samples=args.num_calib_data,
            seq_len=args.max_seq_len,
            field_name="text",
            add_bos_to_every=args.add_bos_to_every,
        ).to(args.device)

        # Register hooks that record one (1 - cos) mean per layer per batch.
        handles, block_info = register_bi_hooks(model)

        print("Do forward to collect BI (1 - cosine) scores per block")
        with torch.no_grad():
            for i in range(0, example_prompts.size(0), args.batch_size):
                input_ids = example_prompts[i : i + args.batch_size]
                # Only the forward pass matters here; no labels or loss needed.
                model(input_ids)

        # Detach the hooks again.
        for h in handles:
            h.remove()

        # Without this guard an empty block_order.csv would be written and then
        # treated as a valid precomputed result by every later run.
        if not block_info:
            raise RuntimeError(
                "no BI scores were collected; "
                "check --num_calib_data and --batch_size"
            )

        # Compute and save the block-level importance.
        # NOTE: every batch contributes one value with equal weight, so a
        # smaller final batch counts as much as a full one.
        block_info_summary = {}
        with open(result_csv_block, "w", newline="") as logfile, open(
            result_csv_block_detail, "w", newline=""
        ) as logfile_detail:
            logwriter = csv.writer(logfile)
            logwriter.writerow(["block_name", "block_score"])
            logwriter_detail = csv.writer(logfile_detail)
            logwriter_detail.writerow(["block_name", "all_batch_bi_scores"])

            for k, v in block_info.items():
                # v is a list of floats: the BI score of each batch.
                logwriter_detail.writerow([k] + v)

                block_imp = float(reduce_block(torch.tensor(v)).item())
                logwriter.writerow([k, block_imp])
                block_info_summary[k] = block_imp

        # Hooks were registered on model.layers.* only, so the summary already
        # contains just the transformer blocks (no norm / lm_head entries).
        # Ascending sort: a smaller score means a less important block.
        sorted_items = sorted(block_info_summary.items(), key=lambda x: x[1])
        block_order = []
        with open(result_csv_block_sort, "w", newline="") as logfile:
            logwriter = csv.writer(logfile)
            logwriter.writerow(["rank", "block_name", "block_score", "block_index"])
            for rank, (key, value) in enumerate(sorted_items, start=1):
                layer_idx = key.split(".")[-1]
                logwriter.writerow([rank, key, value, layer_idx])
                print([rank, key, value, layer_idx])
                block_order.append(int(layer_idx))

        with open(block_order_path, "w", newline="") as logfile_order:
            logwriter_order = csv.writer(logfile_order)
            logwriter_order.writerow(block_order)

        print(f"=== block order saved: {block_order_path}")
        print(block_order)
        print(f"len: {len(block_order)}")

    else:
        print(f"use the precomputed results at {block_order_path}")
