import math
import re
from typing import Dict, Iterable, List, Optional, Tuple
import cv2
from mmcv import BaseTransform
import numpy as np
import torch


from mmhug.registry import TRANSFORMS

import math
from typing import Optional, Tuple
import torch


def sapiens2mask(
    height: int,
    width: int,
    keypoint: torch.Tensor,  # (T, K, 3): [x, y, score], in pixels
    mask_area: str,  # "face" | "lower_face"
    mask_expand: Optional[Tuple[int, int, int, int]] = (
        0,
        0,
        0,
        0,
    ),  # (x_l, y_u, x_r, y_d)
    conf_thr: float = 0.2,  # keypoint visibility threshold
) -> torch.Tensor:
    """Build per-frame face **bounding-box** masks (0/1) from Sapiens/GOLIATH keypoints.

    Only axis-aligned boxes are produced; no per-pixel mask is generated.

    Vertical extent:
      * top    = highest visible eyebrow point; otherwise the highest visible
                 eye raised by ``EYE_FALLBACK_UP`` inter-ocular distances;
                 otherwise the 5th percentile of all visible y's.
      * bottom = chin tip, pushed down to the 98th percentile of the face
                 keypoints (indices 70..219) if those reach lower; if the chin
                 tip is not visible, that percentile plus
                 ``CHIN_FALLBACK_K`` inter-ocular distances.
    Horizontal extent: 5%/95% percentiles of visible x's inside the vertical
    band, then widened so that every visible ear is included.

    ``"lower_face"`` keeps the lower half of the whole-face box (the half is
    taken first, ``mask_expand`` is applied afterwards).

    Args:
        height: Output mask height in pixels.
        width: Output mask width in pixels.
        keypoint: (T, K, 3) tensor or array-like of ``[x, y, score]``. x/y must
            be image pixel coordinates, not normalized to 0..1. K must exceed
            the chin-tip index (77).
        mask_area: ``"face"`` or ``"lower_face"``.
        mask_expand: Extra pixels ``(left, up, right, down)`` added to the box
            before clipping to the image; ``None`` means no expansion.
        conf_thr: Minimum score for a keypoint to count as visible.

    Returns:
        (T, height, width) ``torch.uint8`` tensor on ``keypoint``'s device.
        Frames without enough usable keypoints stay all-zero.

    Raises:
        ValueError: If ``keypoint`` is not shaped (T, K, >=3) with enough K.

    Keypoints with non-finite x/y or a NaN score are treated as not visible.
    """
    assert mask_area in ("face", "lower_face")

    # --- Index conventions (see GOLIATH_KEYPOINTS) ---
    IDX_L_EYE, IDX_R_EYE = 1, 2
    IDX_L_EAR, IDX_R_EAR = 3, 4
    IDX_CHIN_TIP = 77  # tip_of_chin (per the key-name listing)
    EYEBROW_IDXS = range(78, 96)  # right brow 78-86, left brow 87-95 (per key names)
    FACE_IDX_START, FACE_IDX_STOP = 70, 220  # face points used for the chin percentile

    # --- Tunables (could be exposed as parameters) ---
    X_Q_LEFT, X_Q_RIGHT = 0.05, 0.95  # robust horizontal quantiles
    MIN_BOX_W, MIN_BOX_H = 12.0, 16.0  # minimum bbox size
    CHIN_FALLBACK_K = 0.25  # downward extrapolation without chin tip (x inter-ocular)
    EAR_PAD_FRAC = 0.02  # margin added outside the ears (x inter-ocular)
    EYE_FALLBACK_UP = 0.25  # without brows: top = min eye y - 0.25 * inter-ocular

    keypoint = torch.as_tensor(keypoint)
    if (
        keypoint.ndim != 3
        or keypoint.shape[-1] < 3
        or keypoint.shape[1] <= IDX_CHIN_TIP
    ):
        raise ValueError(
            f"keypoint must have shape (T, K, 3) with K > {IDX_CHIN_TIP}, "
            f"got {tuple(keypoint.shape)}"
        )

    num_frames = keypoint.shape[0]
    out = torch.zeros(
        (num_frames, height, width), dtype=torch.uint8, device=keypoint.device
    )
    # The per-frame logic is scalar Python: copy to host once instead of
    # calling `.item()` per keypoint (each one is a device sync on CUDA).
    kp_host = keypoint.detach().cpu()
    ex_l, ex_u, ex_r, ex_d = (int(v) for v in (mask_expand or (0, 0, 0, 0)))

    def _visible(kpts: torch.Tensor, thr: float) -> List[bool]:
        # Finite x/y AND score >= thr. NaN scores compare False, so they are
        # never visible (no nan_to_num: it would turn ±inf coords into finite
        # extremes and NaN scores into +inf).
        finite_xy = torch.isfinite(kpts[:, 0]) & torch.isfinite(kpts[:, 1])
        return (finite_xy & (kpts[:, 2] >= thr)).tolist()

    def _percentile(vals: List[float], q: float) -> Optional[float]:
        # Linear-interpolation percentile, q in [0, 100]; None for empty input.
        if not vals:
            return None
        s = sorted(vals)
        k = (len(s) - 1) * (q / 100.0)
        f, c = math.floor(k), math.ceil(k)
        return s[f] if f == c else s[f] * (c - k) + s[c] * (k - f)

    for t in range(num_frames):
        kpts = kp_host[t]
        n_kpts = kpts.shape[0]
        xs = kpts[:, 0].tolist()
        ys = kpts[:, 1].tolist()
        finite_xy = (torch.isfinite(kpts[:, 0]) & torch.isfinite(kpts[:, 1])).tolist()
        vis_conf = _visible(kpts, conf_thr)
        vis_loose = _visible(kpts, 0.05)  # permissive threshold for fallbacks
        vis_eye = _visible(kpts, max(0.05, 0.5 * conf_thr))
        vis_ear = _visible(kpts, min(0.10, conf_thr))

        # Inter-ocular distance is the scale unit for extrapolation/padding.
        if vis_eye[IDX_L_EYE] and vis_eye[IDX_R_EYE]:
            io = max(
                8.0,
                math.hypot(xs[IDX_L_EYE] - xs[IDX_R_EYE], ys[IDX_L_EYE] - ys[IDX_R_EYE]),
            )
        else:
            io = 32.0  # fallback scale

        # ===== Vertical: brow top & chin tip (anchors only, so mouth and
        # other points do not pull the edges) =====
        brow_ys = [ys[i] for i in EYEBROW_IDXS if i < n_kpts and vis_conf[i]]
        if brow_ys:
            y_top = min(brow_ys)
        else:
            eye_ys = [ys[i] for i in (IDX_L_EYE, IDX_R_EYE) if vis_loose[i]]
            if eye_ys:
                y_top = min(eye_ys) - EYE_FALLBACK_UP * io
            else:
                # Last resort: 5th percentile of all visible y's.
                ys_all = [y for y, ok in zip(ys, vis_loose) if ok]
                if not ys_all:
                    continue
                y_top = _percentile(ys_all, 5.0)

        face_ys = [
            ys[i]
            for i in range(FACE_IDX_START, min(FACE_IDX_STOP, n_kpts))
            if vis_loose[i]
        ]
        if vis_conf[IDX_CHIN_TIP]:
            y_bottom = ys[IDX_CHIN_TIP]
            # If other face points reach lower, take the larger value so the
            # box is not cut off at the lips.
            if face_ys:
                y_bottom = max(y_bottom, _percentile(face_ys, 98.0))
        elif face_ys:
            y_bottom = _percentile(face_ys, 98.0) + CHIN_FALLBACK_K * io
        else:
            continue

        # Minimum-size guard, then clip to the image.
        if (y_bottom - y_top) < MIN_BOX_H:
            mid = 0.5 * (y_top + y_bottom)
            y_top, y_bottom = mid - 0.5 * MIN_BOX_H, mid + 0.5 * MIN_BOX_H
        y_top = max(0.0, y_top)
        y_bottom = min(float(height - 1), y_bottom)

        # ===== Horizontal: robust quantiles inside [y_top, y_bottom] =====
        band_margin = 0.02 * (y_bottom - y_top)
        y_lo = max(0.0, y_top - band_margin)
        y_hi = min(float(height - 1), y_bottom + band_margin)
        xs_band = [
            x for x, y, ok in zip(xs, ys, vis_conf) if ok and y_lo <= y <= y_hi
        ]
        if len(xs_band) < 5:
            # Too few band points: fall back to every finite keypoint.
            xs_band = [x for x, ok in zip(xs, finite_xy) if ok]
            if not xs_band:
                continue

        x1 = _percentile(xs_band, X_Q_LEFT * 100.0)
        x2 = _percentile(xs_band, X_Q_RIGHT * 100.0)

        # Force-include visible ears (looser threshold + small pad). This is
        # side-agnostic on purpose. If "left" means the subject's left
        # (COCO-style), the left ear lies at the larger image x for a frontal
        # face. The old per-side min/max then never widened the box.
        ear_xs = [xs[i] for i in (IDX_L_EAR, IDX_R_EAR) if vis_ear[i]]
        if ear_xs:
            ear_pad = EAR_PAD_FRAC * io
            x1 = min(x1, min(ear_xs) - ear_pad)
            x2 = max(x2, max(ear_xs) + ear_pad)

        # Minimum-size guard, then clip to the image.
        if (x2 - x1) < MIN_BOX_W:
            mid = 0.5 * (x1 + x2)
            x1, x2 = mid - 0.5 * MIN_BOX_W, mid + 0.5 * MIN_BOX_W
        x1 = max(0.0, x1)
        x2 = min(float(width - 1), x2)

        # ===== lower_face: keep the lower half (halve first, expand after) =====
        if mask_area == "lower_face":
            y_top = 0.5 * (y_top + y_bottom)

        # ===== Apply mask_expand and write the box =====
        xi1 = max(0, math.floor(x1) - ex_l)
        yi1 = max(0, math.floor(y_top) - ex_u)
        xi2 = min(width - 1, math.ceil(x2) + ex_r)
        yi2 = min(height - 1, math.ceil(y_bottom) + ex_d)
        if xi2 >= xi1 and yi2 >= yi1:
            out[t, yi1 : yi2 + 1, xi1 : xi2 + 1] = 1

    return out


