import torch
import torch.nn.functional as F


def metric_with_logits(pred, label, **kwargs):
    """Count thresholded sigmoid predictions that equal ``label``.

    ``pred`` is passed through a sigmoid and every value strictly above
    0.5 becomes class 1, everything else class 0. Extra keyword arguments
    are accepted and ignored.

    Returns:
        ``(correct, total)`` where ``correct`` is the number of elements
        whose predicted class equals ``label`` and ``total`` is
        ``len(pred)``.
    """
    probabilities = torch.sigmoid(pred)
    predicted_class = (probabilities > 0.5).long()
    matches = predicted_class == label
    return matches.sum().item(), len(pred)


def ranking_metric(pred, label, **kwargs):
    """Count pairwise orderings that agree with ``label``.

    ``pred`` unpacks into two score tensors. The predicted order is ``-1``
    where the first score is strictly lower than the second and ``1``
    everywhere else (ties included). Extra keyword arguments are accepted
    and ignored.

    Returns:
        ``(correct, total)`` where ``correct`` is the number of predicted
        orders equal to ``label`` and ``total`` is ``len(label)``.
    """
    first_score, second_score = pred
    second_is_higher = first_score < second_score
    predicted_order = torch.where(second_is_higher, -1, 1)
    agreement = predicted_order == label
    return agreement.sum().item(), len(label)


def list_ranking_metric(pred, scoring_label, scoring_mask, **kwargs):
    """Count rows whose best-labelled choice is among the top-3 predictions.

    Only the second element of ``pred`` is used. Rows of ``scoring_mask``
    without any entry equal to 1 are dropped. In the remaining rows, every
    choice whose mask is 0 is set to ``-inf`` in both the predictions and
    the labels, so it can be neither the label argmax nor preferred over a
    labelled choice.

    The number of predictions inspected per row is ``min(3, num_choices)``,
    so inputs with fewer than three choice columns are supported.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(correct, total)`` where ``correct`` is the number of kept rows
        whose label argmax index appears in the top-k predicted indices
        and ``total`` is the number of kept rows.
    """
    _, decimation_pred = pred

    # Filter rows where we do not have any ranking labels at all.
    # Boolean indexing returns copies, so the in-place writes below do not
    # touch the caller's tensors.
    has_label = torch.any(scoring_mask == 1, dim=1)
    decimation_pred = decimation_pred[has_label]
    scoring_label = scoring_label[has_label]
    scoring_mask = scoring_mask[has_label]

    # Mask out the choices for which we do not have a label.
    unlabeled = scoring_mask == 0
    decimation_pred[unlabeled] = -float("inf")
    scoring_label[unlabeled] = -float("inf")
    scoring_argmax = scoring_label.argmax(dim=1)

    # torch.topk raises when k exceeds the dimension size; clamp it.
    top_k = min(3, decimation_pred.size(1))
    _, top_k_indices = torch.topk(decimation_pred, k=top_k, dim=1)

    # top-k indices are distinct within a row, so each row contributes at
    # most one match.
    correct = (
        top_k_indices.eq(scoring_argmax.unsqueeze(1).expand_as(top_k_indices))
        .float()
        .sum()
        .item()
    )
    return correct, scoring_label.size(0)


def list_ranking_metric_v2(pred, scoring_label, scoring_mask, **kwargs):
    """Count ranking positions where predicted and labelled orders agree.

    Only the second element of ``pred`` is used. Rows of ``scoring_mask``
    without any entry equal to 1 are dropped. In the remaining rows, the
    choices whose mask is 0 are pushed to the bottom of both rankings
    (``-inf`` for predictions, ``-1e6`` for labels) before the descending
    orders are compared position by position.

    Both rankings use a stable sort, so tied entries (in particular the
    masked ones) keep their index order and compare deterministically.

    Extra keyword arguments are accepted and ignored.

    Returns:
        ``(correct, total)`` where ``correct`` is the number of
        (row, position) pairs at which both rankings hold the same choice
        index and ``total`` is the number of kept rows.

    NOTE(review): ``correct`` counts positions while ``total`` counts rows,
    so ``correct / total`` can exceed 1 -- confirm this is what callers
    expect.
    """
    _, decimation_pred = pred

    # Boolean indexing returns copies, so the in-place writes below do not
    # touch the caller's tensors.
    has_label = torch.any(scoring_mask == 1, dim=1)
    decimation_pred = decimation_pred[has_label]
    scoring_label = scoring_label[has_label]
    scoring_mask = scoring_mask[has_label]

    # Mask out the choices for which we do not have a label. The ranking
    # must be computed on the masked predictions; otherwise unlabelled
    # choices keep their raw scores and shift every other position.
    unlabeled = scoring_mask == 0
    decimation_pred[unlabeled] = -float("inf")
    scoring_label[unlabeled] = -1e6  # A large negative number

    l1 = torch.sort(
        decimation_pred, dim=1, descending=True, stable=True
    ).indices
    l2 = torch.sort(
        scoring_label, dim=1, descending=True, stable=True
    ).indices
    correct = (l1 == l2).type(torch.float).sum().item()
    total = len(l1)
    return correct, total


def list_bc_metric_v2(pred, ml_class_label, **kwargs):
    """Sum the labels at predicted-positive positions and at label-1 positions.

    Only the first element of ``pred`` is used. It is passed through a
    sigmoid and a position counts as predicted positive when the result is
    strictly above 0.5. Extra keyword arguments are accepted and ignored.

    Returns:
        ``(correct, total)`` where ``correct`` is the sum of
        ``ml_class_label`` over the predicted-positive positions and
        ``total`` is the sum of the entries of ``ml_class_label`` that
        equal 1. For 0/1 labels these are the true-positive count and the
        positive count.
    """
    optimality_logits = pred[0]
    predicted_positive = torch.sigmoid(optimality_logits) > 0.5
    labelled_positive = ml_class_label == 1
    hits = ml_class_label[predicted_positive].sum().item()
    positives = ml_class_label[labelled_positive].sum().item()
    return hits, positives


def mse_metric(pred, child_nodes, **kwargs):
    """Return the root of the mean squared error and the length of ``pred``.

    The error is ``F.mse_loss(pred, child_nodes)`` with its default
    reduction, followed by a square root. Extra keyword arguments are
    accepted and ignored.

    Returns:
        ``(rmse, total)`` where ``rmse`` is a tensor (not a Python float)
        and ``total`` is ``len(pred)``.
    """
    mean_squared_error = F.mse_loss(pred, child_nodes)
    root_error = mean_squared_error.sqrt()
    return root_error, len(pred)
