# pgr_pruning.py
from typing import Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_pruning import function  # torch_pruning.function 别名
from torch_pruning.dependency import Group
from torch_pruning.pruner import function as TPFn
from torch_pruning.pruner.importance import (
    GroupMagnitudeImportance,  # 保持与项目内导入路径一致
    Importance,
)
from tqdm.auto import tqdm

from assp.models.sammed2d.modeling.image_encoder import Attention, ImageEncoderViT

# Names exported by a star-import of this module. The underscore-prefixed helpers are
# listed explicitly because ``from module import *`` would otherwise skip them.
__all__ = [
    "collect_target_and_ignored_layers",
    "right_svd_basis",
    "merge_union_basis",
    "two_fold_split",
    "_vec_row",
    "_vec_col",
    "_param_slice_linear",
    "collect_bank",
    "PGRGroupImportance",
    "RelPosManager",
    "delete_pos_embedding",
    "restore_pos_embedding",
    "_get_encoder",
]


def collect_target_and_ignored_layers(encoder, mode: str = "mode2"):
    """Collect pruning roots, ignored layers and the q/k/v head-count map of *encoder*.

    Args:
        encoder: Encoder exposing ``blocks``; every block must expose ``attn``
            (with ``num_heads``) and ``mlp``. ``modules()`` is used to scan for layers.
        mode: One of ``"mode1"``, ``"mode2"``, ``"mode3"``, ``"mode4"``.
            - mode1: roots are ``attn.{q,k,v}`` and ``mlp.lin1`` (head_dim / MLP hidden width).
            - mode2: roots are ``attn.proj`` and ``mlp.lin2`` (embed_dim); the patch-embedding
              Conv2d, when one is found, is inserted as the very first root.
            - mode3: roots are ``Adapter.channel[0]`` (Linear) and ``Adapter.spatial[0]`` (Conv2d).
            - mode4: union of mode1 and mode3.

    Returns:
        Tuple ``(target_modules, ignored_layers, num_heads_map)``:
        ``target_modules`` is the de-duplicated, order-preserving list of roots;
        ``ignored_layers`` lists every Linear / Conv2d / ConvTranspose2d / LayerNorm /
        BatchNorm2d of the encoder that is not a root;
        ``num_heads_map`` maps each existing ``attn.q`` / ``attn.k`` / ``attn.v`` to ``attn.num_heads``.

    Raises:
        ValueError: If *mode* is not one of the four supported values.
    """
    if mode not in ("mode1", "mode2", "mode3", "mode4"):
        raise ValueError(f"Unknown prune mode: {mode}")

    def _get_patch_embed_conv(enc):
        """Best-effort lookup of the patch-embedding Conv2d; returns None when nothing matches."""
        if hasattr(enc, "patch_embed"):
            pe = enc.patch_embed
            for attr in ("proj", "projection", "conv"):
                candidate = getattr(pe, attr, None)
                if isinstance(candidate, nn.Conv2d):
                    return candidate
            if isinstance(pe, nn.Conv2d):
                return pe
        # Fallback: the first Conv2d fed by 3 input channels is taken as the patch embedding.
        for m in enc.modules():
            if isinstance(m, nn.Conv2d) and getattr(m, "in_channels", None) == 3:
                return m
        return None

    def _first_layer(container, kind):
        """Return ``container[0]`` when *container* is a non-empty Sequential starting with *kind*."""
        if isinstance(container, nn.Sequential) and len(container) > 0 and isinstance(container[0], kind):
            return container[0]
        return None

    target_modules = []
    num_heads_map = {}

    # Independent ``if`` tests (not ``elif``) because mode4 is the union of mode1 and mode3.
    for blk in encoder.blocks:
        attn, mlp = blk.attn, blk.mlp
        heads = attn.num_heads

        # The head map is maintained in every mode.
        for name in ("q", "k", "v"):
            if hasattr(attn, name):
                num_heads_map[getattr(attn, name)] = heads

        if mode in ("mode1", "mode4"):
            candidates = [getattr(attn, name, None) for name in ("q", "k", "v")]
            candidates.append(getattr(mlp, "lin1", None))
            target_modules.extend(c for c in candidates if isinstance(c, nn.Linear))

        if mode in ("mode3", "mode4"):
            adapter = getattr(blk, "Adapter", None)
            if adapter is not None:
                ch0 = _first_layer(getattr(adapter, "channel", None), nn.Linear)
                if ch0 is not None:
                    target_modules.append(ch0)
                sp0 = _first_layer(getattr(adapter, "spatial", None), nn.Conv2d)
                if sp0 is not None:
                    target_modules.append(sp0)

        if mode == "mode2":
            for candidate in (getattr(attn, "proj", None), getattr(mlp, "lin2", None)):
                if isinstance(candidate, nn.Linear):
                    target_modules.append(candidate)

    if mode == "mode2":
        pe_conv = _get_patch_embed_conv(encoder)
        if isinstance(pe_conv, nn.Conv2d):
            # Placed first so that dependencies propagate from the input side.
            target_modules.insert(0, pe_conv)

    # Order-preserving de-duplication (modules hash by identity).
    target_modules = list(dict.fromkeys(target_modules))

    # Every common prunable layer that is not a root goes to the ignore list.
    target_set = set(target_modules)
    prunable_types = (nn.Linear, nn.Conv2d, nn.ConvTranspose2d, nn.LayerNorm, nn.BatchNorm2d)
    ignored_layers = [
        m for m in encoder.modules() if isinstance(m, prunable_types) and m not in target_set
    ]

    return target_modules, ignored_layers, num_heads_map


# ---------- helpers for PGR ----------
def right_svd_basis(X, max_rank, energy_thresh=None):
    """Return the leading right singular vectors of *X* as the columns of a matrix.

    With ``X = U S V^T`` the result is ``V[:, :k]``.

    Rank selection:
      - ``energy_thresh`` is None: ``k = max_rank``.
      - otherwise: ``k`` is the number of cumulative energy fractions
        ``cumsum(S^2) / sum(S^2)`` that are ``<= energy_thresh``. Note that this is a
        count with ``<=``, not the smallest ``k`` whose cumulative energy reaches the
        threshold, so the two can differ by one at the boundary.
    In both cases ``k`` is finally clamped to ``[1, min(max_rank, V.shape[1])]``.

    Returns None when *X* is None or has no elements.
    """
    if X is None or X.numel() == 0:
        return None

    _, sing, vh = torch.linalg.svd(X, full_matrices=False)
    basis = vh.transpose(0, 1)
    available = basis.shape[1]

    if energy_thresh is None or sing.numel() == 0:
        k = min(max_rank, available)
    else:
        energy = sing**2
        fraction = energy / energy.sum().clamp_min(1e-12)
        cumulative = torch.cumsum(fraction, dim=0)
        k = max(int((cumulative <= energy_thresh).sum().item()), 1)

    k = max(1, min(k, max_rank, available))
    return basis[:, :k]


