import numpy as np
import cv2
from typing import Tuple, List, Optional

from tqdm import tqdm

# MediaPipe Face Detection (BlazeFace)
import mediapipe as mp
from mediapipe.tasks import python as mp_python
from mediapipe.tasks.python import vision as mp_vision
from scipy.signal import medfilt

# Handle to the legacy MediaPipe "solutions" face-detection module.
# NOTE(review): not referenced in the visible part of this file (the landmark
# code below uses its own `mp_face_det` alias) — confirm external use before
# removing.
mp_fd = mp.solutions.face_detection


def _choose_interp(src_wh: Tuple[int, int], dst_wh: Tuple[int, int]) -> int:
    """Pick the OpenCV interpolation flag for resizing ``src_wh`` to ``dst_wh``.

    Both arguments are ``(width, height)`` pairs. ``cv2.INTER_AREA`` is
    returned as soon as either dimension shrinks; otherwise (pure upscaling or
    same size) ``cv2.INTER_LINEAR`` is returned.
    """
    src_w, src_h = src_wh
    dst_w, dst_h = dst_wh
    shrinks = dst_w < src_w or dst_h < src_h
    return cv2.INTER_AREA if shrinks else cv2.INTER_LINEAR


def _expand_square_box(x, y, w, h, img_w, img_h, scale=1.0):
    """
    将 (x,y,w,h) 扩为以中心为基的正方形，并按 scale 放大；返回裁剪时需要的整型边界（含越界范围，后续用pad处理）
    """
    cx = x + w * 0.5
    cy = y + h * 0.5
    side = max(w, h) * float(scale)
    left = int(np.floor(cx - side * 0.5))
    top = int(np.floor(cy - side * 0.5))
    right = int(np.ceil(cx + side * 0.5))
    bottom = int(np.ceil(cy + side * 0.5))
    return left, top, right, bottom


def _crop_with_pad(
    img: np.ndarray, left: int, top: int, right: int, bottom: int
) -> np.ndarray:
    """
    允许越界的裁剪：越界部分用边缘像素填充（edge pad）
    """
    h, w = img.shape[:2]
    pad_left = max(0, -left)
    pad_top = max(0, -top)
    pad_right = max(0, right - w)
    pad_bottom = max(0, bottom - h)

    # 先clip到图像内再pad
    cl = max(0, left)
    ct = max(0, top)
    cr = min(w, right)
    cb = min(h, bottom)

    crop = img[ct:cb, cl:cr]
    if any(v > 0 for v in (pad_left, pad_top, pad_right, pad_bottom)):
        crop = cv2.copyMakeBorder(
            crop,
            pad_top,
            pad_bottom,
            pad_left,
            pad_right,
            borderType=cv2.BORDER_REPLICATE,
        )
    return crop


def _create_face_detector(model_path: str, min_conf: float, use_gpu: bool):
    """Build a MediaPipe Tasks ``FaceDetector`` in IMAGE running mode.

    When ``use_gpu`` is true the GPU delegate is tried first and, if creating
    the detector raises, creation is retried once with the CPU delegate (the
    GPU delegate is not available on every platform/driver). When ``use_gpu``
    is false the CPU delegate is used directly and any failure propagates,
    instead of pointlessly retrying the identical configuration.

    The caller owns the returned detector and must close it (it can be used
    as a context manager).
    """

    def build(delegate):
        options = mp_vision.FaceDetectorOptions(
            base_options=mp_python.BaseOptions(
                model_asset_path=model_path, delegate=delegate
            ),
            running_mode=mp_vision.RunningMode.IMAGE,
            min_detection_confidence=min_conf,
            min_suppression_threshold=0.3,
        )
        return mp_vision.FaceDetector.create_from_options(options)

    cpu_delegate = mp_python.BaseOptions.Delegate.CPU
    if not use_gpu:
        return build(cpu_delegate)
    try:
        return build(mp_python.BaseOptions.Delegate.GPU)
    except Exception:
        # Deliberate best-effort: any GPU set-up failure falls back to CPU.
        # If the CPU build fails as well, that error propagates to the caller.
        return build(cpu_delegate)


