import torch
import numpy as np


def rescale_transform(keypoints, scale_factor):
    """
    Center and rescale keypoints using the bounding box of their valid values.

    Coordinates equal to -1 are treated as missing: they are ignored when the
    bounding box is measured and are left untouched. Every valid coordinate is
    mapped as ``(coord + translation) * scale + 0.5``, which is exactly what
    ``rescale_transform_back`` inverts.

    Args:
        keypoints (np.array): The keypoints to rescale with [N, J, 2].
            The array is modified in place.
        scale_factor (float): The scale factor to apply. It is multiplied by
            ``min(1 / width, 1 / height)`` of the bounding box.

    Returns:
        tuple: ``(keypoints, transform_params)`` where ``keypoints`` is the
        input array after the in-place update and ``transform_params`` is
        ``np.array([scale, translation_x, translation_y])`` with the values
        that were actually applied.
    """
    x_channel = keypoints[:, :, 0]
    y_channel = keypoints[:, :, 1]

    # Take both validity masks up front, before anything is written. Each
    # channel must be masked by its own values: masking y with the (already
    # transformed) x channel would rescale missing y values and skip valid ones.
    valid_x_mask = x_channel != -1
    valid_y_mask = y_channel != -1
    valid_points_x = x_channel[valid_x_mask]
    valid_points_y = y_channel[valid_y_mask]

    min_x = np.min(valid_points_x) if valid_points_x.size > 0 else 0
    max_x = np.max(valid_points_x) if valid_points_x.size > 0 else 0
    min_y = np.min(valid_points_y) if valid_points_y.size > 0 else 0
    max_y = np.max(valid_points_y) if valid_points_y.size > 0 else 0

    width = max_x - min_x
    height = max_y - min_y
    if np.abs(width) < 1e-6 or np.abs(height) < 1e-6:
        # Degenerate bounding box (no valid points, or all on one line):
        # dividing by its size is meaningless, so fall back to unit scale.
        scale_factor = 1
        translation_x = 0
        translation_y = 0
    else:
        scale_factor = scale_factor * min(1 / width, 1 / height)
        # Translation that moves the bounding-box center to 0.5.
        translation_x = 0.5 - (min_x + max_x) / 2
        translation_y = 0.5 - (min_y + max_y) / 2

    # Boolean indexing returned copies above, so the right-hand sides are
    # independent of the in-place writes.
    x_channel[valid_x_mask] = (valid_points_x + translation_x) * scale_factor + 0.5
    y_channel[valid_y_mask] = (valid_points_y + translation_y) * scale_factor + 0.5

    # Pack the applied parameters so the transform can be undone later.
    transform_params = np.array([scale_factor, translation_x, translation_y])

    return keypoints, transform_params


def rescale_transform_back(keypoints, transform_params):
    """
    Undo a scale-and-translate transform on keypoints.

    Each valid coordinate is mapped as ``(coord - 0.5) / scale - translation``.
    Coordinates equal to -1 are treated as missing and are left untouched.

    Args:
        keypoints (np.array): The keypoints to restore with [B, T, J, 2].
            The array is modified in place.
        transform_params (np.array): Parameters whose last axis holds
            ``(scale, translation_x, translation_y)``. The leading axes must
            broadcast against the keypoints once a trailing axis is appended.

    Returns:
        np.array: The input ``keypoints`` array after the in-place update.
    """
    # A trailing singleton axis lets the per-sample parameters broadcast
    # over the joint axis.
    scale = np.expand_dims(transform_params[..., 0], -1)

    # x and y are restored independently, each with its own validity mask.
    for axis in (0, 1):
        shift = np.expand_dims(transform_params[..., axis + 1], -1)
        channel = keypoints[..., axis]  # view into the input array
        is_valid = channel != -1
        restored = (channel - 0.5) / scale - shift
        channel[is_valid] = restored[is_valid]

    return keypoints


