"""Basic metric functions that work on predictions and labels."""

import numpy as np
from typing import List, Dict, Optional
from sklearn.metrics import roc_auc_score, brier_score_loss

from calib.common_train_utils import compute_ece


def compute_roc_auc(labels: np.ndarray, probs: np.ndarray) -> float:
    """Compute ROC AUC, or NaN when it is undefined.

    ROC AUC needs at least one positive and one negative example; with a
    single class (or no samples) ``sklearn.metrics.roc_auc_score`` raises
    ``ValueError``. This helper returns NaN in that case instead, matching the
    NaN convention used for "no valid AUC" elsewhere in this module.

    Args:
        labels: Binary ground-truth labels.
        probs: Predicted scores/probabilities, aligned with ``labels``.

    Returns:
        The ROC AUC, or NaN if fewer than two distinct labels are present.
    """
    if np.unique(np.asarray(labels)).size < 2:
        return float("nan")
    return roc_auc_score(labels, probs)

def compute_brier_score(labels: np.ndarray, probs: np.ndarray) -> float:
    """Return the Brier score of ``probs`` against ``labels``.

    Delegates to ``sklearn.metrics.brier_score_loss`` (labels first,
    probabilities second); lower values indicate better predictions.
    """
    score = brier_score_loss(labels, probs)
    return score


def compute_ece_score(labels: np.ndarray, probs: np.ndarray) -> float:
    """Return the Expected Calibration Error of ``probs`` against ``labels``.

    Wraps ``calib.common_train_utils.compute_ece``. Note the argument order:
    this wrapper takes (labels, probs) but forwards (probs, labels) —
    presumably ``compute_ece`` expects probabilities first; verify against
    its definition before changing either signature.
    """
    ece = compute_ece(probs, labels)
    return ece

def compute_nll(
    labels: np.ndarray,
    probs: np.ndarray,
    *,
    eps: float = 1e-4,
) -> float:
    """Compute the mean binary negative log-likelihood (log loss).

    Args:
        labels: Binary labels (0 or 1). Any array-like (list, bool array, ...)
            is accepted and converted to float.
        probs: Predicted probabilities of the positive class, aligned with
            ``labels``.
        eps: Probabilities are clipped to ``[eps, 1 - eps]`` before taking the
            log, so a confidently wrong prediction contributes at most
            ``-log(eps)`` instead of infinity. Defaults to ``1e-4``.

    Returns:
        Mean negative log-likelihood over all samples.

    Raises:
        ValueError: If ``eps`` is not in the open interval (0, 0.5).
    """
    if not 0.0 < eps < 0.5:
        raise ValueError(f"eps must be in (0, 0.5), got {eps}")

    # Convert to float arrays: ``1 - [0, 1]`` raises TypeError for plain lists,
    # and NumPy rejects subtraction on boolean arrays.
    labels = np.asarray(labels, dtype=float)
    # Clip probabilities to avoid log(0).
    probs = np.clip(np.asarray(probs, dtype=float), eps, 1 - eps)

    nll = -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))
    return nll

def compute_roc_auc_per_prompt(
    labels_by_prompt: List[List],
    probs_by_prompt: List[List]
) -> float:
    """Compute the mean ROC AUC over prompts that contain both classes.

    Args:
        labels_by_prompt: One list of binary labels per prompt.
        probs_by_prompt: One list of predicted probabilities per prompt,
            aligned element-wise with ``labels_by_prompt``.

    Returns:
        Unweighted mean of the per-prompt AUCs, or NaN if no prompt has both
        positive and negative labels.

    Raises:
        ValueError: If the two inputs hold a different number of prompts, or
            if a per-prompt AUC comes back NaN.
    """
    # zip() would silently drop trailing prompts on a length mismatch.
    if len(labels_by_prompt) != len(probs_by_prompt):
        raise ValueError(
            f"labels_by_prompt has {len(labels_by_prompt)} prompts but "
            f"probs_by_prompt has {len(probs_by_prompt)}"
        )

    valid_aucs = []
    for group_labels, group_probs in zip(labels_by_prompt, probs_by_prompt):
        group_labels = np.asarray(group_labels)
        group_probs = np.asarray(group_probs)

        # AUC is undefined when a prompt's labels are all 0s or all 1s; skip it.
        if np.unique(group_labels).size < 2:
            continue

        auc = compute_roc_auc(group_labels, group_probs)
        if np.isnan(auc):
            raise ValueError("Invalid AUC")
        valid_aucs.append(auc)

    return np.mean(valid_aucs) if valid_aucs else float("nan")

def compute_argmax_accuracy(
    labels_by_group: List[List],
    probs_by_group: List[List],
) -> float:
    """Compute the mean label of the highest-probability item in each group.

    For every group, select the item with the largest predicted probability
    (the first one on ties, per ``np.argmax``) and take its label; the result
    is the mean of those selected labels across groups. With binary labels
    this is the fraction of groups whose top-scored item is correct.

    Args:
        labels_by_group: One list of labels per group.
        probs_by_group: One list of predicted probabilities per group, aligned
            element-wise with ``labels_by_group``.

    Returns:
        Mean selected label across groups, or NaN if there are no groups.

    Raises:
        ValueError: If the inputs hold a different number of groups, a group's
            labels and probabilities differ in length, or a group is empty.
    """
    # zip() would silently drop trailing groups on a length mismatch.
    if len(labels_by_group) != len(probs_by_group):
        raise ValueError(
            f"labels_by_group has {len(labels_by_group)} groups but "
            f"probs_by_group has {len(probs_by_group)}"
        )
    if not labels_by_group:
        return float("nan")

    selected_labels = []
    for idx, (group_labels, group_probs) in enumerate(zip(labels_by_group, probs_by_group)):
        group_labels = np.asarray(group_labels)
        group_probs = np.asarray(group_probs)
        if len(group_probs) == 0:
            raise ValueError(f"group {idx} is empty; cannot select an argmax")
        # A mismatch would index the wrong label (or go out of bounds).
        if len(group_labels) != len(group_probs):
            raise ValueError(
                f"group {idx} has {len(group_labels)} labels but "
                f"{len(group_probs)} probabilities"
            )
        selected_labels.append(group_labels[np.argmax(group_probs)])

    return float(np.mean(selected_labels))