import os

import numpy as np
import torch

from .wholebody import Wholebody

# KMP_DUPLICATE_LIB_OK is read by Intel's OpenMP runtime; "TRUE" tells it not
# to abort when a second copy of the runtime is loaded into the process
# (presumably torch and the pose backend each bring one — confirm).
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})
# Run on the GPU when CUDA is usable, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class DWposeDetectorAligned:
    """Whole-body pose detector returning keypoints normalised to image size.

    Wraps a ``Wholebody`` estimator and turns its raw ``(candidate, score)``
    output into a dict with ``bodies``, ``hands``, ``hands_score``, ``faces``
    and ``faces_score`` entries. The x coordinate is divided by the image
    width and the y coordinate by the image height.
    """

    # Number of body joints taken from the front of the keypoint array.
    _NUM_BODY_JOINTS = 18
    # Body joints whose score is not above this value are marked -1 in ``subset``.
    _BODY_SCORE_THRESHOLD = 0.3

    def __init__(self, device='cpu'):
        # NOTE(review): ``device`` is only printed; the estimator is built with
        # ``Wholebody()``'s own defaults. Confirm whether it should be forwarded
        # (``get_cpu_detector`` in this file calls ``Wholebody(device)``).
        print(f'DWposeDetectorAligned, {device=}')
        self.pose_estimation = Wholebody()

    def release_memory(self):
        """Drop the ``Wholebody`` estimator and run the garbage collector.

        The estimator is re-created on the next call, so the detector stays
        usable after memory has been released.
        """
        if hasattr(self, 'pose_estimation'):
            del self.pose_estimation
            import gc
            gc.collect()

    def _get_estimator(self):
        """Return the estimator, rebuilding it if ``release_memory`` dropped it."""
        if not hasattr(self, 'pose_estimation'):
            self.pose_estimation = Wholebody()
        return self.pose_estimation

    def __call__(self, oriImg):
        """Detect whole-body poses in ``oriImg``.

        Args:
            oriImg: image array unpacked as ``(H, W, C)``. It is copied before
                being handed to the estimator, so the caller's array is left
                untouched.

        Returns:
            dict with keys:
              * ``bodies``: dict with ``candidate`` (first 18 keypoints of every
                person flattened to ``(N * 18, locs)``), ``subset`` (``(N, 18)``;
                ``18 * person + joint`` where the joint score exceeds 0.3,
                otherwise -1, stored in the score dtype) and ``score`` (the
                ``(N, 18)`` body scores);
              * ``faces`` / ``faces_score``: keypoints ``24:92`` and their scores;
              * ``hands`` / ``hands_score``: keypoints ``92:113`` of every person
                followed by keypoints ``113:`` of every person, stacked on axis 0.
        """
        oriImg = oriImg.copy()
        H, W, _ = oriImg.shape
        estimator = self._get_estimator()
        with torch.no_grad():
            candidate, score = estimator(oriImg)

        nums, _, locs = candidate.shape
        n_body = self._NUM_BODY_JOINTS
        # Scale x by the image width and y by the image height.
        candidate[..., 0] /= float(W)
        candidate[..., 1] /= float(H)
        body = candidate[:, :n_body].copy().reshape(nums * n_body, locs)

        # Vectorised form of: subset[i][j] = n_body * i + j if score > threshold
        # else -1, stored in the dtype of the score array.
        body_score = score[:, :n_body]
        joint_ids = n_body * np.arange(body_score.shape[0])[:, None] + np.arange(body_score.shape[1])
        subset = np.where(body_score > self._BODY_SCORE_THRESHOLD, joint_ids, -1).astype(body_score.dtype)

        faces = candidate[:, 24:92]
        faces_score = score[:, 24:92]
        # Both hand keypoint ranges of every person, stacked along axis 0.
        hands = np.vstack([candidate[:, 92:113], candidate[:, 113:]])
        hands_score = np.vstack([score[:, 92:113], score[:, 113:]])

        bodies = dict(candidate=body, subset=subset, score=body_score)
        return dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)


# Shared detector instance, constructed when this module is imported, using the
# module-level ``device`` (CUDA when available, otherwise CPU).
# NOTE(review): ``DWposeDetectorAligned.__init__`` only prints ``device`` and
# builds ``Wholebody()`` with its own defaults, so the estimator may not
# actually run on ``device`` — confirm. Importing this module also pays the
# cost of building that estimator.
dwpose_detector_aligned = DWposeDetectorAligned(device=device)


import torch
import os
from .wholebody import Wholebody
import numpy as np