def preprocess_frames(
    frames,
    num_body_kps=24,
    num_face_kps=68,
    num_hand_kps=21,
    nomalize2square=True,
    augmentations=True,
    scale_range=0.05,
    rotation_range=5,
    translation_range=0.02,
):
    """
    Normalize every frame's keypoints into one shared square, then optionally augment.

    One bounding box is measured over the valid (not -1) body, face and hand
    coordinates of all frames. Its longer side becomes the square size and the
    shorter side is padded evenly, so all frames share the same mapping.

    Parameters:
        frames (list): List of frame dicts. A frame may hold body keypoints
            under ``frame["bodies"]["candidate"]`` and face / hand keypoints
            under ``frame["faces"]`` / ``frame["hands"]``. Frames are updated
            in place; a 2-D body array gets a leading axis added.
        num_body_kps (int): Number of body keypoints per frame.
        num_face_kps (int): Number of face keypoints per frame.
        num_hand_kps (int): Number of hand keypoints per frame.
        nomalize2square (bool): When true, rescale each part into the square.
        augmentations (bool): When true, pass the frames through
            ``augment_keypoints``.
        scale_range (float): Forwarded to ``augment_keypoints``.
        rotation_range (float): Forwarded to ``augment_keypoints``.
        translation_range (float): Forwarded to ``augment_keypoints``.

    Returns:
        list: Frames with resized and centered keypoints while keeping the original structure intact.

    Raises:
        ValueError: If the bounding box is empty or smaller than 1e-6, which
            includes the case where no valid keypoint exists.
    """
    lo_x, hi_x = np.inf, -np.inf
    lo_y, hi_y = np.inf, -np.inf

    # Pass 1: grow the global bounding box from every part of every frame.
    for frame in frames:
        groups = []

        if "bodies" in frame and "candidate" in frame["bodies"]:
            bodies = frame["bodies"]
            if bodies["candidate"].ndim == 2:
                # Give a bare (J, 2) array a leading axis.
                bodies["candidate"] = bodies["candidate"][None]
            groups.append(bodies["candidate"][..., :num_body_kps, :])
        for key, limit in (("faces", num_face_kps), ("hands", num_hand_kps)):
            if key in frame and frame[key].size > 0:
                groups.append(frame[key][..., :limit, :])

        for group in groups:
            # Keep only coordinates that are not the -1 "missing" marker.
            xs = group[..., 0]
            ys = group[..., 1]
            xs = xs[xs != -1]
            ys = ys[ys != -1]

            if xs.size > 0:
                lo_x = min(lo_x, xs.min())
                hi_x = max(hi_x, xs.max())
            if ys.size > 0:
                lo_y = min(lo_y, ys.min())
                hi_y = max(hi_y, ys.max())

    span_y = hi_y - lo_y
    span_x = hi_x - lo_x
    square_size = max(span_y, span_x)

    # If the extremes were never updated the span is -inf, so this also
    # rejects input without any valid keypoint.
    if square_size < 1e-6:
        raise ValueError("No valid keypoints found in the frames.")

    # Padding that centers the shorter side inside the square.
    offset_x = (square_size - span_x) / 2
    offset_y = (square_size - span_y) / 2

    def to_square(part):
        return rescale_and_center_keypoints(
            part,
            lo_x,
            lo_y,
            square_size,
            offset_x,
            offset_y,
        )

    # Pass 2: map every part into the shared square.
    if nomalize2square:
        for frame in frames:
            if "bodies" in frame and "candidate" in frame["bodies"]:
                frame["bodies"]["candidate"] = to_square(
                    frame["bodies"]["candidate"][..., :num_body_kps, :]
                )
            for key, limit in (("faces", num_face_kps), ("hands", num_hand_kps)):
                if key in frame and frame[key].size > 0:
                    frame[key] = to_square(frame[key][..., :limit, :])

    if augmentations:
        frames = augment_keypoints(
            frames,
            scale_range,
            rotation_range,
            translation_range,
        )

    return frames


