import torch
def get_weight(dataset_raw):
    """Compute class-balancing loss weights from a dataset's labels.

    ``dataset_raw`` must expose ``num_classes`` and ``data.y``. ``data.y`` is
    assumed to be an integer label tensor (e.g. shape ``(N,)`` or ``(N, 1)``);
    TODO confirm against callers.

    Binary case (``num_classes == 2``):
        Returns a Python float ``num_neg / num_pos``, where the counts are the
        numbers of labels equal to 0 and 1. It raises ``ZeroDivisionError``
        when no label equals 1.

    Multiclass case:
        Returns a flat list of floats holding the inverse class frequencies,
        ``total / count_c``, normalised to sum to 1. A class with no samples
        gets weight 0.0 rather than an infinite weight. An infinite weight
        would turn every other weight into 0 and the missing class into NaN
        after normalisation.
    """
    labels = dataset_raw.data.y
    num_classes = dataset_raw.num_classes

    if num_classes == 2:
        num_pos = (labels == 1).sum().item()
        num_neg = (labels == 0).sum().item()
        return num_neg / num_pos

    # Per-class sample counts. For (N, 1) labels this has shape (1, C); the
    # final flatten() reduces it to a flat list of C weights.
    counts = torch.nn.functional.one_hot(labels, num_classes=num_classes).sum(0).float()

    # Inverse frequency, computed only where a class actually occurs. Absent
    # classes keep weight 0 instead of inf, so they cannot poison the
    # normalisation of the remaining classes.
    present = counts > 0
    inv_freq = torch.zeros_like(counts)
    inv_freq[present] = counts.sum() / counts[present]

    weight = inv_freq / inv_freq.sum()
    return weight.flatten().tolist()