import torch
import torch.nn as nn
from typing import List, Optional, Union
from nltk.tree import Tree
import torch.nn.functional as F
import pickle

def get_next_level_children(tree, node_name):
    """
    获取某节点的下一层孩子节点的标签（确保是 Lm+1_* 格式）。
    
    Args:
        tree (dict): 树结构，嵌套字典表示。
        target_label (str): 目标节点的标签，如 "Lm_n"。
        
    Returns:
        list: 下一层孩子节点的标签（Lm+1_* 格式）。
    """
    level_index = int(node_name.split("_")[0][1:])  # 提取层次编号 m
    children_names = []
    labels = []
    # 遍历整个树的每个节点
    for subtree in tree:
        # 如果是列表并且第一个元素是目标节点
        if isinstance(subtree, Tree) and subtree.label() == node_name:
            # 获取直接子节点的标签
            children_names = [
                child.label() if isinstance(child, Tree) else child for child in subtree
            ]
            # 过滤出符合 Lm+1_* 格式的子节点
            children_names = [
                child_name for child_name in children_names
                if int(child_name.split("_")[0][1:]) == level_index + 1
            ]
            for child_name in children_names:
                if int(child_name.split("_")[0][1:]) == level_index + 1:
                    labels.append(int(child_name.split("_")[1]))
            break  # 找到目标节点后停止搜索
        
        # 如果当前节点是子树，递归查找
        if isinstance(subtree, Tree):
            children_names, labels = get_next_level_children(subtree, node_name)
            if children_names:  # 如果找到了结果，直接返回
                break

    # 返回找到的子节点标签列表
    return children_names, labels
        

def get_m_level_nodes(tree: Tree, m: int) -> List[str]:
    """
    获取树中所有层级为 m 的节点标签（格式为 Lm_n 的字符串）。

    Args:
        tree (Tree): 树对象。
        m (int): 目标层级（从 0 开始计数）。

    Returns:
        List[str]: 第 m 层节点标签（格式为 Lm_n 的字符串）。
    """
    if m < 0:
        raise ValueError("层级 m 必须是非负整数。")
    
    nodes = []
    
    for subtree in tree:
        if isinstance(subtree, Tree):
            # 如果当前节点是 Tree 对象，递归获取子树的节点
            if int(subtree.label().split("_")[0][1:]) == m:
                nodes.append(int(subtree.label().split("_")[1]))
            nodes.extend(get_m_level_nodes(subtree, m))
        else:
            # 如果是叶子节点，检查其层级
            if int(subtree.split("_")[0][1:]) == m:
                nodes.append(int(subtree.split("_")[1]))
    return nodes    

def js_divergence(p, q):
    """
    计算两个概率分布之间的 Jensen-Shannon 散度。

    Args:
        p (torch.Tensor): 第一个概率分布张量 (形状: [batch_size, num_classes])。
        q (torch.Tensor): 第二个概率分布张量 (形状: [batch_size, num_classes])。

    Returns:
        torch.Tensor: 每个样本的 Jensen-Shannon 散度 (形状: [batch_size])。
    """
    # 确保两个分布都已经归一化（概率分布）
    p = F.softmax(p, dim=-1)
    q = F.softmax(q, dim=-1)

    # 计算均值分布 M
    m = 0.5 * (p + q)

    # 计算 KL(P || M) 和 KL(Q || M)
    kl_pm = torch.sum(p * torch.log(p / m), dim=-1)
    kl_qm = torch.sum(q * torch.log(q / m), dim=-1)

    # JS散度
    js = 0.5 * (kl_pm + kl_qm)
    return js


# def get_leaf_paths(tree: Tree, path=None): # 获取树的所有叶子节点及其路径
#     if path is None:
#         path = []
#     if isinstance(tree, str) and tree.startswith("L"):  # 叶子节点
#         return {tree: path + [tree]}
#     leaf_paths = {}
#     for subtree in tree:
#         leaf_paths.update(get_leaf_paths(subtree, path + [tree.label()]))
#     return leaf_paths

