from typing import List, Union

import torch

def _pairwise_sqdist_spatial(
    Z_bsd: torch.Tensor,   # [B, S, D]
    P_cmd: torch.Tensor,   # [C, M, D]
) -> torch.Tensor:
    """
    返回 d2[b,s,c,m] = ||Z[b,s,:] - P[c,m,:]||_2^2
    形状：[B, S, C, M]
    """
    z2 = (Z_bsd ** 2).sum(dim=-1, keepdim=True)                # [B,S,1]
    p2 = (P_cmd ** 2).sum(dim=-1)                              # [C,M]
    cross = torch.einsum("bsd,cmd->bscm", Z_bsd, P_cmd)        # [B,S,C,M]
    d2 = z2[..., None] + p2[None, None, :, :] - 2.0 * cross
    return d2.clamp_min_(0.0)

def _protopnet_activation_from_d2(
    d2_bscm: torch.Tensor,  # [B, S, C, M]
    eps: float = 1e-4
) -> torch.Tensor:
    """
    g_p(z) = log((d2 + 1) / (d2 + eps))，对 d2 单调递减
    输出形状同输入：[B, S, C, M]
    """
    eps = float(max(min(eps, 0.99), 1e-8))
    return torch.log((d2_bscm + 1.0) / (d2_bscm + eps))


def _topgamma_mask(max_sim_bs: torch.Tensor, gamma: float) -> torch.Tensor:
    """
    给定每个样本每个位置的相似度 max_sim_bs [B,S]，返回二值掩码 [B,S]（bool）。
    每个样本内选取前 k=ceil(S*gamma) 个位置为 True。
    """
    B, S = max_sim_bs.shape
    if gamma >= 1.0:
        return torch.ones_like(max_sim_bs, dtype=torch.bool)
    if gamma <= 0.0:
        return torch.zeros_like(max_sim_bs, dtype=torch.bool)
    k = max(1, int(round(S * gamma)))
    # topk 索引
    _, idx = torch.topk(max_sim_bs, k=k, dim=1, largest=True, sorted=False)
    mask = torch.zeros_like(max_sim_bs, dtype=torch.bool)
    arange_b = torch.arange(B, device=max_sim_bs.device).unsqueeze(1)  # [B,1]
    mask[arange_b, idx] = True
    return mask



def sim_dis_loss(
    list_of_protofeatures: List[torch.Tensor],  # per level: [B,S,D]
    current_proto_list:   List[torch.Tensor],   # per level: [C,M,D] (p^t)
    old_proto_list:       List[torch.Tensor],   # per level: [C,M,D] (p^{t-1})
    gamma: float = 0.1,                         # fraction of top-gamma positions
    reduction: str = "mean",                    # "mean" | "sum" | "none"
    eps: float = 1e-4
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Similarity-distillation loss between old and current prototypes.

    L_IR = sum_{i,j} | sim(p^{t-1}, z_{i,j}^t) - sim(p^t, z_{i,j}^t) | * S_{i,j}

    `sim` is the ProtoPNet activation g_p(d2), maximised over all prototypes
    at each position. S is a binary mask that keeps, per sample, the top-gamma
    positions ranked by the *old* prototypes' similarity.

    Args:
        list_of_protofeatures: one feature tensor [B, S, D] per level.
        current_proto_list: current prototypes [C, M, D] per level.
        old_proto_list: previous prototypes [C, M, D] per level.
        gamma: fraction of positions kept by the mask.
        reduction: "sum" sums the masked differences of a level; "mean"
            divides that sum by the number of selected positions (at least 1);
            "none" keeps the masked [B, S] map of each level.
        eps: stabiliser passed to the activation.

    Returns:
        For "mean"/"sum": a scalar tensor, the average of the per-level values
        (a zero scalar tensor when the input lists are empty).
        For "none": a list with one masked [B, S] tensor per level.

    Raises:
        AssertionError: if the three lists differ in length.
        ValueError: if `reduction` is not one of "mean", "sum", "none".
    """
    assert len(list_of_protofeatures) == len(current_proto_list) == len(old_proto_list), \
        "All three lists must have the same length (granularity levels)."
    # Validate up front so a bad value is rejected even when the lists are empty.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Unsupported reduction: {reduction}")

    per_layer_vals = []
    for Z_bsd, P_cur_cmd, P_old_cmd in zip(list_of_protofeatures, current_proto_list, old_proto_list):
        # Squared distances [B,S,C,M] -> similarities (larger means closer).
        sim_old = _protopnet_activation_from_d2(
            _pairwise_sqdist_spatial(Z_bsd, P_old_cmd), eps=eps
        )
        sim_cur = _protopnet_activation_from_d2(
            _pairwise_sqdist_spatial(Z_bsd, P_cur_cmd), eps=eps
        )

        # Strongest response over all prototypes: reduce M, then C -> [B,S].
        max_old = sim_old.max(dim=3).values.max(dim=2).values
        max_cur = sim_cur.max(dim=3).values.max(dim=2).values

        # The mask S is derived from the old prototypes' similarity only.
        mask = _topgamma_mask(max_old, gamma=gamma)  # [B,S] bool

        diff = (max_old - max_cur).abs()             # [B,S]
        masked = diff * mask.to(diff.dtype)          # [B,S]

        if reduction == "none":
            per_layer_vals.append(masked)
        elif reduction == "sum":
            per_layer_vals.append(masked.sum())
        else:  # "mean"
            denom = mask.sum().clamp_min(1)          # avoid division by zero
            per_layer_vals.append(masked.sum() / denom)

    if reduction == "none":
        return per_layer_vals  # list of [B,S]

    if not per_layer_vals:
        # No levels supplied: return a tensor (not a Python number) so that
        # callers can treat the result uniformly.
        return torch.zeros(())

    # Average across levels (switch to a plain sum if preferred).
    return sum(per_layer_vals) / len(per_layer_vals)