import torch
import numpy as np
import sklearn.metrics as metrics
from sklearn.metrics import roc_curve
from util.Eval import AverageCalculator


def log_value(probs, small_value=1e-30):
    """Return ``-log(probs)`` elementwise, with ``probs`` floored at ``small_value``.

    Despite the name, the result is the *negative* log. Flooring first means
    zero entries produce a large finite value (``-log(small_value)``) instead
    of ``inf``.
    """
    floored = torch.clamp(probs, min=small_value)
    return -torch.log(floored)


def confidence(outputs, labels):
    """Return the score each sample assigns to its true class.

    Args:
        outputs: (N, C) tensor of per-class scores (presumably softmax
            probabilities; TODO confirm with callers).
        labels: true-class indices, any shape holding N elements; cast to long.

    Returns:
        Detached tensor of shape (N,) with ``outputs[i, labels[i]]``.
    """
    index = labels.long().view(-1, 1)
    # squeeze(1) rather than squeeze(): a batch of one must stay shape (1,)
    # instead of collapsing to a 0-d scalar that cannot be torch.cat-ed with
    # the results of other batches.
    return torch.gather(outputs, 1, index).squeeze(1).detach()



def entropy(outputs, labels=None):
    """Return the *negated* Shannon entropy of each row of ``outputs``.

    ``log_value`` yields ``-log(p)`` (clamped), so summing ``p * log_value(p)``
    over dim 1 gives the entropy H. The function returns ``-H``, so larger
    values mean a more peaked row.

    ``labels`` is accepted only so the signature matches the other
    ``(outputs, labels)`` scoring functions; it is not used.
    """
    per_class = outputs * log_value(outputs)
    row_entropy = torch.sum(per_class, dim=1).detach()
    return -row_entropy



def modified_entropy(outputs, labels):
    """Return the *negated* modified entropy (Mentr) of each row of ``outputs``.

    For row i, with p = outputs[i] and y = labels[i], the value is

        -[(1 - p_y) * -log(p_y) + sum_{j != y} p_j * -log(1 - p_j)]

    where every log goes through ``log_value`` (clamped, so zeros stay finite).
    Because of the leading minus, larger values mean lower modified entropy.

    Args:
        outputs: (N, C) tensor; presumably softmax probabilities (TODO confirm
            with callers).
        labels: true-class indices, any shape holding N elements.

    Returns:
        Detached tensor of shape (N,).
    """
    log_probs = log_value(outputs)                # -log(p)
    reverse_probs = 1 - outputs
    log_reverse_probs = log_value(reverse_probs)  # -log(1 - p)

    # Index with long tensors: some torch versions reject int32 index tensors,
    # and ``confidence`` already casts labels with ``.long()``. The row index
    # is built once, on the labels' device, so both index tensors agree.
    labels = labels.long().reshape(-1)
    rows = torch.arange(labels.size(0), device=labels.device)

    # True class: weight (1 - p_y) paired with -log(p_y).
    # Other classes: weight p_j paired with -log(1 - p_j).
    modified_probs = outputs.clone()
    modified_probs[rows, labels] = reverse_probs[rows, labels]

    modified_log_probs = log_reverse_probs.clone()
    modified_log_probs[rows, labels] = log_probs[rows, labels]

    return -torch.sum(modified_probs * modified_log_probs, dim=1).detach()


def grad_norm(outputs, labels):
    """Per-sample L2 norm of the cross-entropy gradient w.r.t. ``outputs``.

    ``outputs`` is fed directly to ``torch.nn.CrossEntropyLoss``, which applies
    log-softmax itself, so it is treated as unnormalised logits.
    NOTE(review): ``confidence``/``entropy`` appear to take probabilities
    instead; confirm what callers pass here.

    Args:
        outputs: assumed (N, C) tensor of logits. It is never modified. If it
            does not require grad, a detached alias is used.
        labels: N class indices; cast to long as CrossEntropyLoss requires.

    Returns:
        Detached tensor of shape (N,). Element i is the norm of
        d loss_i / d outputs[i], which does not depend on the batch size.
    """
    # reduction='sum' makes row i of the gradient depend only on sample i's
    # loss. With the default 'mean', every row was divided by N, so the score
    # changed with batch size (e.g. for a smaller final batch).
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')
    # enable_grad so the score can be computed even inside torch.no_grad().
    with torch.enable_grad():
        if not outputs.requires_grad:
            # Use a detached alias instead of flipping requires_grad on the
            # caller's tensor in place.
            outputs = outputs.detach().requires_grad_(True)
        loss = criterion(outputs, labels.long())
        # retain_graph=True (as before) keeps the traversed graph usable by the
        # caller. create_graph is unnecessary because the result is detached.
        grads = torch.autograd.grad(loss, outputs, retain_graph=True)[0]
    return torch.norm(grads, dim=1).detach()


