from typing import Optional, Tuple, Union

import torch


def dice_coeff(x_hat, x, eps=1e-6):
    """Return the soft Dice coefficient between ``sigmoid(x_hat)`` and ``x``.

    The overlap and the totals are summed over every element of the inputs
    (no per-sample reduction), giving a single scalar tensor.

    Args:
        x_hat: raw model outputs; a sigmoid is applied before comparison.
        x: target tensor, presumably a binary mask broadcastable against
            ``x_hat`` (TODO confirm with callers).
        eps: added to the denominator to avoid division by zero.

    Returns:
        Scalar tensor ``2 * sum(p * x) / (sum(p) + sum(x) + eps)`` with
        ``p = sigmoid(x_hat)``.
    """
    probs = x_hat.sigmoid()
    overlap = torch.sum(probs * x)
    denom = torch.sum(probs) + torch.sum(x) + eps
    return 2 * overlap / denom


@torch.no_grad()
def average_precision_binary(
    scores: torch.Tensor,
    targets: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    from_logits: bool = True,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    Compute AP for a single sample (binary segmentation/classification over pixels).
    Uses the VOC/COCO-style precision envelope before integrating over recall.

    Pixels that share the same score are treated as one operating point (a
    threshold cannot separate them), so the result does not depend on the
    arbitrary order in which a sort places tied elements.

    Args:
        scores: (1,H,W) or (H,W) float tensor (it is flattened, so any shape
            works). Model output per pixel.
        targets: same number of elements as scores, binary {0,1}.
        mask: optional, same number of elements; pixels with mask==0 are ignored.
        from_logits: whether ``scores`` are logits. AP depends only on the
            ranking of the scores and sigmoid is strictly increasing, so the
            ranking is computed on the raw values in both cases. This avoids
            spurious ties caused by float saturation of sigmoid (e.g.
            sigmoid(20) == sigmoid(30) == 1.0 in float32). Kept for API
            compatibility.
        eps: numerical stability.

    Returns:
        ap: scalar tensor with Average Precision for this sample. It is 0 when
        no pixel survives the mask or when there are no positive pixels.
    """
    # Flatten
    s = scores.reshape(-1)
    y = targets.reshape(-1).to(torch.uint8)

    if mask is not None:
        m = mask.reshape(-1).to(torch.bool)
        s = s[m]
        y = y[m]

    n = s.numel()
    if n == 0:
        return torch.tensor(0.0, device=scores.device)

    # Sort by raw score, descending (see `from_logits` above for why no sigmoid).
    s_sorted, order = torch.sort(s, dim=0, descending=True)
    y_sorted = y[order].to(torch.float32)

    # Cumulative TP and FP after each sorted element
    tp = torch.cumsum(y_sorted, dim=0)
    fp = torch.cumsum(1.0 - y_sorted, dim=0)

    # Keep only the last element of every run of equal scores: each distinct
    # score is a single threshold, i.e. a single point on the PR curve.
    run_ends = torch.where(s_sorted[1:] != s_sorted[:-1])[0]
    run_ends = torch.cat([run_ends, run_ends.new_tensor([n - 1])])
    tp = tp[run_ends]
    fp = fp[run_ends]

    # Number of positives (clamped so an all-negative sample does not divide by 0)
    P = tp[-1].clamp(min=eps)
    # Precision and recall at each distinct threshold
    precision = tp / (tp + fp + eps)
    recall = tp / P

    # Precision envelope: running max from the high-recall end, so precision is
    # non-increasing as recall grows.
    precision_envelope = torch.flip(
        torch.cummax(torch.flip(precision, dims=[0]), dim=0)[0], dims=[0]
    )

    # Insert sentinels (r=0, p=1) and (r=1, p=0) like VOC/COCO do
    zero = recall.new_zeros(1)
    one = recall.new_ones(1)
    mrec = torch.cat([zero, recall, one])
    mpre = torch.cat([one, precision_envelope, zero])

    # Area under the step-wise curve: sum of precision * delta_recall over the
    # positions where recall actually changes.
    idx = torch.where(mrec[1:] != mrec[:-1])[0]
    ap = torch.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1])

    return ap


@torch.no_grad()
def mean_average_precision_binary(
    scores: torch.Tensor,
    targets: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    from_logits: bool = True,
    eps: float = 1e-8,
    reduce: bool = True,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute AP per sample (via ``average_precision_binary``) and mean across the batch.

    Args:
        scores: (B,1,H,W) float tensor of model outputs (logits or probs).
        targets: (B,1,H,W) binary ground truth {0,1}.
        mask: optional tensor of valid pixels (1=valid), indexed as
            ``mask[b, 0]``; expected to be (B,1,H,W) like ``scores``.
        from_logits: forwarded to ``average_precision_binary``.
        eps: numerical stability, forwarded to ``average_precision_binary``.
        reduce: if True, return only the mean AP; else return
            ``(mean_ap, ap_per_sample)``.

    Returns:
        mean_ap (scalar tensor) if reduce=True,
        else (mean_ap, ap_per_sample) where ap_per_sample is a (B,) tensor.

    Raises:
        AssertionError: if ``scores`` and ``targets`` differ in shape or are
            not (B,1,H,W).
        RuntimeError: from ``torch.stack`` when the batch is empty (B == 0).
    """
    assert scores.shape == targets.shape, "scores and targets must have same shape"
    assert scores.dim() == 4 and scores.size(1) == 1, "expect (B,1,H,W)"

    ap_per_sample = torch.stack(
        [
            average_precision_binary(
                scores[b, 0],
                targets[b, 0],
                None if mask is None else mask[b, 0],
                from_logits=from_logits,
                eps=eps,
            )
            for b in range(scores.size(0))
        ],
        dim=0,
    )
    mean_ap = ap_per_sample.mean()

    if reduce:
        return mean_ap
    return mean_ap, ap_per_sample