def merge_union_basis(Va, Vb, max_rank, energy_thresh=None):
    """Merge two column bases into a basis of their joint span.

    The two bases are concatenated column-wise into ``C = [Va Vb]`` and the leading
    left singular vectors ``U[:, :k]`` of ``C`` are returned. When the bases have a
    different number of rows, the shorter one is zero-padded at the bottom first.
    If only one basis is given, ``C`` is that basis; if both are None the result is None.

    ``k`` is ``max_rank`` when ``energy_thresh`` is None; otherwise it is the number of
    cumulative energy fractions ``cumsum(S^2) / sum(S^2)`` that are ``<= energy_thresh``.
    Either way ``k`` is clamped to ``[1, min(max_rank, U.shape[1])]``.
    """
    if Va is None and Vb is None:
        return None

    if Va is None:
        stacked = Vb
    elif Vb is None:
        stacked = Va
    else:
        # Bring both bases to the same ambient dimension before concatenating.
        rows = max(Va.shape[0], Vb.shape[0])
        aligned = []
        for basis in (Va, Vb):
            missing = rows - basis.shape[0]
            if missing > 0:
                basis = torch.cat([basis, basis.new_zeros(missing, basis.shape[1])], dim=0)
            aligned.append(basis)
        stacked = torch.cat(aligned, dim=1)

    left, sing, _ = torch.linalg.svd(stacked, full_matrices=False)
    available = left.shape[1]

    if energy_thresh is None or sing.numel() == 0:
        k = min(max_rank, available)
    else:
        energy = sing**2
        fraction = energy / energy.sum().clamp_min(1e-12)
        cumulative = torch.cumsum(fraction, dim=0)
        k = max(int((cumulative <= energy_thresh).sum().item()), 1)

    k = max(1, min(k, max_rank, available))
    return left[:, :k]


def two_fold_split(n, seed=0):
    """Split the indices ``0..n-1`` into two folds ``(S1, S2)`` returned as lists.

    With an integer *seed* the indices are shuffled by a dedicated, seeded generator;
    with ``seed=None`` they stay in ascending order. ``S1`` receives the first
    ``n // 2`` entries and ``S2`` the rest. Intended as a two-fold approximation of
    leave-one-out: one fold is scored with the basis built from the other.
    """
    if seed is None:
        order = torch.arange(n)
    else:
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(n, generator=generator)

    as_list = order.tolist()
    half = n // 2
    return as_list[:half], as_list[half:]


def _vec_row(mat, idx):
    """Return ``mat[idx]`` flattened to one dimension."""
    selected = mat[idx]
    return selected.reshape(-1)


def _vec_col(mat_t, idx):
    """Return ``mat_t[idx]`` flattened to one dimension.

    *mat_t* is expected to be the transposed gradient, so indexing its first
    dimension picks a column of the untransposed gradient.
    """
    selected = mat_t[idx]
    return selected.reshape(-1)


@torch.no_grad()
def _param_slice_linear(layer: nn.Linear, mode: str, idx: int):
    """Return one flattened slice of a Linear weight for the group parameter vector.

    ``mode == "out"`` selects row ``W[idx, :]``; any other mode selects column ``W[:, idx]``.
    """
    weight = layer.weight.data
    if mode == "out":
        picked = weight[idx, :]
    else:
        picked = weight[:, idx]
    return picked.reshape(-1)


@torch.no_grad()
def _param_slice_conv(layer: nn.Module, mode: str, idx: int):
    """Return one flattened channel slice of a 4-D convolution weight.

    Weight layouts handled here:
      - ``nn.ConvTranspose2d`` (``[in, out/groups, kh, kw]``):
        ``mode == "out"`` -> ``W[:, idx]``, any other mode -> ``W[idx]``.
      - every other layer is treated as Conv2d (``[out, in, kh, kw]``):
        ``mode == "out"`` -> ``W[idx]``, any other mode -> ``W[:, idx]``.
    """
    weight = layer.weight.data
    is_transposed = isinstance(layer, nn.ConvTranspose2d)
    # Output channels sit on dim 0 for Conv2d and on dim 1 for ConvTranspose2d;
    # input channels are the other way round.
    index_first_dim = (mode == "out") != is_transposed
    if index_first_dim:
        picked = weight[idx, :, :, :]
    else:
        picked = weight[:, idx, :, :]
    return picked.reshape(-1)


def _bank_extract_input(batch):
    """Return the input tensor carried by *batch*, or None when the layout is not recognised."""
    if torch.is_tensor(batch):
        return batch
    if isinstance(batch, dict):
        x = batch.get("input_image", None)
        if x is not None:
            return x
        images = batch.get("images", None)
        return images if isinstance(images, torch.Tensor) else None
    if isinstance(batch, (list, tuple)) and len(batch) > 0 and torch.is_tensor(batch[0]):
        return batch[0]
    return None


def _bank_accumulate_weight_grads(model, row_bank, col_bank):
    """Add the current ``weight.grad`` of each Linear / Conv2d / ConvTranspose2d to the banks."""
    for m in model.modules():
        if not isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
            continue
        if m.weight.grad is None:
            continue
        g = m.weight.grad.detach()
        if isinstance(m, nn.Linear):
            g_col = g.transpose(0, 1)  # [in, out]
        elif isinstance(m, nn.Conv2d):
            g_col = g.permute(1, 0, 2, 3)  # [in, out, kh, kw]
        else:
            g_col = g  # ConvTranspose2d weight is already [in, out/groups, kh, kw]
        if m not in row_bank:
            row_bank[m] = g.clone()
            col_bank[m] = g_col.clone()
        else:
            row_bank[m] += g
            col_bank[m] += g_col