@TRANSFORMS.register_module()
class SapiensKeypoint2Mask(BaseTransform):
    """Turn per-frame Sapiens/GOLIATH keypoints into face bounding-box masks.

    Reads ``results[video_key]`` (only for its spatial size) and
    ``results[keypoint_key]``, calls ``sapiens2mask`` and stores the result
    in ``results[mask_key]``.
    """

    def __init__(
        self,
        video_key="video",
        keypoint_key="keypoint",
        mask_key="mask",
        mask_area="lower_face",
        mask_expand=(0, 0, 0, 0),
    ) -> None:
        """
        Args:
            video_key (str): Key of the video in ``results``. Its last two
                dimensions are used as (H, W); the layout is presumably
                (T, C, H, W) — confirm against the loading pipeline.
            keypoint_key (str): Key of the per-frame keypoints in ``results``,
                expected as (T, K, 3) ``[x, y, score]`` in pixel coordinates.
            mask_key (str): Key under which the generated mask is stored.
            mask_area (str): ``"face"`` for the whole-face box, or
                ``"lower_face"`` for only the lower half of that box.
            mask_expand (tuple[int, int, int, int]): Extra pixels
                ``(left, up, right, down)`` added to the box before it is
                clipped to the image.
        """
        super().__init__()
        self.video_key = video_key
        self.keypoint_key = keypoint_key
        self.mask_key = mask_key

        assert mask_area in ("lower_face", "face"), f"unsupported mask_area: {mask_area}"
        self.mask_area = mask_area
        self.mask_expand = mask_expand

    def transform(self, results: Dict) -> Dict | Tuple[List, List] | None:
        """Add a (T, H, W) face-box mask to ``results`` and return it."""
        # Video is T x 3 x H x W; only the trailing spatial size is needed.
        video = results[self.video_key]
        height, width = video.shape[-2:]

        # T x K x 3 keypoints: [x, y, score]
        sapiens_keypoint = results[self.keypoint_key]

        results[self.mask_key] = sapiens2mask(
            height,
            width,
            sapiens_keypoint,
            mask_area=self.mask_area,
            mask_expand=self.mask_expand,
        )
        return results
