import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union, Tuple, Dict
from nltk.tree import Tree
import pickle
import re
from collections import deque

def get_leaf_paths(tree: "Union[Tree, str]", path=None) -> Dict[str, List[str]]:
    """Collect every leaf of ``tree`` together with its root-to-leaf label path.

    A leaf is a plain string whose text starts with ``"L"``. Every other node
    is expected to be iterable over its children and to expose ``label()``.

    Args:
        tree: The (sub)tree to walk, or a leaf string.
        path: Labels of the ancestors visited so far; ``None`` for the
            top-level call.

    Returns:
        A dict mapping each leaf string to the list of labels from the root
        down to and including the leaf itself. When the same leaf string
        occurs more than once, the path of the last occurrence wins.
    """
    if path is None:
        path = []
    if isinstance(tree, str):
        # Strings that are not "L..." leaves are skipped. Treating them as a
        # generic node would iterate their characters and then fail with
        # AttributeError, because ``str`` has no ``label()``.
        if tree.startswith("L"):
            return {tree: path + [tree]}
        return {}
    leaf_paths = {}
    for subtree in tree:
        leaf_paths.update(get_leaf_paths(subtree, path + [tree.label()]))
    return leaf_paths

# Pattern for node names of the form "Lm_n" (m = level, n = id).
_LMN_RE = re.compile(r"^L(\d+)_(\d+)$")

def _parse_level_id(name: str):
    """Split an ``Lm_n`` node name into the integer pair ``(level_m, id_n)``.

    Returns ``None`` when ``name`` does not match the pattern.

    NOTE: this module defines ``_parse_level_id`` a second time further down;
    after import the later definition is the one bound to the name.
    """
    matched = _LMN_RE.match(name)
    if matched:
        level_text, id_text = matched.groups()
        return int(level_text), int(id_text)
    return None

# Fill in the missing entries of labels_list (leaf-driven variant).
def complete_labels(labels_list: List[torch.Tensor], label_tree: Tree) -> List[torch.Tensor]:
    """
    Fill missing ancestor labels (-1) from the leaf label stored in the last row.

    For every column whose last-row entry is a known leaf id, the path from the
    root to that leaf is looked up in ``label_tree`` and each ancestor's id is
    written into the matching row, but only where that row currently holds -1.

    NOTE: this module defines ``complete_labels`` a second time further down;
    after import the later definition is the one bound to the name.

    Args:
        labels_list: One tensor per hierarchy level (row 0 is the first level
            below the root, the last row is the leaf level); -1 marks a
            missing label. The tensors are modified in place.
        label_tree: Hierarchy whose leaves are strings such as ``"L3_9"``; the
            integer after the underscore is used as the leaf id.

    Returns:
        The same ``labels_list`` object, with missing ancestors filled in.
    """
    # Leaf id -> labels along the path [root, ..., leaf].
    leaf_to_path = {
        int(leaf.split("_")[1]): path
        for leaf, path in get_leaf_paths(label_tree).items()
    }

    num_levels = len(labels_list)
    if num_levels == 0:
        return labels_list

    last_layer = labels_list[-1]  # leaf-level labels
    for i, label in enumerate(last_layer):
        leaf_id = int(label.item())
        if leaf_id == -1:
            continue  # no leaf known for this column

        path = leaf_to_path.get(leaf_id)
        if path is None:
            continue  # leaf id not present in the tree: nothing to infer

        # path[0] is the root label and has no row of its own, so it is
        # skipped explicitly instead of being mapped to row -1 (the last row)
        # through negative indexing.
        for row, node in enumerate(path[1:]):
            if row >= num_levels:
                break  # tree is deeper than the rows provided
            if labels_list[row][i] == -1:  # never overwrite a known label
                labels_list[row][i] = int(node.split("_")[1]) if "_" in node else 0

    return labels_list






from typing import List, Dict, Optional, Tuple
import torch
from collections import deque
import re

_LEVEL_ID_RE = re.compile(r"^L(\d+)_(\d+)$")

def _parse_level_id(name: str) -> Optional[Tuple[int, int]]:
    """
    Parse a node name expected to look like 'L{level}_{id}'.

    Returns the pair (level, id) as integers, or None when the name does not
    match that pattern.
    """
    hit = _LEVEL_ID_RE.match(name)
    return (int(hit.group(1)), int(hit.group(2))) if hit else None

def _node_name(node) -> str:
    """Return the label text of ``node``, or ``""`` when it has no string label."""
    if isinstance(node, str):
        return node  # nltk-style leaf: the string itself is the label
    name = getattr(node, "name", None)
    if name is None:
        label = getattr(node, "label", None)  # nltk-style ``Tree.label()``
        name = label() if callable(label) else ""
    return name if isinstance(name, str) else ""