# Lazily-created CPU ``Wholebody`` estimator shared by every caller of
# ``get_cpu_detector``; ``None`` until the first request.
_dwpose_detector_aligned_cpu = None

def get_cpu_detector():
    """Return the shared CPU ``Wholebody`` estimator, creating it on first use.

    Later calls return the cached object without building a new one.
    """
    global _dwpose_detector_aligned_cpu
    if _dwpose_detector_aligned_cpu is not None:
        return _dwpose_detector_aligned_cpu
    device = torch.device('cpu')
    print(f"DWposeDetectorAligned, device={device} (CPU mode)")
    _dwpose_detector_aligned_cpu = Wholebody(device)
    return _dwpose_detector_aligned_cpu

def _aligned_pose_from_wholebody(candidate, score, height, width):
    """Convert raw ``Wholebody`` output into the aligned pose dict.

    Applies the same post-processing as ``DWposeDetectorAligned.__call__`` so
    the CPU path returns the same structure as the default detector.

    Args:
        candidate: keypoints unpacked as ``(num_people, num_keypoints, locs)``
            (the layout ``DWposeDetectorAligned.__call__`` assumes); presumably
            pixel coordinates — confirm. Not modified; a copy is scaled.
        score: per-keypoint confidences, ``(num_people, num_keypoints)``.
        height: image height; column 1 of ``candidate`` is divided by it.
        width: image width; column 0 of ``candidate`` is divided by it.

    Returns:
        dict with ``bodies`` (``candidate``, ``subset``, ``score``), ``hands``,
        ``hands_score``, ``faces`` and ``faces_score``.
    """
    candidate = candidate.copy()
    num_people, _, locs = candidate.shape
    candidate[..., 0] /= float(width)
    candidate[..., 1] /= float(height)

    # The first 18 keypoints are the body joints, flattened across people.
    body = candidate[:, :18].copy().reshape(num_people * 18, locs)
    body_score = score[:, :18]
    # Joint id is 18 * person + joint; joints scoring <= 0.3 are marked -1.
    joint_ids = 18 * np.arange(body_score.shape[0])[:, None] + np.arange(body_score.shape[1])
    subset = np.where(body_score > 0.3, joint_ids, -1).astype(body_score.dtype)

    faces = candidate[:, 24:92]
    faces_score = score[:, 24:92]
    # Both hand keypoint ranges of every person, stacked along axis 0.
    hands = np.vstack([candidate[:, 92:113], candidate[:, 113:]])
    hands_score = np.vstack([score[:, 92:113], score[:, 113:]])

    bodies = dict(candidate=body, subset=subset, score=body_score)
    return dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)

def dwpose_detector_aligned_cpu(oriImg, includeHands=True):
    """CPU-only counterpart of ``dwpose_detector_aligned``.

    Runs the shared CPU ``Wholebody`` estimator and formats its output the
    same way as ``DWposeDetectorAligned`` (normalised coordinates, real face
    and hand keypoints, plus ``faces_score``).

    Args:
        oriImg: image array unpacked as ``(H, W, C)``; copied before inference.
        includeHands: when False, ``hands`` and ``hands_score`` are empty lists.

    Returns:
        dict with ``bodies``, ``hands``, ``hands_score``, ``faces`` and
        ``faces_score``.
    """
    H, W, _ = oriImg.shape
    detector = get_cpu_detector()
    with torch.no_grad():
        # ``Wholebody`` yields (keypoints, scores), as unpacked by
        # DWposeDetectorAligned.__call__ — not an OpenPose (candidate, subset) pair.
        candidate, score = detector(oriImg.copy())
    pose = _aligned_pose_from_wholebody(candidate, score, H, W)
    if not includeHands:
        pose['hands'] = []
        pose['hands_score'] = []
    return pose

def _get_joint_scores(candidate, subset):
    """
    Extract confidence scores for each joint of each person
    """
    scores = []
    for person_idx in range(len(subset)):
        person_scores = []
        for joint_idx in range(18):
            idx = int(subset[person_idx][joint_idx])
            if idx >= 0 and idx < len(candidate):
                person_scores.append(candidate[idx][2])
            else:
                person_scores.append(0.0)
        scores.append(person_scores)
    return np.array(scores)

def _empty_faces(num_people):
    """
    Create empty face landmarks placeholder
    """
    return [np.zeros((70, 2)) for _ in range(num_people)]

def _empty_hands(num_hands):
    """
    Create empty hand landmarks placeholder
    """
    return [np.zeros((21, 2)) for _ in range(num_hands)]

def _empty_hands_score(num_hands):
    """
    Create empty hand scores placeholder
    """
    return [np.zeros(21) for _ in range(num_hands)]