import torch
import numpy as np
import torch.nn.functional as F

def softmax(preact):
    """Return the softmax of ``preact`` normalised along axis 1.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but stops large logits from overflowing
    to ``inf``, which would otherwise produce ``nan`` rows.

    Args:
        preact: array of logits with at least two dimensions; rows are
            normalised along axis 1.

    Returns:
        Array of the same shape whose entries along axis 1 sum to one.
    """
    shifted = preact - np.max(preact, axis=1, keepdims=True)
    exponents = np.exp(shifted)
    sum_exponents = np.sum(exponents, axis=1, keepdims=True)
    return exponents / sum_exponents

def inverse_softmax(preds):
    """Map probability rows back to centred logits.

    Each row is first renormalised to sum to one along axis 1. Its
    log-probabilities are then shifted so that every output row has zero
    mean. This inverts ``softmax`` up to the per-row additive constant that
    softmax cannot recover.

    No clipping is applied. A zero probability gives ``log(0) = -inf``,
    which makes that whole row non-finite (``inf``/``nan``).

    Args:
        preds: non-negative array with at least two dimensions; rows are
            taken along axis 1.

    Returns:
        Array of the same shape with zero-mean rows of log-probabilities.
    """
    preds = preds / np.sum(preds, axis=1, keepdims=True)
    # Compute the log once; it is needed both for the values and their mean.
    log_preds = np.log(preds)
    return log_preds - np.mean(log_preds, axis=1, keepdims=True)

def get_probs(net, loader, device):
    """Run ``net`` over ``loader`` and collect softmax probabilities and targets.

    The network is switched to eval mode and evaluated under
    ``torch.no_grad()``. Outputs come from ``net.predict(inputs)`` and are
    converted to probabilities with a softmax over the last dimension.

    Batch layout: if the first element of a batch is 1-D, the batch is read
    as ``(x0, inputs, targets, ...)``; otherwise it is read as
    ``(inputs, targets, ...)``.

    Args:
        net: model exposing ``eval()`` and ``predict(inputs)``.
        loader: iterable of batches as described above.
        device: device that inputs are moved to before the forward pass.

    Returns:
        ``(probs, labels)``. ``probs`` stacks every batch's probabilities
        along axis 0. ``labels`` stacks the non-``None`` targets. Either is
        ``None`` when nothing was collected (e.g. an empty loader).
    """
    net.eval()

    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_data in loader:
            # A 1-D leading element is presumably a per-sample index/id
            # field that precedes the inputs -- TODO confirm with the loaders.
            if batch_data[0].dim() == 1:
                inputs, targets = batch_data[1], batch_data[2]
            else:
                inputs, targets = batch_data[0], batch_data[1]
            outputs = net.predict(inputs.to(device))
            prob_chunks.append(F.softmax(outputs, dim=-1).detach().cpu().numpy())
            if targets is not None:
                label_chunks.append(targets.cpu().numpy())

    # Concatenate once at the end. Growing an array inside the loop would
    # re-copy all earlier results on every batch (quadratic in dataset size).
    probs = np.concatenate(prob_chunks, axis=0) if prob_chunks else None
    labels = np.concatenate(label_chunks, axis=0) if label_chunks else None
    return probs, labels

# Baseline 1: DoC / DoE (difference of average confidence / average entropy)

def get_DoC(train_probs, test_probs):
    """Difference of Confidences: train average confidence minus test average confidence.

    A sample's confidence is its largest probability along the last axis.
    The average confidence is the mean of that value over all samples.
    """
    def _average_confidence(probs):
        top_class_prob = np.max(probs, axis=-1)
        return np.mean(top_class_prob)

    return _average_confidence(train_probs) - _average_confidence(test_probs)

def get_DoE(train_probs, test_probs):
    """Difference of Entropies: train average entropy minus test average entropy.

    Each sample's entropy is ``-sum_k p_k * log(p_k)`` over the last
    (class) axis, and the average is taken over samples.

    Zero probabilities contribute 0 rather than ``nan``. A ``1e-20``
    epsilon is added inside the log, the same guard ``get_entropy`` uses.
    """
    def _average_entropy(probs):
        per_sample = -np.sum(probs * np.log(probs + 1e-20), axis=-1)
        return np.mean(per_sample)

    return _average_entropy(train_probs) - _average_entropy(test_probs)

# Baseline 2: ATC (Average Thresholded Confidence)
def get_entropy(probs):
    """Return ``sum_k p_k * log(p_k + 1e-20)`` along axis 1 for each row.

    There is no leading minus, so this is the *negative* entropy: higher
    values mean more confident predictions. The 1e-20 keeps zero
    probabilities from producing ``log(0)``.
    """
    log_terms = np.log(probs + 1e-20)
    weighted = np.multiply(probs, log_terms)
    return np.sum(weighted, axis=1)

def get_max_conf(probs):
    """Return each sample's highest probability (maximum over the last axis)."""
    top_prob = np.amax(probs, axis=-1)
    return top_prob

def find_ATC_threshold(scores, labels):
    """Fit the ATC score threshold on labelled data.

    Samples are visited in ascending score order, and each visited sample is
    moved to the "below threshold" side. ``fp`` counts samples with
    ``label == 0`` that are still above the threshold. ``fn`` counts the
    other samples already below it. The threshold kept is the score at
    which ``|fp - fn|`` first reaches its minimum.

    NOTE(review): presumably ``labels`` is a 0/1 correctness indicator
    (0 = misclassified); any non-zero label is treated as correct.

    Args:
        scores: 1-D array of per-sample scores (higher = more confident).
        labels: 1-D array aligned with ``scores``.

    Returns:
        ``(min_fp_fn, thres)``: the minimal ``|fp - fn|`` and the chosen
        threshold.
    """
    sorted_idx = np.argsort(scores)
    sorted_scores = scores[sorted_idx]
    sorted_labels = labels[sorted_idx]

    # Start with every sample above the threshold: every label-0 sample is an
    # fp and there are no fn yet.
    fp = np.sum(labels == 0)
    fn = 0.0
    min_fp_fn = np.abs(fp - fn)

    # This initial value survives only when there are no label-0 samples
    # (or no samples); otherwise the first step always lowers |fp - fn|.
    # -inf keeps every sample in that case even when scores can be negative
    # (e.g. get_entropy output), which the old 0.0 default did not.
    thres = -np.inf

    for score, label in zip(sorted_scores, sorted_labels):
        if label == 0:
            fp -= 1
        else:
            fn += 1

        gap = np.abs(fp - fn)
        if gap < min_fp_fn:
            min_fp_fn = gap
            thres = score

    return min_fp_fn, thres

def get_ATC_acc(thres, scores):
    """Return the ATC accuracy estimate, in percent.

    A sample is counted as correctly classified when its score is strictly
    greater than ``thres``.
    """
    # Strict ">" is consistent with find_ATC_threshold, which puts the sample
    # whose score becomes the threshold on the "below" side. ">=" would count
    # it again: e.g. 75% instead of 50% on scores [0.1, 0.2, 0.3, 0.4] with
    # labels [0, 0, 1, 1] (threshold 0.2).
    return np.mean(scores > thres) * 100.0