import numpy as np
import torch
import joblib

# from .transforms import rescale_transform, rescale_transform_back, resize_to_square
from .transforms import preprocess_frames, preprocess_frames_joblib


# TODO: Remove this function
def auto_pose_decompose(features, return_dict=False):
    """
    Split flat per-frame pose features into body, face and hand keypoints.

    The layout is inferred from the feature width (``features.shape[-1]``):

        48        24 body keypoints (xy) only.
        266       23 body keypoints, then face and hands (xy).
        268       24 body keypoints, then face and hands (xy).
        277, 280  every part is preceded by 3 transform parameters that are
                  undone with ``rescale_transform_back``; hands are only
                  decoded for width 280.
        402       24 body keypoints, face and hands stored as (x, y, score).

    Any other width leaves the corresponding entries out of the result.

    Parameters:
        features (np.ndarray): 2-D array of shape (nframes, width).
        return_dict (bool): Also build a per-frame list of dictionaries.

    Returns:
        dict: ``pose`` with the decoded parts under "bodies" (a dict holding
            "candidate" and "subset"), "faces" and "hands".
        tuple: ``(pose, pose_dict)`` when ``return_dict`` is true, where
            ``pose_dict`` has one entry per frame. Building it needs all
            three parts, so widths without faces or hands raise ``KeyError``.
    """
    num_face_kps = 68
    assert features.ndim == 2
    nframes = features.shape[0]
    ndim = features.shape[-1]
    pose = {}

    def frame_major_subset(num_kps):
        # Flat index array holding one 0..num_kps-1 run per frame. It is
        # reshaped to (nframes, -1) below, so it must be frame-major
        # (np.tile); a keypoint-major layout (repeat) scrambles the rows as
        # soon as there is more than one frame.
        return np.tile(np.arange(num_kps), nframes)

    # Confidence defaults to 1 for layouts that carry no scores. Note that
    # the 402 layout replaces these with 2-D (bodies, faces) score arrays.
    bodies_score = np.ones((nframes, 1, 24))
    hands_score = np.ones((nframes, 2, 21))
    faces_score = np.ones((nframes, 1, 68))

    # TODO: Hardcode part detection
    # NOTE(review): ``rescale_transform_back`` is not imported in this module
    # (its import is commented out), so the 277/280 branches raise NameError
    # unless the name is provided elsewhere -- confirm before relying on them.

    # Bodies
    if ndim in [48, 266, 268]:
        num_body_kps = 23 if ndim == 266 else 24
        pose["bodies"] = {
            "candidate": features[:, : num_body_kps * 2].reshape(
                nframes, num_body_kps, 2
            ),
            "subset": frame_major_subset(num_body_kps),
        }
    elif ndim in [268 + 9, 268 + 12]:
        num_body_kps = 24
        body_transform_params = features[:, :3]
        body_keypoints = features[:, 3 : num_body_kps * 2 + 3].reshape(
            nframes, num_body_kps, 2
        )
        body_candidate = rescale_transform_back(body_keypoints, body_transform_params)
        pose["bodies"] = {
            "candidate": body_candidate,
            "subset": frame_major_subset(num_body_kps),
        }
    # with full-body confidence scores
    elif ndim in [402]:
        num_body_kps = 24
        reshaped_features = features[:, : num_body_kps * 3].reshape(
            nframes, num_body_kps, 3
        )
        pose["bodies"] = {
            "candidate": reshaped_features[:, :, :2],  # xy
            "subset": frame_major_subset(num_body_kps),
        }
        bodies_score = reshaped_features[:, :, 2]  # score

    # Faces
    if ndim in [266, 268]:
        pose["faces"] = feature2face(features)
    elif ndim in [402]:
        pose["faces"] = feature2face(features)
        faces_score = feature2facescore(features)
    elif ndim in [268 + 9, 268 + 12]:
        # Layout: [3 body params | body xy | 3 face params | face xy | ...]
        face_params_start = 3 + num_body_kps * 2
        face_kps_start = face_params_start + 3
        face_transform_params = features[:, face_params_start:face_kps_start]
        face_keypoints = features[
            :, face_kps_start : face_kps_start + num_face_kps * 2
        ].reshape(nframes, num_face_kps, 2)
        face_candidate = rescale_transform_back(face_keypoints, face_transform_params)
        pose["faces"] = face_candidate

    # Hands
    if ndim in [266, 268]:
        pose["hands"] = feature2hands(features)
    elif ndim in [402]:
        pose["hands"] = feature2hands(features)
        hands_score = feature2handsscore(features)
    elif ndim in [268 + 12]:
        num_hand_kps = 21
        # Layout after the face block:
        # [3 left params | left xy | 3 right params | right xy]
        left_params_start = 6 + num_body_kps * 2 + num_face_kps * 2
        left_kps_start = left_params_start + 3
        right_params_start = left_kps_start + num_hand_kps * 2
        right_kps_start = right_params_start + 3

        left_hand_transform_params = features[:, left_params_start:left_kps_start]
        left_hand_keypoints = features[:, left_kps_start:right_params_start].reshape(
            nframes, num_hand_kps, 2
        )
        right_hand_transform_params = features[:, right_params_start:right_kps_start]
        right_hand_keypoints = features[:, right_kps_start:].reshape(
            nframes, num_hand_kps, 2
        )

        left_hand_candidate = rescale_transform_back(
            left_hand_keypoints, left_hand_transform_params
        )
        right_hand_candidate = rescale_transform_back(
            right_hand_keypoints, right_hand_transform_params
        )
        pose["hands"] = np.stack([left_hand_candidate, right_hand_candidate], axis=1)

    if return_dict:
        # Only the first 20 body keypoints are exported per frame.
        pose_dict = [
            {
                "hands": pose["hands"][i],
                "faces": pose["faces"][i][None],
                "bodies": {
                    "candidate": pose["bodies"]["candidate"][:, :20][i],
                    "subset": pose["bodies"]["subset"].reshape(nframes, -1)[:, :20][
                        i : i + 1
                    ],
                },
                "bodies_score": bodies_score[i : i + 1, :],
                "faces_score": faces_score[i : i + 1, :],
                "hands_score": hands_score[i, :, :],
            }
            for i in range(nframes)
        ]
        return (pose, pose_dict)

    return pose