def collect_bank(
    loader,
    model,
    use_adapter: bool,
    accelerator=None,
    reduction: str = "mean",
    loss_fn=None,
    *,
    mode: str = "bank",  # "bank" | "param" | "both"
    average_param_grad: bool = True,
):
    """Collect row/column "banks" of weight gradients and/or accumulate into ``.grad``.

    The model is switched to ``train()`` and one forward/backward pass is run per batch.

    Args:
        loader: Iterable of batches. A batch may be a tensor, a dict holding
            ``"input_image"`` or a tensor under ``"images"``, or a list/tuple whose first
            element is a tensor; any other batch is skipped. When *accelerator* is given,
            the loader is passed through ``accelerator.prepare``.
        model: Model to run. If it has an ``adapter_train`` attribute, that is set to
            ``bool(use_adapter)``.
        use_adapter: Whether adapters are active. When the blocks of
            ``model.image_encoder`` (or of *model* itself) expose ``skip_adapter``, it is
            set to ``not use_adapter`` during collection and restored afterwards.
        accelerator: Optional object providing ``device``, ``is_main_process``,
            ``prepare``, ``backward`` and ``reduce``.
        reduction: ``"mean"`` divides the banks by the global step count; ``"sum"``
            leaves them as sums. Only consulted when a bank is built.
        loss_fn: Optional ``loss_fn(out, batch) -> scalar``. Defaults to ``(y**2).mean()``
            where ``y`` is the model output (its first element for tuple/list outputs).
        mode: ``"bank"`` builds banks only and clears ``.grad`` after every step;
            ``"param"`` only lets ``.grad`` accumulate; ``"both"`` does both.
        average_param_grad: For ``"param"`` / ``"both"``, divide ``.grad`` by the global
            step count at the end.

    Returns:
        ``(row_bank, col_bank)`` dicts keyed by module. ``row_bank`` holds gradients in
        the weight layout; ``col_bank`` holds them with the first two dims swapped for
        Linear and Conv2d, and unchanged for ConvTranspose2d. Both are empty for
        ``mode == "param"``.

    Raises:
        ValueError: If a bank is requested and *reduction* is neither ``"mean"`` nor ``"sum"``.
    """
    # Function-local import: shadows the module-level ``tqdm.auto`` import, so the
    # plain ``tqdm`` progress bar is what this function uses.
    from tqdm import tqdm

    assert mode in ("bank", "param", "both"), f"Unsupported mode: {mode}"
    need_bank = mode in ("bank", "both")
    # Fail before the (expensive) data pass rather than after it.
    if need_bank and reduction not in ("mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction}")

    model.train()
    if hasattr(model, "adapter_train"):
        model.adapter_train = bool(use_adapter)

    # Resolved before any state is toggled so that a failure here leaves the model untouched.
    device = accelerator.device if accelerator is not None else next(model.parameters()).device
    is_main_process = accelerator.is_main_process if accelerator is not None else True

    row_bank, col_bank = {}, {}

    # Cleared once up front; "param"/"both" then let .grad accumulate across batches.
    model.zero_grad(set_to_none=True)

    # Temporarily force skip_adapter according to use_adapter; restored in ``finally``.
    enc = getattr(model, "image_encoder", model)
    prev_skip = None
    if hasattr(enc, "blocks") and len(enc.blocks) > 0 and hasattr(enc.blocks[0], "skip_adapter"):
        prev_skip = [blk.skip_adapter for blk in enc.blocks]
        for blk in enc.blocks:
            blk.skip_adapter = not use_adapter  # adapter enabled => do not skip it

    steps = 0
    try:
        if accelerator is not None:
            loader = accelerator.prepare(loader)
        progress_bar = tqdm(loader, disable=not is_main_process, desc="Collecting Gradients")

        for batch in progress_bar:
            x = _bank_extract_input(batch)
            if x is None:
                continue
            x = x.to(device=device, dtype=torch.float32)

            out = model(x)
            y = out[0] if isinstance(out, (tuple, list)) else out
            loss = loss_fn(out, batch) if loss_fn is not None else (y**2).mean()

            if accelerator is not None:
                accelerator.backward(loss)
            else:
                loss.backward()

            if mode == "bank":
                # .grad holds exactly this step's gradient here: bank it, then clear it.
                _bank_accumulate_weight_grads(model, row_bank, col_bank)
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.zero_()
            # "param" / "both": leave .grad alone so it accumulates across steps.

            steps += 1
    finally:
        if prev_skip is not None:
            for blk, s in zip(enc.blocks, prev_skip):
                blk.skip_adapter = s

    if mode == "both":
        # .grad was never cleared, so it already equals the sum of per-step gradients.
        # Banking it once here avoids adding the running total at every step, which
        # would count early batches several times.
        _bank_accumulate_weight_grads(model, row_bank, col_bank)

    # Cross-process aggregation of the banks (always a sum; averaging happens below).
    if need_bank and accelerator is not None:
        for k in list(row_bank.keys()):
            row_bank[k] = accelerator.reduce(row_bank[k].contiguous(), reduction="sum")
        for k in list(col_bank.keys()):
            col_bank[k] = accelerator.reduce(col_bank[k].contiguous(), reduction="sum")

    # Global step count, summed over processes when an accelerator is used.
    steps = max(steps, 1)
    device_tensor = torch.tensor(steps, device=device, dtype=torch.float32)
    if accelerator is not None:
        global_steps = int(accelerator.reduce(device_tensor, reduction="sum").item())
    else:
        global_steps = int(device_tensor.item())
    global_steps = max(global_steps, 1)

    if need_bank and reduction == "mean":
        for k in row_bank:
            row_bank[k] /= float(global_steps)
        for k in col_bank:
            col_bank[k] /= float(global_steps)

    if mode in ("param", "both") and average_param_grad:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(float(global_steps))

    # "param" returns empty banks to keep the two-value return shape.
    if mode == "param":
        return {}, {}
    return row_bank, col_bank


def collect_bank_old(
    loader,
    model,
    use_adapter: bool,
    accelerator=None,
    reduction: str = "mean",
    loss_fn=None,
):
    """Collect row/column banks of Linear weight gradients (legacy variant).

    Only ``nn.Linear`` layers are banked. ``.grad`` is cleared after every step and the
    banks are always averaged over the global step count.

    Args:
        loader: Iterable of batches. A batch may be a tensor, a dict holding
            ``"input_image"`` or a tensor under ``"images"``, or a list/tuple whose first
            element is a tensor; any other batch is skipped. When *accelerator* is given,
            the loader is passed through ``accelerator.prepare``.
        model: Model to run; it is switched to ``train()``. If it has an
            ``adapter_train`` attribute, that is set to ``bool(use_adapter)``.
        use_adapter: Whether adapters are active. When the blocks of
            ``model.image_encoder`` (or of *model* itself) expose ``skip_adapter``, it is
            set to ``not use_adapter`` during collection and restored afterwards.
        accelerator: Optional object providing ``device``, ``is_main_process``,
            ``prepare``, ``backward`` and ``reduce``. ``None`` runs single-process.
        reduction: Accepted for signature compatibility; this function does not read it.
        loss_fn: Optional ``loss_fn(out, batch) -> scalar``. Defaults to ``(y**2).mean()``
            where ``y`` is the model output (its first element for tuple/list outputs).

    Returns:
        ``(row_bank, col_bank)``: ``{module: mean weight gradient}`` in
        ``[out_features, in_features]`` layout and its transpose, respectively.
    """
    model.train()
    if hasattr(model, "adapter_train"):
        model.adapter_train = bool(use_adapter)

    # Temporarily force skip_adapter according to use_adapter; restored in ``finally``.
    enc = getattr(model, "image_encoder", model)
    prev_skip = None
    if hasattr(enc, "blocks") and len(enc.blocks) > 0 and hasattr(enc.blocks[0], "skip_adapter"):
        prev_skip = [blk.skip_adapter for blk in enc.blocks]
        for blk in enc.blocks:
            blk.skip_adapter = not use_adapter  # adapter enabled => do not skip it

    row_bank, col_bank, steps = {}, {}, 0
    device = accelerator.device if accelerator is not None else next(model.parameters()).device
    for p in model.parameters():
        p.grad = None

    try:
        is_main_process = accelerator.is_main_process if accelerator is not None else True
        # Only wrap the loader when an accelerator exists; calling ``prepare`` on None
        # made the single-process path fail.
        if accelerator is not None:
            loader = accelerator.prepare(loader)
        progress_bar = tqdm(loader, disable=not is_main_process, desc="Collecting Gradients")

        for batch in progress_bar:
            # Pick the input tensor: plain tensor, dict, or list/tuple batches are supported.
            if torch.is_tensor(batch):
                x = batch
            elif isinstance(batch, dict):
                x = batch.get("input_image", None)
                if x is None:
                    # Fall back to a tensor stored under "images".
                    if isinstance(batch.get("images", None), torch.Tensor):
                        x = batch["images"]
                    else:
                        # Nothing usable in this batch: skip it.
                        continue
            elif isinstance(batch, (list, tuple)) and len(batch) > 0 and torch.is_tensor(batch[0]):
                x = batch[0]
            else:
                continue
            x = x.to(device=device, dtype=torch.float32)
            out = model(x)
            y = out[0] if isinstance(out, (tuple, list)) else out
            loss = loss_fn(out, batch) if loss_fn is not None else (y**2).mean()
            if accelerator is not None:
                accelerator.backward(loss)
            else:
                loss.backward()
            # Accumulate this step's weight gradients (Linear layers only).
            for m in model.modules():
                if isinstance(m, nn.Linear) and m.weight.grad is not None:
                    g = m.weight.grad.detach()
                    if m not in row_bank:
                        row_bank[m] = g.clone()
                        col_bank[m] = g.transpose(0, 1).clone()
                    else:
                        row_bank[m] += g
                        col_bank[m] += g.transpose(0, 1)
            # Clear gradients in place so the next step starts from zero.
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.zero_()
            steps += 1

    finally:
        # Restore skip_adapter so later code sees the original configuration.
        if prev_skip is not None:
            for blk, s in zip(enc.blocks, prev_skip):
                blk.skip_adapter = s

    # Cross-process aggregation (sum) so every rank ends up with the same banks.
    if accelerator is not None:
        for k in list(row_bank.keys()):
            row_bank[k] = accelerator.reduce(row_bank[k].contiguous(), reduction="sum")
        for k in list(col_bank.keys()):
            col_bank[k] = accelerator.reduce(col_bank[k].contiguous(), reduction="sum")

    steps = max(steps, 1)
    global_steps = steps
    if accelerator is not None:
        # Total number of steps over all processes.
        global_steps = int(accelerator.reduce(torch.tensor(steps, device=device, dtype=torch.float32), reduction="sum").item())
    global_steps = max(global_steps, 1)

    # Average over the global number of batches.
    for k in row_bank:
        row_bank[k] /= float(global_steps)
    for k in col_bank:
        col_bank[k] /= float(global_steps)
    return row_bank, col_bank


