import torch


def sort_by_label_and_verifier_preds(mat, ys_true, v_preds, max_per_category, concat=True):
    """Group the entries of ``mat`` by (true label, verifier prediction) pair.

    Entries are bucketed into four categories in the fixed order
    ``00, 01, 10, 11``. The first digit is the value of ``ys_true`` and the
    second is the value of ``v_preds``. Each bucket is ``mat`` indexed by the
    corresponding boolean mask and then truncated to its first
    ``max_per_category`` entries. Entries whose label or prediction equals
    neither 0 nor 1 fall into no bucket and are dropped.

    Args:
        mat: Tensor indexed by boolean masks derived from ``ys_true`` and
            ``v_preds`` (presumably one row per example; confirm with callers).
        ys_true: Tensor of true labels, compared elementwise with 0 and 1.
        v_preds: Tensor of verifier predictions, compared elementwise with 0 and 1.
        max_per_category: Maximum number of entries kept per bucket; ``None``
            keeps all of them.
        concat: If True, concatenate the buckets along dim 0 and also return
            their sizes. Otherwise return the buckets in a dict.

    Returns:
        If ``concat`` is True: ``(tensor, quantities)``. ``quantities`` is a
        list of four ints giving each bucket's size along dim 0, in
        ``00, 01, 10, 11`` order. Otherwise: a dict with keys ``"idx00"``,
        ``"idx01"``, ``"idx10"`` and ``"idx11"`` mapping to the buckets.
    """
    # Compare against scalars rather than materialising zeros_like/ones_like.
    label_is = {0: ys_true == 0, 1: ys_true == 1}
    pred_is = {0: v_preds == 0, 1: v_preds == 1}

    # Insertion order (00, 01, 10, 11) fixes both the concat order and the
    # dict key order.
    buckets = {}
    for label in (0, 1):
        for pred in (0, 1):
            # `&` is the logical AND for boolean masks.
            mask = label_is[label] & pred_is[pred]
            buckets[f"idx{label}{pred}"] = mat[mask][:max_per_category]

    if not concat:
        return buckets

    parts = tuple(buckets.values())
    quantities = [part.shape[0] for part in parts]
    return torch.cat(parts, dim=0), quantities