def pose_decompose(features, category="whole_body"):
    """
    Extract the keypoints of one body-part category from flat features.

    Parameters:
        features: 2-D array of per-frame features.
        category (str): One of "whole_body", "body", "foot", "face", "hands".

    Returns:
        For "whole_body", ``features`` reshaped to (nframes, -1, 2); for the
        other categories, the result of the matching ``feature2*`` helper.

    Raises:
        ValueError: If ``category`` is not one of the names above.
    """
    if category == "whole_body":
        frame_count = features.shape[0]
        return features.reshape(frame_count, -1, 2)
    if category == "body":
        return feature2body(features)
    if category == "foot":
        return feature2foot(features)
    if category == "face":
        return feature2face(features)
    if category == "hands":
        return feature2hands(features)
    raise ValueError(f"Unknown category: {category}")


def feature2body(features):
    """
    Extract the body keypoints (xy) from flat features.

    Parameters:
        features: 2-D array of shape (nframes, width).

    Returns:
        Array of shape (nframes, 24, 2) for widths 48 and 268, or
        (nframes, 23, 2) for width 266. Any other width yields ``None``.
    """
    assert features.ndim == 2
    frame_count = features.shape[0]
    width = features.shape[-1]

    # Feature width -> number of leading body keypoints.
    body_kps_by_width = {48: 24, 268: 24, 266: 23}
    body_kps = body_kps_by_width.get(width)
    if body_kps is None:
        # Unsupported widths fall through without raising.
        return None

    body_block = features[:, : body_kps * 2]
    return body_block.reshape(frame_count, body_kps, 2)


def feature2foot(features):
    """
    Extract the last two body keypoints (xy) from flat features.

    Parameters:
        features: 2-D array of shape (nframes, width), width 266 or 268.

    Returns:
        Array of shape (nframes, 2, 2) holding the final two body keypoints.

    Raises:
        ValueError: If the feature width is neither 266 nor 268.
    """
    assert features.ndim == 2
    frame_count = features.shape[0]
    ndim = features.shape[-1]

    if ndim not in (266, 268):
        raise ValueError(f"Unknown feature dimension: {ndim}")
    body_kps = 24 if ndim == 268 else 23

    body = features[:, : body_kps * 2].reshape(frame_count, body_kps, 2)
    return body[:, -2:]