def preprocess_frames_joblib(
    frames,
    H,
    W,
    original_kps=(17, 6, 68, 21, 21),
    num_body_kps=24,
    num_face_kps=68,
    num_hand_kps=21,
    nomalize2square=True,
    augmentations=True,
    scale_range=0.05,
    rotation_range=5,
    translation_range=0.02,
):
    """
    Split whole-body keypoints into bodies / faces / hands, then normalize and augment.

    For each frame, ``frame["keypoints"]`` is divided in place by the image
    size and split into body, foot, face, left-hand and right-hand groups
    according to ``original_kps``. Body and foot points are converted with
    ``coco_to_openpose`` / ``coco_foot_to_openpose``. The results are stored
    under ``frame["bodies"]``, ``frame["faces"]`` and ``frame["hands"]``, and
    the matching scores are concatenated into ``frame["score"]``. The frames
    are then normalized to one shared square and optionally augmented, as in
    ``preprocess_frames``.

    Parameters:
        frames (list): List of frame dicts with ``"keypoints"`` and
            ``"keypoints_scores"`` arrays. The indexing below treats them as
            (N, K, 2) and (N, K) - TODO confirm against the producer; an
            earlier note in this function gave the layout as (c, 2, 133, 2).
        H: Image height; y coordinates are divided by it.
        W: Image width; x coordinates are divided by it.
        original_kps (tuple): Number of source keypoints in each group, in
            the order (body, foot, face, left hand, right hand).
        num_body_kps (int): Number of body keypoints per frame.
        num_face_kps (int): Number of face keypoints per frame.
        num_hand_kps (int): Number of hand keypoints per frame.
        nomalize2square (bool): When true, rescale each part into the square.
        augmentations (bool): When true, pass the frames through
            ``augment_keypoints``.
        scale_range (float): Forwarded to ``augment_keypoints``.
        rotation_range (float): Forwarded to ``augment_keypoints``.
        translation_range (float): Forwarded to ``augment_keypoints``.

    Returns:
        list: Frames with resized and centered keypoints while keeping the original structure intact.

    Raises:
        ValueError: If the bounding box is empty or smaller than 1e-6, which
            includes the case where no valid keypoint exists.
    """
    min_x, max_x, min_y, max_y = np.inf, -np.inf, np.inf, -np.inf

    def pad_keypoints(kps, target_kps):
        # Truncate or pad along the keypoint axis; padding uses the -1
        # "missing" marker. The filler is sized from `kps` itself so the
        # helper works for every group, not only the body.
        current = kps.shape[-2]
        if current == target_kps:
            return kps
        if current > target_kps:
            return kps[..., :target_kps, :]
        filler = np.full(
            kps.shape[:-2] + (target_kps - current, kps.shape[-1]), -1.0
        )
        return np.concatenate([kps, filler], axis=-2)

    def pad_scores(scores, target_kps):
        # Same as pad_keypoints for the score arrays; padded slots score 0.
        current = scores.shape[-1]
        if current == target_kps:
            return scores
        if current > target_kps:
            return scores[..., :target_kps]
        filler = np.zeros(scores.shape[:-1] + (target_kps - current,))
        return np.concatenate([scores, filler], axis=-1)

    # Start offsets of each group inside the flat whole-body layout.
    n_body, n_foot, n_face, n_left_hand, n_right_hand = original_kps[:5]
    foot_start = n_body
    face_start = foot_start + n_foot
    left_start = face_start + n_face
    right_start = left_start + n_left_hand
    right_end = right_start + n_right_hand

    # Pass 1: convert each frame and grow the global bounding box.
    for frame in frames:
        keypoints = []
        scores = []

        # Pixel coordinates -> fractions of the image size (in place).
        frame["keypoints"][..., 0] = frame["keypoints"][..., 0] / W
        frame["keypoints"][..., 1] = frame["keypoints"][..., 1] / H

        all_kps = frame["keypoints"]
        all_scores = frame["keypoints_scores"]

        # Body + feet
        body_kps = coco_to_openpose(all_kps[..., :n_body, :])
        foot_kps = coco_foot_to_openpose(all_kps[..., foot_start:face_start, :])
        body_kps = np.concatenate([body_kps, foot_kps], axis=-2)
        body_kps = pad_keypoints(body_kps, num_body_kps)
        frame["bodies"] = {}
        frame["bodies"]["subset"] = np.arange(num_body_kps)[None]
        frame["bodies"]["candidate"] = body_kps
        keypoints.append(body_kps)

        # Scores go through the same conversions as the coordinates so that
        # score i always describes keypoint i. The foot scores in particular
        # must be reduced to the same points the foot conversion keeps.
        body_scores = coco_to_openpose(all_scores[..., :n_body])
        foot_scores = coco_foot_to_openpose(all_scores[..., foot_start:face_start])
        body_scores = np.concatenate([body_scores, foot_scores], axis=-1)
        body_scores = pad_scores(body_scores, num_body_kps)
        scores.append(body_scores)

        # Face
        face_kps = pad_keypoints(
            all_kps[..., face_start:left_start, :], num_face_kps
        )
        frame["faces"] = face_kps
        keypoints.append(face_kps)
        scores.append(pad_scores(all_scores[..., face_start:left_start], num_face_kps))

        # Hands: coordinates and scores are padded together to stay aligned.
        left_hand_kps = pad_keypoints(
            all_kps[..., left_start:right_start, :], num_hand_kps
        )
        scores.append(
            pad_scores(all_scores[..., left_start:right_start], num_hand_kps)
        )
        right_hand_kps = pad_keypoints(
            all_kps[..., right_start:right_end, :], num_hand_kps
        )
        scores.append(
            pad_scores(all_scores[..., right_start:right_end], num_hand_kps)
        )

        # Stack left and right hand keypoints on a new axis 1.
        hand_kps = np.stack([left_hand_kps, right_hand_kps], axis=1)
        frame["hands"] = hand_kps
        keypoints.append(hand_kps)

        # Scores in the order body, face, left hand, right hand.
        frame["score"] = np.concatenate(scores, axis=-1)

        for kp_set in keypoints:
            # Keep only coordinates that are not the -1 "missing" marker.
            x_coords = kp_set[..., 0]
            y_coords = kp_set[..., 1]
            valid_x = x_coords[x_coords != -1]
            valid_y = y_coords[y_coords != -1]

            if valid_x.size > 0:
                min_x = min(min_x, valid_x.min())
                max_x = max(max_x, valid_x.max())
            if valid_y.size > 0:
                min_y = min(min_y, valid_y.min())
                max_y = max(max_y, valid_y.max())

    max_H = max_y - min_y
    max_W = max_x - min_x
    square_size = max(max_H, max_W)

    # If the extremes were never updated the span is -inf, so this also
    # rejects input without any valid keypoint.
    if square_size < 1e-6:
        raise ValueError("No valid keypoints found in the frames.")

    # Padding that centers the shorter side inside the square.
    offset_x = (square_size - max_W) / 2
    offset_y = (square_size - max_H) / 2

    def to_square(part):
        return rescale_and_center_keypoints(
            part,
            min_x,
            min_y,
            square_size,
            offset_x,
            offset_y,
        )

    # Pass 2: map every part into the shared square.
    if nomalize2square:
        for frame in frames:
            if "bodies" in frame and "candidate" in frame["bodies"]:
                frame["bodies"]["candidate"] = to_square(
                    frame["bodies"]["candidate"][..., :num_body_kps, :]
                )
            for key, limit in (("faces", num_face_kps), ("hands", num_hand_kps)):
                if key in frame and frame[key].size > 0:
                    frame[key] = to_square(frame[key][..., :limit, :])

    if augmentations:
        frames = augment_keypoints(
            frames,
            scale_range,
            rotation_range,
            translation_range,
        )

    return frames