# ---------- PGR Importance ----------
class PGRGroupImportance(Importance):
    """Group-level PGR importance; the dependency group is the smallest unit scored.

    For every root channel index of a group, the gradient slices of all coupled layers
    are concatenated into one vector ``g`` (and the matching weight slices into ``theta``).

    Subspaces (rows are L2-normalised before the bases are built):
        ``U_pre  = right_svd_basis(normalize_rows(G_pre))``
        ``U_down = right_svd_basis(normalize_rows(G_down[fold]))``
        ``U_cup  = merge_union_basis(U_pre, U_down)``

    Residual score (optionally with a two-fold split, where each fold is scored with the
    basis built from the other fold):
        ``R^2 = ||g||^2 - ||U_cup^T g||^2``, ``score = sqrt(max(R^2, 0))``,
        multiplied by ``||theta||_2`` when ``weight_scale`` is set.
    Roots without a downstream gradient vector fall back to ``||theta||_p`` with
    ``p = fallback_p`` (or 0 when no weights were collected).
    """

    def __init__(
        self,
        bank_pre_row,
        bank_pre_col,
        bank_down_row,
        bank_down_col,
        rank: int = 16,
        energy_thresh: Optional[float] = None,
        two_fold: bool = True,
        weight_scale: bool = True,
        fallback_p: int = 2,
        seed: int = 0,
    ):
        """Store the gradient banks and scoring options.

        Args:
            bank_pre_row / bank_pre_col: ``{layer: gradient}`` banks used to build ``U_pre``;
                the row bank is indexed for output-channel pruning, the column bank for
                input-channel pruning.
            bank_down_row / bank_down_col: Same layout; used for ``U_down`` and as the
                vectors whose residual is scored.
            rank: Maximum rank of every basis.
            energy_thresh: Optional energy threshold forwarded to the basis helpers.
            two_fold: Score each half of the roots with the basis of the other half.
            weight_scale: Multiply the residual by the L2 norm of the weight slice.
            fallback_p: Norm order of the weight-only fallback score.
            seed: Seed of the two-fold split.
        """
        self.Gpre_row = bank_pre_row
        self.Gpre_col = bank_pre_col
        self.Gdown_row = bank_down_row
        self.Gdown_col = bank_down_col
        self.rank = rank
        self.energy_thresh = energy_thresh
        self.two_fold = two_fold
        self.weight_scale = weight_scale
        self.fallback_p = fallback_p
        self.seed = seed

    @staticmethod
    def _normalize_rows(X: torch.Tensor, eps: float = 1e-12):
        """L2-normalise the rows of a non-empty 2-D tensor; other inputs are returned unchanged."""
        if X is None:
            return None
        if X.dim() != 2 or X.numel() == 0:
            return X
        nrm = X.norm(dim=1, keepdim=True).clamp_min(eps)
        return X / nrm

    @torch.no_grad()
    def __call__(self, group):
        """Return one score per distinct root index of *group* (sorted ascending), or None.

        None is returned when no item of the group carries ``root_idxs``.
        """
        roots = []
        for i, (_dep, _idxs) in enumerate(group):
            r = getattr(group[i], "root_idxs", None)
            if r is not None:
                roots.extend(r)
        if len(roots) == 0:
            return None
        roots = sorted(set(int(r) for r in roots))
        n = len(roots)
        rid2pos = {r: i for i, r in enumerate(roots)}
        fr_pre = [[] for _ in range(n)]
        fr_down = [[] for _ in range(n)]
        fr_param = [[] for _ in range(n)]

        # Pruning functions that remove output / input channels. Names differ between
        # torch_pruning versions, so the optional ones are collected by presence.
        OUT_FNS = [TPFn.prune_linear_out_channels]
        IN_FNS = [TPFn.prune_linear_in_channels]
        for maybe in ("prune_conv_out_channels", "prune_depthwise_conv_out_channels", "prune_convtranspose_out_channels"):
            f = getattr(TPFn, maybe, None)
            if f is not None:
                OUT_FNS.append(f)
        for maybe in ("prune_conv_in_channels", "prune_convtranspose_in_channels"):
            f = getattr(TPFn, maybe, None)
            if f is not None:
                IN_FNS.append(f)

        for i, (dep, idxs) in enumerate(group):
            # Two dependency layouts are supported: dep.layer / dep.pruning_fn and
            # dep.target.module / dep.handler.
            layer = getattr(dep, "layer", None)
            if layer is None:
                layer = getattr(dep, "target", None)
                layer = getattr(layer, "module", None)
            prune_fn = getattr(dep, "pruning_fn", None)
            if prune_fn is None:
                prune_fn = getattr(dep, "handler", None)
            root_idxs = getattr(group[i], "root_idxs", None)
            if layer is None or root_idxs is None or idxs is None or len(idxs) == 0:
                continue

            if prune_fn in OUT_FNS:
                mode = "out"
                Gp = self.Gpre_row.get(layer, None)
                Gd = self.Gdown_row.get(layer, None)
            elif prune_fn in IN_FNS:
                mode = "in"
                Gp = self.Gpre_col.get(layer, None)
                Gd = self.Gdown_col.get(layer, None)
            else:
                continue

            for j, loc in enumerate(idxs):
                r = int(root_idxs[j])
                pos = rid2pos.get(r, None)
                if pos is None:
                    continue
                if Gp is not None:
                    fr_pre[pos].append(_vec_row(Gp, loc))
                if Gd is not None:
                    fr_down[pos].append(_vec_row(Gd, loc))
                # Weight slice: Linear and convolution layers are handled separately.
                if isinstance(layer, nn.Linear):
                    fr_param[pos].append(_param_slice_linear(layer, mode, int(loc)))
                elif isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                    fr_param[pos].append(_param_slice_conv(layer, mode, int(loc)))

        # One concatenated vector per root (None when nothing was collected for it).
        gpre_list, gdown_list, p_list = [], [], []
        for r in range(n):
            if len(fr_down[r]) == 0 and len(fr_pre[r]) == 0:
                gdown_list.append(None)
                gpre_list.append(None)
                p_list.append(None)
                continue
            ddown = torch.cat(fr_down[r], dim=0).reshape(-1) if len(fr_down[r]) > 0 else None
            dpre = torch.cat(fr_pre[r], dim=0).reshape(-1) if len(fr_pre[r]) > 0 else None
            gdown_list.append(ddown)
            gpre_list.append(dpre)
            if len(fr_param[r]) > 0:
                theta = torch.cat(fr_param[r], dim=0).reshape(-1)
            else:
                theta = torch.zeros_like(ddown) if ddown is not None else (torch.zeros_like(dpre) if dpre is not None else torch.zeros(1))
            p_list.append(theta)

        def _stack_dense(vecs):
            """Stack the non-None vectors as rows, right-padding shorter ones with zeros."""
            val = [v for v in vecs if v is not None]
            if len(val) == 0:
                return None
            d = max(v.numel() for v in val)
            return torch.stack([v if v.numel() >= d else F.pad(v, (0, d - v.numel())) for v in val], dim=0)

        # Raw (magnitude-preserving) matrices.
        Gd = _stack_dense(gdown_list)
        Gp = _stack_dense(gpre_list)
        if Gd is None and Gp is None:
            # No gradients at all: score by weight norm only.
            scores = []
            for r in range(n):
                if p_list[r] is not None and p_list[r].numel() > 0:
                    scores.append(p_list[r].norm(p=self.fallback_p))
                elif gdown_list[r] is not None:
                    scores.append(gdown_list[r].abs().sum())
                else:
                    scores.append(torch.tensor(0.0))
            return torch.stack(scores, dim=0)

        # Direction matrices (row-normalised), used to build the bases.
        Gd_dir = self._normalize_rows(Gd)
        Gp_dir = self._normalize_rows(Gp)

        if self.two_fold:
            S1, S2 = two_fold_split(n, seed=self.seed)
            S1_set = set(S1)  # O(1) membership during scoring

            def _rows_select(M, sel):
                """Select the rows of *M* that belong to the roots in *sel*.

                *M* only has rows for roots with a downstream vector, so the row
                index is counted over those roots.
                """
                if M is None or len(sel) == 0:
                    return None
                wanted = set(sel)
                rows = []
                cur = 0
                for r in range(n):
                    if gdown_list[r] is not None:
                        if r in wanted:
                            rows.append(cur)
                        cur += 1
                if len(rows) == 0:
                    return None
                return M[torch.tensor(rows, device=M.device)]

            Upre = right_svd_basis(Gp_dir, self.rank, energy_thresh=self.energy_thresh)
            Gd1_dir, Gd2_dir = _rows_select(Gd_dir, S1), _rows_select(Gd_dir, S2)
            Udown1 = right_svd_basis(Gd1_dir, self.rank, energy_thresh=self.energy_thresh)
            Udown2 = right_svd_basis(Gd2_dir, self.rank, energy_thresh=self.energy_thresh)
            # merge_union_basis already copes with a missing (None) basis on either side.
            Ucup1 = merge_union_basis(Upre, Udown1, self.rank, energy_thresh=self.energy_thresh)
            Ucup2 = merge_union_basis(Upre, Udown2, self.rank, energy_thresh=self.energy_thresh)
        else:
            Upre = right_svd_basis(Gp_dir, self.rank, energy_thresh=self.energy_thresh)
            Udown = right_svd_basis(Gd_dir, self.rank, energy_thresh=self.energy_thresh)
            Ucup = merge_union_basis(Upre, Udown, self.rank, energy_thresh=self.energy_thresh)

        # Residual scoring.
        scores = []
        for r in range(n):
            gd = gdown_list[r]
            if gd is None:
                sc = (
                    p_list[r].norm(p=self.fallback_p)
                    if p_list[r] is not None and p_list[r].numel() > 0
                    else torch.tensor(
                        0.0,
                        device=(Upre.device if Upre is not None else (Udown.device if (not self.two_fold and Udown is not None) else "cpu")),
                    )
                )
                scores.append(sc)
                continue

            if self.two_fold:
                # Roots of fold 1 are scored with the basis built from fold 2 and vice versa.
                U = Ucup2 if r in S1_set else Ucup1
            else:
                U = Ucup

            if U is None or U.numel() == 0:
                resid2 = gd.pow(2).sum()
            else:
                # The bases live in a zero-padded ambient space while ``gd`` keeps its raw
                # length. Projecting over the shared leading coordinates equals
                # zero-padding the shorter operand and avoids a shape mismatch.
                m = min(U.shape[0], gd.numel())
                proj = U[:m].transpose(0, 1) @ gd[:m]
                resid2 = gd.pow(2).sum() - proj.pow(2).sum()
                resid2 = torch.clamp(resid2, min=0.0)
            sc = torch.sqrt(resid2)
            if self.weight_scale and p_list[r] is not None and p_list[r].numel() > 0:
                sc = sc * p_list[r].norm(p=2)
            scores.append(sc)
        return torch.stack(scores, dim=0)


