import numpy as np
import torch
import torch.nn.functional as F


def get_msp_score(logits, snr, w=0.3):
    """Combine the maximum softmax probability (MSP) with an SNR value.

    The per-sample score is ``msp ** w * snr ** (1 - w)``, i.e. a weighted
    geometric mean of the MSP and the SNR.

    Args:
        logits: Tensor of raw class scores; softmax and max are taken over
            dim 1 (presumably shape ``(batch, num_classes)`` — confirm with
            callers).
        snr: Tensor of SNR values, broadcast against the per-sample MSP
            array by NumPy.
        w: Exponent weight on the MSP term; ``1 - w`` is applied to SNR.

    Returns:
        NumPy array of combined scores.

    Note:
        A negative SNR raised to the fractional power ``1 - w`` gives NaN.
    """
    probs = F.softmax(logits.detach(), dim=1)
    msp = probs.max(dim=1).values.cpu().numpy()
    # Detach so .numpy() does not raise when snr is part of an autograd graph.
    snr_np = snr.detach().cpu().numpy()
    return np.power(msp, w) * np.power(snr_np, 1 - w)


def get_energy_score(logits, snr, w=0.5):
    """Combine the log-sum-exp ("energy") of the logits with an SNR value.

    The per-sample score is ``logsumexp(logits) ** w * snr ** (1 - w)``.

    Args:
        logits: Tensor of raw class scores; logsumexp is taken over dim 1
            (presumably shape ``(batch, num_classes)`` — confirm with
            callers).
        snr: Tensor of SNR values, broadcast against the per-sample energy
            array by NumPy.
        w: Exponent weight on the energy term; ``1 - w`` is applied to SNR.

    Returns:
        NumPy array of combined scores.

    Note:
        logsumexp can be negative when all logits of a row are sufficiently
        negative, and a negative SNR is possible too; ``np.power`` of a
        negative base with a fractional exponent returns NaN.
    """
    energy = torch.logsumexp(logits.detach().cpu(), dim=1).numpy()
    # Detach so .numpy() does not raise when snr is part of an autograd graph.
    snr_np = snr.detach().cpu().numpy()
    return np.power(energy, w) * np.power(snr_np, 1 - w)

def get_snr_score(snr):
    """Use the SNR values directly as the score.

    Args:
        snr: Tensor of SNR values (on any device).

    Returns:
        NumPy array with the same values as ``snr``.
    """
    # Detach so .numpy() does not raise when snr is part of an autograd graph.
    return snr.detach().cpu().numpy()

def get_score(logits, snr, method, w=0.5):
    """Dispatch to the scoring function selected by ``method``.

    Args:
        logits: Tensor of raw class scores (ignored for ``"snr"``).
        snr: Tensor of SNR values.
        method: One of ``"msp"``, ``"energy"`` or ``"snr"``.
        w: Exponent weight forwarded to the MSP/energy scorers. Note that
            this default (0.5) overrides ``get_msp_score``'s own default of
            0.3 when ``method == "msp"``.

    Returns:
        NumPy array of scores produced by the selected scorer.

    Raises:
        ValueError: If ``method`` is not a supported scoring method.
    """
    if method == "msp":
        return get_msp_score(logits, snr, w=w)
    if method == "energy":
        return get_energy_score(logits, snr, w=w)
    if method == "snr":
        return get_snr_score(snr)
    # Raise instead of calling the site-provided exit(), which would
    # terminate the whole process and is missing under `python -S`.
    raise ValueError(f"Unsupported scoring method: {method!r}")