def rescale_and_center_keypoints(
    keypoints, min_x, min_y, square_size, padding_x, padding_y
):
    """
    Map keypoints into a square frame, shifting by the padding first.

    Each coordinate becomes ``(coord - min + padding) / square_size``.
    Coordinates equal to -1 are treated as missing and are still -1 afterwards.

    Parameters:
        keypoints (np.ndarray): Keypoints to rescale and center; x and y are
            read from indices 0 and 1 of the last axis. Modified in place.
        min_x (float): Minimum X coordinate across all frames.
        min_y (float): Minimum Y coordinate across all frames.
        square_size (float): Side length of the square frame.
        padding_x (float): Padding to apply on the X-axis.
        padding_y (float): Padding to apply on the Y-axis.

    Returns:
        np.ndarray: The input array after the in-place update.
    """
    per_axis = ((min_x, padding_x), (min_y, padding_y))
    for axis, (origin, padding) in enumerate(per_axis):
        channel = keypoints[..., axis]  # view into the input array
        # Remember the missing entries before the arithmetic overwrites them.
        missing = channel == -1
        channel[...] = (channel - origin + padding) / square_size
        channel[missing] = -1

    return keypoints


def augment_keypoints(
    keypoints, scale_range=0.1, rotation_range=15, translation_range=0.05
):
    """
    Augment keypoints by applying random scaling, rotation, and translation.

    One scale, one rotation angle and one translation are drawn per call and
    applied to every part of every frame, so the frames stay consistent with
    each other. Scaling and rotation are performed about the point (0.5, 0.5).
    Coordinates equal to -1 are treated as missing and are still -1 afterwards.

    Parameters:
        keypoints (list): List of frame dicts. A frame may hold body keypoints
            under ``frame["bodies"]["candidate"]`` and face / hand keypoints
            under ``frame["faces"]`` / ``frame["hands"]``. The arrays are
            modified in place. Each frame is inspected on its own, so a part
            that is missing or empty in one frame is still augmented in the
            frames where it exists.
        scale_range (float): The scale is drawn uniformly from
            ``1 - scale_range`` to ``1 + scale_range``.
        rotation_range (float): Largest rotation in degrees; the angle is
            drawn uniformly from ``-rotation_range`` to ``rotation_range``.
        translation_range (float): Largest shift per axis; x and y shifts are
            drawn uniformly from ``-translation_range`` to ``translation_range``.

    Returns:
        list: The input list, with its frames augmented in place.
    """
    # Nothing to augment; indexing into an empty list would fail.
    if len(keypoints) == 0:
        return keypoints

    # Set random augmentation parameters (shared by all frames and parts).
    scale_factor = 1 + np.random.uniform(-scale_range, scale_range)
    rotation_range = np.deg2rad(rotation_range)
    rotation_angle = np.random.uniform(-rotation_range, rotation_range)
    translation = np.random.uniform(-translation_range, translation_range, 2)

    # Built once; it is the same for every array.
    rotation_matrix = np.array(
        [
            [np.cos(rotation_angle), -np.sin(rotation_angle)],
            [np.sin(rotation_angle), np.cos(rotation_angle)],
        ]
    )

    def apply_to_valid(kps, operation):
        # Run an in-place operation on the xy channels, then put the -1
        # markers back where the coordinates were missing.
        missing_x = kps[..., 0] == -1
        missing_y = kps[..., 1] == -1
        operation(kps[..., :2])
        kps[..., 0][missing_x] = -1
        kps[..., 1][missing_y] = -1
        return kps

    def rescale(xy):
        xy -= 0.5
        xy *= scale_factor
        xy += 0.5

    def rotate(xy):
        xy -= 0.5
        xy[...] = np.dot(xy, rotation_matrix)
        xy += 0.5

    def translate(xy):
        xy += translation

    def augment_array(kps):
        # Markers are restored after every step so that each step sees the
        # missing entries as -1 again.
        for operation in (rescale, rotate, translate):
            kps = apply_to_valid(kps, operation)
        return kps

    # Apply the augmentation to each part that this frame actually has.
    for frame in keypoints:
        if "bodies" in frame and "candidate" in frame["bodies"]:
            frame["bodies"]["candidate"] = augment_array(frame["bodies"]["candidate"])
        for part in ("faces", "hands"):
            if part in frame and frame[part].size > 0:
                frame[part] = augment_array(frame[part])

    return keypoints