def feature2face(features):
    """
    Extract the face keypoints (xy) from flat whole-body features.

    Parameters:
        features: 2-D array of shape (nframes, width). Width 266 (23 body
            keypoints) and 268 (24 body keypoints) store xy pairs; width 402
            (24 body keypoints) stores (x, y, score) triples.

    Returns:
        Array of shape (nframes, 68, 2) with the face xy coordinates.

    Raises:
        ValueError: If the feature width is not 266, 268 or 402.
    """
    assert features.ndim == 2
    # Must be defined here: the module has no global ``num_face_kps``, so
    # leaving it out makes every call fail with NameError.
    num_face_kps = 68
    nframes = features.shape[0]
    ndim = features.shape[-1]

    if ndim == 268:
        num_body_kps, stride = 24, 2
    elif ndim == 266:
        num_body_kps, stride = 23, 2
    elif ndim == 402:
        num_body_kps, stride = 24, 3  # xy plus a confidence score
    else:
        raise ValueError(f"Unknown feature dimension: {ndim}")

    # The face block directly follows the body block.
    start = num_body_kps * stride
    face = features[:, start : start + num_face_kps * stride].reshape(
        nframes, num_face_kps, stride
    )
    # Drop the score channel when present.
    return face[:, :, :2] if stride == 3 else face


def feature2facescore(features):
    """
    Extract the face confidence scores from 402-wide features.

    Parameters:
        features: 2-D array of shape (nframes, 402) storing (x, y, score)
            triples, with 24 body keypoints ahead of the face block.

    Returns:
        Array of shape (nframes, 68) with one score per face keypoint.

    Raises:
        ValueError: If the feature width is not 402.
    """
    assert features.ndim == 2
    # Must be defined here: the module has no global ``num_face_kps``, so
    # leaving it out makes every call fail with NameError.
    num_face_kps = 68
    nframes = features.shape[0]
    ndim = features.shape[-1]
    if ndim == 402:
        num_body_kps = 24
    else:
        raise ValueError(f"Uncorrect feature dimension: {ndim}")

    # The face block directly follows the body block; 3 values per keypoint.
    start = num_body_kps * 3
    rs = features[:, start : start + num_face_kps * 3].reshape(
        nframes, num_face_kps, 3
    )
    return rs[:, :, 2]


def feature2hands(features):
    """
    Extract both hands' keypoints (xy) from flat whole-body features.

    Parameters:
        features: 2-D array of shape (nframes, width). Width 266 (23 body
            keypoints) and 268 (24 body keypoints) store xy pairs; width 402
            (24 body keypoints) stores (x, y, score) triples.

    Returns:
        Array of shape (nframes, 2, 21, 2): two hands of 21 keypoints each.

    Raises:
        ValueError: If the feature width is not 266, 268 or 402.
    """
    assert features.ndim == 2
    # Must be defined here: the module has no global ``num_face_kps``, so
    # leaving it out makes every call fail with NameError.
    num_face_kps = 68
    nframes = features.shape[0]
    ndim = features.shape[-1]

    if ndim == 268:
        num_body_kps, stride = 24, 2
    elif ndim == 266:
        num_body_kps, stride = 23, 2
    elif ndim == 402:
        num_body_kps, stride = 24, 3  # xy plus a confidence score
    else:
        raise ValueError(f"Unknown feature dimension: {ndim}")

    # Hands occupy everything after the body and face blocks.
    start = (num_body_kps + num_face_kps) * stride
    hands = features[:, start:].reshape(nframes, 2, 21, stride)
    # Drop the score channel when present.
    return hands[:, :, :, :2] if stride == 3 else hands


def feature2handsscore(features):
    """
    Extract both hands' confidence scores from 402-wide features.

    Parameters:
        features: 2-D array of shape (nframes, 402) storing (x, y, score)
            triples, with 24 body keypoints and the face block ahead of the
            hands.

    Returns:
        Array of shape (nframes, 2, 21) with one score per hand keypoint.

    Raises:
        ValueError: If the feature width is not 402.
    """
    assert features.ndim == 2
    # Must be defined here: the module has no global ``num_face_kps``, so
    # leaving it out makes every call fail with NameError.
    num_face_kps = 68
    nframes = features.shape[0]
    ndim = features.shape[-1]
    if ndim == 402:
        num_body_kps = 24
    else:
        raise ValueError(f"Uncorrect feature dimension: {ndim}")

    # Hands occupy everything after the body and face blocks.
    start = (num_body_kps + num_face_kps) * 3
    rs = features[:, start:].reshape(nframes, 2, 21, 3)
    return rs[:, :, :, 2]


