import torch
from torch_geometric.graphgym.register import register_metric
from torch_geometric.graphgym.config import cfg


@register_metric('classification_multilabel')
def classification_multilabel(true, pred, task_type):
    """Compute accuracy for a multilabel classification task.

    Predicted scores are binarized with ``cfg.model.thresh`` (a score strictly
    greater than the threshold becomes 1, otherwise 0) and compared with the
    ground truth using :func:`sklearn.metrics.accuracy_score`. For multilabel
    (2-D indicator) inputs sklearn reports *subset* accuracy: a sample counts
    as correct only if all of its labels match.

    Args:
        true: List of ground-truth label tensors, concatenated along dim 0.
            Presumably each is (num_samples, num_labels) with 0/1 entries —
            TODO confirm against the logger that collects them.
        pred: List of predicted score tensors, concatenated along dim 0 and
            expected to align element-wise with ``true``.
        task_type: Task identifier; must be ``'classification_multilabel'``.

    Returns:
        dict: ``{'accuracy': value}`` with the value rounded to ``cfg.round``
        decimal places.

    Raises:
        ValueError: If ``task_type`` is not ``'classification_multilabel'``.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if task_type != 'classification_multilabel':
        raise ValueError('task_type has to be classification_multilabel. '
                         f'{task_type} is given.')
    # Local import: sklearn is only loaded when this metric is actually called.
    from sklearn.metrics import accuracy_score

    true, pred_score = torch.cat(true), torch.cat(pred)
    pred_int = (pred_score > cfg.model.thresh).long()
    # sklearn converts its inputs through NumPy, which fails for CUDA tensors
    # and for tensors that require grad; detach and move to host memory first.
    true_np = true.detach().cpu().numpy()
    pred_np = pred_int.detach().cpu().numpy()
    return {'accuracy': round(accuracy_score(true_np, pred_np), cfg.round)}