# # 补全 labels_list 的函数
# def complete_labels(labels_list: List[torch.Tensor], label_tree: Tree) -> List[torch.Tensor]:
#     """
#     根据 label_tree 和 leaf_paths 补全 labels_list 中的缺失值（-1）。

#     Args:
#         labels_list: List[torch.Tensor]，每一行表示一个层次的标签，-1 为缺失值。
#         label_tree: Tree，表示层次结构的树。

#     Returns:
#         List[torch.Tensor]: 补全后的 labels_list。
#     """
#     # 获取所有叶子节点及其路径
#     leaf_paths = get_leaf_paths(label_tree)  # {leaf_name: [path_to_leaf]}

#     # 获取叶子节点编号到路径的映射
#     leaf_to_path = {}
#     for leaf, path in leaf_paths.items():
#         # 从路径中提取叶子节点的编号
#         leaf_to_path[int(leaf.split("_")[1])] = path

#     # 获取最深层次
#     num_levels = len(labels_list)

#     # 遍历 labels_list 的最后一行，尝试匹配叶子节点
#     last_layer = labels_list[-1]  # 最后一层的标签
#     for i, label in enumerate(last_layer):
#         if label == -1:
#             continue  # 跳过缺失值

#         # 找到对应叶子节点的路径
#         path = leaf_to_path[int(label.item())]  # 根据叶子编号找到路径

#         # 补全上层的标签（包括 root 节点）
#         for level, node in enumerate(path):  # 包括 root 节点
#             level_index = int(node.split("_")[1]) if "_" in node else 0  # 根节点编号为 0
#             if labels_list[level-1][i] == -1:  # 仅在缺失值时补全
#                 labels_list[level-1][i] = level_index

#     return labels_list




from typing import List, Dict, Optional, Tuple
import torch
from collections import deque
import re

_LEVEL_ID_RE = re.compile(r"^L(\d+)_(\d+)$")

def _parse_level_id(name: str) -> Optional[Tuple[int, int]]:
    """
    解析节点名，期望形如 'L{level}_{id}'，返回 (level, id)；否则返回 None。
    """
    m = _LEVEL_ID_RE.match(name)
    if not m:
        return None
    return int(m.group(1)), int(m.group(2))

def _build_maps_all_nodes(label_tree) -> Tuple[Dict[int, str], Dict[str, List[str]], set]:
    """
    遍历整棵树，构建：
    - id_to_name: id -> 'Lm_id'
    - name_to_prefix_path: 'Lm_id' -> ['L1_x', 'L2_y', ..., 'Lm_id']（不包含 root 等无法解析层号的名字）
    - leaf_names: 叶子节点名集合（仅限可解析为 'Lm_id' 的节点）
    """
    id_to_name: Dict[int, str] = {}
    name_to_prefix_path: Dict[str, List[str]] = {}
    leaf_names: set = set()

    q = deque()
    # 队列元素: (node, prefix_names)；prefix_names 仅包含已解析 'Lm_id' 的名字
    q.append((label_tree, []))

    while q:
        node, prefix = q.popleft()
        name = getattr(node, "name", "")
        parsed = _parse_level_id(name)

        if parsed is not None:
            level_m, id_n = parsed
            cur = prefix + [name]
            id_to_name[id_n] = name
            if name not in name_to_prefix_path:
                name_to_prefix_path[name] = cur
        else:
            cur = prefix  # root 或无法解析层号的名字，不计入路径

        children = getattr(node, "children", [])
        if len(children) == 0:
            # 叶子：仅记录可解析的名字
            if parsed is not None:
                leaf_names.add(name)
        else:
            for ch in children:
                q.append((ch, cur))

    return id_to_name, name_to_prefix_path, leaf_names

