import torch
import torch.nn as nn
import torch.nn.functional as F

def build_group_tensors(
    class_group: dict, 
    num_classes: int
):
    """
    예) class_group = {
          'vehicle': [1,2,3,4,5],
          'human': [6,7,8],
          ...
        }
    를 받아서, 
    1) group_id_of_class[c] = g (어느 그룹에 속하는지)
    2) group_mask[g, c] = True/False 
    3) group_sizes[g] = 그 그룹의 클래스 수
    등을 텐서로 반환
    """
    # 1) 각 그룹에 인덱스 부여
    group_names = list(class_group.keys())
    group_to_id = {gname: idx for idx, gname in enumerate(group_names)}
    n_groups = len(group_names)
    
    # 2) (1) group_id_of_class 초기화: -1이면 "소속 그룹 없음"
    group_id_of_class = torch.full((num_classes,), -1, dtype=torch.long)
    
    for gname, cls_list in class_group.items():
        g_id = group_to_id[gname]
        for c in cls_list:
            group_id_of_class[c] = g_id
    
    # 3) group_mask[g, c] : g그룹에 c클래스가 속하면 True
    group_mask = torch.zeros(n_groups, num_classes, dtype=torch.bool)
    for gname, cls_list in class_group.items():
        g_id = group_to_id[gname]
        for c in cls_list:
            group_mask[g_id, c] = True
    
    # 4) group_sizes[g] : 해당 그룹의 클래스 개수
    group_sizes = group_mask.sum(dim=1)  # shape: (n_groups,)

    return group_id_of_class, group_mask, group_sizes