def compose_feature(body, face, hands):
    """
    Flatten body, face and hand keypoints and join them per frame.

    Parameters:
        body, face, hands: Arrays or tensors whose first axis is the frame
            axis; everything after it is flattened.

    Returns:
        2-D array (NumPy input) or tensor (torch input) with the three parts
        concatenated along axis 1, in body / face / hands order.

    Raises:
        ValueError: If ``body`` is neither a NumPy array nor a torch tensor.
    """
    # Collapse each part to (nframes, -1) before joining.
    body, face, hands = (
        part.reshape(part.shape[0], -1) for part in (body, face, hands)
    )
    parts = [body, face, hands]

    # The container type of ``body`` decides which backend does the join.
    if isinstance(body, np.ndarray):
        return np.concatenate(parts, axis=1)
    if isinstance(body, torch.Tensor):
        return torch.cat(parts, dim=1)
    raise ValueError(f"Unknown type: {type(body)}")


def decompose_feature(features):
    """
    Split flat whole-body features into their three parts.

    Parameters:
        features: 2-D array of per-frame features.

    Returns:
        tuple: ``(body, face, hands)`` as produced by ``feature2body``,
        ``feature2face`` and ``feature2hands``, evaluated in that order.
    """
    return (
        feature2body(features),
        feature2face(features),
        feature2hands(features),
    )


def process_keypoints(
    candidate, num_kps, max_num_candidates, fill_value=0, conf_scores=None
):
    """
    Pad keypoints, optionally attach confidences, and flatten per candidate.

    Parameters:
        candidate (np.ndarray): Keypoints whose last two axes are
            (keypoint, coordinate). A 2-D input is treated as a single
            candidate and gets a leading axis of length 1.
        num_kps (int): Number of keypoints to pad the keypoint axis up to.
        max_num_candidates (int): Number of rows in the result.
        fill_value (float): Value written into padded keypoints.
        conf_scores (np.ndarray): Optional scores; they are cut to the
            candidate count and to ``num_kps`` along the second axis, then
            appended as an extra trailing coordinate.

    Returns:
        np.ndarray: Array of shape (max_num_candidates, -1).
    """
    if candidate.ndim == 2:
        candidate = candidate[None]

    # Pad along the keypoint axis, which is always second to last. Using a
    # fixed axis index here breaks 4-D inputs such as (n, 2, kps, 2) hands.
    missing = num_kps - candidate.shape[-2]
    # 3-D inputs that already have num_kps keypoints still go through the
    # concatenate with an empty padding block: callers have always received
    # the dtype promoted against the padding's dtype, so that is kept as is.
    promote_only = candidate.ndim == 3 and missing == 0
    if missing > 0 or promote_only:
        padding = np.full(
            (*candidate.shape[:-2], missing, candidate.shape[-1]),
            fill_value,
        )
        candidate = np.concatenate([candidate, padding], axis=-2)

    # Concatenate confidence scores if specified
    if conf_scores is not None:
        # (frame, num_kps) -> (frame, num_kps, 1)
        conf_scores = np.expand_dims(conf_scores[: candidate.shape[0], :num_kps], -1)
        candidate = np.concatenate([candidate, conf_scores], axis=-1)

    # np.repeat duplicates each candidate in place, so after the cut the
    # rows are copies of the leading candidate(s).
    # NOTE(review): with several input candidates the later ones are dropped
    # here -- confirm that this is intended.
    candidate = np.repeat(candidate, max_num_candidates, axis=0)[:max_num_candidates]

    # One flat row per candidate.
    return candidate.reshape(max_num_candidates, -1)