def accuracy(result_member, result_non_member):
    """Fraction of all samples (members and non-members) classified correctly.

    Both arguments hold per-sample predictions where 1/True means "predicted
    member". Member samples count as correct when predicted 1; non-member
    samples count as correct when predicted 0. Returns a Python float, or 0.0
    when both inputs are empty.
    """
    member = result_member.float()
    non_member = result_non_member.float()
    tp, fn = member.sum(), (1. - member).sum()
    fp, tn = non_member.sum(), (1. - non_member).sum()
    total = tp + fp + tn + fn
    if total > 0:
        return ((tp + tn) / total).item()
    return 0.0


def precision_p(result_member, result_non_member):
    """Precision for the member (positive) class: TP / (TP + FP).

    Predictions are per-sample values where 1/True means "predicted member".
    Returns a Python float, or 0.0 when nothing is predicted as a member.
    """
    true_pos = result_member.float().sum()
    false_pos = result_non_member.float().sum()
    predicted_pos = true_pos + false_pos
    if predicted_pos > 0:
        return (true_pos / predicted_pos).item()
    return 0.0


def recall_p(result_member, result_non_member):
    """Recall for the member (positive) class: TP / (TP + FN).

    Only ``result_member`` (predictions for true members, 1/True meaning
    "predicted member") affects the result; ``result_non_member`` is accepted
    for a uniform signature. Returns a Python float, or 0.0 when there are no
    member samples.
    """
    predicted = result_member.float()
    hits = predicted.sum()
    misses = (1. - predicted).sum()
    members_seen = hits + misses
    if members_seen > 0:
        return (hits / members_seen).item()
    return 0.0


def precision_n(result_member, result_non_member):
    """Precision for the non-member (negative) class: TN / (TN + FN).

    Args:
        result_member: predictions for true members (1/True = predicted member).
        result_non_member: predictions for true non-members.

    Returns:
        Python float; 0.0 when no sample is predicted as a non-member.
    """
    # Cast to float *before* subtracting. ``1. - t`` raises for bool tensors,
    # which is what ``torch.tensor`` builds from a list of Python bools. This
    # matches the ``1. - x.float()`` form used for ``fn`` and in ``accuracy``.
    tn = (1. - result_non_member.float()).sum()
    fn = (1. - result_member.float()).sum()
    if tn + fn > 0:
        return (tn / (tn + fn)).item()
    return 0.0


def recall_n(result_member, result_non_member):
    """Recall for the non-member (negative) class (specificity): TN / (TN + FP).

    Args:
        result_member: predictions for true members; not used, but accepted
            for a uniform signature.
        result_non_member: predictions for true non-members
            (1/True = predicted member).

    Returns:
        Python float; 0.0 when there are no non-member samples.
    """
    non_member = result_non_member.float()
    # Cast to float *before* subtracting: ``1. - t`` raises for bool tensors.
    tn = (1. - non_member).sum()
    fp = non_member.sum()
    if tn + fp > 0:
        return (tn / (tn + fp)).item()
    return 0.0


def f1_p(result_member, result_non_member):
    """F1 score for the member class.

    This is the harmonic mean of ``precision_p`` and ``recall_p``, and 0.0 when
    both are zero. Returns a Python float.
    """
    prec = precision_p(result_member, result_non_member)
    rec = recall_p(result_member, result_non_member)
    denom = prec + rec
    if denom > 0:
        return 2 * (prec * rec) / denom
    return 0.0


def f1_n(result_member, result_non_member):
    """F1 score for the non-member class.

    This is the harmonic mean of ``precision_n`` and ``recall_n``, and 0.0 when
    both are zero. Returns a Python float.
    """
    prec = precision_n(result_member, result_non_member)
    rec = recall_n(result_member, result_non_member)
    denom = prec + rec
    if denom > 0:
        return 2 * (prec * rec) / denom
    return 0.0


