import torch
import torch.nn as nn
from typing import List, Optional, Union
from nltk.tree import Tree
import pickle
import re
from typing import List, Dict
import torch
from collections import deque

def get_leaf_paths(tree: "Tree", path: Optional[List[str]] = None) -> Dict[str, List[str]]:
    """Map every leaf of an nltk-style tree to its root-to-leaf label path.

    Only string leaves whose text starts with ``"L"`` are recorded; any other
    string leaf is skipped. Each internal node contributes ``tree.label()`` to
    the path of the leaves below it.

    Args:
        tree: an nltk ``Tree`` (or any iterable node exposing ``label()``), or a
            string leaf.
        path: labels of the ancestors visited so far; ``None`` for the root call.

    Returns:
        ``{leaf: [root_label, ..., parent_label, leaf]}``. If a leaf name occurs
        more than once, the path of the last occurrence (iteration order) wins.
    """
    if path is None:
        path = []
    if isinstance(tree, str):
        # A ``str`` has no ``label()``: non-"L" strings are unlabeled leaves, so
        # skip them instead of recursing into their characters.
        return {tree: path + [tree]} if tree.startswith("L") else {}
    leaf_paths: Dict[str, List[str]] = {}
    for subtree in tree:
        # ``path + [...]`` builds a new list, so sibling branches never share state.
        leaf_paths.update(get_leaf_paths(subtree, path + [tree.label()]))
    return leaf_paths

# 补全 labels_list 的函数
# def complete_labels(labels_list: List[torch.Tensor], label_tree: Tree) -> List[torch.Tensor]:
#     """
#     根据 label_tree 和 leaf_paths 补全 labels_list 中的缺失值（-1）。

#     Args:
#         labels_list: List[torch.Tensor]，每一行表示一个层次的标签，-1 为缺失值。
#         label_tree: Tree，表示层次结构的树。

#     Returns:
#         List[torch.Tensor]: 补全后的 labels_list。
#     """
#     # 获取所有叶子节点及其路径
#     leaf_paths = get_leaf_paths(label_tree)  # {leaf_name: [path_to_leaf]}

#     # 获取叶子节点编号到路径的映射
#     leaf_to_path = {}
#     for leaf, path in leaf_paths.items():
#         # 从路径中提取叶子节点的编号
#         leaf_to_path[int(leaf.split("_")[1])] = path

#     # 获取最深层次
#     num_levels = len(labels_list)

#     # 遍历 labels_list 的最后一行，尝试匹配叶子节点
#     last_layer = labels_list[-1]  # 最后一层的标签
#     for i, label in enumerate(last_layer):
#         if label == -1:
#             continue  # 跳过缺失值

#         # 找到对应叶子节点的路径
#         path = leaf_to_path[int(label.item())]  # 根据叶子编号找到路径

#         # 补全上层的标签（包括 root 节点）
#         for level, node in enumerate(path):  # 包括 root 节点
#             level_index = int(node.split("_")[1]) if "_" in node else 0  # 根节点编号为 0
#             if labels_list[level-1][i] == -1:  # 仅在缺失值时补全
#                 labels_list[level-1][i] = level_index

#     return labels_list






from typing import List, Dict, Optional, Tuple
import torch
from collections import deque
import re

_LEVEL_ID_RE = re.compile(r"^L(\d+)_(\d+)$")

def _parse_level_id(name: str) -> Optional[Tuple[int, int]]:
    """
    解析节点名，期望形如 'L{level}_{id}'，返回 (level, id)；否则返回 None。
    """
    m = _LEVEL_ID_RE.match(name)
    if not m:
        return None
    return int(m.group(1)), int(m.group(2))