def dict2feature(
    data,
    num_body_kps=24,
    num_face_kps=68,
    num_hand_kps=21,
    body_parts=("body", "face", "hand"),
    cat_confidence=False,
    abl_square=False,
    augmentations=True,
    augmentation_params={
        "scale_range": 0.5,
        "rotation_range": 5,
        "translation_range": 0.02,
    },
):
    """
    Turn a list of per-frame pose dictionaries into flat feature arrays.

    Parameters:
        data (list): Per-frame dictionaries with "bodies" (holding
            "candidate" and "subset"), "faces", "hands" and optionally
            "score". Frames without "score" get an all-zero one written into
            them in place.
        num_body_kps (int): Number of body keypoints to process.
        num_face_kps (int): Number of face keypoints to process.
        num_hand_kps (int): Number of hand keypoints to process.
        body_parts (tuple): Parts to include, any of "body", "face", "hand".
        cat_confidence (bool): Whether to append confidence scores to the
            keypoint coordinates.
        abl_square (bool): Forwarded to ``preprocess_frames`` as
            ``nomalize2square``.
        augmentations (bool): Forwarded to ``preprocess_frames``.
        augmentation_params (dict): Extra keyword arguments unpacked into
            ``preprocess_frames``. Default: {"scale_range": 0.5,
            "rotation_range": 5, "translation_range": 0.02}. The default
            dict is only unpacked here, never modified.

    Returns:
        dict: "feature", "bodies_subset", "ignore_mask" and
        "keypoint_scores" arrays stacked over frames, plus "length" (the
        number of frames). "ignore_mask" marks feature entries equal to -1.

    Raises:
        ValueError: If ``data`` is empty or any frame holds more than one
            candidate.
    """
    if not data:
        raise ValueError("Input data is empty")

    # Frames without scores receive a zero placeholder, one row per body
    # subset entry; this is written into the caller's frame dicts.
    for frame in data:
        if "score" not in frame:
            frame["score"] = np.zeros((frame["bodies"]["subset"].shape[0], 1))

    # Largest candidate count seen in any frame.
    per_frame_counts = (
        max(
            frame["bodies"]["subset"].shape[0],
            frame["score"].shape[0],
            frame["faces"].shape[0],
        )
        for frame in data
    )
    max_num_candidates = max(per_frame_counts)
    if max_num_candidates > 1:
        raise ValueError("Multiple candidates per frame are not supported")

    data = preprocess_frames(
        data,
        num_body_kps=num_body_kps,
        num_face_kps=num_face_kps,
        num_hand_kps=num_hand_kps,
        nomalize2square=abl_square,
        augmentations=augmentations,
        **augmentation_params,
    )

    feature_rows, mask_rows, subset_rows, score_rows = [], [], [], []
    for frame in data:
        parts = []

        if "body" in body_parts:
            body_xy = frame["bodies"]["candidate"][:, :num_body_kps]
            body_conf = frame["score"] if cat_confidence else None
            parts.append(
                process_keypoints(
                    body_xy,
                    num_body_kps,
                    max_num_candidates,
                    fill_value=0,
                    conf_scores=body_conf,
                )
            )

        if "face" in body_parts:
            face_xy = frame["faces"]
            face_conf = None
            if cat_confidence:
                # Face scores sit right after the body scores.
                face_conf = frame["score"][
                    :, num_body_kps : num_body_kps + num_face_kps
                ]
            parts.append(
                process_keypoints(
                    face_xy,
                    num_face_kps,
                    max_num_candidates,
                    fill_value=-1,
                    conf_scores=face_conf,
                )
            )

        if "hand" in body_parts:
            hand_xy = frame["hands"]
            hand_conf = None
            if cat_confidence:
                # Hand scores are the trailing 2 * num_hand_kps entries.
                hand_conf = frame["score"][:, -2 * num_hand_kps :].reshape(
                    frame["hands"].shape[0], 2, num_hand_kps
                )
            parts.append(
                process_keypoints(
                    hand_xy,
                    num_hand_kps,
                    max_num_candidates,
                    fill_value=-1,
                    conf_scores=hand_conf,
                )
            )

        frame_feature = np.concatenate(parts, axis=1)
        feature_rows.append(frame_feature)
        # Entries equal to -1 are treated as missing keypoints.
        mask_rows.append(frame_feature == -1)

        repeated_subset = np.repeat(
            frame["bodies"]["subset"], max_num_candidates, axis=0
        )
        subset_rows.append(repeated_subset[:max_num_candidates])

        repeated_score = np.repeat(frame["score"], max_num_candidates, axis=0)
        score_rows.append(repeated_score[:max_num_candidates])

    return {
        "feature": np.array(feature_rows),
        "length": len(feature_rows),
        "bodies_subset": np.array(subset_rows),
        "ignore_mask": np.array(mask_rows),
        "keypoint_scores": np.array(score_rows),
    }