def _select_face_box(
    result, min_conf: float, frame_w: int, frame_h: int
) -> Optional[Tuple[int, int, int, int]]:
    """Return the clamped ``(x, y, w, h)`` of the best detection, or ``None``.

    "Best" is the detection whose first category has the highest score
    (detections without categories count as score 0). ``None`` is returned
    when there is no detection or the best score is below ``min_conf``.
    The Tasks API bounding box is read as pixel coordinates; the origin is
    limited to ``[-frame_w, frame_w]`` / ``[-frame_h, frame_h]`` and the size
    to ``[1, frame_w]`` / ``[1, frame_h]``.
    """
    if not result.detections:
        return None
    best = max(
        result.detections,
        key=lambda d: (d.categories[0].score if d.categories else 0.0),
    )
    if best.categories and best.categories[0].score >= min_conf:
        bb = best.bounding_box
        x = max(-frame_w, min(frame_w, int(bb.origin_x)))
        y = max(-frame_h, min(frame_h, int(bb.origin_y)))
        w0 = max(1, min(frame_w, int(bb.width)))
        h0 = max(1, min(frame_h, int(bb.height)))
        return (x, y, w0, h0)
    return None


def _center_fallback_box(frame_w: int, frame_h: int) -> Tuple[int, int, int, int]:
    """Return a centred square box with side ``0.6 * min(frame_w, frame_h)``.

    Used when no face has been detected yet. The side is at least 1 pixel so
    that very small frames still yield a non-empty crop.
    """
    side = max(1, int(min(frame_w, frame_h) * 0.6))
    cx, cy = frame_w // 2, frame_h // 2
    return (max(0, cx - side // 2), max(0, cy - side // 2), side, side)


def _odd_kernel_size(k: int, n: int) -> int:
    """Return the largest odd kernel size that is ``<= min(k, n)`` (minimum 1)."""
    k = min(k, n)
    if k < 1:
        return 1
    return k if (k % 2 == 1) else (k - 1)


def detect_and_resize_faces_rgb(
    video: np.ndarray,
    target_hw: Tuple[int, int] = (96, 96),
    min_conf: float = 0.5,
    model_selection: int = 0,  # unused by the Tasks API; kept for signature compatibility
    expand_scale: float = 1.0,
    *,
    model_path: str = "checkpoints/mediapipe/detector.tflite",  # MediaPipe Tasks face-detector model file
    use_gpu: bool = True,  # try the GPU delegate first, fall back to CPU on failure
) -> Tuple[np.ndarray, List[Optional[Tuple[int, int, int, int]]]]:
    """Detect one face per frame, crop it as a square and resize it.

    For every frame the highest-scoring detection with score >= ``min_conf``
    is used. If a frame has no such detection, the box of the previous frame
    is reused; if there is no previous box either, a centred square covering
    60% of the shorter image side is used. The box is expanded to a square of
    side ``max(w, h) * expand_scale`` around its centre, cropped with
    edge-replicated padding and resized to ``target_hw``.

    Args:
        video: ``(T, H, W, 3)`` RGB array of dtype ``uint8``.
        target_hw: Output ``(height, width)`` of every crop.
        min_conf: Minimum detection confidence.
        model_selection: Ignored; kept so existing callers keep working.
        expand_scale: Scale factor applied to the square side.
        model_path: Path to the MediaPipe Tasks face-detector model.
        use_gpu: Whether to try the GPU delegate before the CPU delegate.

    Returns:
        ``(crops, bboxes)`` where ``crops`` is ``(T, target_h, target_w, 3)``
        ``uint8`` and ``bboxes[t]`` is the ``(x, y, w, h)`` pixel box actually
        used for frame ``t``. Because of the fallbacks above every entry is
        filled in; no entry is left as ``None``.
    """
    assert video.ndim == 4 and video.shape[-1] == 3, "video must be (T,H,W,3)"
    assert video.dtype == np.uint8, "video dtype must be uint8 (0..255)"
    T, H, W, _ = video.shape
    target_h, target_w = target_hw

    crops = np.empty((T, target_h, target_w, 3), dtype=np.uint8)
    bboxes: List[Optional[Tuple[int, int, int, int]]] = [None] * T
    last_box_xywh: Optional[Tuple[int, int, int, int]] = None

    # IMAGE running mode: frames are detected independently, no timestamps.
    with _create_face_detector(model_path, min_conf, use_gpu) as detector:
        for t in range(T):
            frame = video[t]  # (H, W, 3) RGB uint8

            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
            box = _select_face_box(detector.detect(mp_image), min_conf, W, H)

            if box is not None:
                last_box_xywh = box
            elif last_box_xywh is None:
                # Nothing detected so far: start from the centred fallback box.
                last_box_xywh = _center_fallback_box(W, H)

            x, y, w0, h0 = last_box_xywh

            # Square-expand, crop (out-of-bounds handled by edge replication),
            # then resize to the requested output size.
            l, tt, r, b = _expand_square_box(x, y, w0, h0, W, H, scale=expand_scale)
            face_crop = _crop_with_pad(frame, l, tt, r, b)
            interp = _choose_interp(
                src_wh=(face_crop.shape[1], face_crop.shape[0]),
                dst_wh=(target_w, target_h),
            )
            crops[t] = cv2.resize(
                face_crop, (target_w, target_h), interpolation=interp
            )
            bboxes[t] = (x, y, w0, h0)

    return crops, bboxes


def detect_and_resize_faces_for_syncnet(
    video: np.ndarray,
    target_hw: Tuple[int, int] = (96, 96),
    min_conf: float = 0.5,
    expand_scale: float = 0.40,  # original comment: matches crop_scale ~= 0.40 in run_pipeline.py
    *,
    model_path: str = "checkpoints/mediapipe/detector.tflite",
    use_gpu: bool = True,
) -> Tuple[np.ndarray, List[Optional[Tuple[int, int, int, int]]]]:
    """Detect, temporally smooth and crop faces with a downward-biased window.

    Pass 1 detects one face box per frame (same selection and fallback rules
    as ``detect_and_resize_faces_rgb``) and records the box centre
    ``(cx, cy)`` and half-size ``s = 0.5 * max(w, h)``. The three series are
    then median-filtered with an odd kernel of at most 13 frames. Pass 2 pads
    each frame with constant grey (value 110) and crops, around the smoothed
    centre and with ``cs = expand_scale``::

        rows: [cy - s,            cy + s * (1 + 2 * cs))
        cols: [cx - s * (1 + cs), cx + s * (1 + cs))

    i.e. a square window shifted downwards, which is resized to ``target_hw``.
    If the window ends up entirely outside the padded frame, the output frame
    is filled with the same grey value instead of failing in ``cv2.resize``.

    Args:
        video: ``(T, H, W, 3)`` RGB array of dtype ``uint8``.
        target_hw: Output ``(height, width)`` of every crop.
        min_conf: Minimum detection confidence.
        expand_scale: Crop scale ``cs`` used in the window formula above.
        model_path: Path to the MediaPipe Tasks face-detector model.
        use_gpu: Whether to try the GPU delegate before the CPU delegate.

    Returns:
        ``(crops, bboxes)`` where ``crops`` is ``(T, target_h, target_w, 3)``
        ``uint8`` and ``bboxes[t]`` is the unsmoothed ``(x, y, w, h)`` pixel
        box recorded for frame ``t`` in pass 1 (every entry is filled in).
    """
    assert video.ndim == 4 and video.shape[-1] == 3, "video must be (T,H,W,3)"
    assert video.dtype == np.uint8, "video dtype must be uint8 (0..255)"
    T, H, W, _ = video.shape
    target_h, target_w = target_hw

    crops = np.empty((T, target_h, target_w, 3), dtype=np.uint8)
    bboxes: List[Optional[Tuple[int, int, int, int]]] = [None] * T
    last_box_xywh: Optional[Tuple[int, int, int, int]] = None

    # -------- Pass 1: per-frame detection, collect centre and scale --------
    cxs = np.zeros(T, dtype=np.float32)
    cys = np.zeros(T, dtype=np.float32)
    ss = np.zeros(T, dtype=np.float32)

    # The detector is only needed for pass 1, so it is released right after.
    with _create_face_detector(model_path, min_conf, use_gpu) as detector:
        for t in range(T):
            frame = video[t]  # (H, W, 3) RGB uint8
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame)
            box = _select_face_box(detector.detect(mp_image), min_conf, W, H)

            if box is not None:
                last_box_xywh = box
            elif last_box_xywh is None:
                # No detection yet (not even on the first frame).
                last_box_xywh = _center_fallback_box(W, H)

            x, y, w0, h0 = last_box_xywh
            bboxes[t] = (x, y, w0, h0)

            cxs[t] = x + 0.5 * w0
            cys[t] = y + 0.5 * h0
            ss[t] = 0.5 * max(w0, h0)

    # -------- Smoothing: median filter on s / cx / cy (kernel 13, or the
    # largest odd size <= T for short clips) --------
    k = _odd_kernel_size(13, T)
    if k > 1:
        ss = medfilt(ss, kernel_size=k)
        cxs = medfilt(cxs, kernel_size=k)
        cys = medfilt(cys, kernel_size=k)

    # -------- Pass 2: grey padding + downward-biased crop + resize --------
    cs = float(expand_scale)
    for t in range(T):
        frame = video[t]
        s = float(ss[t])
        cx = float(cxs[t])
        cy = float(cys[t])

        # Pad by the largest extent the window can reach beyond the centre.
        bsi = int(max(1, round(s * (1 + 2 * cs))))
        padded = np.pad(
            frame,
            pad_width=((bsi, bsi), (bsi, bsi), (0, 0)),
            mode="constant",
            constant_values=110,
        )

        # Centre expressed in padded-frame coordinates.
        mx = cx + bsi
        my = cy + bsi

        top = int(round(my - s))
        bottom = int(round(my + s * (1 + 2 * cs)))
        left = int(round(mx - s * (1 + cs)))
        right = int(round(mx + s * (1 + cs)))

        Hpad, Wpad = padded.shape[:2]
        top = max(0, min(Hpad, top))
        bottom = max(0, min(Hpad, bottom))
        left = max(0, min(Wpad, left))
        right = max(0, min(Wpad, right))

        crop = padded[top:bottom, left:right]
        if crop.size == 0:
            # Window lies completely outside the padded frame; cv2.resize
            # cannot handle an empty source, so emit the padding colour.
            crops[t] = 110
            continue

        interp = _choose_interp(
            src_wh=(crop.shape[1], crop.shape[0]), dst_wh=(target_w, target_h)
        )
        crops[t] = cv2.resize(crop, (target_w, target_h), interpolation=interp)

    return crops, bboxes


# MediaPipe: either the newer Tasks API or the classic "solutions" API would
# work here; the classic solutions API is used (install: pip install mediapipe).
import mediapipe as mp

# Module aliases for the legacy FaceMesh (landmarks) and FaceDetection (bbox)
# solutions used by the landmark/bbox extraction below.
mp_face_mesh = mp.solutions.face_mesh
mp_face_det = mp.solutions.face_detection


def _landmarks_to_pixels(landmarks, img_w, img_h):
    """MediaPipe 归一化关键点 -> 像素坐标的 (N, 2) int32 数组。"""
    pts = []
    for lm in landmarks:
        x = int(round(lm.x * img_w))
        y = int(round(lm.y * img_h))
        pts.append([x, y])
    return np.asarray(pts, dtype=np.int32)


def _mp_bbox_to_xyxy(detection, img_w, img_h):
    """MediaPipe Detection bbox (relative) -> (x1, y1, x2, y2) 像素框。"""
    bbox = detection.location_data.relative_bounding_box
    x1 = int(round(bbox.xmin * img_w))
    y1 = int(round(bbox.ymin * img_h))
    x2 = int(round((bbox.xmin + bbox.width) * img_w))
    y2 = int(round((bbox.ymin + bbox.height) * img_h))
    # 规范化与边界裁剪
    x1, y1 = max(0, x1), max(0, y1)
    x2, y2 = max(x1 + 1, x2), max(y1 + 1, y2)
    return (x1, y1, x2, y2)


def _detector_fallback_bbox(face_det, rgb, img_w, img_h):
    """Run the face detector on ``rgb`` and return the first box, or ``None``.

    The box is converted to pixel ``(x1, y1, x2, y2)`` via ``_mp_bbox_to_xyxy``.
    """
    det_res = face_det.process(rgb)
    if det_res.detections:
        return _mp_bbox_to_xyxy(det_res.detections[0], img_w, img_h)
    return None


def get_landmark_and_bbox_mediapipe(
    frames, upperbondrange=0, coord_placeholder=(0, 0, 1, 1)
) -> Tuple[List[Tuple[int, int, int, int]], List[np.ndarray]]:
    """Compute one face box per frame from MediaPipe FaceMesh landmarks.

    For each frame the FaceMesh landmarks of the first face are converted to
    pixels and a box is derived from them:

    * left/right: min/max landmark x,
    * bottom: max landmark y,
    * top: the bottom edge mirrored about the "half face" row, where that row
      is the median landmark y plus ``upperbondrange``; the top is floored
      at 0.

    If no landmarks are found, or the landmark box has non-positive
    width/height or a negative left edge, the first FaceDetection box is used
    instead; if that is missing too, ``coord_placeholder`` is used. ``None``
    frames also yield ``coord_placeholder``. The detector is only run when a
    fallback is actually needed.

    Side effects: prints progress/diagnostic messages and shows a tqdm bar.
    The summary line reports the mean of ``(p55 - median)`` and
    ``(median - p45)`` of the landmark y values as a suggested adjustment
    range for ``upperbondrange``.

    Args:
        frames: Sequence of images; each is passed to MediaPipe as-is.
            NOTE(review): MediaPipe expects RGB and no conversion is done
            here — confirm that callers pass RGB frames.
        upperbondrange: Vertical shift in pixels applied to the half-face row
            (positive moves it down, negative up). 0 means no shift.
        coord_placeholder: Box returned for frames without any usable result.

    Returns:
        ``(coords_list, frames)``: the per-frame ``(x1, y1, x2, y2)`` boxes
        and the ``frames`` argument itself, unchanged.
    """
    coords_list = []
    if upperbondrange != 0:
        print(
            "get key_landmark and face bounding boxes with the bbox_shift:",
            upperbondrange,
        )
    else:
        print("get key_landmark and face bounding boxes with the default value")

    average_range_minus, average_range_plus = [], []

    # One FaceMesh (landmarks) and one FaceDetection (bbox fallback) instance
    # are reused for all frames; the `with` block guarantees both are closed
    # even if processing raises.
    with mp_face_mesh.FaceMesh(
        static_image_mode=True,
        max_num_faces=1,
        refine_landmarks=False,
        min_detection_confidence=0.5,
    ) as face_mesh, mp_face_det.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    ) as face_det:
        for img in tqdm(frames):
            if img is None:
                coords_list.append(coord_placeholder)
                continue

            img = np.asarray(img)
            h, w = img.shape[:2]
            rgb = img  # used as-is; see the RGB note in the docstring

            # 1) FaceMesh landmarks; without them fall back to the detector.
            mesh_res = face_mesh.process(rgb)
            if not mesh_res.multi_face_landmarks:
                fallback_bbox = _detector_fallback_bbox(face_det, rgb, w, h)
                coords_list.append(
                    fallback_bbox if fallback_bbox is not None else coord_placeholder
                )
                continue

            # Only the first face is used.
            face_landmarks = mesh_res.multi_face_landmarks[0].landmark
            face_land_mark = _landmarks_to_pixels(face_landmarks, w, h)  # (N, 2) int32

            # 2) Landmark statistics that define the box.
            ys = face_land_mark[:, 1]
            xs = face_land_mark[:, 0]
            y_max = int(ys.max())
            x_min, x_max = int(xs.min()), int(xs.max())

            # The median y stands in for the mid-face row; the 45th/55th
            # percentiles give a small, noise-tolerant spread around it that
            # is only used for the adjustment-range hint printed at the end.
            median_y = int(np.median(ys))
            q45 = int(np.percentile(ys, 45))
            q55 = int(np.percentile(ys, 55))
            average_range_minus.append(q55 - median_y)
            average_range_plus.append(median_y - q45)

            half_face_coord_y = median_y
            if upperbondrange != 0:
                # Positive shifts the half-face row down, negative shifts it up.
                half_face_coord_y = half_face_coord_y + upperbondrange

            # Mirror the lower half (half-face row .. chin) upwards to get the
            # top edge, never going above the image.
            half_face_dist = y_max - half_face_coord_y
            upper_bond = max(0, half_face_coord_y - half_face_dist)

            f_landmark = (x_min, int(upper_bond), x_max, y_max)
            x1, y1, x2, y2 = f_landmark

            # 3) Reject degenerate landmark boxes and fall back to the detector.
            if (y2 - y1) <= 0 or (x2 - x1) <= 0 or x1 < 0:
                fallback_bbox = _detector_fallback_bbox(face_det, rgb, w, h)
                if fallback_bbox is not None:
                    coords_list.append(fallback_bbox)
                    print("error bbox (fallback to detector):", fallback_bbox)
                else:
                    coords_list.append(coord_placeholder)
                    print(
                        "error bbox and no detector result; use placeholder:",
                        coord_placeholder,
                    )
            else:
                coords_list.append(f_landmark)

    print(
        "********************************************bbox_shift parameter adjustment**********************************************************"
    )
    if len(average_range_minus) == 0:
        am, ap = 0, 0
    else:
        am = int(sum(average_range_minus) / len(average_range_minus))
        ap = int(sum(average_range_plus) / len(average_range_plus))
    print(
        f"Total frame:「{len(frames)}」 Manually adjust range : [ -{am}~{ap} ] , the current value: {upperbondrange}"
    )
    print(
        "*************************************************************************************************************************************"
    )
    return coords_list, frames


