#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Compute Spearman correlation of neuron importance between full and pruned models.

Inputs:
- full_attn_metric.npy: shape [L_full, D_attn]
- full_mlp_metric.npy:  shape [L_full, D_mlp]
- pruned_attn_metric.npy: shape [L_pruned, D_attn]
- pruned_mlp_metric.npy:  shape [L_pruned, D_mlp]

We also need the list of removed layer indices in the *full* model.
We then map pruned layer index -> full layer index by:
- iterate full layers from 0..L_full-1
- skip layers in removed list
- assign remaining layers in order to pruned layers 0..L_pruned-1
"""

import argparse
import numpy as np
import os


def rankdata(a: np.ndarray) -> np.ndarray:
    """
    Return 0-based ranks of a 1D array, averaging ranks across ties.

    Distinct values receive the ranks 0..n-1 in ascending order. Equal
    values all receive the mean of the ordinal positions they occupy
    (e.g. ``[1, 1, 2]`` -> ``[0.5, 0.5, 2.0]``), so the result does not
    depend on the order in which the sort happens to place ties.

    Args:
        a: 1D array-like of comparable values.

    Returns:
        Float array of the same length as ``a`` holding the ranks.
    """
    a = np.asarray(a)
    n = len(a)
    if n == 0:
        return np.empty(0, dtype=float)

    order = np.argsort(a)
    sorted_a = a[order]

    # Mark the first element of every run of equal values in sorted order.
    starts_run = np.empty(n, dtype=bool)
    starts_run[0] = True
    starts_run[1:] = sorted_a[1:] != sorted_a[:-1]

    run_id = np.cumsum(starts_run) - 1          # run index per sorted item
    run_start = np.flatnonzero(starts_run)      # first ordinal rank of a run
    run_end = np.append(run_start[1:], n)       # one past last ordinal rank

    # Average of the ordinal ranks run_start .. run_end-1 for each run.
    run_rank = (run_start + run_end - 1) / 2.0

    ranks = np.empty(n, dtype=float)
    ranks[order] = run_rank[run_id]
    return ranks


def spearman_corr(x: np.ndarray, y: np.ndarray) -> float:
    """
    Compute the Spearman rank correlation of two equally shaped 1D arrays.

    Each input is converted to ranks with ``rankdata``; the result is the
    Pearson correlation of those ranks. If either rank vector has zero
    variance the correlation is undefined and 0.0 is returned.
    """
    assert x.shape == y.shape, "Spearman inputs must have same shape."

    # Rank each input and centre the ranks around their mean.
    centered = []
    for values in (x, y):
        ranks = rankdata(values)
        centered.append(ranks - ranks.mean())
    dx, dy = centered

    # Pearson denominator: product of the two rank norms.
    norm = np.sqrt((dx ** 2).sum() * (dy ** 2).sum())
    if norm == 0:
        return 0.0

    covariance = (dx * dy).sum()
    return float(covariance / norm)


def build_layer_mapping(num_layers_full: int,
                        removed_layers: list,
                        num_layers_pruned: int):
    """
    Build the index mappings between pruned-model and full-model layers.

    The full-model layers 0..num_layers_full-1 that are not listed in
    ``removed_layers`` are taken in ascending order and assigned to pruned
    layers 0, 1, 2, ... At most ``num_layers_pruned`` layers are mapped.

    A warning is printed whenever the number of surviving full layers
    differs from ``num_layers_pruned`` -- both when too few layers survive
    (some pruned layers stay unmapped) and when too many survive (the
    trailing full layers are left out of the mapping).

    Args:
        num_layers_full: Number of layers in the full model.
        removed_layers: Full-model layer indices that were pruned away.
            Order and duplicates do not matter.
        num_layers_pruned: Number of layers in the pruned model.

    Returns:
        Tuple ``(full_to_pruned, pruned_to_full)`` of dicts mapping layer
        indices in each direction.
    """
    removed_set = set(removed_layers)
    kept_layers = [
        f_idx for f_idx in range(num_layers_full) if f_idx not in removed_set
    ]

    # Compare against the *total* number of surviving layers, not the number
    # actually mapped, so an overly short removed list is reported as well.
    if len(kept_layers) != num_layers_pruned:
        print(
            f"[WARN] Number of pruned layers ({num_layers_pruned}) does not "
            f"match full_layers - removed_layers ({len(kept_layers)}). "
            "Check removed_layers list or metric shapes."
        )

    # Only the first num_layers_pruned surviving layers can be paired up.
    mapped = kept_layers[:max(num_layers_pruned, 0)]
    pruned_to_full = dict(enumerate(mapped))
    full_to_pruned = {f_idx: p_idx for p_idx, f_idx in pruned_to_full.items()}

    return full_to_pruned, pruned_to_full


def main():
    """
    Command-line entry point.

    Loads the four importance-metric arrays, maps every pruned layer to its
    originating full-model layer, computes the per-layer Spearman
    correlation for the attention and MLP metrics, prints the results and
    writes them to ``<output_dir>/spearman_layer_corr.tsv``.

    Pruned layers without a counterpart in the full model are reported with
    full layer ``-1`` and ``nan`` correlations.

    Raises:
        ValueError: If a metric array is not 2D, or if the layer counts /
            feature dimensions of the arrays are inconsistent.
    """
    parser = argparse.ArgumentParser(
        description="Compute Spearman correlation of FLAP importance "
                    "between full and pruned models."
    )
    parser.add_argument("--full_attn_metric", type=str, required=True)
    parser.add_argument("--full_mlp_metric", type=str, required=True)
    parser.add_argument("--pruned_attn_metric", type=str, required=True)
    parser.add_argument("--pruned_mlp_metric", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save spearman_layer_corr.tsv")

    # Layer indices (in the full model) that were pruned away.
    parser.add_argument(
        "--removed_layers", type=int, nargs="+", required=False,
        default=[23, 24, 25, 26, 27, 22, 28, 21, 29, 19, 20,
                 18, 30, 17, 10, 13],
        help="Indices of layers removed from the full model to obtain the "
             "pruned model, e.g. 23 24 25 26 ..."
    )

    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("[INFO] Loading metrics...")
    full_attn = np.load(args.full_attn_metric)   # [L_full, D_attn]
    full_mlp = np.load(args.full_mlp_metric)     # [L_full, D_mlp]
    pruned_attn = np.load(args.pruned_attn_metric)   # [L_pruned, D_attn]
    pruned_mlp = np.load(args.pruned_mlp_metric)     # [L_pruned, D_mlp]

    # Fail with a readable message instead of an opaque unpacking error.
    for label, metric in (
        ("full_attn", full_attn),
        ("full_mlp", full_mlp),
        ("pruned_attn", pruned_attn),
        ("pruned_mlp", pruned_mlp),
    ):
        if metric.ndim != 2:
            raise ValueError(
                f"{label} metric must be 2D [num_layers, dim], "
                f"got shape {metric.shape}."
            )

    L_full, D_attn_full = full_attn.shape
    L_pruned, D_attn_pruned = pruned_attn.shape
    L_full_mlp, D_mlp_full = full_mlp.shape
    L_pruned_mlp, D_mlp_pruned = pruned_mlp.shape

    print(f"[INFO] full_attn shape   = {full_attn.shape}")
    print(f"[INFO] pruned_attn shape = {pruned_attn.shape}")
    print(f"[INFO] full_mlp shape    = {full_mlp.shape}")
    print(f"[INFO] pruned_mlp shape  = {pruned_mlp.shape}")

    # Explicit raises rather than `assert`: assertions are removed when
    # Python runs with -O, which would let mismatched inputs through.
    consistency_checks = (
        (L_full == L_full_mlp, "Full attn/mlp must have same num_layers."),
        (L_pruned == L_pruned_mlp,
         "Pruned attn/mlp must have same num_layers."),
        (D_attn_full == D_attn_pruned, "Attn hidden size must match."),
        (D_mlp_full == D_mlp_pruned, "MLP hidden size must match."),
    )
    for is_ok, message in consistency_checks:
        if not is_ok:
            raise ValueError(message)

    # Build the pruned_layer_idx -> full_layer_idx mapping.
    _, pruned_to_full = build_layer_mapping(
        num_layers_full=L_full,
        removed_layers=args.removed_layers,
        num_layers_pruned=L_pruned,
    )

    print("[INFO] Layer mapping (pruned -> full):")
    for p_idx in range(L_pruned):
        f_idx = pruned_to_full.get(p_idx)
        if f_idx is None:
            print(f"  pruned layer {p_idx} -> <None> (check removed_layers)")
        else:
            print(f"  pruned layer {p_idx} -> full layer {f_idx}")
    print("")

    # Per-layer Spearman correlation; unmapped pruned layers stay NaN.
    attn_spearman = np.full(L_pruned, np.nan, dtype=float)
    mlp_spearman = np.full(L_pruned, np.nan, dtype=float)

    for p_idx, f_idx in pruned_to_full.items():
        attn_spearman[p_idx] = spearman_corr(full_attn[f_idx],
                                             pruned_attn[p_idx])
        mlp_spearman[p_idx] = spearman_corr(full_mlp[f_idx],
                                            pruned_mlp[p_idx])

    # Assemble the table once; it feeds both the console and the TSV file.
    rows = [
        (p_idx, pruned_to_full.get(p_idx, -1),
         attn_spearman[p_idx], mlp_spearman[p_idx])
        for p_idx in range(L_pruned)
    ]

    # Brief summary table on stdout.
    print("Layer\tFullLayer\tAttnSpearman\tMLPSpearman")
    for p_idx, f_idx, attn_rho, mlp_rho in rows:
        print(f"{p_idx}\t{f_idx}\t{attn_rho:.4f}\t{mlp_rho:.4f}")

    # Persist the results.
    out_path = os.path.join(args.output_dir, "spearman_layer_corr.tsv")
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("pruned_layer\tfull_layer\tattn_spearman\tmlp_spearman\n")
        f.writelines(
            f"{p_idx}\t{f_idx}\t{attn_rho:.6f}\t{mlp_rho:.6f}\n"
            for p_idx, f_idx, attn_rho, mlp_rho in rows
        )

    print(f"[DONE] Saved layer-wise Spearman correlations to {out_path}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
