import torch
import torch.nn as nn
from typing import List, Optional, Union
from nltk.tree import Tree

def multi_pair_ce_loss(
    logits_list: List[torch.Tensor],
    labels_list: List[torch.Tensor],
    aggregation: str = 'mean',  # Options: 'mean', 'sum', 'individual'
    label_tree: Optional["Tree"] = None,  # Unused; kept for future use
) -> Union[torch.Tensor, List[float]]:
    """Compute a masked cross-entropy loss from paired logits and labels.

    Labels equal to -1 mark samples to ignore; those positions keep a loss
    of 0 in the per-sample buffer.

    NOTE(review): only the LAST ``(logits, labels)`` pair contributes to the
    result; earlier pairs are skipped. This preserves the existing behaviour
    of the function -- confirm whether every pair was meant to be
    accumulated. Removing the ``contributing_index`` guard below would sum
    the per-sample losses of all pairs.

    Args:
        logits_list: One logits tensor per pair. Each is indexed with the
            boolean mask derived from its labels, so its leading dimensions
            must match the labels' shape (presumably ``(batch, classes)`` --
            verify against callers).
        labels_list: One tensor of class indices per pair, same length as
            ``logits_list``. Non-long dtypes are cast to long. The shape of
            the first entry defines the shape of the per-sample loss buffer.
        aggregation: ``'sum'`` or ``'mean'`` reduce the per-sample buffer to
            a scalar tensor (``'mean'`` divides by the total number of
            positions, ignored ones included); ``'individual'`` returns the
            buffer converted with ``Tensor.tolist()``.
        label_tree: Not used by this function.

    Returns:
        A scalar tensor for ``'sum'`` / ``'mean'``, or a (possibly nested)
        list of Python floats for ``'individual'``.

    Raises:
        ValueError: If the two lists differ in length, are empty,
            ``aggregation`` is not recognised, or the contributing labels do
            not have the same shape as ``labels_list[0]``.
    """
    if len(logits_list) != len(labels_list):
        raise ValueError(f"logits_list (length {len(logits_list)}) and labels_list (length {len(labels_list)}) must have the same length.")
    if len(logits_list) == 0:
        raise ValueError("logits_list and labels_list must not be empty.")
    # Fail fast, before any loss is computed.
    if aggregation not in ('sum', 'mean', 'individual'):
        raise ValueError(
            f"Invalid aggregation method: '{aggregation}'. "
            "Choose 'mean', 'sum' or 'individual'."
        )

    # reduction='none' keeps one loss value per valid sample so that
    # 'individual' reports real per-sample losses; 'sum' and 'mean' over the
    # buffer are numerically the same as with a broadcast mean loss.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Per-sample loss buffer, shaped like the first labels tensor.
    sample_losses = torch.zeros_like(labels_list[0], dtype=torch.float)

    contributing_index = len(logits_list) - 1
    for i, (logits, labels) in enumerate(zip(logits_list, labels_list)):
        # Only the last pair contributes (see NOTE(review) in the docstring);
        # skip the others instead of computing a loss that is discarded.
        if i != contributing_index:
            continue

        # CrossEntropyLoss requires class indices of type long.
        if labels.dtype != torch.long:
            labels = labels.long()

        # True where the label is usable; -1 marks an ignored sample.
        valid_mask = labels != -1
        if not valid_mask.any():
            continue

        if valid_mask.shape != sample_losses.shape:
            raise ValueError(
                f"labels_list[{i}] has shape {tuple(labels.shape)}, but the "
                f"per-sample loss buffer has shape "
                f"{tuple(sample_losses.shape)} (taken from labels_list[0])."
            )

        per_sample = criterion(logits[valid_mask], labels[valid_mask])
        # Write each loss back to the position of the sample it belongs to.
        sample_losses[valid_mask] += per_sample.to(sample_losses.dtype)

    if aggregation == 'sum':
        return sample_losses.sum()
    if aggregation == 'mean':
        return sample_losses.mean()
    return sample_losses.tolist()

# --- Example Usage ---

if __name__ == '__main__':
    # --- Setup ---
    torch.manual_seed(0)  # make the demo output reproducible

    # All pairs share one batch size: the per-sample loss buffer inside
    # multi_pair_ce_loss is shaped like the first labels tensor. The number
    # of classes may differ from pair to pair.
    batch_size = 4

    # Example Data Pair 1 (e.g., output from one head)
    logits1 = torch.randn(batch_size, 10, requires_grad=True)  # 10 classes
    labels1 = torch.randint(0, 10, (batch_size,))

    # Example Data Pair 2 (another head; -1 marks a sample to ignore)
    logits2 = torch.randn(batch_size, 5, requires_grad=True)   # 5 classes
    labels2 = torch.tensor([0, 4, -1, 2])

    # Example Data Pair 3 (labels stored as int, cast to long internally)
    logits3 = torch.randn(batch_size, 8, requires_grad=True)   # 8 classes
    labels3 = torch.tensor([1, 6, -1, 3], dtype=torch.int)

    list_of_logits = [logits1, logits2, logits3]
    list_of_labels = [labels1, labels2, labels3]

    sum_loss = multi_pair_ce_loss(list_of_logits, list_of_labels, aggregation='sum')
    print(f"Sum Loss (creating criterion internally): {sum_loss.item()}")
    mean_loss = multi_pair_ce_loss(list_of_logits, list_of_labels, aggregation='mean')
    print(f"Mean Loss (creating criterion internally): {mean_loss.item()}")
    individual_losses = multi_pair_ce_loss(list_of_logits, list_of_labels, aggregation='individual')
    print(f"Individual Losses: {individual_losses}")