def _build_maps_all_nodes(label_tree) -> Tuple[Dict[int, str], Dict[str, List[str]], set]:
    """Index every node of ``label_tree`` with a breadth-first traversal.

    Nodes are read duck-typed through ``getattr(node, "name", "")`` and
    ``getattr(node, "children", None)``. Only names that parse as
    ``'L{level}_{id}'`` are indexed; nodes with other names (e.g. a root) are
    still traversed but never appear in any path.

    Returns:
        id_to_name: id -> ``'Lm_id'``. Keyed by the id ALONE: when the same id
            occurs on several levels (e.g. ``'L1_0'`` and ``'L3_0'``) the node
            visited last wins, i.e. the deepest one in the tree (BFS visits
            shallower nodes first).
        name_to_prefix_path: ``'Lm_id'`` -> parseable ancestor names from the
            root down, ending with the node itself. First occurrence wins.
        leaf_names: parseable names of nodes without children.
    """
    id_to_name: Dict[int, str] = {}
    name_to_prefix_path: Dict[str, List[str]] = {}
    leaf_names: set = set()

    # Queue items: (node, parseable names of the node's ancestors).
    queue = deque([(label_tree, [])])

    while queue:
        node, prefix = queue.popleft()
        name = getattr(node, "name", "")
        # Non-string names (ints, None) cannot follow the 'L{m}_{id}' scheme.
        parsed = _parse_level_id(name) if isinstance(name, str) else None

        if parsed is not None:
            _level, node_id = parsed
            path = prefix + [name]
            id_to_name[node_id] = name
            name_to_prefix_path.setdefault(name, path)
        else:
            # Root / unparseable node: traversed, but not part of any path.
            path = prefix

        # A missing or ``None`` children attribute means "no children".
        children = getattr(node, "children", None) or ()
        if children:
            # ``path`` is never mutated (always rebuilt with ``+``), so sharing it is safe.
            queue.extend((child, path) for child in children)
        elif parsed is not None:
            leaf_names.add(name)

    return id_to_name, name_to_prefix_path, leaf_names

def complete_labels(labels_list: List[torch.Tensor], label_tree) -> List[torch.Tensor]:
    """Fill missing (-1) ancestor labels in ``labels_list`` using ``label_tree``.

    Conventions:
    - Tree node names follow ``'L{m}_{id}'``; level ``m`` is aligned with row
      ``m - 1`` of ``labels_list`` (row 0 is level L1).
    - A known label ``v`` in row ``r`` is resolved as the node with level
      ``r + 1`` and id ``v``, so ids only need to be unique within a level.
    - Only ancestors are filled (upward); descendants are never generated.
    - Only -1 entries are written; existing labels are never overwritten.
    - Labels whose node is not in the tree are skipped.

    Args:
        labels_list: one tensor per level, holding one label per sample; every
            row must have the same leading (batch) size B. -1 marks a missing label.
        label_tree: root node; nodes are read via ``.name`` / ``.children``.

    Returns:
        New tensors with the same dtype, device and shape as the inputs; the
        inputs are not modified.
    """
    assert len(labels_list) > 0
    num_rows = len(labels_list)
    batch = labels_list[0].shape[0]
    assert all(t.shape[0] == batch for t in labels_list), "All rows must have same batch size"

    _, name_to_prefix_path, _ = _build_maps_all_nodes(label_tree)

    # (level, id) -> [(row_index, id), ...] for the node and its parseable ancestors.
    # Keying by (level, id) instead of id alone keeps e.g. 'L1_0' and 'L3_0' apart.
    ancestors: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}
    for node_name, prefix in name_to_prefix_path.items():
        key = _parse_level_id(node_name)
        if key is None or key in ancestors:
            continue
        path_rows: List[Tuple[int, int]] = []
        for anc_name in prefix:
            parsed = _parse_level_id(anc_name)
            if parsed is not None:
                path_rows.append((parsed[0] - 1, parsed[1]))  # level m -> row m-1
        ancestors[key] = path_rows

    # One host copy per row instead of one ``.item()`` (a device sync on GPU) per element.
    rows = [t.reshape(batch).tolist() for t in labels_list]

    for i in range(batch):
        for row in range(num_rows):
            val = int(rows[row][i])
            if val == -1:
                continue
            path_rows = ancestors.get((row + 1, val))
            if not path_rows:
                # Node absent from the tree (or id spaces disagree): skip.
                continue
            for row_idx, node_id in path_rows:
                if 0 <= row_idx < num_rows and int(rows[row_idx][i]) == -1:
                    rows[row_idx][i] = node_id
            # No downward completion: the path ends at the current node.

    return [
        torch.tensor(values, dtype=t.dtype, device=t.device).reshape(t.shape)
        for values, t in zip(rows, labels_list)
    ]

















