import torch
import torch.nn.functional as F
from typing import List, Sequence, Union

def logits_dis_loss(
    new_logits_list: List[torch.Tensor],      # [H], each [B, C_total]
    old_logits_list: List[torch.Tensor],      # [H], each [B, C_total]
    old_class_indexes: List[torch.Tensor],    # [H], each a LongTensor of this head's old-class column indices (may be empty)
    temperature: float = 6.0,
    loss_type: str = "kl",                    # "ce" or "kl"
    reduction: str = "mean",                  # "mean" | "sum" | "none"
) -> torch.Tensor:
    """Compute an LwF-style distillation loss, head by head, on each head's old classes.

    For head ``h`` the student (new) and teacher (old) logits are restricted to
    the columns ``old_class_indexes[h]``, divided by ``temperature``, and compared
    with either the soft-target cross-entropy ``-sum(p_t * log p_s)`` (``"ce"``)
    or ``KL(p_t || p_s)`` (``"kl"``). Both are multiplied by ``temperature ** 2``.
    Teacher probabilities are detached, so gradients reach only the new logits.

    Args:
        new_logits_list: H student logit tensors, each ``[B, C_total]``.
        old_logits_list: H teacher logit tensors, each ``[B, C_total]``.
        old_class_indexes: H column-index collections (a tensor or anything
            ``torch.as_tensor`` accepts, e.g. a list of ints). An empty one means
            the head has no old classes; that head's loss is 0.
        temperature: Softmax temperature ``T``.
        loss_type: ``"ce"`` or ``"kl"``.
        reduction: ``"mean"`` (batch mean per head), ``"sum"`` (batch sum per
            head) or ``"none"`` (per-sample losses).

    Returns:
        For ``"mean"``/``"sum"``: a scalar equal to the per-head reductions summed
        over heads. For ``"none"``: a ``[H, B]`` tensor. With no heads at all: a
        zero scalar (or an empty ``[0, 0]`` tensor for ``"none"``). The zero
        tensors produced for head-less or old-class-less cases carry no autograd
        history.

    Raises:
        AssertionError: if the three lists differ in length, or ``loss_type`` /
            ``reduction`` is not one of the allowed values.
    """
    H = len(new_logits_list)
    assert H == len(old_logits_list) == len(old_class_indexes), "new/old/index 列表长度需一致"
    assert loss_type in ("ce", "kl")
    assert reduction in ("mean", "sum", "none")

    if H == 0:
        # Nothing to distil; torch.stack below would reject an empty list.
        return torch.zeros(0, 0) if reduction == "none" else torch.zeros(())

    T2 = temperature * temperature

    def per_sample_loss(s_logits: torch.Tensor, t_logits: torch.Tensor) -> torch.Tensor:
        """Return the T^2-scaled CE or KL for each row (reduces the last dim)."""
        log_p_s = F.log_softmax(s_logits / temperature, dim=-1)
        p_t = F.softmax(t_logits / temperature, dim=-1).detach()
        if loss_type == "ce":
            loss_vec = -(p_t * log_p_s).sum(dim=-1)
        else:
            # xlogy(p, p) is exactly 0 where p == 0. The previous log(p_t + 1e-12)
            # broke in float16, where 1e-12 rounds to 0: an underflowed teacher
            # probability produced 0 * log(0) = 0 * -inf = NaN.
            loss_vec = (torch.xlogy(p_t, p_t) - p_t * log_p_s).sum(dim=-1)
        return loss_vec * T2

    per_head = []
    for s_all, t_all, idx_h in zip(new_logits_list, old_logits_list, old_class_indexes):
        # Accept tensors or plain sequences; index_select needs int64 on the logits' device.
        idx_h = torch.as_tensor(idx_h, dtype=torch.long, device=s_all.device)

        if idx_h.numel() == 0:
            # This head has no old classes: it contributes zero loss.
            if reduction == "none":
                per_head.append(s_all.new_zeros(s_all.size(0)))
            else:
                per_head.append(s_all.new_zeros(()))
            continue

        # Restrict both logit tensors to this head's old-class columns.
        s = s_all.index_select(-1, idx_h)  # [B, C_old_h]
        t = t_all.index_select(-1, idx_h)  # [B, C_old_h]
        loss_vec = per_sample_loss(s, t)   # [B]

        if reduction == "none":
            per_head.append(loss_vec)
        elif reduction == "sum":
            per_head.append(loss_vec.sum())
        else:  # "mean"
            per_head.append(loss_vec.mean())

    if reduction == "none":
        return torch.stack(per_head, dim=0)  # [H, B]
    # For "mean" and "sum", heads are combined by summation (not averaged).
    return torch.stack(per_head).sum()
    


if __name__ == "__main__":
    torch.manual_seed(0)

    # Toy setup: 3 heads whose total class counts are:
    C = [8, 6, 10]
    B = 4  # batch size

    # Random student (new) / teacher (old) logits; in practice these come from model forward passes.
    new_logits_list = [torch.randn(B, c) for c in C]
    old_logits_list = [torch.randn(B, c) for c in C]

    # Per-head old-class indices (they differ per head and may be empty):
    # head0: [0, 2, 3]; head1: empty (no old classes); head2: [1, 4, 9, 0]
    old_class_indexes = [
        torch.tensor([0, 2, 3], dtype=torch.long),
        torch.tensor([], dtype=torch.long),
        torch.tensor([1, 4, 9, 0], dtype=torch.long),
    ]

    # Evaluate every loss_type / reduction combination.
    for loss_type in ("ce", "kl"):
        for reduction in ("mean", "sum", "none"):
            loss = logits_dis_loss(
                new_logits_list,
                old_logits_list,
                old_class_indexes,
                temperature=2.0,
                loss_type=loss_type,
                reduction=reduction,
            )
            print(f"loss_type={loss_type:>2}, reduction={reduction:>4} -> shape={tuple(loss.shape)}, value=\n{loss}\n")