import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict
from nltk.tree import Tree
import math
import pickle
import numpy as np

def iter_all_node_names_via_levels(tree, max_empty_streak: int = 3):
    """Yield every distinct node name of ``tree``, scanning level by level.

    Levels are probed with ``get_nodes_at_level(tree, level)`` starting at
    level 0 and increasing by one. Scanning stops as soon as
    ``max_empty_streak`` consecutive levels return no nodes, so small gaps in
    the level numbering do not end the scan early.

    Args:
        tree: Tree object accepted by ``get_nodes_at_level``.
        max_empty_streak: Number of consecutive empty levels that ends the
            scan. A value <= 0 yields nothing.

    Yields:
        Each node name once, in the order it is first encountered.
    """
    level = 0
    empty_streak = 0
    seen = set()
    while empty_streak < max_empty_streak:
        nodes = get_nodes_at_level(tree, level)
        if nodes:
            empty_streak = 0
            for node in nodes:
                if isinstance(node, str):
                    name = node
                else:
                    # Node objects may expose the label either as a plain
                    # attribute or as a method (nltk-style ``label()``).
                    # Call it when callable so a name is yielded rather than
                    # a bound-method object.
                    label = getattr(node, "label", None)
                    if callable(label):
                        label = label()
                    name = str(node) if label is None else label
                if name not in seen:
                    seen.add(name)
                    yield name
        else:
            empty_streak += 1
        level += 1

def _get_max_level_from_names(tree) -> int:
    """Return the largest level number ``m`` found in node names like ``Lm_n``.

    Names are taken from ``iter_all_node_names_via_levels(tree)``. Names that
    do not start with ``"L"``, contain no ``"_"``, or whose level part is not
    an integer are ignored.

    Returns:
        The maximum level seen, or 0 when no name carries a positive level
        (e.g. only a root ``L0_*`` exists). A tree reaching ``L3_*`` gives 3.
    """
    max_level = 0
    for name in iter_all_node_names_via_levels(tree):
        if not (isinstance(name, str) and name.startswith("L") and "_" in name):
            continue
        level_text = name.split("_")[0][1:]  # 'L3_149' -> '3'
        try:
            level = int(level_text)
        except ValueError:
            # Only int() can fail here; anything else should surface.
            continue
        if level > max_level:
            max_level = level
    return max_level

def get_leaf_paths(tree: Tree, path=None):
    """Map every leaf of ``tree`` to its path of labels.

    A leaf is a string that starts with ``"L"``. Every other node is iterated
    for children and must provide ``label()``.

    Args:
        tree: Tree node or leaf string to start from.
        path: Labels of the ancestors above ``tree``; defaults to no ancestors.

    Returns:
        dict: ``{leaf_name: [ancestor labels..., leaf_name]}``.
    """
    found = {}

    def _descend(node, trail):
        # Leaf reached: record the trail plus the leaf itself.
        if isinstance(node, str) and node.startswith("L"):
            found[node] = trail + [node]
            return
        for child in node:
            _descend(child, trail + [node.label()])

    _descend(tree, [] if path is None else path)
    return found

# Helpers for filling in missing (-1) entries of labels_list from the label tree
from typing import List, Dict, Optional, Tuple
import torch
from collections import deque
import re

_LEVEL_ID_RE = re.compile(r"^L(\d+)_(\d+)$")