def _node_children(node) -> list:
    """Return the direct children of ``node`` as a list (empty for a leaf)."""
    if isinstance(node, str):
        return []
    children = getattr(node, "children", None)
    if children is not None:
        return list(children)
    if isinstance(node, (list, tuple)):
        return list(node)  # nltk.tree.Tree is a list of its children
    return []


def _build_maps_all_nodes(label_tree) -> Tuple[Dict[int, str], Dict[str, List[str]], set]:
    """
    Walk the whole tree breadth-first and index every node named 'L{m}_{id}'.

    Two node styles are accepted:
    - objects exposing ``.name`` and ``.children``;
    - nltk-style trees: ``.label()`` for the name, the node itself is the
      sequence of its children, and leaves are plain strings.

    Args:
        label_tree: Root node of the hierarchy.

    Returns:
        A tuple ``(id_to_name, name_to_prefix_path, leaf_names)``:
        - id_to_name: id -> 'Lm_id'. The key is the id alone, so when the same
          id is used on several nodes the one visited last overwrites the
          earlier entries.
        - name_to_prefix_path: 'Lm_id' -> ['L1_x', 'L2_y', ..., 'Lm_id'], the
          parseable names from the top down to the node itself (the root and
          any other name that cannot be parsed are left out). The first
          occurrence of a name wins.
        - leaf_names: names of childless nodes, restricted to parseable names.
    """
    id_to_name: Dict[int, str] = {}
    name_to_prefix_path: Dict[str, List[str]] = {}
    leaf_names: set = set()

    # Queue items are (node, prefix); prefix holds only parseable ancestor names.
    queue = deque([(label_tree, [])])

    while queue:
        node, prefix = queue.popleft()
        name = _node_name(node)
        parsed = _parse_level_id(name)

        if parsed is not None:
            _level, node_id = parsed
            cur = prefix + [name]
            id_to_name[node_id] = name
            name_to_prefix_path.setdefault(name, cur)
        else:
            cur = prefix  # root or unparseable name: not part of the path

        children = _node_children(node)
        if children:
            queue.extend((child, cur) for child in children)
        elif parsed is not None:
            leaf_names.add(name)  # leaf: only parseable names are recorded

    return id_to_name, name_to_prefix_path, leaf_names

def complete_labels(labels_list: List[torch.Tensor], label_tree) -> List[torch.Tensor]:
    """
    Fill missing entries (-1) of ``labels_list`` with ancestors taken from ``label_tree``.

    Conventions and policy:
    - Node names follow 'L{m}_{id}'; level m corresponds to row m-1 of
      ``labels_list`` (row 0 is level L1).
    - Only ancestors are filled in (upwards); children are never generated.
    - A value is written only where the entry is -1; known labels are kept.
    - Labels whose node is not present in the tree are skipped.

    The node behind a known label is resolved from BOTH its row and its value
    ('L{row+1}_{value}'). The ``id_to_name`` map returned by
    ``_build_maps_all_nodes`` is keyed by id alone, so ids that are reused on
    several levels collapse onto a single node there and would make a label
    resolve to a node of the wrong level.

    Args:
        labels_list: One tensor per level, all with the same first dimension;
            -1 marks a missing label. The input tensors are not modified.
        label_tree: Root of the hierarchy, as accepted by
            ``_build_maps_all_nodes``.

    Returns:
        A new list of cloned tensors with missing ancestors filled in.
    """
    assert len(labels_list) > 0
    D = len(labels_list)
    B = labels_list[0].shape[0]
    assert all(t.shape[0] == B for t in labels_list), "All rows must have same batch size"

    # Index the whole tree so that labels known at inner levels are usable too.
    _id_to_name, name_to_prefix_path, _leaf_names = _build_maps_all_nodes(label_tree)

    out = [t.clone() for t in labels_list]

    # For each column, scan the known labels of every row and fill the
    # ancestors of each known node into the rows given by their level number.
    for i in range(B):
        for row in range(D):
            val = int(out[row][i].item())
            if val == -1:
                continue

            # Row r holds level r+1, so the node name is fully determined.
            prefix = name_to_prefix_path.get(f"L{row + 1}_{val}")
            if not prefix:
                continue  # node missing from the tree: nothing to infer

            # The prefix ends at the node itself, so nothing below it is filled.
            for name in prefix:
                parsed = _parse_level_id(name)
                if parsed is None:
                    continue
                level_m, id_n = parsed
                row_idx = level_m - 1  # level m -> row m-1
                if 0 <= row_idx < D and int(out[row_idx][i].item()) == -1:
                    out[row_idx][i] = id_n

    return out















