import numpy as np

from .metrics import compute_metrics
from .regions import split_by_distance_bins


def apply_max_depth_mask(gt: np.ndarray, base_mask: np.ndarray, max_depth: float | None) -> np.ndarray:
    if max_depth is None:
        return base_mask
    return base_mask & np.isfinite(gt) & (gt > 0) & (gt <= max_depth)


def build_distance_masks(
    gt: np.ndarray,
    base_mask: np.ndarray,
    distance_bins: list[tuple[float, float | None]] | None,
    bin_names: list[str] | None = None,
) -> dict:
    if not distance_bins:
        return {}
    masks = split_by_distance_bins(gt, base_mask, distance_bins)
    if bin_names is None:
        return {f"bin_{i}": mask for i, mask in enumerate(masks)}
    return {name: mask for name, mask in zip(bin_names, masks)}


def compute_metrics_by_region(pred: np.ndarray, gt: np.ndarray, base_mask: np.ndarray, region_masks: dict) -> dict:
    """Compute metrics over ``base_mask`` and over each named region mask.

    Args:
        pred: Predictions, forwarded to ``compute_metrics``.
        gt: Ground truth, forwarded to ``compute_metrics``.
        base_mask: Mask used for the ``"overall"`` entry.
        region_masks: Mapping of region name to mask. Each mask is passed to
            ``compute_metrics`` as-is (it is not intersected with ``base_mask``
            here).

    Returns:
        Dict with key ``"overall"`` first, followed by one entry per region in
        ``region_masks`` iteration order; values are whatever
        ``compute_metrics`` returns.

    Raises:
        ValueError: if ``region_masks`` contains the reserved key ``"overall"``,
            which would otherwise silently overwrite the overall metrics.
    """
    # Checked before any metric computation so a bad call fails fast.
    if "overall" in region_masks:
        raise ValueError('region name "overall" is reserved for the base-mask metrics')

    results = {"overall": compute_metrics(pred, gt, base_mask)}
    for name, mask in region_masks.items():
        results[name] = compute_metrics(pred, gt, mask)
    return results