def coco_to_openpose(coco_keypoints):
    """
    Reorder COCO 17-point keypoints into the OpenPose 18-point layout.

    OpenPose has one point COCO lacks, the neck (index 1); it is filled with
    the midpoint of the left and right shoulders.

    :param coco_keypoints: numpy array of shape (N, 17, 2) with COCO-format
        keypoints, or of shape (N, 17) with one value per keypoint
    :return: float numpy array of shape (N, 18, 2), or (N, 18) for 2-D input,
        in OpenPose order
    """
    neck_slot = 1
    left_shoulder, right_shoulder = 5, 6

    # COCO source index for every OpenPose slot.
    source_index = [
        0,  # nose
        0,  # neck: placeholder, overwritten with the shoulder midpoint below
        6,  # right shoulder
        8,  # right elbow
        10,  # right wrist
        5,  # left shoulder
        7,  # left elbow
        9,  # left wrist
        12,  # right hip
        14,  # right knee
        16,  # right ankle
        11,  # left hip
        13,  # left knee
        15,  # left ankle
        2,  # right eye
        1,  # left eye
        4,  # right ear
        3,  # left ear
    ]

    neck = (coco_keypoints[:, left_shoulder] + coco_keypoints[:, right_shoulder]) / 2

    # 2-D input carries one value per keypoint, otherwise an (x, y) pair.
    count = coco_keypoints.shape[0]
    out_shape = (count, 18) if coco_keypoints.ndim == 2 else (count, 18, 2)
    openpose_keypoints = np.zeros(out_shape)

    # Gather all slots in one indexed read, then patch in the neck.
    openpose_keypoints[...] = coco_keypoints[:, source_index]
    openpose_keypoints[:, neck_slot] = neck

    return openpose_keypoints


def coco_foot_to_openpose(coco_keypoints):
    """
    Reduce a block of foot keypoints to the two points kept here.

    Positions 2 and 5 along axis 1 are selected; the existing code labels
    them left heel and right heel.

    :param coco_keypoints: numpy array whose axis 1 indexes the foot
        keypoints, e.g. (N, 6, 2) coordinates or (N, 6) scores
    :return: numpy array with the same layout but only two entries on axis 1,
        ordered (left heel, right heel)
    """
    # Positions are relative to the foot block passed in, not to the full
    # whole-body keypoint array.
    heel_positions = (2, 5)  # left heel, right heel
    heels = [coco_keypoints[:, position] for position in heel_positions]

    return np.stack(heels, axis=1)
