
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, roc_auc_score, roc_curve

def plot_training_curves(loss, val_loss, acc, val_acc, figsize, *,
                         loss_label='BCE-Dice', metric_label='IoU',
                         loss_ylim=(0, 1.0)):
    '''
    Plot training/validation loss and metric curves side by side and show them.

    loss = list of loss
    val_loss = list of validation loss
    acc = list of accuracy (plotted as the metric named by metric_label)
    val_acc = list of validation accuracy
    figsize = tuple of shape (w,h)
    loss_label = y-axis label of the loss subplot (default 'BCE-Dice')
    metric_label = metric name used in the legend, title and y-axis label of
                   the second subplot (default 'IoU')
    loss_ylim = (bottom, top) y-limits of the loss subplot. The default (0, 1.0)
                keeps the original fixed range; pass None to autoscale so that
                loss values above 1 are not clipped out of view.
    '''
    plt.figure(figsize=figsize)
    # left subplot: loss curves
    plt.subplot(1, 2, 1)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    plt.ylabel(loss_label)
    if loss_ylim is not None:
        plt.ylim(loss_ylim)
    plt.legend(loc='upper right')
    # right subplot: metric curves
    plt.subplot(1, 2, 2)
    plt.plot(acc, label='Training {}'.format(metric_label))
    plt.plot(val_acc, label='Validation {}'.format(metric_label))
    plt.title('Training and Validation {}'.format(metric_label))
    plt.xlabel('epoch')
    plt.ylabel(metric_label)
    # keep the autoscaled lower limit, pin the upper limit at 1
    plt.ylim([min(plt.ylim()), 1])
    plt.legend(loc='lower right')
    # Show the figure
    plt.show()

def calc_score(y_true, y_pred, thresh):
    '''
    Compute pixel-wise binary segmentation metrics.

    y_true = ground_truth masks [0,1]; binarized with the same `> thresh` rule as y_pred
    y_pred = predicted masks (probability values) [0-1]
    thresh = threshold for binarization

    Returns [Accuracy, Precision, Recall, F1_score, Sensitivity, Specificity, AUC, IOU]
    (Sensitivity equals Recall).
    A ratio whose denominator is zero (e.g. no predicted positives, or no
    positives at all) is reported as 0.0, like sklearn's default zero_division.
    AUC is NaN when the binarized ground truth contains only one class, since
    ROC AUC is undefined there.

    Raises ValueError if y_true and y_pred do not have the same number of elements.
    '''
    positives = np.asarray(y_true).ravel() > thresh
    probs = np.asarray(y_pred).ravel()
    if positives.shape != probs.shape:
        raise ValueError('y_true and y_pred must have the same number of elements, '
                         'got {} and {}'.format(positives.size, probs.size))
    predicted = probs > thresh

    # confusion-matrix counts over all pixels
    tp = int(np.count_nonzero(positives & predicted))
    fp = int(np.count_nonzero(~positives & predicted))
    fn = int(np.count_nonzero(positives & ~predicted))
    tn = int(np.count_nonzero(~positives & ~predicted))

    def _ratio(num, den):
        # zero denominator -> 0.0 instead of ZeroDivisionError
        return num / den if den else 0.0

    accuracy = _ratio(tp + tn, positives.size)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1_score = _ratio(2 * tp, 2 * tp + fp + fn)
    sensitivity = recall
    specificity = _ratio(tn, tn + fp)
    # TP / (TP + FP + FN) equals P*R / (P + R - P*R) algebraically, but stays
    # defined when TP == 0 (where the product form divides by zero)
    iou = _ratio(tp, tp + fp + fn)
    # score AUC against the same binarized labels as the other metrics;
    # roc_auc_score raises when only one class is present
    if positives.all() or not positives.any():
        auc = float('nan')
    else:
        auc = roc_auc_score(positives, probs)
    return [accuracy, precision, recall, f1_score, sensitivity, specificity, auc, iou]

def print_score(scores):
    '''
    Print the eight metrics produced by calc_score as an aligned table,
    one "<name> : <value>" row per metric, to four decimal places.

    scores = indexable of [Accuracy, Precision, Recall, F1_score, Sensitivity, Specificity, AUC, IOU]
    '''
    # names carry their own tab padding so the colons line up
    names = ('Accuracy\t', 'Precision\t', 'Recall\t\t', 'F1_score\t',
             'Sensitivity\t', 'Specificity\t', 'AUC\t\t', 'IOU\t\t')
    for idx, name in enumerate(names):
        print('{}:\t{:.4f}'.format(name, scores[idx]))

def show_images(images, title, figsize, cmap='binary'):
    '''
    Display images side by side in a single row under a shared title.

    images: array like of shape [n,w,h,3] or [n,w,h]
    title: str
    figsize: tuple of shape (w,h)
    cmap: matplotlib cmap string

    Raises ValueError if images is empty.
    '''
    n = len(images)
    if n == 0:
        raise ValueError('show_images() needs at least one image')
    # squeeze=False keeps axes a 2-D array even for n == 1, where subplots()
    # would otherwise return a single Axes object that has no .flatten()
    fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)
    fig.suptitle(title, fontsize=15)
    for img, ax in zip(images, axes.ravel()):
        ax.imshow(img, cmap=cmap)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

def plt_roc_curve(y_true, y_pred, figsize=None):
    '''
    Plot the ROC curve of flattened predictions against the chance diagonal.

    y_true: ground truth values [0,1]
    y_pred: predicted probability values [0-1]
    figsize: tuple of shape (w,h); None uses matplotlib's default figure size
    '''
    y_true = np.ravel(y_true)
    y_pred = np.ravel(y_pred)
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = roc_auc_score(y_true, y_pred)
    # a new figure so that figsize actually takes effect
    plt.figure(figsize=figsize)
    # chance-level reference line
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.legend()
    plt.show()
