import re
from typing import Iterable, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np


# ---------- Utility: 2D soft-argmax (stable keypoint coordinate estimate) ----------
def soft_argmax2d(hm: np.ndarray, beta: float = 10.0) -> Tuple[float, float]:
    """Estimate a sub-pixel peak location with a 2D soft-argmax.

    The map is converted to weights ``softmax(beta * hm)`` and the expected
    row / column index under those weights is returned.

    Args:
        hm: 2D response map of shape ``(H, W)``. Any real dtype is accepted;
            the computation is always carried out in ``float64`` so that
            integer inputs cannot wrap around when the maximum is subtracted.
        beta: Softmax temperature; larger values move the result closer to
            the hard argmax.

    Returns:
        ``(y, x)`` as Python floats, in pixel-index units.

    Raises:
        ValueError: If ``hm`` is not 2D or has no elements.
    """
    hm = np.asarray(hm, dtype=np.float64)
    h, w = hm.shape

    # Subtract the maximum before exponentiating for numerical stability.
    weights = np.exp(beta * (hm - hm.max()))
    total = weights.sum()

    # Degenerate input (NaN / inf responses, overflow): use the hard argmax
    # rather than returning NaN coordinates.
    if not np.isfinite(total) or total <= 1e-12:
        yx = np.unravel_index(np.argmax(hm), hm.shape)
        return float(yx[0]), float(yx[1])

    weights /= total

    # Expected coordinates computed from the row / column marginals; this is
    # equivalent to weighting a full (H, W) coordinate grid without building it.
    y = float(weights.sum(axis=1) @ np.arange(h, dtype=np.float64))
    x = float(weights.sum(axis=0) @ np.arange(w, dtype=np.float64))
    return y, x


# ---------- Utility: select channels by keypoint name ----------
_DEFAULT_INCLUDE_PATTERNS = [
    # 眉/眼/眼睑/睫毛/虹膜/瞳孔
    r"eyebrow",
    r"eyelid",
    r"lash",
    r"iris",
    r"pupil",
    # 鼻
    r"nose",
    r"nose_bridge",
    r"glabella",
    # 嘴/唇/法令沟/颏唇沟/人中
    r"mouth",
    r"lip",
    r"philtrum",
    r"labiomental",
    r"nasolabial",
    # 下巴
    r"chin",
]

_DEFAULT_EXCLUDE_PATTERNS = [
    r"\bear\b",
    r"helix",
    r"concha",
    r"tragus",
    r"antihelix",
    r"ear_lobe",
]


def _select_indices_from_names(
    kpt_names: Sequence[str],
    include_patterns: Optional[Sequence[str]] = None,
    exclude_patterns: Optional[Sequence[str]] = None,
) -> List[int]:
    include_patterns = list(include_patterns or _DEFAULT_INCLUDE_PATTERNS)
    exclude_patterns = list(exclude_patterns or _DEFAULT_EXCLUDE_PATTERNS)

    inc = [re.compile(p, re.IGNORECASE) for p in include_patterns]
    exc = [re.compile(p, re.IGNORECASE) for p in exclude_patterns]

    sel: List[int] = []
    for i, name in enumerate(kpt_names):
        if any(r.search(name) for r in exc):
            continue
        if any(r.search(name) for r in inc):
            sel.append(i)

    # 额外排除粗粒度的 COCO 风格耳朵点（若存在于前若干名）
    for i, name in enumerate(kpt_names[:10]):  # nose/eyes/ears 通常在前部
        if re.fullmatch(r"(left_ear|right_ear)", name, flags=re.IGNORECASE):
            if i in sel:
                sel.remove(i)
    return sel


def _ensure_4d(x: np.ndarray) -> Tuple[np.ndarray, bool]:
    """
    接受 (K,H,W) 或 (T,K,H,W)。统一成 (T,K,H,W)，并返回是否原本是3D。
    """
    if x.ndim == 3:  # K,H,W
        x = x[None, ...]
        return x, True
    if x.ndim != 4:
        raise ValueError(f"heatmap must be KxHxW or T×K×H×W, got {x.shape}")
    return x, False


def _largest_cc(mask: np.ndarray, min_area: int = 64) -> np.ndarray:
    """Keep only the largest 4-connected foreground component of a mask.

    Args:
        mask: Binary ``uint8`` image (0 = background, non-zero = foreground).
        min_area: Minimum pixel area the largest component must have. When it
            is smaller, the component is treated as noise and an all-zero
            mask is returned.

    Returns:
        A ``uint8`` 0/1 mask with at most one connected component. When the
        input has no foreground, the input itself is returned unchanged.
    """
    num, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=4)
    if num <= 1:  # label 0 is the background, so there is no foreground
        return mask

    # Row 0 of ``stats`` is the background; skip it when ranking by area.
    areas = stats[1:, cv2.CC_STAT_AREA]
    largest = 1 + int(np.argmax(areas))

    # Every other component is dropped by construction, so the only place the
    # area threshold can act is on the winner itself.
    if stats[largest, cv2.CC_STAT_AREA] < min_area:
        return np.zeros(mask.shape, dtype=np.uint8)

    return (labels == largest).astype(np.uint8)


