import torch
from torch.utils.data import DataLoader, Subset



def _label_key(label):
    """Return a value-comparable, hashable form of a sample label.

    Single-element tensors hash by object identity, so ``torch.tensor(3) in {3}``
    is False; converting them to Python scalars makes set membership work.
    Any other label is returned unchanged.
    """
    if isinstance(label, torch.Tensor) and label.numel() == 1:
        return label.item()
    return label


def create_unlearn_train_loaders(original_train_loaders, unlearn_client, unlearn_class):
    """
    Build a new list of training loaders with the given class(es) removed from one client.

    Only the loader at position ``unlearn_client`` is rebuilt; every other
    loader object is passed through unchanged (same object, not a copy).

    Args:
        original_train_loaders: Iterable of per-client training DataLoaders.
        unlearn_client: Index of the client whose data should be filtered.
        unlearn_class: Label(s) to remove. Either a single label, or a
            list / tuple / set / frozenset / range of labels.

    Returns:
        list: New list of DataLoaders. The entry for ``unlearn_client`` wraps
        a ``Subset`` of the original dataset that excludes the given labels.
    """
    if isinstance(unlearn_class, (list, tuple, set, frozenset, range)):
        classes = unlearn_class
    else:
        classes = [unlearn_class]
    unlearn_set = {_label_key(c) for c in classes}

    unlearn_train_loaders = []
    for client_idx, loader in enumerate(original_train_loaders):
        if client_idx != unlearn_client:
            unlearn_train_loaders.append(loader)
            continue

        original_dataset = loader.dataset
        # Assumes each sample is indexable with the label at position 1
        # (e.g. (data, label) pairs) — confirm for custom datasets.
        # Note: this loads every sample, so dataset transforms are executed.
        filtered_indices = [
            idx
            for idx in range(len(original_dataset))
            if _label_key(original_dataset[idx][1]) not in unlearn_set
        ]
        filtered_dataset = Subset(original_dataset, filtered_indices)

        new_loader = DataLoader(
            filtered_dataset,
            batch_size=loader.batch_size,
            # DataLoader does not keep a `shuffle` attribute; `shuffle=True`
            # installs a RandomSampler, so recover the setting from it.
            shuffle=isinstance(loader.sampler, torch.utils.data.RandomSampler),
            num_workers=loader.num_workers,
            collate_fn=loader.collate_fn,
            pin_memory=loader.pin_memory,
            drop_last=loader.drop_last,
            timeout=loader.timeout,
            worker_init_fn=loader.worker_init_fn,
        )
        unlearn_train_loaders.append(new_loader)

    return unlearn_train_loaders