import numpy as np
import cv2


def _safe_mean(values: np.ndarray) -> float:
    """Return the mean of *values* as a Python float, or NaN for an empty array.

    Guarding on ``size`` avoids calling ``np.mean`` on an empty array.
    """
    return float(np.mean(values)) if values.size != 0 else float("nan")


def _extract_log_edges(depth: np.ndarray, valid_mask: np.ndarray, eps: float) -> np.ndarray:
    """Return a boolean edge map from the Sobel gradient of log-depth.

    Args:
        depth: Depth map; only the entries selected by ``valid_mask`` are read.
        valid_mask: Boolean mask of usable pixels (used for boolean indexing).
        eps: Lower clamp applied to depth before taking the logarithm.

    Returns:
        Boolean array shaped like ``valid_mask``. A pixel is an edge when it
        survives a 3x3 erosion of ``valid_mask`` and its gradient magnitude
        exceeds mean + one standard deviation of the magnitudes over those
        surviving pixels. All False when there is nothing to evaluate.
    """
    if not np.any(valid_mask):
        return np.zeros_like(valid_mask, dtype=bool)

    # Invalid pixels stay at 0 in the log image; they are excluded again below.
    clamped_depth = np.maximum(depth[valid_mask], eps)
    log_img = np.zeros_like(depth, dtype=np.float32)
    log_img[valid_mask] = np.log(clamped_depth).astype(np.float32)

    grad_x = cv2.Sobel(log_img, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(log_img, cv2.CV_32F, 0, 1, ksize=3)
    grad_mag = cv2.magnitude(grad_x, grad_y)

    # Erode by the Sobel footprint so gradients that touched an invalid
    # (zero-filled) neighbour are not counted.
    structuring = np.ones((3, 3), np.uint8)
    interior = cv2.erode(valid_mask.astype(np.uint8), structuring, iterations=1) > 0

    grad_mag[~interior] = 0.0
    interior_mag = grad_mag[interior]
    if interior_mag.size == 0:
        return np.zeros_like(valid_mask, dtype=bool)

    # Adaptive threshold: one standard deviation above the mean magnitude.
    threshold = float(interior_mag.mean() + interior_mag.std())
    return (grad_mag > threshold) & interior


def _compute_ibims_edge_metrics(pred: np.ndarray, gt: np.ndarray, valid_mask: np.ndarray, eps: float) -> tuple[float, float]:
    """Compute edge accuracy and completeness between predicted and GT depth edges.

    Edges are extracted from both maps with ``_extract_log_edges`` and compared
    through Euclidean distance transforms.
    NOTE(review): the name suggests this follows the iBims-1 depth boundary
    error protocol; confirm against the reference implementation.

    Args:
        pred: Predicted depth map.
        gt: Ground-truth depth map.
        valid_mask: Boolean mask of pixels to evaluate.
        eps: Lower clamp for depth, forwarded to the edge extractor.

    Returns:
        ``(edge_acc, edge_comp)``:
        * ``edge_acc`` is the mean distance to the nearest GT edge, taken over
          predicted edge pixels closer than 10.0 to a GT edge.
        * ``edge_comp`` is the mean distance to the nearest predicted edge,
          taken over all GT edge pixels.
        Both are NaN when the ground truth has no edges, and both are 10.0 when
        no predicted edge lies within that distance of a GT edge.
    """
    max_dist = 10.0

    edges_gt = _extract_log_edges(gt, valid_mask, eps)
    edges_pred = _extract_log_edges(pred, valid_mask, eps)

    if not np.any(edges_gt):
        nan = float("nan")
        return nan, nan

    def _distance_to_edges(edge_map: np.ndarray) -> np.ndarray:
        # distanceTransform measures distance to the nearest zero pixel, so the
        # edge map is inverted to place the zeros on the edges.
        inverted = (~edge_map).astype(np.uint8)
        return cv2.distanceTransform(inverted, cv2.DIST_L2, cv2.DIST_MASK_PRECISE)

    dist_to_gt = _distance_to_edges(edges_gt)
    dist_to_pred = _distance_to_edges(edges_pred)

    # Only predicted edges in the neighbourhood of a GT edge count for accuracy.
    pred_near_gt = edges_pred & (dist_to_gt < max_dist)
    if not np.any(pred_near_gt):
        return max_dist, max_dist

    accuracy = float(dist_to_gt[pred_near_gt].mean())
    completeness = float(dist_to_pred[edges_gt].mean())
    return accuracy, completeness


def compute_metrics(pred: np.ndarray, gt: np.ndarray, mask: np.ndarray, eps: float = 1e-6) -> dict:
    """Compute depth-estimation error metrics over the valid pixels.

    A pixel is valid when ``mask`` is truthy there, ``pred`` and ``gt`` are both
    finite, and ``gt`` is strictly positive.

    Args:
        pred: Predicted depth map (array-like, converted to float64).
        gt: Ground-truth depth map (array-like, converted to float64).
            Expected to be shape-compatible with ``pred``; not checked here.
        mask: Pixel selection mask. Any dtype is accepted and coerced to bool,
            so 0/1 or 0/255 integer masks behave like boolean masks instead of
            being used as integer indices.
        eps: Lower clamp for denominators and logarithm arguments.

    Returns:
        Dict of Python floats with the keys:
        ``delta_1``/``delta_2``/``delta_3`` (fraction of valid pixels with
        ``max(gt/pred, pred/gt) < 1.25**k``), ``mae``, ``abs_rel``, ``rmse``,
        ``silog`` (100 * std of the log-depth error), ``irmse`` (RMSE of
        inverse depth), ``sq_rel``, ``edge_acc`` and ``edge_comp``.
        Every value is NaN when no pixel is valid.
    """
    metric_names = (
        "delta_1",
        "delta_2",
        "delta_3",
        "mae",
        "abs_rel",
        "rmse",
        "silog",
        "irmse",
        "sq_rel",
        "edge_acc",
        "edge_comp",
    )

    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    # Coerce to bool: an integer mask would make `pred[valid]` integer fancy
    # indexing (silently wrong), and a float mask would make `&` raise.
    valid = np.asarray(mask, dtype=bool) & np.isfinite(pred) & np.isfinite(gt) & (gt > 0)

    if not np.any(valid):
        return dict.fromkeys(metric_names, float("nan"))

    pred_v = pred[valid]
    gt_v = gt[valid]
    # Clamped copies, used wherever a value is a denominator or a log argument.
    pred_safe = np.maximum(pred_v, eps)
    gt_safe = np.maximum(gt_v, eps)
    err = pred_v - gt_v

    # Threshold accuracy on the symmetric ratio.
    ratio = np.maximum(gt_v / pred_safe, pred_v / gt_safe)
    delta_1 = _safe_mean(ratio < 1.25)
    delta_2 = _safe_mean(ratio < 1.25 ** 2)
    delta_3 = _safe_mean(ratio < 1.25 ** 3)

    mae = _safe_mean(np.abs(err))
    abs_rel = _safe_mean(np.abs(err) / gt_safe)
    rmse = float(np.sqrt(_safe_mean(err ** 2)))

    # Scale-invariant log error; the variance is clamped at 0 to absorb
    # negative values caused by floating-point round-off.
    d = np.log(pred_safe) - np.log(gt_safe)
    silog_var = max(_safe_mean(d ** 2) - _safe_mean(d) ** 2, 0.0)
    silog = float(np.sqrt(silog_var) * 100.0)

    irmse = float(np.sqrt(_safe_mean((1.0 / pred_safe - 1.0 / gt_safe) ** 2)))

    sq_rel = _safe_mean((err ** 2) / gt_safe)
    # Edge metrics need the full images plus the validity mask, not the
    # flattened valid pixels.
    edge_acc, edge_comp = _compute_ibims_edge_metrics(pred, gt, valid, eps)

    return {
        "delta_1": delta_1,
        "delta_2": delta_2,
        "delta_3": delta_3,
        "mae": mae,
        "abs_rel": abs_rel,
        "rmse": rmse,
        "silog": silog,
        "irmse": irmse,
        "sq_rel": sq_rel,
        "edge_acc": edge_acc,
        "edge_comp": edge_comp,
    }