# ---------- Main function: build the face mask ----------
def _find_clip_landmark_indices(
    kpt_names: Sequence[str],
) -> Tuple[Optional[int], Optional[int], Optional[int], Optional[int]]:
    """Find the channels of chin tip, glabella and both outer eye ends by name.

    Returns ``(chin, glabella, left_outer, right_outer)``; each entry is the
    index of the first matching name, or ``None`` when no name matches.
    """
    chin_idx = glab_idx = l_outer_idx = r_outer_idx = None
    for i, raw_name in enumerate(kpt_names):
        name = raw_name.lower()
        if chin_idx is None and "tip_of_chin" in name:
            chin_idx = i
        if glab_idx is None and "glabella" in name:
            glab_idx = i
        # Outer end of the upper lash line / upper eyelid line, per side.
        if l_outer_idx is None and (
            "l_outer_end_of_upper_lash_line" in name
            or "l_outer_end_of_upper_eyelid_line" in name
        ):
            l_outer_idx = i
        if r_outer_idx is None and (
            "r_outer_end_of_upper_lash_line" in name
            or "r_outer_end_of_upper_eyelid_line" in name
        ):
            r_outer_idx = i
    return chin_idx, glab_idx, l_outer_idx, r_outer_idx


def heatmap2facemask(
    heatmap: np.ndarray,
    kpt_names: Optional[Sequence[str]] = None,
    face_indices: Optional[Iterable[int]] = None,
    threshold: Union[float, str] = "otsu",  # "otsu" or a float threshold in [0, 1]
    morph_kernel: int = 5,
    min_area: int = 64,
    clip_to_face_via_landmarks: bool = True,
    forehead_ratio: float = 0.35,  # forehead extension as a fraction of glabella-chin distance
) -> np.ndarray:
    """Build a binary face mask from keypoint heatmaps.

    Per frame: take the max over the selected face channels, min-max normalise
    it to [0, 1], threshold it, clean it with close + open morphology, keep the
    largest connected component, and optionally clip it around landmarks.

    Args:
        heatmap: ``(K, H, W)`` or ``(T, K, H, W)`` array of channel responses.
        kpt_names: Channel names. Used to select face channels when
            ``face_indices`` is ``None``, and to locate clipping landmarks
            when ``clip_to_face_via_landmarks`` is true.
        face_indices: Explicit channel indices to fuse; takes precedence over
            name-based selection.
        threshold: ``"otsu"`` (case-insensitive) or a number compared against
            the normalised fused response.
        morph_kernel: Side length of the elliptical structuring element;
            values below 1 disable the morphology step.
        min_area: Forwarded to ``_largest_cc``.
        clip_to_face_via_landmarks: When true and the landmark names are
            found, zero the mask above the estimated forehead line, below the
            chin, and outside the outer eye ends (with a 10% margin).
        forehead_ratio: How far above the glabella the mask may extend, as a
            fraction of the glabella-to-chin distance.

    Returns:
        ``uint8`` mask with values 0/1, shaped ``(H, W)`` for 3D input and
        ``(T, H, W)`` for 4D input.

    Raises:
        ValueError: If neither ``kpt_names`` nor ``face_indices`` is given, if
            no face channel is selected, if the heatmap is not 3D/4D, or if
            ``threshold`` is a string other than ``"otsu"`` that is not numeric.
    """
    x, was_3d = _ensure_4d(np.asarray(heatmap))
    if not np.issubdtype(x.dtype, np.floating):
        # Integer / bool heatmaps: promote once so that later arithmetic
        # (normalisation, soft-argmax) cannot wrap around.
        x = x.astype(np.float64)
    T, K, H, W = x.shape

    # Channel selection
    if face_indices is None:
        if kpt_names is None:
            raise ValueError("需提供 kpt_names 或 face_indices 以筛选脸部通道。")
        face_indices = _select_indices_from_names(kpt_names)

    face_indices = list(face_indices)
    if not face_indices:
        raise ValueError("筛选到的人脸通道为空，请检查 kpt_names/face_indices。")

    # Fuse the face channels with max (sharper boundaries than a sum).
    fused = np.max(x[:, face_indices, :, :], axis=1)  # (T, H, W)

    # Per-frame min-max normalisation to [0, 1]
    f_min = fused.min(axis=(1, 2), keepdims=True)
    f_max = fused.max(axis=(1, 2), keepdims=True)
    fused01 = (fused - f_min) / np.maximum(f_max - f_min, 1e-6)

    # Landmarks used for clipping; all stay None when clipping is disabled.
    chin_idx = glab_idx = l_outer_idx = r_outer_idx = None
    if clip_to_face_via_landmarks and kpt_names is not None:
        chin_idx, glab_idx, l_outer_idx, r_outer_idx = _find_clip_landmark_indices(
            kpt_names
        )

    use_otsu = isinstance(threshold, str) and threshold.lower() == "otsu"
    fixed_thr = None if use_otsu else float(threshold)

    kernel = None
    if morph_kernel >= 1:
        kernel = cv2.getStructuringElement(
            cv2.MORPH_ELLIPSE, (morph_kernel, morph_kernel)
        )

    masks = np.zeros((T, H, W), dtype=np.uint8)
    for t in range(T):
        fm = fused01[t]

        # Thresholding
        if use_otsu:
            fm8 = (fm * 255).astype(np.uint8)
            # Use OpenCV's own binarisation (strictly greater than the Otsu
            # threshold). Comparing with ">=" would mark the whole frame as
            # foreground whenever the threshold comes out as 0, e.g. for a
            # flat response map.
            _, binary = cv2.threshold(
                fm8, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU
            )
            m = (binary > 0).astype(np.uint8)
        else:
            m = (fm >= fixed_thr).astype(np.uint8)

        # Morphology: close to fill gaps, then open to remove specks.
        if kernel is not None:
            m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel)
            m = cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel)

        # Keep only the largest connected component.
        m = _largest_cc(m, min_area=min_area)

        # Vertical clip to [estimated forehead line, just below the chin].
        if chin_idx is not None and glab_idx is not None:
            y_chin, _ = soft_argmax2d(x[t, chin_idx])
            y_glab, _ = soft_argmax2d(x[t, glab_idx])
            dy = max(1.0, abs(y_chin - y_glab))
            y_top = int(round(max(0.0, y_glab - forehead_ratio * dy)))
            y_bot = int(round(min(H - 1, y_chin + 0.05 * dy)))
            clip_top, clip_bot = min(y_top, y_bot), max(y_top, y_bot)
            m[:clip_top, :] = 0
            m[clip_bot + 1 :, :] = 0

        # Horizontal clip to the outer eye ends, widened by 10% of their
        # distance on each side, to cut off the ears.
        if l_outer_idx is not None and r_outer_idx is not None:
            _, x_l = soft_argmax2d(x[t, l_outer_idx])
            _, x_r = soft_argmax2d(x[t, r_outer_idx])
            pad = 0.1 * abs(x_r - x_l)
            x_left = int(round(max(0.0, min(x_l, x_r) - pad)))
            x_right = int(round(min(W - 1, max(x_l, x_r) + pad)))
            m[:, :x_left] = 0
            m[:, x_right + 1 :] = 0

        masks[t] = m

    return masks[0] if was_3d else masks