def feature2dict(
    features: np.ndarray | torch.Tensor,
    num_body_kps=24,
    num_face_kps=68,
    num_hand_kps=21,
    body_parts=("body", "face", "hand"),
    cat_confidence=False,
    confidence_threshold=0.37,
):
    """
    Convert processed feature data back to the original dictionary structure.

    Parameters:
        features (np.ndarray | torch.Tensor): Processed pose features.
        num_body_kps (int): Number of body keypoints.
        num_face_kps (int): Number of face keypoints.
        num_hand_kps (int): Number of hand keypoints.
        body_parts (tuple): Body parts to process (e.g., "body", "face", "hand").
        cat_confidence (bool): Whether confidence scores are concatenated in the features.
        confidence_threshold (float): Confidence threshold below which joints are set to -1.

    Returns:
        list: List of dictionaries representing the original pose data format.
    """
    assert features.ndim == 2
    num_frames = features.shape[0]
    dims = 3 if cat_confidence else 2  # Each keypoint's dimensionality

    reconstructed_frames = []

    for frame_idx in range(num_frames):
        current_feature = features[frame_idx]
        frame_data = {}
        offset = 0

        # Process body keypoints
        if "body" in body_parts:
            body_length = num_body_kps * dims
            body_keypoints = current_feature[offset : offset + body_length].reshape(
                -1, num_body_kps, dims
            )
            offset += body_length

            # Apply confidence threshold if required
            if cat_confidence:
                confidence = body_keypoints[..., -1]
                valid_joint_mask = confidence >= confidence_threshold
                body_keypoints[~valid_joint_mask, :2] = -1
            else:
                valid_joint_mask = np.all(body_keypoints[..., :2] > -0.5, axis=-1)

            subset = np.where(
                valid_joint_mask,  # Ensure all joints are valid
                np.arange(body_keypoints.shape[-2]),  # Sequential indices
                -1,  # Invalid subset
            )

            frame_data["bodies"] = {
                "candidate": body_keypoints[..., :2],  # Remove confidence if present
                "subset": subset,
            }

        # Process face keypoints
        if "face" in body_parts:
            face_length = num_face_kps * dims
            face_keypoints = current_feature[offset : offset + face_length].reshape(
                -1, num_face_kps, dims
            )
            offset += face_length

            # Apply confidence threshold if required
            if cat_confidence:
                confidence = face_keypoints[..., -1]
                face_keypoints[confidence < confidence_threshold, :2] = -1

            frame_data["faces"] = face_keypoints[..., :2].reshape(-1, num_face_kps, 2)

        # Process hand keypoints
        if "hand" in body_parts:
            hand_length = 2 * num_hand_kps * dims
            hand_keypoints = current_feature[offset : offset + hand_length].reshape(
                -1, 2, num_hand_kps, dims
            )
            offset += hand_length

            # Apply confidence threshold if required
            if cat_confidence:
                confidence = hand_keypoints[..., -1]
                hand_keypoints[confidence < confidence_threshold, :2] = -1

            frame_data["hands"] = hand_keypoints[..., :2].reshape(
                -1, 2, num_hand_kps, 2
            )

        # Append processed frame data
        reconstructed_frames.append(frame_data)

    return reconstructed_frames


