import torch


def accuracy(model, dataloader, device=torch.device("cuda:0")):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is moved to ``device`` and evaluated under ``torch.no_grad()``.
    If it is a ``torch.nn.Module`` it is switched to eval mode for the duration
    of the call, so layers such as dropout do not perturb the predictions, and
    its previous train/eval mode is restored before returning.

    Args:
        model: Callable accepting ``input_ids`` and ``attention_mask`` keyword
            arguments and returning a mapping with a ``"logits"`` entry. The
            predicted class is the argmax of the logits along axis 1
            (presumably logits are ``(batch, num_classes)`` -- confirm against
            the model in use).
        dataloader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device the model and every batch are moved to.

    Returns:
        A 0-dim tensor holding ``correct / total``. It is NaN when the
        dataloader yields no samples.

    Side effects:
        Prints the running accuracy after every batch.
    """
    model.to(device)

    # Remember the current mode so a caller in the middle of training gets the
    # model back exactly as it handed it over.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    # Accumulating in a tensor keeps the result type uniform and makes the
    # no-sample case produce NaN instead of an int ZeroDivisionError.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask)[
                    "logits"
                ]
                predictions = logits.argmax(dim=1)
                correct += (predictions == labels).sum()
                total += len(labels)
                print(correct / total)
    finally:
        if was_training:
            model.train()
    return correct / total


def accuracy_casual(model, dataloader, device=torch.device("cuda:0")):
    """Return the next-token prediction accuracy of ``model`` over ``dataloader``.

    For each sequence, the prediction at position ``t`` (argmax of the logits
    along axis 2) is compared with the label at position ``t + 1``. Positions
    whose shifted label equals ``-100`` are excluded from both the numerator
    and the denominator.

    The model itself is NOT moved to ``device``; only the batches are. If the
    model is a ``torch.nn.Module`` it is switched to eval mode for the duration
    of the call and its previous train/eval mode is restored before returning.
    Plain callables are used as they are.

    Args:
        model: Callable accepting ``input_ids``, ``attention_mask`` and
            ``labels`` keyword arguments and returning a mapping with a
            ``"logits"`` entry (presumably ``(batch, seq_len, vocab)`` --
            confirm against the model in use).
        dataloader: Iterable of batches; each batch is a mapping with
            ``"input_ids"``, ``"attention_mask"`` and ``"labels"`` tensors.
        device: Device every batch is moved to. The logits are assumed to come
            back on this same device.

    Returns:
        A 0-dim tensor holding ``correct / total`` over all unmasked positions.
        It is NaN when there are no unmasked positions at all.

    Side effects:
        Prints the running accuracy after every batch.
    """
    # Remember the current mode so a caller in the middle of training gets the
    # model back exactly as it handed it over.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    # Accumulating in a tensor makes the no-sample case produce NaN, matching
    # what an all-masked batch already yields.
    correct = torch.zeros((), dtype=torch.long, device=device)
    total = 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].to(device)

                output = model(
                    input_ids=input_ids, attention_mask=attention_mask, labels=labels
                )
                predictions = output["logits"].argmax(dim=2)

                # Shift so that the prediction made at position t lines up with
                # the label at position t + 1, for the whole batch at once.
                shifted_predictions = predictions[:, :-1]
                targets = labels[:, 1:]
                keep = targets != -100
                kept_targets = targets[keep]

                correct += (shifted_predictions[keep] == kept_targets).sum()
                total += kept_targets.numel()
                print(correct / total)
    finally:
        if was_training:
            model.train()
    return correct / total