def _parse_level_id(name: str) -> Optional[Tuple[int, int]]:
    """Parse a node name of the form ``L{level}_{id}``.

    Returns:
        ``(level, id)`` as integers, or ``None`` when ``name`` is not a string
        or does not follow the pattern (e.g. a root called ``"root"``).
    """
    if not isinstance(name, str):
        return None
    match = _LEVEL_ID_RE.match(name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def _node_name(node) -> str:
    """Return a node's name, or ``""`` when it has none.

    Supports plain string leaves, nodes with a string ``name`` attribute, and
    nodes exposing ``label`` as an attribute or a method (nltk-style).
    """
    if isinstance(node, str):
        return node
    name = getattr(node, "name", None)
    if isinstance(name, str):
        return name
    label = getattr(node, "label", None)
    if callable(label):
        label = label()
    return label if isinstance(label, str) else ""


def _node_children(node) -> list:
    """Return a node's children as a list (empty for leaves).

    Uses the ``children`` attribute when present; otherwise list/tuple-like
    nodes (nltk-style trees are list subclasses) are their own child sequence.
    """
    if isinstance(node, str):
        return []
    children = getattr(node, "children", None)
    if children is None and isinstance(node, (list, tuple)):
        children = node
    return [] if children is None else list(children)


def _build_maps_all_nodes(label_tree) -> Tuple[Dict[int, str], Dict[str, List[str]], set]:
    """Walk the whole tree breadth-first and index its ``L{m}_{id}`` nodes.

    Both node styles used in this file are supported: objects with
    ``name`` / ``children`` attributes and nltk-style trees (``label()``,
    iterable children, string leaves).

    Returns:
        A tuple ``(id_to_name, name_to_prefix_path, leaf_names)``:

        - ``id_to_name``: id -> ``'Lm_id'``. Keyed by id only, so when the
          same id occurs on several levels the node visited last wins.
        - ``name_to_prefix_path``: ``'Lm_id'`` -> ``['L1_x', ..., 'Lm_id']``.
          Names that cannot be parsed (such as the root) are left out of the
          path. The first path found for a name is kept.
        - ``leaf_names``: names of childless nodes that parse as ``'Lm_id'``.
    """
    id_to_name: Dict[int, str] = {}
    name_to_prefix_path: Dict[str, List[str]] = {}
    leaf_names: set = set()

    # Queue items are (node, parsed ancestor names above the node).
    queue = deque([(label_tree, [])])
    while queue:
        node, prefix = queue.popleft()
        name = _node_name(node)
        parsed = _parse_level_id(name)

        if parsed is not None:
            _level, node_id = parsed
            path = prefix + [name]
            id_to_name[node_id] = name
            name_to_prefix_path.setdefault(name, path)
        else:
            # Root or otherwise unparsable name: not part of any path.
            path = prefix

        children = _node_children(node)
        if children:
            queue.extend((child, path) for child in children)
        elif parsed is not None:
            leaf_names.add(name)

    return id_to_name, name_to_prefix_path, leaf_names


def complete_labels(labels_list: List[torch.Tensor], label_tree) -> List[torch.Tensor]:
    """Fill missing (-1) ancestor labels in ``labels_list`` from ``label_tree``.

    Conventions:
        - Row ``r`` of ``labels_list`` holds the ids of level ``r + 1``, i.e.
          the value ``v`` in row ``r`` denotes the node ``L{r+1}_{v}``.
        - Only ancestors are filled in; children are never generated.
        - Only entries equal to -1 are written; known values are kept.
        - Labels that do not exist in the tree are skipped.

    Nodes are looked up by (level, id). Ids repeat across levels, so an
    id-only lookup would confuse e.g. ``L1_1`` with ``L3_1``.

    Args:
        labels_list: One tensor per level, all with the same first dimension
            (one entry per sample); -1 marks a missing label.
        label_tree: Label hierarchy whose nodes are named ``L{m}_{id}``.

    Returns:
        New tensors (clones of the inputs) with ancestors filled in.

    Raises:
        AssertionError: If ``labels_list`` is empty or the rows disagree on
            the batch size.
    """
    assert len(labels_list) > 0
    D = len(labels_list)
    B = labels_list[0].shape[0]
    assert all(t.shape[0] == B for t in labels_list), "All rows must have same batch size"

    _id_to_name, name_to_prefix_path, _leaf_names = _build_maps_all_nodes(label_tree)

    # (level, id) -> [(level, id), ...] for the node and all its ancestors.
    ancestors_by_node: Dict[Tuple[int, int], List[Tuple[int, int]]] = {}
    for name, prefix in name_to_prefix_path.items():
        key = _parse_level_id(name)
        if key is None:
            continue
        slots = [p for p in map(_parse_level_id, prefix) if p is not None]
        ancestors_by_node.setdefault(key, slots)

    out = [t.clone() for t in labels_list]

    for i in range(B):
        for row in range(D):
            val = int(out[row][i].item())
            if val == -1:
                continue

            # Row index fixes the level, so the lookup is level-qualified.
            slots = ancestors_by_node.get((row + 1, val))
            if not slots:
                # Label not present in the tree: nothing to propagate.
                continue

            for level_m, id_n in slots:
                row_idx = level_m - 1  # level m lives in row m - 1
                if 0 <= row_idx < D and int(out[row_idx][i].item()) == -1:
                    out[row_idx][i] = id_n

    return out

def get_nodes_at_level(tree: Tree, m: int):
    """Collect the labels of all nodes on level ``m``.

    Every label returned by ``extract_all_nodes(tree)`` is checked against
    the prefix ``L{m}_``; matching labels are kept in traversal order.

    Args:
        tree (Tree): Tree object.
        m (int): Target level (counted from 0).

    Returns:
        list: Labels of the level-``m`` nodes (strings of the form ``Lm_n``).
    """
    level_prefix = f"L{m}_"
    matches = []
    for label in extract_all_nodes(tree):
        if label.startswith(level_prefix):
            matches.append(label)
    return matches


def extract_all_nodes(tree):
    """List the labels of all nodes in ``tree``, leaves and inner nodes alike.

    The traversal is pre-order: a node's own label comes first, followed by
    the labels of each child subtree in iteration order.

    Args:
        tree (Tree): Tree object, or a plain string for a single leaf.

    Returns:
        list: All node labels. String leaves are returned as-is; every other
        node contributes ``label()``.
    """
    def _walk(node):
        # A string is a leaf and is its own label.
        if isinstance(node, str):
            yield node
            return
        yield node.label()
        for child in node:
            yield from _walk(child)

    return list(_walk(tree))

def get_all_paths(tree: Tree, path=None):
    """Map every node of ``tree`` (inner nodes and leaves) to its label path.

    Args:
        tree (Tree): Tree object, or a leaf string.
        path (list): Labels of the ancestors above ``tree``; used by the
            recursion and defaults to no ancestors.

    Returns:
        dict: ``{node_name: [ancestor labels..., node_name]}``. Leaves are
        strings assumed to start with ``"L"``.
    """
    collected = {}

    def _collect(node, trail):
        if isinstance(node, str):
            # Leaf: record it and stop descending.
            if node.startswith("L"):
                collected[node] = trail + [node]
                return
        else:
            # Inner node: record its own path before visiting children.
            collected[node.label()] = trail + [node.label()]
        for child in node:
            _collect(child, trail + [node.label()])

    _collect(tree, [] if path is None else path)
    return collected

def compute_dist(tree: Tree, label1: str, label2: str, beta=1) -> float:
    """Return an exponentially decayed similarity between two tree nodes.

    The value is ``exp(-beta * h / H)``, where ``H`` is ``tree.height()`` and
    ``h`` grows with the distance from ``label1`` up to the deepest common
    ancestor of both labels. Larger results mean closer nodes.

    Args:
        tree (Tree): Tree object; must provide ``height()`` and be accepted
            by ``get_all_paths``.
        label1 (str): First node name, of the form ``Lm_n``.
        label2 (str): Second node name, of the form ``Lm_n``.
        beta: Decay rate; larger values make the similarity fall off faster.

    Returns:
        float: One of
            - ``exp(-beta * 100)`` when either label is not in the tree,
            - ``exp(-beta / H)`` when both labels are identical,
            - ``-1`` when the labels are on different levels (sentinel),
            - ``exp(-beta * (level - common_prefix_len + 2) / H)`` otherwise.

    Raises:
        ValueError: If both labels are in the tree and differ, but a level
            prefix is not an integer (e.g. a root named ``"root"``).
    """
    total_height = tree.height()
    all_paths = get_all_paths(tree)

    if label1 not in all_paths or label2 not in all_paths:
        # Unknown label: treat it as maximally distant.
        return math.exp(-beta * 100)

    if label1 == label2:
        return math.exp(-beta * 1 / total_height)

    level1 = int(label1.split('_')[0][1:])
    level2 = int(label2.split('_')[0][1:])
    if level1 != level2:
        return -1

    # Length of the common prefix of the two root paths. Stop at the first
    # mismatch so a coincidental match further down is never counted.
    lca_depth = 0
    for node1, node2 in zip(all_paths[label1], all_paths[label2]):
        if node1 != node2:
            break
        lca_depth += 1

    # With the +2 offset, siblings score 2 where identical labels score 1.
    true_height = level1 - lca_depth + 2
    return math.exp(-beta * true_height / total_height)
    
def get_soft_labels_dict(tree: Tree, beta=1, num_classes:list = None) -> Dict[str, np.ndarray]:
    """Build a normalized soft-label vector for every non-root node.

    For each level ``m >= 1`` and each node ``Lm_i`` on it, the vector holds
    at index ``j`` the value of ``compute_dist(tree, Lm_i, Lm_j, beta)`` for
    every node ``Lm_j`` on the same level. All other entries keep the floor
    value ``exp(-beta * 100)``. The vector is then divided by its sum.

    Args:
        tree (Tree): Tree object whose nodes are named ``Lm_n``.
        beta (float): Decay rate passed to ``compute_dist``; larger values
            concentrate more mass on the node itself.
        num_classes (list): Optional vector length per level, where
            ``num_classes[m - 1]`` is used for level ``m``. When missing for a
            level, the length is the larger of the node count and the largest
            node id + 1, so non-contiguous ids still fit.

    Returns:
        dict: ``{node_name: 1-D numpy array summing to 1}``.

    Raises:
        IndexError: If ``num_classes`` gives a length smaller than a node id
            on that level requires.
        ValueError: If the id part of a node name is not an integer.
    """
    soft_labels = {}
    max_level = _get_max_level_from_names(tree)
    floor_value = math.exp(-beta * 100)  # weight of entries with no node

    for level in range(1, max_level + 1):  # level 0 is the root; skip it
        nodes = get_nodes_at_level(tree, level)
        if not nodes:
            continue

        node_ids = [int(node.split("_")[1]) for node in nodes]
        if num_classes is not None and level - 1 < len(num_classes):
            width = num_classes[level - 1]
        else:
            width = max(len(nodes), max(node_ids) + 1)

        for node in nodes:
            distances = np.ones(width) * floor_value
            for other_node, other_id in zip(nodes, node_ids):
                distance = compute_dist(tree, node, other_node, beta=beta)
                if distance == -1:
                    # Sentinel for "different levels": keep the floor value.
                    continue
                distances[other_id] = distance
            # Normalize so the vector sums to 1.
            distances /= distances.sum()
            soft_labels[node] = distances

    return soft_labels



def soft_ce_loss(
    logits_list: List[torch.Tensor],
    labels_list: List[torch.Tensor],
    label_tree: Tree,  # tree describing the label hierarchy
    aggregation: str = 'mean',  # Options: 'mean', 'sum', 'none' or 'individual'
    num_classes: list = None,  # number of classes per level
    beta: float = 5          # decay rate of the soft labels
) -> torch.Tensor:
    """Cross-entropy against tree-aware soft labels, summed over all levels.

    Missing labels (-1) are first filled in from the tree with
    ``complete_labels``. For each level, every sample with a known label is
    scored as ``-sum(soft_label * log_softmax(logits))``, where the soft label
    comes from ``get_soft_labels_dict``. Samples without a label on a level
    contribute nothing for that level.

    Args:
        logits_list (List[torch.Tensor]): Per-level logits, each of shape
            [batch_size, num_classes_l].
        labels_list (List[torch.Tensor]): Per-level labels, each of shape
            [batch_size]; -1 marks a missing label.
        label_tree (Tree): Tree describing the label hierarchy.
        aggregation (str): ``'mean'`` or ``'sum'`` over the batch, ``'none'``
            for the per-sample tensor, or ``'individual'`` for the per-sample
            losses as a Python list.
        num_classes (list): Per-level soft-label length; should match the
            logits width of each level.
        beta (float): Decay rate passed to ``get_soft_labels_dict``.

    Returns:
        torch.Tensor: Scalar loss for ``'mean'`` / ``'sum'``, a [batch_size]
        tensor for ``'none'``. For ``'individual'`` a list of floats.

    Raises:
        ValueError: If the two lists differ in length or ``aggregation`` is
            not one of the supported options.
        KeyError: If a label does not correspond to a node in ``label_tree``.

    Note:
        The soft-label table is rebuilt from the tree on every call.
    """
    # Validate cheap things first, before the expensive soft-label build.
    if len(logits_list) != len(labels_list):
        raise ValueError("logits_list and labels_list must have the same length.")
    if aggregation not in ('mean', 'sum', 'none', 'individual'):
        raise ValueError(
            f"Invalid aggregation method: '{aggregation}'. "
            "Choose 'mean', 'sum', 'none' or 'individual'."
        )

    soft_labels_dict = get_soft_labels_dict(label_tree, beta=beta, num_classes=num_classes)

    # Fill in missing ancestor labels.
    labels_list = complete_labels(labels_list, label_tree)

    # Per-sample loss, kept on the logits' device so labels may live elsewhere.
    sample_losses = torch.zeros_like(
        labels_list[0], dtype=torch.float, device=logits_list[0].device
    )  # [batch_size]

    for i, (logits, labels) in enumerate(zip(logits_list, labels_list)):
        labels = labels.to(device=logits.device, dtype=torch.long)

        # Samples with label -1 have no target on this level.
        valid_mask = labels != -1
        if not valid_mask.any():
            continue

        valid_logits = logits[valid_mask]
        valid_labels = labels[valid_mask].tolist()

        # Row i holds level i + 1, so label n on this row is node L{i+1}_{n}.
        node_names = [f"L{i+1}_{n}" for n in valid_labels]
        soft_labels = [torch.from_numpy(soft_labels_dict[key]) for key in node_names]
        soft_labels = torch.stack(soft_labels, dim=0).to(logits.device)  # [n_valid, num_classes_l]

        # Soft-label cross-entropy: -sum_c q_c * log p_c
        log_probs = F.log_softmax(valid_logits, dim=1)
        loss = -torch.sum(soft_labels * log_probs, dim=1)

        # Write the losses back to the positions of the valid samples.
        sample_losses[valid_mask] += loss

    if aggregation == 'sum':
        return sample_losses.sum()
    if aggregation == 'mean':
        return sample_losses.mean()
    if aggregation == 'none':
        return sample_losses
    return sample_losses.tolist()  # 'individual'

# --- Example Usage ---

if __name__ == '__main__':
    # Demo: three hierarchy levels, batch of 6 samples, 100 logits per level.
    batch_size = 6
    num_classes = [100, 100, 100]
    beta = 5

    logits_list = [torch.randn(batch_size, n) for n in num_classes]

    # -1 marks a missing label; soft_ce_loss fills in ancestors from the tree.
    labels_list = [
        torch.tensor([0, -1, -1, -1, -1, -1]),  # level 1
        torch.tensor([-1, -1, -1, -1, -1, -1]), # level 2
        torch.tensor([9, 22, 2, 3, 4, 5])       # level 3
    ]

    # Label hierarchy stored as a pickled tree.
    # An alternative tree is ./data/fgvc/fgvc_label_hierarchy_tree.pkl
    tree_fname = "./data/cifar100/cifar_100_tree.pkl"
    # NOTE(security): pickle.load can execute arbitrary code; only load tree
    # files from a trusted source.
    with open(tree_fname, "rb") as f:
        label_tree = pickle.load(f)
    print(label_tree)

    soft_labels_dict = get_soft_labels_dict(label_tree, beta=beta, num_classes=num_classes)

    print(soft_labels_dict)
    loss = soft_ce_loss(
        logits_list=logits_list,
        labels_list=labels_list,
        label_tree=label_tree,
        aggregation='mean',
        beta=beta,
        num_classes=num_classes
    )
    print(f"Loss: {loss}")
    

    