def pseudo_tree_ce_loss(
    logits_list: List[torch.Tensor],
    labels_list: List[torch.Tensor],
    label_tree: Tree,
    aggregation: str = 'mean', # Options: 'mean', 'sum', 'none', 'individual'
) -> Union[torch.Tensor, List[float]]:
    """Cross-entropy summed over hierarchy levels, after completing missing labels.

    ``labels_list`` is first passed through ``complete_labels`` so missing (-1)
    ancestor labels are filled from ``label_tree``. Then, for each level, the
    cross-entropy of every sample whose label is not -1 is added to that
    sample's running total; levels without any valid label contribute nothing.

    Args:
        logits_list: one logits tensor per level; ``logits[mask]`` must give
            ``(n_valid, num_classes)`` — presumably each is ``(B, C_level)``, confirm.
        labels_list: one label tensor per level (-1 = missing), same length as
            ``logits_list``.
        label_tree: hierarchy forwarded to ``complete_labels`` (which reads nodes
            via ``.name`` / ``.children``; the ``Tree`` annotation may be stale — confirm).
        aggregation: 'mean' (average over all B samples), 'sum', 'none'
            (per-sample loss tensor of shape [B], differentiable) or 'individual'
            (per-sample losses as a Python list of floats).

    Raises:
        ValueError: if the list lengths differ or ``aggregation`` is unknown.
    """
    if len(logits_list) != len(labels_list):
        raise ValueError(f"logits_list (length {len(logits_list)}) and labels_list (length {len(labels_list)}) must have the same length.")
    if aggregation not in ('mean', 'sum', 'none', 'individual'):
        # Validate before doing any work.
        raise ValueError(f"Invalid aggregation method: '{aggregation}'. Choose 'mean', 'sum', 'none' or 'individual'.")

    # reduction='none' yields one loss per valid sample, so each sample gets its
    # own loss instead of the batch mean. The 'mean'/'sum' totals are unchanged,
    # since n_valid * batch_mean == sum of per-sample losses.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Fill missing ancestor labels from the tree.
    labels_list = complete_labels(labels_list, label_tree)

    # Running per-sample loss, summed over levels.
    sample_losses = torch.zeros_like(labels_list[0], dtype=torch.float)  # [batch_size]

    for logits, labels in zip(logits_list, labels_list):
        labels = labels.long()  # no-op when already long

        # Samples labelled -1 are excluded at this level.
        valid_mask = labels != -1
        if not valid_mask.any():
            continue

        per_sample = criterion(logits[valid_mask], labels[valid_mask])  # [n_valid]
        # Scatter each valid sample's own loss back to its batch position.
        sample_losses[valid_mask] += per_sample

    if aggregation == 'sum':
        return sample_losses.sum()
    if aggregation == 'mean':
        return sample_losses.mean()
    if aggregation == 'none':
        return sample_losses
    # 'individual': plain Python floats (not part of the autograd graph).
    return sample_losses.tolist()

# --- Example Usage ---

if __name__ == '__main__':
    # Demo: load a pickled label tree and complete a small hand-made batch.
    # NOTE: pickle.load can execute arbitrary code; only load trusted tree files.
    tree_fname = "./data/cifar100/cifar_100_tree.pkl"
    with open(tree_fname, "rb") as f:
        label_tree = pickle.load(f)
    print(label_tree)
    labels_list = [
        torch.tensor([0, -1, -1, -1, -1, -1]),  # level 1
        torch.tensor([-1, -1, -1, -1, -1, -1]), # level 2
        torch.tensor([9, 22, 2, 3, 4, 5])       # level 3 (leaf level)
    ]
    completed_labels = complete_labels(labels_list, label_tree)
    print(completed_labels)