def complete_labels(labels_list: List[torch.Tensor], label_tree) -> List[torch.Tensor]:
    """
    根据 label_tree 补全 labels_list 中的缺失值（-1）。

    约定与策略：
    - 节点名遵循 'L{m}_{id}'，用 m 与 labels_list 的第 m-1 行对齐（第0行是 L1 层）。
    - 仅向上补全祖先，不向下生成子节点。
    - 只在缺失(-1)时填充，不覆盖已有值。
    - 对树的不完整/缺失保持稳健：若某 id 不在树中则跳过。
    """
    assert len(labels_list) > 0
    D = len(labels_list)
    B = labels_list[0].shape[0]
    assert all(t.shape[0] == B for t in labels_list), "All rows must have same batch size"

    # 构建整树索引，避免仅从叶子推导造成的遗漏
    id_to_name, name_to_prefix_path, _leaf_names = _build_maps_all_nodes(label_tree)

    out = [t.clone() for t in labels_list]

    # 策略：对每一列，扫描各层已知标签；对每个已知节点，向上补齐其祖先（按层号对齐行）
    for i in range(B):
        for row in range(D):
            val = int(out[row][i].item())
            if val == -1:
                continue

            node_name = id_to_name.get(val)
            if node_name is None:
                # 树中无此 id（树缺失或 id 空间不一致），跳过
                continue

            prefix = name_to_prefix_path.get(node_name)
            if not prefix:
                continue

            # 将前缀路径中的每个 'Lm_id' 映射到第 m-1 行，仅在 -1 时填充
            for name in prefix:
                parsed = _parse_level_id(name)
                if parsed is None:
                    continue
                level_m, id_n = parsed
                row_idx = level_m - 1  # 第 m 层 -> 第 m-1 行（与你的数据示例一致）
                if 0 <= row_idx < D and int(out[row_idx][i].item()) == -1:
                    out[row_idx][i] = id_n

            # 不向下补：prefix 只到当前节点为止

    return out






from typing import List, Dict, Optional, Tuple
import torch
from collections import deque
import re

_LEVEL_ID_RE = re.compile(r"^L(\d+)_(\d+)$")

def _parse_level_id(name: str) -> Optional[Tuple[int, int]]:
    """
    解析节点名，期望形如 'L{level}_{id}'，返回 (level, id)；否则返回 None。
    """
    m = _LEVEL_ID_RE.match(name)
    if not m:
        return None
    return int(m.group(1)), int(m.group(2))

def _build_maps_all_nodes(label_tree) -> Tuple[Dict[int, str], Dict[str, List[str]], set]:
    """
    遍历整棵树，构建：
    - id_to_name: id -> 'Lm_id'
    - name_to_prefix_path: 'Lm_id' -> ['L1_x', 'L2_y', ..., 'Lm_id']（不包含 root 等无法解析层号的名字）
    - leaf_names: 叶子节点名集合（仅限可解析为 'Lm_id' 的节点）
    """
    id_to_name: Dict[int, str] = {}
    name_to_prefix_path: Dict[str, List[str]] = {}
    leaf_names: set = set()

    q = deque()
    # 队列元素: (node, prefix_names)；prefix_names 仅包含已解析 'Lm_id' 的名字
    q.append((label_tree, []))

    while q:
        node, prefix = q.popleft()
        name = getattr(node, "name", "")
        parsed = _parse_level_id(name)

        if parsed is not None:
            level_m, id_n = parsed
            cur = prefix + [name]
            id_to_name[id_n] = name
            if name not in name_to_prefix_path:
                name_to_prefix_path[name] = cur
        else:
            cur = prefix  # root 或无法解析层号的名字，不计入路径

        children = getattr(node, "children", [])
        if len(children) == 0:
            # 叶子：仅记录可解析的名字
            if parsed is not None:
                leaf_names.add(name)
        else:
            for ch in children:
                q.append((ch, cur))

    return id_to_name, name_to_prefix_path, leaf_names

