import numpy as np
import json
import copy


def load_kps_conf_normalized(data, resolution):
    """Build normalized keypoint features and confidences from pose data.

    The per-frame bounding-box centers, keypoints and scores are pulled out
    of ``data`` with ``get_bbox_centers_numpy``. Keypoints are re-expressed
    relative to the top-left corner of a square crop of side
    ``max(max_width, max_height)`` centered on each frame's bbox center,
    rescaled, and handed to ``rtm2dwpose`` for re-ordering.

    :param data: dict-like pose record accepted by ``get_bbox_centers_numpy``.
    :param resolution: indexable pair; ``resolution[0]`` is paired with the
        x axis and ``resolution[1]`` with the y axis (i.e. width, height).
    :return: the ``(features, confs)`` pair produced by ``rtm2dwpose``.
    :raises ValueError: if the pose data could not be extracted, if the
        number of frames in the keypoints and centers differ, or if the
        coordinate dimension is not 2.
    """
    extracted = get_bbox_centers_numpy(data)
    if extracted is None:
        # The extractor signals failure by returning None; fail here with a
        # clear message instead of an opaque tuple-unpacking TypeError.
        raise ValueError(
            "Could not extract bounding boxes and keypoints from the pose data."
        )
    centers, max_width, max_height, keypoints_array, conf_array = extracted

    max_img_len = max(resolution)
    max_bbox_len = max(max_width, max_height)
    bbox_img_ratio = max_bbox_len / max_img_len

    normalized_centers = np.ones_like(centers)
    start_points = np.zeros_like(centers)

    # Padding that would make the image square with side max_img_len.
    pad_left = (max_img_len - resolution[0]) / 2
    pad_top = (max_img_len - resolution[1]) / 2

    # Bbox centers expressed in the padded square image, scaled to [0, 1].
    normalized_centers[:, 0] = (centers[:, 0] + pad_left) / max_img_len
    normalized_centers[:, 1] = (centers[:, 1] + pad_top) / max_img_len

    # Top-left corner of the square crop centered on each bbox center.
    start_points[:, 0] = centers[:, 0] - (max_bbox_len / 2)
    start_points[:, 1] = centers[:, 1] - (max_bbox_len / 2)

    if keypoints_array.shape[0] != start_points.shape[0]:
        raise ValueError(
            "帧数不匹配：keypoints_array 和 start_points 的第一维必须相同。"
        )
    if keypoints_array.shape[2] != 2 or start_points.shape[1] != 2:
        raise ValueError("坐标维度必须为 2。")

    # Broadcast start_points to (N, 1, 2) so it lines up with the
    # (N, K, 2) keypoint array.
    start_points_reshaped = start_points[:, np.newaxis, :]

    # Per-frame keypoint offsets from the crop corner. The result is then
    # divided by bbox_img_ratio here and by max_bbox_len inside rtm2dwpose.
    shifted_keypoints = keypoints_array - start_points_reshaped
    shifted_keypoints /= bbox_img_ratio
    final_features, final_confs = rtm2dwpose(
        shifted_keypoints, conf_array, max_bbox_len, normalized_centers
    )

    return final_features, final_confs


def get_bbox_centers_numpy(data):
    """Collect per-frame bbox centers, keypoints and scores from a pose dict.

    ``data`` is read through ``.get`` for the keys ``"bboxes"``,
    ``"pose_keypoints"`` and ``"pose_keypoint_scores"``; each is indexed by
    frame id, and only the first entry (index 0) of every frame is used.
    Frames are visited in ascending numeric order of their ids.

    :param data: dict-like object holding the three per-frame mappings.
    :return: tuple ``(centers, max_width, max_height, keypoints, confs)``
        where ``centers``, ``keypoints`` and ``confs`` are NumPy arrays with
        one row per frame and ``max_width`` / ``max_height`` are the largest
        bbox extents seen across frames. Returns ``None`` (after printing a
        message) when keypoints or scores are missing/empty or when any
        error occurs while reading the data.
    """
    try:
        bboxes = data.get("bboxes", {})
        pose_keypoints = data.get("pose_keypoints", {})
        pose_keypoint_scores = data.get("pose_keypoint_scores", {})
        if not pose_keypoints or not pose_keypoint_scores:
            print("Error: 'pose_keypoints' field missing or empty in JSON.")
            return None

        centers = []
        keypoints_list = []
        confs = []
        max_width = 0
        max_height = 0

        # Frame ids are sorted numerically so "10" comes after "2".
        for frame_id in sorted(pose_keypoints.keys(), key=int):
            # Only the first keypoint set / score set / bbox of a frame is used.
            keypoints_list.append(pose_keypoints[frame_id][0])
            confs.append(pose_keypoint_scores[frame_id][0])

            x_min, y_min, x_max, y_max = bboxes[frame_id][0]
            centers.append([(x_min + x_max) / 2, (y_min + y_max) / 2])
            max_width = max(max_width, x_max - x_min)
            max_height = max(max_height, y_max - y_min)

        return (
            np.array(centers),
            max_width,
            max_height,
            np.array(keypoints_list),
            np.array(confs),
        )

    # This function performs no file or JSON I/O, so a single handler keeps
    # the "print and return None" contract for every failure mode.
    except Exception as e:
        print(f"Error: {str(e)}")
        return None


