import numpy as np
import torch
import time
from tqdm import tqdm

def compute_ap_binned_fixed(y_true, y_score, n_bins=10000):
    """Compute Average Precision (AP) from scores bucketed into a fixed histogram.

    Scores are quantized into ``n_bins`` equal-width bins over [0, 1], so the
    precision/recall curve is evaluated at ``n_bins`` thresholds instead of at
    every distinct score. Precision is then interpolated at 101 evenly spaced
    recall levels (0.00, 0.01, ..., 1.00) -- the COCO-style 101-point scheme --
    and averaged. At each recall level ``r`` the interpolated precision is the
    maximum precision over all thresholds whose recall is ``>= r`` (0 if none).

    Args:
        y_true: 1D array-like of ground-truth binary labels (0 or 1).
        y_score: 1D array-like of predicted scores in [0, 1]. Values outside
            that range are clipped into the first/last bin.
        n_bins: Number of histogram bins (score thresholds). Defaults to 10000.

    Returns:
        float: AP averaged over the 101 recall levels. Returns 0.0 when there
        are no positive labels (including empty input).

    Raises:
        ValueError: If ``n_bins`` is not positive or the inputs differ in size.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be a positive integer, got {n_bins}")

    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    # asarray keeps list inputs from being "multiplied" by list repetition below.
    y_score = np.asarray(y_score).ravel()
    if y_true.size != y_score.size:
        raise ValueError(
            f"y_true and y_score must have the same size, "
            f"got {y_true.size} and {y_score.size}"
        )

    bin_idx = np.clip(np.floor(y_score * n_bins).astype(np.int64), 0, n_bins - 1)
    tp_hist = np.bincount(bin_idx, weights=y_true, minlength=n_bins)
    fp_hist = np.bincount(bin_idx, weights=1.0 - y_true, minlength=n_bins)

    # Sweep thresholds from the highest bin downwards: entry k accumulates every
    # prediction whose bin index is >= n_bins - 1 - k.
    tp_cum = np.cumsum(tp_hist[::-1])
    fp_cum = np.cumsum(fp_hist[::-1])

    # The final cumulative entry is the total positive weight. Using it (instead
    # of a separately summed total plus an epsilon) guarantees the lowest
    # threshold reaches recall == 1.0 exactly, so the r = 1.0 level is scored.
    n_pos = tp_cum[-1]
    if n_pos <= 0:
        # Recall is undefined without positives; report 0 AP.
        return 0.0

    n_pred = tp_cum + fp_cum
    # Thresholds above every score select nothing; define their precision as 0
    # rather than padding the denominator (which kept perfect precision < 1).
    precision = np.divide(tp_cum, n_pred, out=np.zeros_like(tp_cum), where=n_pred > 0)
    recall = tp_cum / n_pos

    # i / 100 (rather than linspace's i * 0.01) makes levels such as 0.03
    # bit-identical to an exact recall of 3/100, so `recall >= r` is not lost
    # to floating-point rounding.
    recall_levels = np.arange(101) / 100.0
    interp_precision = np.zeros_like(recall_levels)
    for i, r in enumerate(recall_levels):
        reached = recall >= r
        if reached.any():
            interp_precision[i] = precision[reached].max()
    return float(interp_precision.mean())


def calculate_ap_with_confidence(
    pred_prob_maps,
    gt_masks,
):
    """Calculate pixel-level Average Precision from probability maps and masks.

    All maps are flattened and concatenated into one score vector. All masks are
    binarized (values > 0.5 count as positive) and flattened into one label
    vector. AP is then computed with ``compute_ap_binned_fixed``.

    Args:
        pred_prob_maps: Sequence of PyTorch tensors of predicted probabilities,
            typically shaped (H, W) with values in [0, 1]. Tensors may live on
            any device and may require grad.
        gt_masks: Sequence of PyTorch tensors of ground-truth masks, one per
            prediction and with the same number of elements as its prediction.

    Returns:
        float: AP over all pixels of all maps. Returns 1.0 when both sequences
        are empty.

    Raises:
        ValueError: If the sequences differ in length, or a prediction and its
            mask differ in element count.
    """
    if len(pred_prob_maps) != len(gt_masks):
        raise ValueError(
            f"pred_prob_maps and gt_masks must have the same length, "
            f"got {len(pred_prob_maps)} and {len(gt_masks)}"
        )

    if len(gt_masks) == 0:
        # Return a float like every other path (previously a stray tuple).
        return 1.0

    for i, (pred, mask) in enumerate(zip(pred_prob_maps, gt_masks)):
        # Maps are flattened before pairing, so a size mismatch would silently
        # misalign scores and labels further down the concatenated vectors.
        if pred.numel() != mask.numel():
            raise ValueError(
                f"pred_prob_maps[{i}] has {pred.numel()} elements but "
                f"gt_masks[{i}] has {mask.numel()}"
            )

    # detach(): .numpy() refuses tensors that require grad.
    # float(): .numpy() does not support bfloat16.
    y_score = (
        torch.cat([p.detach().reshape(-1) for p in pred_prob_maps])
        .cpu()
        .float()
        .numpy()
    )
    y_true = (
        torch.cat([(mask_gt > 0.5).int().reshape(-1) for mask_gt in gt_masks])
        .cpu()
        .numpy()
    )
    return compute_ap_binned_fixed(y_true, y_score)