def _count_valid(labels: torch.Tensor, C: int) -> torch.Tensor:
    """
    Count the labels that lie in the valid class range ``[0, C-1]``.

    Intended as the denominator when averaging over labelled samples.

    Args:
        labels: Tensor of class indices; entries outside ``[0, C-1]``
            (for example -1 for "missing") are not counted.
        C: Number of classes.

    Returns:
        A scalar tensor holding the number of in-range entries.
    """
    non_negative = labels >= 0
    below_limit = labels < C
    return torch.logical_and(non_negative, below_limit).sum()

def clst_sep_loss(
    prototypes: torch.Tensor,   # [C, M, D]
    features_z: torch.Tensor,   # [B, S, D]
    labels: torch.Tensor,       # [B], integer class indices in [0, C-1]
    reduction: str = "none",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the cluster (Clst) and separation (Sep) prototype losses, vectorized.

    For sample i with label y_i and prototype set P_c for class c:
      clst_i = min_{j in P_{y_i}}     min_{z in patches(x_i)} ||z - p_j||^2
      sep_i  = - min_{j not in P_{y_i}} min_{z in patches(x_i)} ||z - p_j||^2

    Args:
        prototypes: Float tensor [C, M, D], M prototypes for each of C classes.
        features_z: Float tensor [B, S, D], S patch features per sample.
        labels: Tensor [B] of class indices; every entry must lie in
            [0, C-1] (missing labels such as -1 must be filtered out first).
        reduction: "none" (default) returns per-sample values of shape [B];
            "mean" / "sum" reduce each loss over the batch to a 0-D tensor.

    Returns:
        ``(clst, sep)``. With a single class (C == 1) there is no "other"
        prototype and the separation term is all zeros.

    Raises:
        ValueError: If ``reduction`` is not "none", "mean" or "sum".
        AssertionError: If tensor ranks, feature dims, batch size or dtypes
            do not match the shapes described above.
    """
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid reduction: '{reduction}'. Choose 'none', 'mean' or 'sum'."
        )

    assert prototypes.dim() == 3 and features_z.dim() == 3 and labels.dim() == 1
    C, M, Dp = prototypes.shape        # [C,M,D]
    B, S, Dz = features_z.shape        # [B,S,D]
    assert Dp == Dz, f"feature dim mismatch: prototypes D={Dp}, features D={Dz}"
    assert labels.shape[0] == B, "labels length must equal batch size"
    assert prototypes.is_floating_point() and features_z.is_floating_point()

    # 1) Squared Euclidean distance of every patch to every prototype,
    #    expanded as ||z||^2 + ||p||^2 - 2 z.p  ->  d2[b,s,c,m]
    z2 = (features_z ** 2).sum(dim=-1, keepdim=True)              # [B,S,1]
    p2 = (prototypes  ** 2).sum(dim=-1)                           # [C,M]
    cross = torch.einsum("bsd,cmd->bscm", features_z, prototypes) # [B,S,C,M]
    d2 = z2[..., None] + p2[None, None, :, :] - 2.0 * cross       # [B,S,C,M]
    d2 = d2.clamp_min_(0.0)  # rounding can push tiny distances below zero

    # 2) Clst: pick the slice of the sample's own class, then min over M and S.
    y = labels.to(dtype=torch.long)
    y_idx = y.view(B, 1, 1, 1).expand(B, S, 1, M)                 # [B,S,1,M]
    d2_same = torch.gather(d2, dim=2, index=y_idx).squeeze(2)     # [B,S,M]
    clst_per_sample = d2_same.min(dim=2).values.min(dim=1).values # [B]

    # 3) Sep: mask the own class with +inf, then min over M, C and S.
    if C == 1:
        sep_per_sample = torch.zeros(B, device=d2.device, dtype=d2.dtype)  # [B]
    else:
        same_mask = F.one_hot(y, num_classes=C).to(dtype=torch.bool)       # [B,C]
        same_mask = same_mask.view(B, 1, C, 1).expand(B, S, C, M)          # [B,S,C,M]
        d2_other = d2.masked_fill(same_mask, float("inf"))                 # [B,S,C,M]
        sep_per_sample = (
            -d2_other.min(dim=3).values   # min over M -> [B,S,C]
                   .min(dim=2).values     # min over C -> [B,S]
                   .min(dim=1).values     # min over S -> [B]
        )  # [B]

    if reduction == "mean":
        return clst_per_sample.mean(), sep_per_sample.mean()
    if reduction == "sum":
        return clst_per_sample.sum(), sep_per_sample.sum()
    return clst_per_sample, sep_per_sample

def proto_loss(
    logits_list: List[torch.Tensor],
    labels_list: List[torch.Tensor],
    protofeatures_list: List[torch.Tensor],
    prototypes_list: List[torch.Tensor],
    label_tree: Tree,
    aggregation: str = 'mean', # Options: 'mean', 'sum', 'individual', 'none'
    clst_weight: float = 0.8,
    sep_weight: float = 0.08,
) -> Union[torch.Tensor, List[float]]:
    """
    Hierarchical prototype loss accumulated per sample over all levels.

    Missing labels (-1) are first completed from ``label_tree`` via
    ``complete_labels``. Then, for every level, the samples that have a label
    receive
        cross_entropy + clst_weight * clst + sep_weight * sep
    where the cross-entropy is the mean over that level's labelled samples
    (``nn.CrossEntropyLoss`` default) and clst / sep are the per-sample values
    from ``clst_sep_loss``. Samples without a label at a level get nothing
    added for that level.

    Args:
        logits_list: Per-level classification logits.
        labels_list: Per-level labels, -1 where missing.
        protofeatures_list: Per-level patch features passed to ``clst_sep_loss``.
        prototypes_list: Per-level prototypes passed to ``clst_sep_loss``.
        label_tree: Label hierarchy used to complete missing labels.
        aggregation: 'mean' or 'sum' reduce over the batch to a scalar tensor;
            'individual' returns the per-sample losses as a Python list of
            floats; 'none' returns the per-sample loss tensor unreduced.
        clst_weight: Weight of the cluster term.
        sep_weight: Weight of the separation term.

    Returns:
        The aggregated loss as described under ``aggregation``.

    Raises:
        ValueError: If the four per-level lists differ in length or
            ``aggregation`` is not one of the supported options.
    """
    if len(logits_list) != len(labels_list):
        raise ValueError(f"logits_list (length {len(logits_list)}) and labels_list (length {len(labels_list)}) must have the same length.")
    if len(protofeatures_list) != len(labels_list) or len(prototypes_list) != len(labels_list):
        # zip() below would otherwise silently drop the trailing levels.
        raise ValueError(
            f"protofeatures_list (length {len(protofeatures_list)}) and prototypes_list "
            f"(length {len(prototypes_list)}) must have the same length as labels_list "
            f"(length {len(labels_list)})."
        )
    # Reject a bad option before doing any work.
    if aggregation not in ('mean', 'sum', 'individual', 'none'):
        raise ValueError(f"Invalid aggregation method: '{aggregation}'. Choose 'mean', 'sum', 'individual' or 'none'.")

    criterion = nn.CrossEntropyLoss()

    # Complete the missing labels from the hierarchy.
    labels_list = complete_labels(labels_list, label_tree)

    # Running per-sample loss, shape [batch_size].
    sample_losses = torch.zeros_like(labels_list[0], dtype=torch.float)

    for logits, labels, features, prototypes in zip(logits_list, labels_list, protofeatures_list, prototypes_list):
        # Ensure labels are of type long
        if labels.dtype != torch.long:
            labels = labels.long()

        # Only samples with a label at this level contribute.
        valid_mask = labels != -1
        if valid_mask.sum() == 0:
            continue  # nothing labelled at this level

        valid_logits = logits[valid_mask]
        valid_labels = labels[valid_mask]
        valid_features = features[valid_mask]
        sup_loss = criterion(valid_logits, valid_labels)
        clst_loss, sep_loss = clst_sep_loss(prototypes, valid_features, valid_labels)

        # Scatter the level's loss back to the original sample positions.
        sample_losses[valid_mask] += sup_loss + clst_weight * clst_loss + sep_weight * sep_loss

    if aggregation == 'sum':
        return sample_losses.sum()
    if aggregation == 'mean':
        return sample_losses.mean()
    if aggregation == 'individual':
        return sample_losses.tolist()
    return sample_losses  # 'none': unreduced per-sample tensor

# --- Example Usage ---

if __name__ == '__main__':
    # Demo: load a pickled label tree and complete a small batch of labels.
    name = 'cifar'
    tree_fname = "./data/cifar100/cifar_100_tree.pkl"
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(tree_fname, "rb") as tree_file:
        label_tree = pickle.load(tree_file)
    print(label_tree)

    # One row per level; -1 marks a missing label.
    raw_rows = (
        [0, -1, -1, -1, -1, -1],   # level 1
        [-1] * 6,                  # level 2
        [9, 22, 2, 3, 4, 5],       # level 3 (leaf level)
    )
    labels_list = [torch.tensor(row) for row in raw_rows]
    completed_labels = complete_labels(labels_list, label_tree)
    print(completed_labels)