def coco_to_openpose(coco_keypoints, coco_confs):
    """Reorder COCO 17-point body keypoints into the OpenPose 18-point layout.

    The extra OpenPose point (index 1, the neck) does not exist in the input;
    it is synthesized as the midpoint of the two shoulders (input indices 5
    and 6), and its confidence as the mean of the two shoulder confidences.

    :param coco_keypoints: NumPy array of shape (N, 17, 2), or a single pose
        of shape (17, 2), which is treated as a batch of one.
    :param coco_confs: NumPy array of shape (N, 17), or (17,) for a single
        pose.
    :return: tuple ``(openpose_keypoints, openpose_confs)`` of float arrays
        with shapes (N, 18, 2) and (N, 18).
    """
    # Promote a single pose to a batch of one *before* any indexing, so the
    # [:, idx] lookups below address the keypoint axis rather than x/y.
    if coco_keypoints.ndim == 2:
        coco_keypoints = coco_keypoints[np.newaxis, ...]
    if coco_confs.ndim == 1:
        coco_confs = coco_confs[np.newaxis, ...]

    # Output slot -> input index; None marks the synthesized neck.
    mapping = [
        0,  # nose
        None,  # neck (computed below)
        6,  # right shoulder
        8,  # right elbow
        10,  # right wrist
        5,  # left shoulder
        7,  # left elbow
        9,  # left wrist
        12,  # right hip
        14,  # right knee
        16,  # right ankle
        11,  # left hip
        13,  # left knee
        15,  # left ankle
        2,  # right eye
        1,  # left eye
        4,  # right ear
        3,  # left ear
    ]
    left_shoulder_idx = 5
    right_shoulder_idx = 6

    neck = (
        coco_keypoints[:, left_shoulder_idx] + coco_keypoints[:, right_shoulder_idx]
    ) / 2
    neck_conf = (
        coco_confs[:, left_shoulder_idx] + coco_confs[:, right_shoulder_idx]
    ) / 2

    num_frames = coco_keypoints.shape[0]
    openpose_keypoints = np.zeros((num_frames, 18, 2))
    openpose_confs = np.zeros((num_frames, 18))
    for i, coco_idx in enumerate(mapping):
        if coco_idx is None:
            openpose_keypoints[:, i] = neck
            openpose_confs[:, i] = neck_conf
        else:
            openpose_keypoints[:, i] = coco_keypoints[:, coco_idx]
            openpose_confs[:, i] = coco_confs[:, coco_idx]
    return openpose_keypoints, openpose_confs


def coco_foot_to_openpose(coco_keypoints, coco_confs):
    """Reduce a 6-point foot block to the two heel points.

    Indices are local to the foot block passed in: position 2 is taken as
    the left heel and position 5 as the right heel. (Within this file,
    ``rtm2dwpose`` passes columns 17:23 of the full keypoint array.)

    :param coco_keypoints: NumPy array indexed as (N, 6, 2) holding the foot
        keypoints.
    :param coco_confs: NumPy array indexed as (N, 6) holding the matching
        confidences.
    :return: tuple ``(foot, foot_confs)`` with shapes (N, 2, 2) and (N, 2),
        ordered left heel then right heel.
    """
    heel_positions = [2, 5]  # left heel, right heel within the foot block
    heels = np.take(coco_keypoints, heel_positions, axis=1)
    heel_scores = np.take(coco_confs, heel_positions, axis=1)
    return heels, heel_scores


def rtm2dwpose(rtm_keypoints, rtm_confs, max_len, normalized_centers):
    """Rescale whole-body keypoints and assemble the output feature layout.

    The keypoint axis of the input is split by position: ``[:17]`` body,
    ``[17:23]`` feet, ``[23:91]`` face, ``[91:]`` hands. Body points are
    re-ordered with ``coco_to_openpose`` (17 -> 18 points) and the feet are
    reduced to two heel points with ``coco_foot_to_openpose``.

    :param rtm_keypoints: NumPy array indexed as (N, K, 2); not modified.
    :param rtm_confs: NumPy array indexed as (N, K) of per-keypoint scores.
    :param max_len: scalar every keypoint coordinate is divided by.
    :param normalized_centers: NumPy array indexed as (N, 2); placed as the
        first point of every frame's feature row.
    :return: tuple ``(features, confs)``. ``features`` is the concatenation
        along axis 1 of center, body, foot, face and hand points. ``confs``
        is the concatenation of the body, foot, face and hand scores divided
        by 10.0; it has no entry for the center point, so it is one column
        shorter than ``features``.
    """
    # Divide out-of-place: an in-place ``/=`` would overwrite the caller's
    # array and raises a casting error when the input has an integer dtype.
    scaled_keypoints = rtm_keypoints / max_len

    body, body_conf = coco_to_openpose(scaled_keypoints[:, :17], rtm_confs[:, :17])
    foot, foot_conf = coco_foot_to_openpose(
        scaled_keypoints[:, 17:23], rtm_confs[:, 17:23]
    )
    face, face_conf = scaled_keypoints[:, 23:91], rtm_confs[:, 23:91]
    hands, hands_conf = scaled_keypoints[:, 91:], rtm_confs[:, 91:]

    # (N, 2) -> (N, 1, 2) so the center can be concatenated as one more point.
    normalized_centers_reshape = normalized_centers[:, np.newaxis, :]
    features = np.concatenate(
        [normalized_centers_reshape, body, foot, face, hands], axis=1
    )
    # NOTE(review): the reason for the 1/10 score scaling is not visible
    # here; kept as-is — confirm against the consumer of these confidences.
    confs = np.concatenate([body_conf, foot_conf, face_conf, hands_conf], axis=1) / 10.0
    return features, confs