class GroupDownGradWGImportance(GroupMagnitudeImportance):
    """Disturb-style grouped importance: ``|W * G|`` per prunable channel, aggregated over a group.

    - By default the gradient is ``layer.weight.grad``. When ``bank_row`` /
      ``bank_col`` hold an entry for a layer, that bank tensor is used instead
      (e.g. when only downstream gradient statistics are available).
    - Handles the out-/in-channel views of Linear / Conv2d / ConvTranspose2d,
      including the ``transposed`` weight layout.
    - ``ch_groups`` optionally reproduces Disturb's cross-layer channel-group
      homogenisation heuristic.
    - Group reduction and normalisation are delegated to the ``_reduce`` /
      ``_normalize`` helpers inherited from ``GroupMagnitudeImportance``.

    Args:
        bank_row: optional ``{module: tensor}`` gradient bank. Used for
            out-channel pruning and as a fallback for in-channel pruning.
        bank_col: optional ``{module: tensor}`` gradient bank. Preferred for
            in-channel pruning.
        p: exponent applied to ``|grad|`` in the BatchNorm branch; also
            forwarded to the base class.
        group_reduction: forwarded to the base class.
        normalizer: forwarded to the base class and used by ``__call__``.
        target_types: layer types that are scored; other layers are skipped.
        bias: forwarded to the base class (not read by this class itself).
    """

    def __init__(
        self,
        bank_row: Optional[Dict[nn.Module, torch.Tensor]] = None,
        bank_col: Optional[Dict[nn.Module, torch.Tensor]] = None,
        *,
        p: int = 2,
        group_reduction: str = "mean",
        normalizer: Optional[str] = "mean",
        target_types: List[type] = [nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm],
        bias: bool = False,
    ):
        # Copy target_types so that instances never share (and cannot mutate)
        # the default list object of the signature.
        super().__init__(
            p=p,
            group_reduction=group_reduction,
            normalizer=normalizer,
            bias=bias,
            target_types=list(target_types),
        )
        self.bank_row = bank_row if bank_row is not None else {}
        self.bank_col = bank_col if bank_col is not None else {}

    @staticmethod
    def _get_layer_and_handler(dep):
        """Return ``(layer, prune_fn)`` for a dependency object.

        Two attribute spellings are supported:
        - ``dep.layer`` / ``dep.pruning_fn``
        - ``dep.target.module`` / ``dep.handler``

        Each element falls back independently, so a dependency that exposes
        ``layer`` but only ``handler`` is still resolved. Missing pieces are
        returned as ``None``.
        """
        layer = getattr(dep, "layer", None)
        prune_fn = getattr(dep, "pruning_fn", None)
        if layer is None:
            target = getattr(dep, "target", None)
            layer = getattr(target, "module", None)
        if prune_fn is None:
            prune_fn = getattr(dep, "handler", None)
        return layer, prune_fn

    @staticmethod
    def _weight_and_grad_or_bank_out(layer, idxs, bank_tensor: Optional[torch.Tensor], use_transposed: bool):
        """Select weights and gradients for out-channel pruning.

        Args:
            layer: module whose ``weight`` is scored.
            idxs: channel indices to select.
            bank_tensor: gradient bank entry for ``layer``; when ``None`` the
                layer's ``weight.grad`` is used instead.
            use_transposed: when True, dims 0 and 1 are swapped before
                indexing so that the logical out-channel axis is dim 0.

        Returns:
            ``(w_sel, g_sel)``, each flattened to ``[len(idxs), -1]``.
            ``g_sel`` is ``None`` when no gradient source is available.
        """
        W = layer.weight
        Gw = bank_tensor if bank_tensor is not None else W.grad

        if use_transposed:
            # Bring the logical out-channel axis to dim 0 before indexing.
            w_sel = W.data.transpose(1, 0)[idxs].flatten(1)
            g_sel = None if Gw is None else Gw.detach().transpose(1, 0)[idxs].flatten(1)
        else:
            w_sel = W.data[idxs].flatten(1)
            g_sel = None if Gw is None else Gw.detach()[idxs].flatten(1)

        return w_sel, g_sel

    @staticmethod
    def _weight_and_grad_or_bank_in(layer, prune_fn, bank_tensor_row: Optional[torch.Tensor], bank_tensor_col: Optional[torch.Tensor], use_transposed: bool):
        """Build the per-input-channel view of weights and gradients.

        Args:
            layer: module whose ``weight`` is scored.
            prune_fn: unused; kept so existing call sites keep working.
            bank_tensor_row: row-style bank entry, used when no column bank exists.
            bank_tensor_col: column-style bank entry, preferred when present.
            use_transposed: when True the weight is flattened as-is (its
                input-channel axis is taken to be dim 0 already).

        Returns:
            ``(w_in, g_in)`` flattened with ``flatten(1)`` so that each row
            corresponds to one input-channel unit. ``g_in`` is ``None`` when
            no gradient source is available.
        """
        W = layer.weight

        if (bank_tensor_row is not None) or (bank_tensor_col is not None):
            # Prefer the input-channel-major bank; otherwise fall back to the
            # row bank and permute/transpose it below.
            Gw = bank_tensor_col if bank_tensor_col is not None else bank_tensor_row
        else:
            Gw = W.grad

        if use_transposed:
            # Transposed layout: the input-channel axis is already dim 0.
            w_in = W.data.flatten(1)
            g_in = None if Gw is None else Gw.detach().flatten(1)
        else:
            if isinstance(layer, nn.Linear):
                # [out, in] -> [in, out]
                # NOTE(review): the gradient is always transposed here, i.e. a
                # col-bank entry for a Linear layer is assumed to share the
                # weight's [out, in] layout - confirm against the bank builder.
                w_in = W.data.transpose(0, 1).flatten(1)
                g_in = None if Gw is None else Gw.detach().transpose(0, 1).flatten(1)
            else:
                # Conv2d: [out, in, kh, kw] -> [in, out, kh, kw]
                w_in = W.data.permute(1, 0, 2, 3).flatten(1)
                if Gw is None:
                    g_in = None
                else:
                    # A gradient shaped like the weight is treated as row-like
                    # and permuted; any other shape is taken to be in
                    # input-major layout already and is only flattened.
                    if Gw.shape == W.shape:  # row-like
                        g_in = Gw.detach().permute(1, 0, 2, 3).flatten(1)
                    else:
                        g_in = Gw.detach().flatten(1)  # col-like
        return w_in, g_in

    @torch.no_grad()
    def __call__(self, group: Group, ch_groups: int = 1):
        """Score the channels of a pruning group.

        Args:
            group: iterable of ``(dep, idxs)`` pairs; ``group[i].root_idxs``
                is recorded alongside every local score.
            ch_groups: when > 1, apply the Disturb channel-group heuristic
                (sum over groups, then repeat).

        Returns:
            The reduced and normalised importance produced by the inherited
            ``_reduce`` / ``_normalize``, or ``None`` when no dependency in
            the group contributed a score.
        """
        group_imp = []
        group_idxs = []

        for i, (dep, idxs) in enumerate(group):
            # Sorts list-like idxs in place. Note that torch.Tensor.sort() is
            # not in-place, so a tensor of indices is left unchanged here.
            if hasattr(idxs, "sort"):
                idxs.sort()
            layer, prune_fn = self._get_layer_and_handler(dep)
            if layer is None or prune_fn is None:
                continue
            if not isinstance(layer, tuple(self.target_types)):
                continue

            # Layers flagged ``transposed`` keep the in/out axes swapped.
            is_transposed = bool(getattr(layer, "transposed", False))

            # ========== out channels ==========
            if prune_fn in [function.prune_conv_out_channels, function.prune_linear_out_channels]:
                # Bank gradient (row bank) when registered, else weight.grad.
                bank_tensor = self.bank_row.get(layer, None)

                w_sel, g_sel = self._weight_and_grad_or_bank_out(layer, idxs, bank_tensor=bank_tensor, use_transposed=is_transposed)
                if g_sel is None:
                    # No gradient available: skip this dependency.
                    continue

                local_imp = (w_sel * g_sel).abs().sum(1)

                # Disturb's channel-group homogenisation heuristic.
                if ch_groups > 1:
                    local_imp = local_imp.view(ch_groups, -1).sum(0)
                    local_imp = local_imp.repeat(ch_groups)

                group_imp.append(local_imp)
                group_idxs.append(group[i].root_idxs)

            # ========== in channels ==========
            elif prune_fn in [function.prune_conv_in_channels, function.prune_linear_in_channels]:
                # Prefer the column bank; fall back to row bank, then weight.grad.
                bank_row = self.bank_row.get(layer, None)
                bank_col = self.bank_col.get(layer, None)

                w_in, g_in = self._weight_and_grad_or_bank_in(layer, prune_fn, bank_tensor_row=bank_row, bank_tensor_col=bank_col, use_transposed=is_transposed)
                if g_in is None:
                    continue

                # Disturb's special alignment: a non-grouped conv is re-arranged
                # by the length of the first score collected for this group.
                if (ch_groups > 1) and (prune_fn == function.prune_conv_in_channels) and getattr(layer, "groups", 1) == 1:
                    if len(group_imp) > 0:
                        base_len = group_imp[0].shape[0]
                        w_in = w_in.view(w_in.shape[0] // base_len, base_len, w_in.shape[1]).transpose(0, 1).flatten(1)
                        g_in = g_in.view(g_in.shape[0] // base_len, base_len, g_in.shape[1]).transpose(0, 1).flatten(1)

                local_imp = (w_in * g_in).abs().sum(1)

                if ch_groups > 1:
                    # Only aggregate + repeat when the length matches the
                    # first collected score.
                    if len(group_imp) > 0 and len(local_imp) == len(group_imp[0]):
                        local_imp = local_imp.view(ch_groups, -1).sum(0)
                        local_imp = local_imp.repeat(ch_groups)

                # Keep only the channels that are candidates for pruning.
                local_imp = local_imp[idxs]
                group_imp.append(local_imp)
                group_idxs.append(group[i].root_idxs)

            # ========== BatchNorm (gradient only) ==========
            elif prune_fn == function.prune_batchnorm_out_channels:
                if isinstance(layer, nn.modules.batchnorm._BatchNorm) and layer.affine:
                    if layer.weight is None or layer.weight.grad is None:
                        continue
                    g = layer.weight.grad.data[idxs]
                    local_imp = g.abs().pow(self.p)
                    if ch_groups > 1:
                        local_imp = local_imp.view(ch_groups, -1).sum(0)
                        local_imp = local_imp.repeat(ch_groups)
                    group_imp.append(local_imp)
                    group_idxs.append(group[i].root_idxs)

        if len(group_imp) == 0:
            return None

        # Reduce across dependencies and normalise with the base-class helpers.
        reduced = self._reduce(group_imp, group_idxs)
        reduced = self._normalize(reduced, self.normalizer)
        return reduced


class GroupHessianImportance(GroupMagnitudeImportance):
    """Grouped Optimal Brain Damage:
    https://proceedings.neurips.cc/paper/1989/hash/6c9882bbac1c7093bd25041881277658-Abstract.html

    Per-sample squared gradients are accumulated with ``accumulate_grad`` and
    their mean is used as the ``h`` term of the ``w**2 * h`` channel score.

    Example:

         It accepts a group as inputs, and return a 1-D tensor with the same length as the number of channels.
         All groups must be pruned simultaneously and thus their importance should be accumulated across channel groups.

         ```python
             inputs, labels = ...
             DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))
             scorer = GroupHessianImportance()
             scorer.zero_grad() # clean the accumulated gradients if necessary
             loss = loss_fn(model(inputs), labels, reduction='none') # compute loss for each sample
             for l in loss:
                 model.zero_grad() # clean the model gradients
                 l.backward(retain_graph=True) # compute gradients for each sample
                 scorer.accumulate_grad(model) # accumulate gradients of each sample
             group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )
             imp_score = scorer(group)
             #imp_score is a 1-D tensor with length 3 for channels [2, 6, 9]
             min_score = imp_score.min()
         ```
    """

    def __init__(
        self,
        group_reduction: str = "mean",
        normalizer: str = "mean",
        bias=False,
        target_types: list = [nn.modules.conv._ConvNd, nn.Linear, nn.modules.batchnorm._BatchNorm, nn.modules.LayerNorm],
    ):
        self.group_reduction = group_reduction
        self.normalizer = normalizer
        # Copy so instances never share the default list of the signature.
        self.target_types = list(target_types)
        self.bias = bias
        self._accu_grad = {}  # {parameter: running sum of squared gradients}
        self._counter = {}  # {parameter: number of accumulated samples}

    def zero_grad(self):
        """Drop all accumulated squared gradients and their counters."""
        self._accu_grad = {}
        self._counter = {}

    def accumulate_grad(self, model):
        """Add the squared current gradient of every parameter of ``model``.

        Parameters whose ``grad`` is ``None`` are skipped.
        """
        for _, p in model.named_parameters():
            if p.grad is None:
                continue
            # pow() already allocates a new tensor, so no extra clone is needed.
            squared = p.grad.detach().pow(2)
            if p not in self._accu_grad:
                self._accu_grad[p] = squared
                self._counter[p] = 1
            else:
                self._accu_grad[p] += squared
                self._counter[p] += 1

    @torch.no_grad()
    def __call__(self, group):
        """Score the channels of a pruning group with ``w**2 * h``.

        Side effect: when gradients have been accumulated, the ``grad`` of each
        accumulated parameter is overwritten with its mean squared gradient and
        the accumulator is cleared.

        Returns:
            The reduced and normalised importance from the inherited
            ``_reduce`` / ``_normalize``, or ``None`` when the group contains
            no scored layer.
        """
        group_imp = []
        group_idxs = []

        if len(self._accu_grad) > 0:  # fill gradients so that we can re-use the implementation for Taylor
            for p, g in self._accu_grad.items():
                mean_sq = g / self._counter[p]
                if p.grad is None:
                    # grad may have been reset to None (e.g. zero_grad with
                    # set_to_none=True) after accumulation.
                    p.grad = mean_sq
                else:
                    p.grad.data = mean_sq
            self.zero_grad()

        for i, (dep, idxs) in enumerate(group):
            idxs.sort()
            layer = dep.target.module
            prune_fn = dep.handler
            root_idxs = group[i].root_idxs

            if not isinstance(layer, tuple(self.target_types)):
                continue

            # Conv/Linear out_channels
            if prune_fn in [
                function.prune_conv_out_channels,
                function.prune_linear_out_channels,
            ]:
                if layer.weight.grad is not None:
                    if hasattr(layer, "transposed") and layer.transposed:
                        w = layer.weight.data.transpose(1, 0)[idxs].flatten(1)
                        h = layer.weight.grad.data.transpose(1, 0)[idxs].flatten(1)
                    else:
                        w = layer.weight.data[idxs].flatten(1)
                        h = layer.weight.grad.data[idxs].flatten(1)

                    local_imp = (w**2 * h).sum(1)
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

                if self.bias and layer.bias is not None and layer.bias.grad is not None:
                    b = layer.bias.data[idxs]
                    h = layer.bias.grad.data[idxs]
                    local_imp = b**2 * h
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

            # Conv/Linear in_channels
            elif prune_fn in [
                function.prune_conv_in_channels,
                function.prune_linear_in_channels,
            ]:
                if layer.weight.grad is not None:
                    if hasattr(layer, "transposed") and layer.transposed:
                        w = (layer.weight).flatten(1)
                        h = (layer.weight.grad).flatten(1)
                    else:
                        w = (layer.weight).transpose(0, 1).flatten(1)
                        h = (layer.weight.grad).transpose(0, 1).flatten(1)

                    local_imp = (w**2 * h).sum(1)
                    # Grouped (but not depthwise) conv: tile the per-group
                    # scores so they can be indexed with idxs.
                    if prune_fn == function.prune_conv_in_channels and layer.groups != layer.in_channels and layer.groups != 1:
                        local_imp = local_imp.repeat(layer.groups)
                    local_imp = local_imp[idxs]
                    group_imp.append(local_imp)
                    group_idxs.append(root_idxs)

            # BN
            elif prune_fn == function.prune_batchnorm_out_channels:
                if layer.affine:
                    if layer.weight.grad is not None:
                        w = layer.weight.data[idxs]
                        h = layer.weight.grad.data[idxs]
                        local_imp = w**2 * h
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)

                    # The bias can only be scored when its gradient exists;
                    # the previous ``is None`` test dereferenced a None grad.
                    if self.bias and layer.bias is not None and layer.bias.grad is not None:
                        b = layer.bias.data[idxs]
                        h = layer.bias.grad.data[idxs]
                        local_imp = (b**2 * h).abs()
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)

            # LN
            elif prune_fn == function.prune_layernorm_out_channels:
                if layer.elementwise_affine:
                    if layer.weight.grad is not None:
                        w = layer.weight.data[idxs]
                        h = layer.weight.grad.data[idxs]
                        local_imp = w**2 * h
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)
                    if self.bias and layer.bias is not None and layer.bias.grad is not None:
                        b = layer.bias.data[idxs]
                        h = layer.bias.grad.data[idxs]
                        local_imp = b**2 * h
                        group_imp.append(local_imp)
                        group_idxs.append(root_idxs)

        if len(group_imp) == 0:  # skip groups without parameterized layers
            return None
        group_imp = self._reduce(group_imp, group_idxs)
        group_imp = self._normalize(group_imp, self.normalizer)
        return group_imp