# ---------------- main: read with decord, write mp4 with imageio ----------------
if __name__ == "__main__":
    import os
    import imageio
    from decord import VideoReader, cpu

    in_path = "data/hallo3/hallo3_training_data/videos_cropped_new/fe0eb399d3372546b6437401d707551b.mp4"
    out_path = "output.mp4"

    assert os.path.isfile(in_path), f"Input not found: {in_path}"

    # 1) Decode every frame as RGB uint8 and stack into (T, H, W, 3).
    vr = VideoReader(in_path, ctx=cpu(0))
    fps = float(vr.get_avg_fps())  # average frame rate of the input
    T = len(vr)
    video_np = np.stack([vr[i].asnumpy() for i in range(T)], axis=0)

    # 2) Crop the face in every frame and resize it to 224x224.
    crops, _ = detect_and_resize_faces_for_syncnet(
        video_np, target_hw=(224, 224), min_conf=0.5
    )

    # 3) Write the crops back as an H.264 MP4. The `with` block closes the
    #    writer (and the ffmpeg process behind it) even if a frame fails.
    with imageio.get_writer(
        out_path,
        format="FFMPEG",
        fps=fps,
        codec="libx264",
        macro_block_size=None,  # do not force the frame size to a macro-block multiple
        output_params=["-pix_fmt", "yuv420p"],
    ) as writer:
        for f in crops:
            writer.append_data(f)  # (224, 224, 3) RGB uint8

    print(f"[OK] Wrote {T} frames to {out_path} @ {fps:.2f} FPS")
