import os
import sys
import torch
import torchvision
import numpy as np
from torchvision import datasets, models, transforms
import torch.nn as nn
from sklearn.metrics import roc_auc_score, roc_curve, auc
from torch.autograd import Variable
import time
from sklearn.metrics import confusion_matrix, roc_curve, auc
import datetime

def compute_measures1(labels, preds, predict_prob):
    """Compute binary-classification metrics for one set of predictions.

    :param labels: ground-truth label tensor; classes are expected to be
        0 (negative) and 1 (positive)
    :param preds: predicted label tensor (0 or 1), same length as ``labels``
    :param predict_prob: tensor of scores/probabilities for the positive
        class (1), one value per sample
    :return: tuple ``(sensitivity, specificity, auc, accuracy)``, each rounded
        to 4 decimals. Note the order: it is NOT (ACC, SEN, SPE, AUC).

    If ``labels`` contains only one class the ROC curve is undefined; expect
    sklearn to emit an UndefinedMetricWarning and the AUC to come out as NaN.
    """
    # Detach before moving to CPU so tensors that still require grad can be
    # converted to numpy without raising.
    labels = labels.detach().cpu().numpy()
    preds = preds.detach().cpu().numpy()
    predict_prob = predict_prob.detach().cpu().numpy()

    # Pin the label set to [0, 1] so the matrix is always 2x2. Without it, a
    # batch where labels and preds contain only one class yields a 1x1 matrix
    # and indexing the off-diagonal / [1][1] cells raises IndexError.
    cm = confusion_matrix(labels, preds, labels=[0, 1])
    # sklearn layout: rows = true class, columns = predicted class, so the
    # flattened order is TN, FP, FN, TP.
    tn, fp, fn, tp = (float(v) for v in cm.ravel())

    eps = 1e-8  # avoids ZeroDivisionError when one class is absent
    sensitivity = tp / (tp + fn + eps)
    specificity = tn / (tn + fp + eps)
    accuracy = (tp + tn) / (tp + tn + fp + fn)

    fpr, tpr, _ = roc_curve(labels, predict_prob)
    roc_auc = auc(fpr, tpr)
    return (
        round(sensitivity, 4),
        round(specificity, 4),
        round(roc_auc, 4),
        round(accuracy, 4),
    )