from config import dtype
import torch
import torch.nn.functional as F

def masked_softmax_cross_entropy(preds, labels, mask):
    """Return the mask-weighted mean softmax cross-entropy.

    ``preds`` are passed to ``F.cross_entropy`` as unnormalised logits and
    ``labels`` are reduced to class indices with ``argmax`` along dim 1
    (e.g. one-hot rows).  ``mask`` is cast to float32 and divided by its own
    mean before weighting the per-sample losses, so for a 0/1 mask the result
    is the average loss over the selected samples.  An all-zero mask yields
    NaN because of the 0/0 normalisation.

    All inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim tensor.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    preds, labels, mask = (t.to(device) for t in (preds, labels, mask))

    class_ids = labels.argmax(dim=1)
    per_sample = F.cross_entropy(preds, class_ids, reduction='none')

    # Normalise the weights so they average to one across all samples.
    weights = mask.float()
    weights = weights / weights.mean()
    return (per_sample * weights).mean()

def softmax_cross_entropy(preds, labels):
    """Return the mean softmax cross-entropy over all samples (no masking).

    ``preds`` are passed to ``F.cross_entropy`` as unnormalised logits and
    ``labels`` are reduced to class indices with ``argmax`` along dim 1
    (e.g. one-hot rows).

    Both inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim tensor.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    preds, labels = (t.to(device) for t in (preds, labels))

    class_ids = labels.argmax(dim=1)
    per_sample = F.cross_entropy(preds, class_ids, reduction='none')
    return per_sample.mean()

def masked_sigmoid_cross_entropy(preds, labels, mask):
    """Return the mask-weighted mean sigmoid (binary) cross-entropy.

    The element-wise loss is
    ``-(labels * log(sigmoid(preds)) + (1 - labels) * log(1 - sigmoid(preds)))``
    computed from the raw logits ``preds``.  ``mask`` is cast to float32,
    divided by its own mean, and multiplied into the element-wise loss using
    ordinary broadcasting; the result is then averaged over every element.
    An all-zero mask yields NaN because of the 0/0 normalisation.

    NOTE(review): the mask is broadcast against a loss that has the same
    shape as ``preds``, so a per-row mask for 2-D ``preds`` must be shaped
    ``(N, 1)`` -- confirm against callers.

    All inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim tensor that stays attached to the autograd graph of
    ``preds``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    logits = preds.to(device)
    targets = labels.to(device)
    mask = mask.to(device)

    # Use the tensors as they are: re-wrapping them with torch.tensor(...)
    # copies the data, emits a UserWarning and detaches the result from
    # autograd, so no gradient would ever reach ``preds``.
    if not logits.is_floating_point():
        # logsigmoid needs a floating dtype; integer logits are promoted.
        logits = logits.float()

    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are taken
    # via logsigmoid, which stays finite for large |x|.  The naive form
    # saturates to log(0) = -inf and produces 0 * inf = NaN.
    loss = -(targets * F.logsigmoid(logits)
             + (1 - targets) * F.logsigmoid(-logits))

    mask = mask.type(torch.float32)
    mask = mask / torch.mean(mask)
    loss = loss * mask
    return torch.mean(loss)

def sigmoid_cross_entropy(preds, labels):
    """Return the mean sigmoid (binary) cross-entropy (no masking).

    The element-wise loss is
    ``-(labels * log(sigmoid(preds)) + (1 - labels) * log(1 - sigmoid(preds)))``
    computed from the raw logits ``preds`` and averaged over every element.

    Both inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim tensor that stays attached to the autograd graph of
    ``preds``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    logits = preds.to(device)
    targets = labels.to(device)

    # Use the tensors as they are: re-wrapping them with torch.tensor(...)
    # copies the data, emits a UserWarning and detaches the result from
    # autograd, so no gradient would ever reach ``preds``.
    if not logits.is_floating_point():
        # logsigmoid needs a floating dtype; integer logits are promoted.
        logits = logits.float()

    # log(sigmoid(x)) and log(1 - sigmoid(x)) == log(sigmoid(-x)) are taken
    # via logsigmoid, which stays finite for large |x|.  The naive form
    # saturates to log(0) = -inf and produces 0 * inf = NaN.
    loss = -(targets * F.logsigmoid(logits)
             + (1 - targets) * F.logsigmoid(-logits))
    return torch.mean(loss)

def masked_accuracy(preds, labels, mask):
    """Return the mask-weighted classification accuracy.

    A sample counts as correct when the ``argmax`` along dim 1 of ``preds``
    equals that of ``labels``.  ``mask`` is cast to float32 and divided by
    its own mean before weighting the per-sample hits, so for a 0/1 mask the
    result is the accuracy over the selected samples.  An all-zero mask
    yields NaN because of the 0/0 normalisation.

    All inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim float32 tensor.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    preds, labels, mask = (t.to(device) for t in (preds, labels, mask))

    hits = (preds.argmax(dim=1) == labels.argmax(dim=1)).float()

    # Normalise the weights so they average to one across all samples.
    weights = mask.float()
    weights = weights / weights.mean()
    return (hits * weights).mean()

def accuracy(preds, labels):
    """Return the classification accuracy over all samples (no masking).

    A sample counts as correct when the ``argmax`` along dim 1 of ``preds``
    equals that of ``labels``.

    Both inputs are moved to ``cuda:0`` when CUDA is available, else the CPU.
    Returns a 0-dim float32 tensor.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    preds, labels = (t.to(device) for t in (preds, labels))

    hits = preds.argmax(dim=1) == labels.argmax(dim=1)
    return hits.float().mean()