class RelPosManager:
    """Manage the relative position embeddings of an encoder's attention blocks.

    ``capture_current()`` snapshots the current tables, ``delete()`` removes
    them from the model, and ``restore()`` re-attaches them using one of two
    strategies: the cached values when their shapes still fit, all-zero tables
    otherwise.
    """

    def __init__(self):
        # {block_idx: {"has": bool, "h": Tensor or None, "w": Tensor or None}}
        self._cache = {}

    @torch.no_grad()
    def capture_current(self, model: nn.Module):
        """Replace the cache with a CPU copy of every block's rel-pos tables.

        Blocks without active relative position embeddings are recorded with
        ``"has": False``. Returns ``self``.
        """
        enc = _get_encoder(model)
        self._cache.clear()
        for i in range(enc.depth):
            attn: Attention = enc.blocks[i].attn
            has = getattr(attn, "use_rel_pos", False) and hasattr(attn, "rel_pos_h") and hasattr(attn, "rel_pos_w")
            if has:
                # Cache on CPU so the snapshot does not hold accelerator
                # memory; restore() moves it back to the right device/dtype.
                h = attn.rel_pos_h.detach().cpu().clone()
                w = attn.rel_pos_w.detach().cpu().clone()
                self._cache[i] = {"has": True, "h": h, "w": w}
            else:
                self._cache[i] = {"has": False, "h": None, "w": None}
        return self

    @torch.no_grad()
    def delete(self, model: nn.Module):
        """Disable rel-pos on every block and remove its tables. Returns ``self``."""
        enc = _get_encoder(model)
        for i in range(enc.depth):
            attn: Attention = enc.blocks[i].attn
            # Switch off and drop the attributes so that graph tracing and
            # pruning do not pick them up as dependencies.
            attn.use_rel_pos = False
            for name in ("rel_pos_h", "rel_pos_w"):
                if hasattr(attn, name):
                    delattr(attn, name)
        return self

    @torch.no_grad()
    def restore(self, model: nn.Module):
        """Re-create ``rel_pos_h`` / ``rel_pos_w`` on every block.

        The expected shapes are ``(2*h-1, head_dim)`` and ``(2*w-1, head_dim)``
        with ``h, w = attn.input_size`` and
        ``head_dim = attn.q.out_features // attn.num_heads``. Cached tables are
        reused only when both match these shapes; otherwise both tables are
        rebuilt as zeros. ``use_rel_pos`` is set to True and ``scale`` is
        recomputed from ``head_dim`` for every block. Returns ``self``.
        """
        enc = _get_encoder(model)
        for i in range(enc.depth):
            attn: Attention = enc.blocks[i].attn
            device = attn.q.weight.device
            dtype = attn.q.weight.dtype
            head_dim = attn.q.out_features // attn.num_heads
            h, w = attn.input_size
            H_shape = (2 * h - 1, head_dim)
            W_shape = (2 * w - 1, head_dim)
            cached = self._cache.get(i, {"has": False, "h": None, "w": None})
            use_cached = (
                cached["has"]
                and cached["h"] is not None
                and cached["w"] is not None
                and tuple(cached["h"].shape) == H_shape
                and tuple(cached["w"].shape) == W_shape
            )
            if use_cached:
                # copy=True: Tensor.to() returns the very same tensor when the
                # device and dtype already match (CPU model), which would make
                # the live parameter share storage with the cached snapshot.
                rel_h = nn.Parameter(cached["h"].to(device=device, dtype=dtype, copy=True))
                rel_w = nn.Parameter(cached["w"].to(device=device, dtype=dtype, copy=True))
            else:
                # Shape mismatch (or nothing cached): rebuild as all zeros.
                rel_h = nn.Parameter(torch.zeros(*H_shape, device=device, dtype=dtype))
                rel_w = nn.Parameter(torch.zeros(*W_shape, device=device, dtype=dtype))
            attn.rel_pos_h = rel_h
            attn.rel_pos_w = rel_w
            attn.use_rel_pos = True
            attn.scale = head_dim**-0.5  # keep the scale in sync with head_dim
        return self

    def clear(self):
        """Empty the cache. Returns ``self``."""
        self._cache.clear()
        return self


