import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


@torch.no_grad()
def masking(logits_x_ulb, softmax_x_ulb=False, cutoff=None, *args, **kwargs):
    """
    Build a confidence mask by thresholding the top score of each sample.

    Args:
        logits_x_ulb (Tensor): Scores whose last dimension is reduced with
            ``max``. Treated as probabilities unless ``softmax_x_ulb`` is set.
        softmax_x_ulb (bool): When True, apply softmax over the last
            dimension before taking the maximum.
        cutoff (float or Tensor): Threshold compared with ``>=`` against the
            per-sample maximum. Required despite the ``None`` default.
        *args, **kwargs: Accepted and ignored.

    Returns:
        Tensor: Mask with the dtype of the max scores; 1 where the maximum
        is ``>= cutoff`` and 0 elsewhere.

    Raises:
        TypeError: If ``cutoff`` is None.
    """
    if cutoff is None:
        # Fail before doing any tensor work, and name the offending argument.
        # TypeError is the type the comparison itself would raise for None,
        # so callers handling that failure keep working.
        raise TypeError("masking() requires a `cutoff` threshold, got None")

    probs_x_ulb = logits_x_ulb.detach()
    if softmax_x_ulb:
        probs_x_ulb = torch.softmax(probs_x_ulb, dim=-1)
    # Otherwise the input is assumed to already hold probabilities.

    max_probs = probs_x_ulb.max(dim=-1).values
    return max_probs.ge(cutoff).to(max_probs.dtype)


@torch.no_grad()
def inlier_p(logits_out):
    """
    Compute per-class inlier probabilities from OVA logits.

    Args:
        logits_out (Tensor): Logits of shape [N, 2 * C] or [N, 2, C]. A 2-D
            input is viewed as [N, 2, C], so its first C columns form
            channel 0 and its last C columns form channel 1.

    Returns:
        Tensor: Shape [N, C]. Channel 1 of the softmax taken over the
        size-2 axis, i.e. the inlier probability for each class.
    """
    ova_logits = logits_out.detach()
    if ova_logits.ndim == 2:
        ova_logits = ova_logits.view(ova_logits.shape[0], 2, -1)

    # Normalise each (channel 0, channel 1) pair, then keep channel 1.
    pair_probs = torch.softmax(ova_logits, dim=1)  # [N, 2, C]
    return pair_probs[:, 1, :]


@torch.no_grad()
def p_masking(logits_x_ulb, softmax_x_ulb=False, *args, **kwargs):
    """
    Return the top score and its index along the last dimension.

    Args:
        logits_x_ulb (Tensor): Scores reduced over the last dimension.
            Treated as probabilities unless ``softmax_x_ulb`` is set.
        softmax_x_ulb (bool): When True, apply softmax over the last
            dimension before taking the maximum.
        *args, **kwargs: Accepted and ignored.

    Returns:
        tuple[Tensor, Tensor]: ``(max_probs, pseudo_label)``, the maximum
        value and its argmax index along the last dimension.
    """
    scores = logits_x_ulb.detach()
    if softmax_x_ulb:
        scores = torch.softmax(scores, dim=-1)
    # Otherwise the input is assumed to already hold probabilities.

    top = scores.max(dim=-1)
    return top.values, top.indices


@torch.no_grad()
def in_p_masking(logits_out):
    """
    Compute per-class inlier probabilities from OVA logits and their maximum.

    Args:
        logits_out (Tensor): Logits of shape [N, 2 * C] or [N, 2, C]. A 2-D
            input is viewed as [N, 2, C], so its first C columns form
            channel 0 and its last C columns form channel 1.

    Returns:
        tuple[Tensor, Tensor, Tensor]:
            - in_p: Shape [N, C]; channel 1 of the softmax taken over the
              size-2 axis (inlier probability for each class).
            - inlier_probs: Shape [N]; the largest inlier probability of
              each sample.
            - ova_pred: Shape [N]; the class index of that largest value.
    """
    ova_logits = logits_out.detach()
    if ova_logits.ndim == 2:
        ova_logits = ova_logits.view(ova_logits.shape[0], 2, -1)

    # Normalise each (channel 0, channel 1) pair, then keep channel 1.
    in_p = torch.softmax(ova_logits, dim=1)[:, 1, :]  # [N, C]
    best = torch.max(in_p, dim=1)
    return in_p, best.values, best.indices


@torch.no_grad()
def compute_pseudo_accuracy(p, y_ulb, mask):
    """
    Count correct pseudo-labels overall and within the masked subset.

    Args:
        p (Tensor): Scores; the pseudo-label is the argmax over the last
            dimension.
        y_ulb (Tensor): Reference labels compared against the pseudo-labels.
        mask (Tensor): Selection mask; converted to bool and used to index
            both the pseudo-labels and ``y_ulb``.

    Returns:
        tuple: ``(tot_samples, tot_correct, masked_samples, masked_correct)``.
        The two masked counts are both 0 when the mask selects nothing.
    """
    predicted = p.argmax(dim=-1)

    # Statistics over every sample.
    n_all = y_ulb.shape[0]
    n_all_hits = (predicted == y_ulb).sum().item()

    # Statistics restricted to the selected samples.
    keep = mask.bool()
    n_kept = int(keep.sum().item())
    if n_kept:
        n_kept_hits = (predicted[keep] == y_ulb[keep]).sum().item()
    else:
        n_kept_hits = 0

    return n_all, n_all_hits, n_kept, n_kept_hits