def dict2feature_joblib(
    data,
    H,
    W,
    num_body_kps=24,
    num_face_kps=68,
    num_hand_kps=21,
    body_parts=("body", "face", "hand"),
    cat_confidence=False,
    abl_square=False,
    augmentations=True,
    augmentation_params={
        "scale_range": 0.5,
        "rotation_range": 5,
        "translation_range": 0.02,
    },
):
    """
    Turn a list of per-frame keypoint records into flat feature arrays.

    Parameters:
        data (list): Per-frame dictionaries with "keypoints" and
            "keypoints_scores". 2-D keypoints and 1-D scores get a leading
            candidate axis added in place.
        H, W: Forwarded positionally to ``preprocess_frames_joblib``
            (presumably the frame height and width -- confirm against that
            function).
        num_body_kps (int): Number of body keypoints to process.
        num_face_kps (int): Number of face keypoints to process.
        num_hand_kps (int): Number of hand keypoints to process.
        body_parts (tuple): Parts to include, any of "body", "face", "hand".
        cat_confidence (bool): Whether to append confidence scores to the
            keypoint coordinates.
        abl_square (bool): Forwarded to ``preprocess_frames_joblib`` as
            ``nomalize2square``.
        augmentations (bool): Forwarded to ``preprocess_frames_joblib``.
        augmentation_params (dict): Extra keyword arguments unpacked into
            ``preprocess_frames_joblib``. Default: {"scale_range": 0.5,
            "rotation_range": 5, "translation_range": 0.02}. The default
            dict is only unpacked here, never modified.

    Returns:
        dict: "feature", "bodies_subset", "ignore_mask" and
        "keypoint_scores" arrays stacked over frames, plus "length" (the
        number of frames). "ignore_mask" marks feature entries equal to -1.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("Input data is empty")

    # Give single-candidate frames an explicit leading candidate axis; this
    # is written into the caller's frame dicts.
    for frame in data:
        if frame["keypoints"].ndim == 2:
            frame["keypoints"] = frame["keypoints"][None]

        if frame["keypoints_scores"].ndim == 1:
            frame["keypoints_scores"] = frame["keypoints_scores"][None]

    # Largest candidate count seen in any frame. Unlike ``dict2feature``,
    # more than one candidate per frame is not rejected here.
    per_frame_counts = (
        max(
            frame["keypoints"].shape[0],
            frame["keypoints_scores"].shape[0],
        )
        for frame in data
    )
    max_num_candidates = max(per_frame_counts)

    data = preprocess_frames_joblib(
        data,
        H,
        W,
        num_body_kps=num_body_kps,
        num_face_kps=num_face_kps,
        num_hand_kps=num_hand_kps,
        nomalize2square=abl_square,
        augmentations=augmentations,
        **augmentation_params,
    )

    feature_rows, mask_rows, subset_rows, score_rows = [], [], [], []
    for frame in data:
        parts = []

        if "body" in body_parts:
            body_xy = frame["bodies"]["candidate"][:, :num_body_kps]
            body_conf = frame["score"] if cat_confidence else None
            parts.append(
                process_keypoints(
                    body_xy,
                    num_body_kps,
                    max_num_candidates,
                    fill_value=0,
                    conf_scores=body_conf,
                )
            )

        if "face" in body_parts:
            face_xy = frame["faces"]
            face_conf = None
            if cat_confidence:
                # Face scores sit right after the body scores.
                face_conf = frame["score"][
                    :, num_body_kps : num_body_kps + num_face_kps
                ]
            parts.append(
                process_keypoints(
                    face_xy,
                    num_face_kps,
                    max_num_candidates,
                    fill_value=-1,
                    conf_scores=face_conf,
                )
            )

        if "hand" in body_parts:
            hand_xy = frame["hands"]
            hand_conf = None
            if cat_confidence:
                # Hand scores are the trailing 2 * num_hand_kps entries.
                hand_conf = frame["score"][:, -2 * num_hand_kps :].reshape(
                    frame["hands"].shape[0], 2, num_hand_kps
                )
            parts.append(
                process_keypoints(
                    hand_xy,
                    num_hand_kps,
                    max_num_candidates,
                    fill_value=-1,
                    conf_scores=hand_conf,
                )
            )

        frame_feature = np.concatenate(parts, axis=1)
        feature_rows.append(frame_feature)
        # Entries equal to -1 are treated as missing keypoints.
        mask_rows.append(frame_feature == -1)

        repeated_subset = np.repeat(
            frame["bodies"]["subset"], max_num_candidates, axis=0
        )
        subset_rows.append(repeated_subset[:max_num_candidates])

        repeated_score = np.repeat(frame["score"], max_num_candidates, axis=0)
        score_rows.append(repeated_score[:max_num_candidates])

    return {
        "feature": np.array(feature_rows),
        "length": len(feature_rows),
        "bodies_subset": np.array(subset_rows),
        "ignore_mask": np.array(mask_rows),
        "keypoint_scores": np.array(score_rows),
    }