@torch.no_grad()
def delete_pos_embedding(model: nn.Module) -> nn.Module:
    """Switch off relative position embeddings on every attention block.

    For each of the encoder's ``depth`` blocks, ``use_rel_pos`` is set to
    False and the ``rel_pos_h`` / ``rel_pos_w`` attributes are removed when
    present. Returns the same ``model`` object.
    """
    encoder = _get_encoder(model)
    for block_idx in range(encoder.depth):
        attention: Attention = encoder.blocks[block_idx].attn
        attention.use_rel_pos = False
        # Collect the tables that exist, then drop them.
        present = [name for name in ("rel_pos_h", "rel_pos_w") if hasattr(attention, name)]
        for name in present:
            delattr(attention, name)
    return model


@torch.no_grad()
def restore_pos_embedding(model: nn.Module) -> nn.Module:
    """Re-attach zero-initialised relative position embeddings to every block.

    The tables are created with shapes ``(2*h-1, head_dim)`` and
    ``(2*w-1, head_dim)``, where ``h, w = attn.input_size`` and
    ``head_dim = attn.q.out_features // attn.num_heads``, on the device and
    dtype of ``attn.q.weight``. ``scale`` is recomputed from ``head_dim`` and
    ``use_rel_pos`` is set to True. Returns the same ``model`` object.
    """
    encoder = _get_encoder(model)
    for block_idx in range(encoder.depth):
        attention: Attention = encoder.blocks[block_idx].attn
        q_weight = attention.q.weight
        factory = {"device": q_weight.device, "dtype": q_weight.dtype}
        head_dim = attention.q.out_features // attention.num_heads
        grid_h, grid_w = attention.input_size
        attention.scale = head_dim**-0.5
        attention.use_rel_pos = True
        zeros_h = torch.zeros(2 * grid_h - 1, head_dim, **factory)
        attention.rel_pos_h = nn.Parameter(zeros_h)
        zeros_w = torch.zeros(2 * grid_w - 1, head_dim, **factory)
        attention.rel_pos_w = nn.Parameter(zeros_w)
    return model


def _get_encoder(obj: nn.Module) -> ImageEncoderViT:
    """Resolve the image encoder from ``obj``.

    Returns ``obj`` itself when it is an ``ImageEncoderViT``; otherwise returns
    its ``image_encoder`` attribute.

    Raises:
        TypeError: if ``obj`` is neither an ``ImageEncoderViT`` nor exposes an
            ``image_encoder`` attribute.
    """
    is_encoder = isinstance(obj, ImageEncoderViT)
    if not is_encoder and not hasattr(obj, "image_encoder"):
        raise TypeError("Expected ImageEncoderViT or an object with .image_encoder")
    return obj if is_encoder else obj.image_encoder
