import numpy as np


class FaceAbsentError(ValueError):
    pass


def _face_indices_ex_ear_from_names(names):
    """基于官方名称表，返回“面部(不含耳朵)”的索引列表。names: 长度应为 K 的小写字符串列表"""
    FACE_TOKENS = (
        "glabella",
        "brow",
        "eyebrow",
        "eyelid",
        "lash",
        "eye",
        "iris",
        "pupil",
        "canthus",
        "nose",
        "nasal",
        "subnasale",
        "alar",
        "columella",
        "philtrum",
        "lip",
        "labial",
        "mouth",
        "cupid",
        "cheek",
        "malar",
        "chin",
        "menton",
        "gonion",
        "forehead",
        "trichion",
        "face",
    )
    EAR_TOKENS = (
        "ear",
        "tragus",
        "antitragus",
        "intertragic",
        "helix",
        "antihelix",
        "concha",
        "scapha",
        "cymba",
        "lobule",
        "lobe",
        "crus_of_helix",
    )
    is_face = [any(tok in n for tok in FACE_TOKENS) for n in names]
    is_ear = [any(tok in n for tok in EAR_TOKENS) for n in names]
    face_no_ear = [i for i, (f, e) in enumerate(zip(is_face, is_ear)) if (f and not e)]
    # 基础 0..2（鼻/眼）若未被关键词命中，兜底加入
    for i in (0, 1, 2):
        if i < len(names) and i not in face_no_ear:
            face_no_ear.append(i)
    return sorted(set(face_no_ear))


def _fallback_face_indices_ex_ear_308(K=308):
    """
    对 GOLIATH-308 的稳健回退划分（无需名称表）：
    - 面部：0..2 以及 70..279
    - 排除耳朵：3,4 以及 >=280 的耳廓密集点
    """
    idxs = []
    if K >= 1:
        idxs += [0, 1, 2]  # 基础面部点（鼻/眼等）
    if K >= 280:
        idxs += list(range(70, 280))
    else:
        # 非严格 308 时，尽量取后段为面部的近似划分
        start = min(70, K)
        end = min(280, K)
        if end > start:
            idxs += list(range(start, end))
    # 去除耳基点 3,4
    idxs = [i for i in idxs if i not in (3, 4) and i < K]
    return sorted(set(idxs))


def _get_face_indices_ex_ear(K):
    """
    优先基于官方名称表（若可用）筛选；否则使用 308 点的回退分段。
    """
    names = None
    try:
        # 你的环境若已把官方 Space 中的 classes_and_palettes.py 放入 PYTHONPATH
        from .kpt_classes_and_palettes import GOLIATH_KEYPOINTS as GK

        if isinstance(GK, (list, tuple)) and len(GK) >= K:
            names = [str(x).lower() for x in GK[:K]]
    except Exception:
        names = None
    if names is not None:
        face_idx = _face_indices_ex_ear_from_names(names)
        if len(face_idx) > 0:
            return face_idx
    # 回退
    return _fallback_face_indices_ex_ear_308(K)


def check_face(
    keypoint,
    score_thr: float = 0.3,
    min_face_points: int = 20,
) -> bool:
    """
    检查 T×K×3 的 GOLIATH 关键点序列是否每一帧都“包含面部（不含耳朵）”。
    包含面的判定：该帧“面部(不含耳)”索引集合中，置信度 >= score_thr 的点数 ≥ min_face_points。
    若任一帧不满足，抛出 FaceAbsentError；全部满足则返回 True。

    参数
    ----
    keypoint : np.ndarray or torch.Tensor, 形状 (T,K,3) 且末维是 (x,y,score)
    score_thr : 置信度阈值（默认 0.3，参考 MMPose 可视化的 kpt_thr 默认值）
    min_face_points : 每帧至少命中的面部关键点数量阈值（默认 20）

    返回
    ----
    True (若所有帧均满足)

    异常
    ----
    FaceAbsentError: 若存在任一帧不满足“包含面部”的判定
    """
    # 转为 numpy
    if hasattr(keypoint, "detach"):  # torch.Tensor
        keypoint = keypoint.detach().cpu().numpy()
    else:
        keypoint = np.asarray(keypoint)
    assert (
        keypoint.ndim == 3 and keypoint.shape[-1] >= 3
    ), f"Expect (T,K,3), got {keypoint.shape}"

    T, K, _ = keypoint.shape

    # 取得面部(不含耳)的索引集合
    face_idx = _get_face_indices_ex_ear(K)
    if len(face_idx) == 0:
        raise RuntimeError(
            "Cannot determine face indices (excluding ears); please provide name list or adjust fallback."
        )

    # 取分数并做数值消毒（NaN/Inf -> 0）
    scores = keypoint[..., 2]
    scores = np.nan_to_num(scores, nan=0.0, posinf=0.0, neginf=0.0)

    # 逐帧统计：面部有效点数
    face_scores = scores[:, face_idx]  # (T, |F|)
    valid_mask = face_scores >= float(score_thr)  # (T, |F|)
    valid_cnt = valid_mask.sum(axis=1)  # (T,)

    bad = np.where(valid_cnt < int(min_face_points))[0]
    if bad.size > 0:
        # 给出前若干坏帧的细节，便于排错
        preview = ", ".join([f"{i}(cnt={int(valid_cnt[i])})" for i in bad[:10]])
        raise FaceAbsentError(
            f"Face not detected on {bad.size}/{T} frames. "
            f"Examples: [{preview}] (thr={score_thr}, min_face_points={min_face_points})"
        )

    return True
