import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os

from sklearn.metrics import average_precision_score, roc_auc_score

# PU loss function
def pu_loss_ranking_multi(y_pred, y_obs, prior=0.1, margin_factor=0.0):
    """PU risk with a ReLU-clamped negative-class term, summed over labels.

    Args:
        y_pred: logits before sigmoid, shape [N, L] (batch_size, n_labels).
        y_obs: partially labeled positives, shape [N, L]
            (1 = observed positive, 0 = unlabeled).
        prior: class prior (pi). Either a scalar (Python/NumPy number or a
            0-dim tensor), applied to every label, or a per-label
            sequence/array/tensor of length L. It is converted to the dtype
            and device of the per-label risks.
        margin_factor: the clamped term is shifted by ``prior * margin_factor``
            before the ReLU.

    Returns:
        Scalar tensor: the sum over labels of
        ``prior * p_above + relu(u_below - prior * p_below + prior * margin_factor)``.
    """
    # Masks selecting observed positives and unlabeled entries, [N, L].
    pos_mask = (y_obs == 1).float()
    unlabel_mask = (y_obs == 0).float()

    # Per-label counts, clamped to 1 so a label with no positives (or no
    # unlabeled entries) in the batch does not divide by zero.
    pos_counts = pos_mask.sum(dim=0).clamp(min=1.0)  # [L]
    un_counts = unlabel_mask.sum(dim=0).clamp(min=1.0)  # [L]

    # Positive risk: mean of -log p over observed positives.
    p_above = -(F.logsigmoid(y_pred) * pos_mask).sum(dim=0) / pos_counts
    # Positives-treated-as-negatives risk: mean of -log(1 - p) over positives.
    p_below = -(F.logsigmoid(-y_pred) * pos_mask).sum(dim=0) / pos_counts
    # Unlabeled term: positives contribute -log p, unlabeled entries
    # -log(1 - p), and entries that are neither contribute log 2
    # (logsigmoid(0)).
    # NOTE(review): the sum runs over ALL N entries but is normalized by the
    # unlabeled count only, so observed positives also enter this term; the
    # textbook nnPU estimator averages -log(1 - p) over unlabeled entries
    # only. Kept unchanged because it may be the intended "ranking" variant.
    signed_logits = y_pred * pos_mask - y_pred * unlabel_mask
    u_below = -F.logsigmoid(signed_logits).sum(dim=0) / un_counts

    # Normalize the prior into a length-L vector. as_tensor accepts Python and
    # NumPy scalars (e.g. np.float32), sequences, arrays and tensors; using the
    # risks' dtype keeps fp16/bf16 logits computing the loss in float32 (the
    # masks are float32) and float64 logits in float64.
    n_labels = y_pred.shape[1]
    prior_vec = torch.as_tensor(prior, dtype=p_above.dtype, device=p_above.device)
    if prior_vec.dim() == 0:
        prior_vec = prior_vec.expand(n_labels)

    margin_vec = prior_vec * margin_factor
    # The ReLU keeps the estimated negative-class risk from going below zero.
    risk_per_label = prior_vec * p_above + torch.relu(
        u_below - prior_vec * p_below + margin_vec
    )
    return risk_per_label.sum()


def ranking_val_loss(y_pred, y_true, average="macro"):
    """Validation ranking loss ``1 - AP`` (Average Precision); lower is better.

    Args:
        y_pred: logits, shape [N, L]. A tensor, or anything accepted by
            ``torch.as_tensor`` (NumPy array, nested lists).
        y_true: labels, shape [N, L] (1 = positive, 0 = negative/unlabeled).
            Make sure the validation set contains true negatives; otherwise
            the "negatives" are really unlabeled examples.
        average: forwarded unchanged to
            ``sklearn.metrics.average_precision_score``.

    Returns:
        ``1 - AP``. If ``average_precision_score`` raises ValueError, a
        RuntimeWarning is emitted and AP is treated as 0 (loss 1.0).

    ``1 - roc_auc_score(y_true, scores, average=average)`` is a possible
    alternative validation metric.
    """
    import warnings

    # AP depends only on how the scores rank the examples, and sigmoid is
    # strictly increasing, so the logits are scored directly. Applying sigmoid
    # to float32 logits rounds every logit above roughly 17 to exactly 1.0,
    # creating artificial ties that change AP. Casting to float64 keeps
    # distinct logits distinct whatever the input precision (incl. bf16,
    # which .numpy() cannot convert); the cast happens after .cpu() because
    # some accelerators lack float64.
    scores = torch.as_tensor(y_pred).detach().cpu().to(torch.float64).numpy()
    # sigmoid(+/-inf) used to be a finite 1/0; keep infinite logits scoreable
    # by clipping them to the extreme finite float64 values (NaN stays NaN).
    finfo = np.finfo(np.float64)
    scores = np.clip(scores, finfo.min, finfo.max)
    labels = torch.as_tensor(y_true).detach().cpu().numpy()

    try:
        ap = average_precision_score(labels, scores, average=average)
    except ValueError as err:
        # Best effort: keep validation running, but surface the failure.
        # NOTE(review): this path is for inputs sklearn rejects (e.g. NaN
        # scores, mismatched shapes); single-class label columns appear to be
        # handled inside sklearn with a warning rather than a ValueError —
        # confirm for the installed sklearn version.
        warnings.warn(
            f"average_precision_score failed ({err}); treating AP as 0.",
            RuntimeWarning,
            stacklevel=2,
        )
        ap = 0.0

    return 1.0 - ap
