
import torch

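# Maximum Softmax Probability (MSP) detector, augmented with the maximum cosine
# similarity between latent features and the classifier weight rows.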
def msp(dataset_in, dataset_out, net, device):
    """Score an OOD test set and an in-distribution test set for OOD detection.

    Despite the name, the per-sample score is the maximum softmax probability
    PLUS the maximum cosine similarity between the sample's latent features
    and the (L2-normalized) rows of ``net.linear.weight``.

    Args:
        dataset_in: in-distribution dataset object exposing ``test_loader``,
            an iterable of ``(data, labels)`` batches.
        dataset_out: out-of-distribution dataset object with the same interface.
        net: model with a ``linear`` layer (only its ``weight`` is read) whose
            ``net(data, latent=True)`` returns ``(logits, features)``.
        device: device that batches and the classifier weights are moved to.

    Returns:
        ``(labels, pred)`` as NumPy arrays. OOD samples come first (label 1),
        followed by in-distribution samples (label 0). ``pred`` is the negated
        score, so larger values correspond to a lower MSP + cosine score.
    """
    # Only the weight matrix is needed. Reading ``net.linear.bias`` here would
    # crash for bias-free layers, where ``bias`` is None.
    fc_w_norm = torch.nn.functional.normalize(
        net.linear.weight.detach().clone().to(device), dim=1
    )

    with torch.no_grad():
        ood_scores = _score_loader(dataset_out.test_loader, net, fc_w_norm, device)
        ind_scores = _score_loader(dataset_in.test_loader, net, fc_w_norm, device)

    dtype = torch.get_default_dtype()
    # Label 1 marks OOD samples, label 0 marks in-distribution samples.
    y = torch.cat([
        torch.ones(ood_scores.shape[0], dtype=dtype, device=device),
        torch.zeros(ind_scores.shape[0], dtype=dtype, device=device),
    ])
    pred = torch.cat([ood_scores, ind_scores])

    labels = y.cpu().numpy()
    # Negated, presumably so that higher values point toward the OOD label (1).
    pred = -pred.cpu().numpy()
    return labels, pred


def _score_loader(loader, net, fc_w_norm, device):
    """Return a 1-D tensor with one MSP + max-cosine score per sample in ``loader``.

    Scores are collected per batch and concatenated, so the result length is
    exactly the number of samples the loader yields (this stays correct when
    the loader uses ``drop_last=True``). Callers run this under ``no_grad``.
    """
    dtype = torch.get_default_dtype()
    scores = []
    for data, _ in loader:
        out, fet = net(data.to(device), latent=True)
        norm_fet = torch.nn.functional.normalize(fet, dim=1)
        # Highest cosine similarity between the features and any weight row.
        cos_theta = torch.mm(norm_fet, fc_w_norm.T).max(dim=1)[0]
        # Maximum softmax probability over classes.
        smax = torch.nn.functional.softmax(out, dim=1).max(dim=1)[0]
        scores.append((smax + cos_theta).to(dtype))
    if not scores:
        return torch.zeros(0, dtype=dtype, device=device)
    return torch.cat(scores)