def complete_labels(labels_list: List[torch.Tensor], label_tree) -> List[torch.Tensor]:
    """
    根据 label_tree 补全 labels_list 中的缺失值（-1）。

    约定与策略：
    - 节点名遵循 'L{m}_{id}'，用 m 与 labels_list 的第 m-1 行对齐（第0行是 L1 层）。
    - 仅向上补全祖先，不向下生成子节点。
    - 只在缺失(-1)时填充，不覆盖已有值。
    - 对树的不完整/缺失保持稳健：若某 id 不在树中则跳过。
    """
    assert len(labels_list) > 0
    D = len(labels_list)
    B = labels_list[0].shape[0]
    assert all(t.shape[0] == B for t in labels_list), "All rows must have same batch size"

    # 构建整树索引，避免仅从叶子推导造成的遗漏
    id_to_name, name_to_prefix_path, _leaf_names = _build_maps_all_nodes(label_tree)

    out = [t.clone() for t in labels_list]

    # 策略：对每一列，扫描各层已知标签；对每个已知节点，向上补齐其祖先（按层号对齐行）
    for i in range(B):
        for row in range(D):
            val = int(out[row][i].item())
            if val == -1:
                continue

            node_name = id_to_name.get(val)
            if node_name is None:
                # 树中无此 id（树缺失或 id 空间不一致），跳过
                continue

            prefix = name_to_prefix_path.get(node_name)
            if not prefix:
                continue

            # 将前缀路径中的每个 'Lm_id' 映射到第 m-1 行，仅在 -1 时填充
            for name in prefix:
                parsed = _parse_level_id(name)
                if parsed is None:
                    continue
                level_m, id_n = parsed
                row_idx = level_m - 1  # 第 m 层 -> 第 m-1 行（与你的数据示例一致）
                if 0 <= row_idx < D and int(out[row_idx][i].item()) == -1:
                    out[row_idx][i] = id_n

            # 不向下补：prefix 只到当前节点为止

    return out






def margin_loss(logits, labels, margin=3.0):
    """
    计算损失函数 L_m。
    
    Args:
        logits (torch.Tensor): 模型的输出 logits，形状为 (N, D)。
        labels (torch.Tensor): 样本的标签，形状为 (N,)。
        margin (float): 超参数 m，默认值为 1.0。
    Returns:
        torch.Tensor: 损失值。
    """
     # 计算 softmax 概率分布
    probs = F.softmax(logits, dim=-1)
    
    # 初始化损失
    loss = 0.0
    # 获取所有类别
    unique_labels = labels.unique()
        # 遍历每个类别
    for h in unique_labels:
        # 找到属于当前类别 h 的样本索引
        indices_h = (labels == h).nonzero(as_tuple=True)[0]
        
        # 如果当前类别的样本数少于 2，跳过
        if len(indices_h) < 2:
            continue
        
        # 获取当前类别的样本分布和标签
        probs_h = probs[indices_h]  # (N_h, D)
        labels_h = labels[indices_h]  # (N_h,)
        
        # 遍历当前类别的所有样本对
        for i in range(len(probs_h)):
            for j in range(len(probs_h)):
                if labels_h[i] != labels_h[j]:  # 只考虑标签不同的样本对
                    # 计算 JS 散度
                    js = js_divergence(probs_h[i], probs_h[j])
                    # 累加 max(0, margin - JS)
                    loss += torch.clamp(margin - js, min=0)
    
    return loss