# ---------- Main function: bounding box from the face mask ----------
def heatmap2bbox(
    heatmap: np.ndarray,
    kpt_names: Optional[Sequence[str]] = None,
    face_indices: Optional[Iterable[int]] = None,
    threshold: Union[float, str] = "otsu",
    morph_kernel: int = 5,
    min_area: int = 64,
    clip_to_face_via_landmarks: bool = True,
    forehead_ratio: float = 0.35,
    margin_ratio: float = 0.05,  # relative expansion applied to the final box
) -> Union[Tuple[int, int, int, int], np.ndarray]:
    """Compute the face bounding box from keypoint heatmaps.

    The face mask is produced by ``heatmap2facemask`` (all mask-related
    arguments are forwarded unchanged). The tight box around the mask is then
    expanded on every side by ``margin_ratio`` times the box height / width
    and clamped to the image bounds.

    Args:
        margin_ratio: Fraction of the box height (for top/bottom) and width
            (for left/right) added on each side.

    Returns:
        For 3D input: ``(x1, y1, x2, y2)`` as Python ints, inclusive pixel
        coordinates. For 4D input: an ``int32`` array of shape ``(T, 4)`` in
        the same order. Frames with an empty mask yield ``(0, 0, 0, 0)``.
    """
    mask = heatmap2facemask(
        heatmap,
        kpt_names=kpt_names,
        face_indices=face_indices,
        threshold=threshold,
        morph_kernel=morph_kernel,
        min_area=min_area,
        clip_to_face_via_landmarks=clip_to_face_via_landmarks,
        forehead_ratio=forehead_ratio,
    )
    single_frame = mask.ndim == 2
    masks = mask[None, ...] if single_frame else mask

    T, H, W = masks.shape
    boxes = np.zeros((T, 4), dtype=np.int32)

    for t in range(T):
        fg = masks[t] > 0
        # Project onto rows / columns; the extents of the non-empty
        # projections give the tight box.
        rows = np.flatnonzero(fg.any(axis=1))
        if rows.size == 0:
            continue  # empty mask: the row stays (0, 0, 0, 0)
        cols = np.flatnonzero(fg.any(axis=0))

        y1, y2 = int(rows[0]), int(rows[-1])
        x1, x2 = int(cols[0]), int(cols[-1])

        # Margin proportional to the box size, clamped to the image.
        dy = int(round(margin_ratio * (y2 - y1 + 1)))
        dx = int(round(margin_ratio * (x2 - x1 + 1)))
        boxes[t] = (
            max(0, x1 - dx),
            max(0, y1 - dy),
            min(W - 1, x2 + dx),
            min(H - 1, y2 + dy),
        )

    if single_frame:
        # Return plain Python ints, matching the annotated return type.
        bx1, by1, bx2, by2 = boxes[0].tolist()
        return bx1, by1, bx2, by2
    return boxes
