import numpy as np
from scipy import stats
from sklearn import metrics
import torch

def d_prime(auc):
    """Map an AUC value to d-prime.

    Computes ``sqrt(2) * Phi^-1(auc)``, where ``Phi^-1`` is the inverse CDF
    (percent-point function) of the standard normal distribution.

    Args:
      auc: area under the ROC curve (scalar or array-like accepted by
        ``scipy.stats.norm.ppf``).

    Returns:
      The d-prime value(s) corresponding to ``auc``.
    """
    z_score = stats.norm().ppf(auc)
    return z_score * np.sqrt(2.0)

def _to_numpy(values):
    """Return `values` as a numpy array (torch tensors are detached and moved to CPU)."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def calculate_stats(output, target):
    """Calculate statistics including mAP, AUC, etc.

    Args:
      output: 2d array or torch tensor of scores, (samples_num, classes_num)
      target: 2d array or torch tensor of labels, (samples_num, classes_num)

    Returns:
      stats: list of dicts, one per class that has at least one positive
        target. Classes whose target column does not sum to a positive value
        are skipped, so the list can be shorter than classes_num. Each dict
        has the keys 'precisions', 'recalls', 'AP', 'fpr', 'fnr', 'auc' and
        'acc'. If the AUC / curve computation fails for a class, every value
        except 'AP' and 'acc' is set to -1.
    """
    # Work on numpy arrays throughout: the previous torch.sum() check rejected
    # the numpy inputs this docstring promises to accept.
    output = _to_numpy(output)
    target = _to_numpy(target)

    classes_num = target.shape[-1]
    class_stats = []

    # Keep every N-th point of the PR / ROC curves to reduce the result size.
    save_every_steps = 1000

    # Accuracy, only used for single-label classification such as esc-50, not
    # for multiple label one such as AudioSet.
    acc = metrics.accuracy_score(np.argmax(target, 1), np.argmax(output, 1))

    # Class-wise statistics
    for k in range(classes_num):
        class_target = target[:, k]
        class_output = output[:, k]

        # Skip classes without any true samples.
        if not np.sum(class_target) > 0:
            continue

        # Average precision
        avg_precision = metrics.average_precision_score(
            class_target, class_output, average=None)

        try:
            # AUC
            auc = metrics.roc_auc_score(class_target, class_output, average=None)

            # Precisions, recalls
            precisions, recalls, _ = metrics.precision_recall_curve(
                class_target, class_output)

            # FPR, TPR
            fpr, tpr, _ = metrics.roc_curve(class_target, class_output)
        except Exception:
            # Best-effort fallback: keep AP and mark the rest as unavailable.
            # Presumably hit when the class has no negative samples, which
            # leaves ROC AUC undefined -- confirm against the sklearn version
            # in use. `Exception` (not a bare except) so that
            # KeyboardInterrupt / SystemExit still propagate.
            entry = {'precisions': -1,
                     'recalls': -1,
                     'AP': avg_precision,
                     'fpr': -1,
                     'fnr': -1,
                     'auc': -1,
                     # note acc is not class-wise, this is just to keep consistent with other metrics
                     'acc': acc
                     }
        else:
            entry = {'precisions': precisions[0::save_every_steps],
                     'recalls': recalls[0::save_every_steps],
                     'AP': avg_precision,
                     'fpr': fpr[0::save_every_steps],
                     'fnr': 1. - tpr[0::save_every_steps],
                     'auc': auc,
                     # note acc is not class-wise, this is just to keep consistent with other metrics
                     'acc': acc
                     }
        class_stats.append(entry)
    return class_stats

def calculate_group_stats(output_all, target_all, group_idx, num_groups=5):
    """Calculate statistics including mAP, AUC, etc. separately for each group.

    For every group ``g`` in ``range(num_groups)`` the rows selected by
    ``group_idx[g]`` are taken from ``output_all`` and ``target_all`` and
    passed to ``calculate_stats``. The number of classes skipped for the
    group (classes without a positive target) is printed.

    Args:
      output_all: 2d array, (samples_num, classes_num)
      target_all: 2d array, (samples_num, classes_num)
      group_idx: container indexable by the integers 0..num_groups-1; each
        entry is a row index usable on ``output_all`` / ``target_all``.
      num_groups: number of groups to evaluate. Defaults to 5, the value
        that was previously hard-coded.

    Returns:
      stats: list with one entry per group; each entry is the list of
        per-class statistic dicts returned by ``calculate_stats`` for the
        rows of that group.
    """
    classes_num = target_all.shape[-1]
    stats_per_group = []

    for group in range(num_groups):
        rows = group_idx[group]
        # Inputs are forwarded untouched so they keep whatever type
        # calculate_stats expects.
        stats_group = calculate_stats(output_all[rows], target_all[rows])

        # calculate_stats emits exactly one dict per class that has positive
        # samples, so the remainder is the number of skipped classes.
        non_class_cnt = classes_num - len(stats_group)
        print("{} classes have no samples".format(non_class_cnt))
        stats_per_group.append(stats_group)
    return stats_per_group