def haf_loss(
    logits_list: List[torch.Tensor],
    labels_list: List[torch.Tensor],
    label_tree: Tree,
    margin: float = 3.0,
    aggregation: str = 'mean',  # 'mean' | 'sum' | 'individual'
) -> Union[torch.Tensor, List[torch.Tensor], None]:
    if len(logits_list) != len(labels_list):
        raise ValueError(f"logits_list (length {len(logits_list)}) and labels_list (length {len(labels_list)}) must have the same length.")
    current_criterion = nn.CrossEntropyLoss()

    # 若需要：labels 完全性
    labels_list = complete_labels(labels_list, label_tree)

    B = labels_list[0].shape[0]
    device = logits_list[0].device
    sample_losses = torch.zeros(B, dtype=torch.float, device=device)

    for i, (logits, labels) in enumerate(zip(logits_list, labels_list)):
        if labels.dtype != torch.long:
            labels = labels.long()

        valid_mask = labels != -1
        if valid_mask.sum() == 0:
            continue

        valid_logits = logits[valid_mask]        # [N_i, C_i]
        valid_labels = labels[valid_mask]        # [N_i]
        loss_dis = current_criterion(valid_logits, valid_labels)

        loss_margin = 0
        if i < len(logits_list) - 1:
            loss_margin = margin_loss(valid_logits, valid_labels, margin=margin)

        # 一致性：与下一层同一样本对齐（关键修正）
        loss_const = 0
        if i < len(logits_list) - 1:
            # 下一层的原始 logits/labels
            logits_next = logits_list[i + 1]     # [B, C_{i+1}]
            labels_next = labels_list[i + 1]

            # 两层同时有效的样本（在原 batch 维度上对齐）
            both_mask = (labels != -1) & (labels_next != -1)
            if both_mask.any():
                idx_all = torch.nonzero(both_mask, as_tuple=False).squeeze(1)  # 全局 batch 索引
                # 将当前层的有效 mask 里的“局部索引 j”映射回全局 batch 索引
                idx_valid = torch.nonzero(valid_mask, as_tuple=False).squeeze(1)

                # 建立从局部 j 到全局 b 的对应
                # 遍历局部 j，对应的全局 b = idx_valid[j]，仅在 b∈both_mask 时计算
                for j_local, b in enumerate(idx_valid.tolist()):
                    if not both_mask[b]:
                        continue
                    # 当前样本在层 i、i+1 的 logits
                    logit_1 = logits[b]          # [C_i]
                    logit_2 = logits_next[b]     # [C_{i+1}]

                    prob_1 = F.softmax(logit_1, dim=-1)
                    prob_2 = F.softmax(logit_2, dim=-1)

                    # node1：本层的列索引集合；node2：下一层的列索引集合（如需）
                    node1 = get_m_level_nodes(label_tree, i + 1)  # 列索引（与 prob_1 对齐）
                    # 构造 q：把下一层子节点概率聚合到父节点
                    p = prob_1[node1]
                    q = torch.zeros_like(p)
                    for k, node in enumerate(node1):
                        name = f"L{i+1}_{node}"
                        _, child_cols = get_next_level_children(label_tree, name)  # 需返回列索引列表
                        if len(child_cols) > 0:
                            q[k] = prob_2[child_cols].sum()
                    loss_const += js_divergence(p, q)

        # 将损失填回原始样本位置（保持你原有的加法逻辑）
        sample_losses[valid_mask] += loss_dis
        sample_losses[valid_mask] += loss_margin
        # 一致性这里是标量（累加了多个样本），为与原逻辑最接近，均匀分摊到当前层有效样本
        if valid_mask.sum() > 0 and isinstance(loss_const, torch.Tensor):
            sample_losses[valid_mask] += loss_const / valid_mask.sum()

    if aggregation == 'sum':
        return sample_losses.sum()
    elif aggregation == 'mean':
        return sample_losses.mean()
    elif aggregation == 'individual':
        return sample_losses.tolist()
    else:
        raise ValueError(f"Invalid aggregation method: '{aggregation}'. Choose 'mean', 'sum', or 'individual'.")



# --- Example Usage ---

if __name__ == '__main__':
    tree_fname = "./data/cifar100/cifar_100_tree.pkl"
    name = 'cifar'
    with open(tree_fname, "rb") as f:
        label_tree = pickle.load(f)
    print(label_tree)
    labels_list = [
        torch.tensor([0, -1, -1, -1, -1, -1]),  # 第一层
        torch.tensor([-1, -1, -1, -1, -1, -1]), # 第二层
        torch.tensor([9, 22, 2, 3, 4, 5])       # 第三层（叶子节点层）
    ]
    # completed_labels = complete_labels(labels_list, label_tree)
    # print(completed_labels)
    # children, labels = get_next_level_children(label_tree, "L3_15")
    # print(labels)  # 输出下一层孩子节点的标签
    # nodes = get_m_level_nodes(label_tree, 1)
    # print(nodes)
    loss=haf_loss(
        logits_list=[
            torch.randn(6, 100),  # 第一层 logits (batch_size=4, num_classes=3)
            torch.randn(6, 100),  # 第二层 logits (batch_size=4, num_classes=5)
            torch.randn(6, 100)   # 第三层 logits (batch_size=4, num_classes=6)
        ],
        labels_list=labels_list,
        label_tree=label_tree,
        margin=3.0,
        aggregation='mean'
    )