function [p, aup, aupr, auc, pbcorr] = prediction_evaluation_20(scores, labels)
%PREDICTION_EVALUATION_20 Evaluate how well scores rank binary-labelled samples.

%%% INPUT %%%
% scores - finite numeric vector (row or column) with one score per sample;
%          higher scores are expected for positive samples
% labels - vector (row or column) with the same number of elements as
%          scores, containing only 0 (negative) and 1 (positive);
%          numeric or logical

%%% OUTPUT %%%
% p      - precision at rank n1, where n1 is the number of positive samples
% aup    - area under the precision curve over the top n1 ranks
% aupr   - area under the precision-recall curve
% auc    - probability that a positive sample has higher score than a negative sample
% pbcorr - point-biserial correlation coefficient between scores and labels

%%% ERRORS %%%
% validateattributes raises an error for non-vector, non-finite, non-binary
% or size-mismatched inputs; an explicit error is raised when the labels do
% not contain at least one positive and one negative sample.

validateattributes(scores, {'numeric'}, {'vector','finite'})
n = length(scores);
% Logical labels are the natural class for a binary indicator, so accept them.
validateattributes(labels, {'numeric','logical'}, {'vector','binary','numel',n})

% Integer and logical classes would make the downstream arithmetic run in
% integer class (rounded ratios, saturating negation and cumulative sums) or
% fail outright (std), so promote anything that is not floating point.
if ~isfloat(scores); scores = double(scores); end
if ~isfloat(labels); labels = double(labels); end

n1 = sum(labels==1);   % number of positive samples
n0 = n - n1;           % number of negative samples
if n1==0 || n0==0
    error('labels cannot be all ones or all zeros')
end

% The computations below index with column vectors.
if isrow(scores); scores = scores'; end
if isrow(labels); labels = labels'; end

% p, aup, aupr
[p, aup, aupr] = compute_p_aup_aupr(scores, labels, n, n1);

% auc
auc = compute_auc(scores, labels, n1, n0);

% pbcorr
pbcorr = compute_pbcorr(scores, labels, n, n1, n0);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [p, aup, aupr] = compute_p_aup_aupr(scores, labels, n, n1)
% Precision-based measures of the ranking induced by the scores.
%
% scores - column vector of scores (higher = ranked earlier)
% labels - column vector of 0/1 labels aligned with scores
% n      - number of samples
% n1     - number of positive samples
%
% p    - precision at rank n1
% aup  - area under the precision curve over ranks 1..n1, divided by n1-1
%        (equal to p when n1 == 1, where the curve is a single point)
% aupr - area under the precision-recall curve, evaluated only at the last
%        rank of each group of tied scores and normalised by the recall
%        range actually covered

% Rank samples by decreasing score; negating keeps the sorted keys ascending.
[scores,idx] = sort(-scores, 'ascend');

% Cast to double so the cumulative count and the ratios below are computed
% in floating point even when labels arrive as an integer or logical class.
labels = double(labels(idx));

% Index of the last sample in each group of tied scores (the final entry is
% always n). Thresholds can only be placed between distinct score values.
[~,ut] = unique(scores, 'last');

tp = full(cumsum(labels));      % true positives among the top-k samples
recall = tp ./ sum(labels);
precision = tp ./ (1:n)';

% precision
p = precision(n1);

% aup
if n1 > 1
    aup = trapz(1:n1,precision(1:n1)) / (n1-1);
else
    % A single point spans no interval: the normalised area would be 0/0,
    % so report the only precision value available.
    aup = p;
end

% aupr
recall = recall(ut);
precision = precision(ut);
if all(recall==1)
    % Only one recall level exists, so the curve degenerates to one point.
    aupr = precision(1);
else
    aupr = trapz(recall,precision) / (1-recall(1));
end



%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% code: https://www.mathworks.com/matlabcentral/fileexchange/50962-fast-auc
% reference: http://www.springerlink.com/content/nn141j42838n7u21/fulltext.pdf

function auc = compute_auc(scores, labels, n1, n0)
% Area under the ROC curve obtained from the rank sum of the positives.
%
% scores - vector of scores
% labels - vector of 0/1 labels aligned with scores
% n1     - number of positive samples
% n0     - number of negative samples

% Ranks of all samples; tied scores share their average rank.
sampleRanks = tiedrank(scores);

% Rank sum of the positives minus the smallest value it could take.
positiveRankSum = sum(sampleRanks(labels==1));
uStatistic = positiveRankSum - n1*(n1+1)/2;

% Normalise by the number of positive-negative pairs.
auc = uStatistic / (n1*n0);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function pbcorr = compute_pbcorr(scores, labels, n, n1, n0)
% Point-biserial correlation coefficient between scores and binary labels.
%
% scores - vector of scores
% labels - vector of 0/1 labels aligned with scores
% n      - number of samples
% n1     - number of positive samples
% n0     - number of negative samples

% Mean score within each class.
meanPositive = mean(scores(labels==1));
meanNegative = mean(scores(labels==0));

% Difference of the class means in units of the standard deviation of all
% scores (std uses its default normalisation).
scoreStd = std(scores);
standardisedGap = (meanPositive - meanNegative) / scoreStd;

% Weight determined by the class sizes.
classWeight = sqrt(n1*n0/(n*(n-1)));

pbcorr = standardisedGap * classWeight;