def auc(prob_member, prob_non_member):
    """Area under the ROC curve for separating members from non-members.

    Both inputs are per-sample scores, not binary predictions. Members are
    labelled 1 and non-members 0, so higher scores are treated as more
    member-like (sklearn's ``roc_curve`` convention).
    """
    scores = torch.cat([prob_member, prob_non_member])
    truth = torch.cat([torch.ones_like(prob_member), torch.zeros_like(prob_non_member)])
    fpr, tpr, _thresholds = metrics.roc_curve(truth.cpu().numpy(), scores.cpu().numpy())
    return metrics.auc(fpr, tpr)

def tpr_fpr(prob_member, prob_non_member, threshold=0.0001):
    """True-positive rate at a fixed low false-positive rate.

    Builds the member-vs-non-member ROC curve from the per-sample scores and
    returns the TPR of the last ROC point whose FPR is strictly below
    ``threshold``. Raises IndexError if no point qualifies. sklearn's ROC curve
    starts at FPR 0, so this can only happen for ``threshold <= 0``.
    """
    scores = torch.cat([prob_member, prob_non_member])
    truth = torch.cat([torch.ones_like(prob_member), torch.zeros_like(prob_non_member)])
    fpr, tpr, _thresholds = metrics.roc_curve(truth.cpu().numpy(), scores.cpu().numpy())
    qualifying = np.where(fpr < threshold)[0]
    return tpr[qualifying[-1]]



def get_metric(nonmember_acc, member_acc,
               batch_result_member, batch_result_nonmember,
               tosave):
    """Update the running accuracy meters with one batch and report metrics.

    Args:
        nonmember_acc: running meter for non-member accuracy; updated in place.
        member_acc: running meter for member accuracy; updated in place.
            Both meters are used through ``update(value, n)``, ``.sum``,
            ``.count`` and ``.average``. They are presumably
            ``util.Eval.AverageCalculator``, with ``sum`` accumulating
            value * n (confirm).
        batch_result_member: this batch's predictions for member samples
            (1 = predicted member).
        batch_result_nonmember: this batch's predictions for non-member
            samples (1 = predicted member).
        tosave: mapping whose 'member_pred' and 'nonmember_pred' entries are
            dicts of the per-sample predictions collected so far.

    Returns:
        dict with the running confusion counts ('tp', 'tn', 'fp', 'fn'), the
        two running accuracies, and accuracy / precision / recall / F1 for
        both classes computed over every prediction in ``tosave``.
    """
    # Batch accuracies: share of members predicted "member", and share of
    # non-members predicted "non-member".
    batch_member_acc = batch_result_member.float().mean()
    batch_nonmember_acc = 1. - batch_result_nonmember.float().mean()

    member_acc.update(batch_member_acc.item(), batch_result_member.size(0))
    nonmember_acc.update(batch_nonmember_acc.item(), batch_result_nonmember.size(0))

    # Running confusion counts recovered from the meters (sum = correct count,
    # count - sum = incorrect count, assuming sum accumulates value * n).
    tp = round(member_acc.sum)
    fn = round(member_acc.count - member_acc.sum)
    tn = round(nonmember_acc.sum)
    fp = round(nonmember_acc.count - nonmember_acc.sum)

    # Whole-run metrics over all predictions stored in ``tosave``.
    all_member = torch.tensor(list(tosave['member_pred'].values()))
    all_nonmember = torch.tensor(list(tosave['nonmember_pred'].values()))

    # TODO: also aggregate results here.

    return {
        'tp': tp, 'tn': tn, 'fp': fp, 'fn': fn,
        'member_acc': member_acc.average,
        'nonmember_acc': nonmember_acc.average,
        'accuracy': accuracy(all_member, all_nonmember),
        'precision@p': precision_p(all_member, all_nonmember),
        'recall@p': recall_p(all_member, all_nonmember),
        'precision@n': precision_n(all_member, all_nonmember),
        'recall@n': recall_n(all_member, all_nonmember),
        'f1@p': f1_p(all_member, all_nonmember),
        'f1@n': f1_n(all_member, all_nonmember),
    }
