import sys
# Make the DWPose ControlNet annotator package (``annotator.dwpose``) importable.
# The path is relative, so Python resolves it against the current working
# directory: the script has to be launched from the directory containing
# ./projects/. Inserting at index 0 gives it precedence over other entries.
project_path = "./projects/DWPose/ControlNet-v1-1-nightly"
if project_path not in sys.path:
    sys.path.insert(0, project_path)

from annotator.dwpose import DWposeDetector
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import json
from moviepy.editor import ImageSequenceClip

def list_jpg_files(folder_path, extensions=('.jpg', '.png')):
    """Recursively collect image file paths under *folder_path*.

    Despite the name, both ``.jpg`` and ``.png`` files are matched by default.
    The suffix comparison is case-insensitive. Sub-directories and files are
    visited in sorted order, so the result is deterministic regardless of
    filesystem ordering.

    Args:
        folder_path: Root directory to walk.
        extensions: A suffix string or an iterable of suffixes to accept.
            Comparison is case-insensitive.

    Returns:
        List of ``os.path.join(root, name)`` paths in traversal order.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)
    jpg_files = []
    for root, dirs, files in os.walk(folder_path):
        # Sorting ``dirs`` in place controls the order in which os.walk
        # descends. Without it, the order of sub-directories depends on the
        # filesystem.
        dirs.sort()
        jpg_files.extend(
            os.path.join(root, name)
            for name in sorted(files)
            if name.lower().endswith(suffixes)
        )
    return jpg_files

def numpy_to_list(np_array, H, W, score_thr=0.3):
    """Flatten the first person's keypoints into ``[x0, y0, s0, x1, y1, s1, ...]``.

    *np_array* is indexed as ``[person][keypoint][x, y, score, ...]``. Only the
    first person is used.

    A keypoint is kept only if it lies strictly inside the image border
    (``0 < x < W - 1`` and ``0 < y < H - 1``) and its third value is above
    *score_thr*. Kept keypoints have x and y truncated to ``int`` and the score
    rounded to 4 decimals. All other keypoints become ``0, 0, 0``.

    Args:
        np_array: Array supporting ``.tolist()`` with the layout above.
        H, W: Image height and width in pixels.
        score_thr: Minimum (exclusive) score for a keypoint to be kept.

    Returns:
        Flat list of values. It is empty when *np_array* holds no person, e.g.
        when nothing was detected in the frame.
    """
    tmp_list = np_array.tolist()
    if not tmp_list:
        # No detected person: indexing [0] would raise IndexError.
        return []
    res_list = []
    for item in tmp_list[0]:
        if item[0] >= W - 1 or item[0] <= 0 or item[1] >= H - 1 or item[1] <= 0 or item[2] <= score_thr:
            item[0], item[1], item[2] = 0, 0, 0
        else:
            item[0] = int(item[0])
            item[1] = int(item[1])
            item[2] = round(item[2], 4)
        res_list.extend(item)
    return res_list

def calculate_bounding_box(keypoints):
    """Return ``(x, y, w, h)`` spanning all strictly positive coordinates.

    *keypoints* is a flat ``[x0, y0, s0, x1, y1, s1, ...]`` sequence. The two
    axes are filtered independently: every x > 0 contributes to the horizontal
    extent and every y > 0 to the vertical extent. If either axis has no
    positive value, ``(0, 0, 0, 0)`` is returned.
    """
    xs = []
    ys = []
    for base in range(0, len(keypoints), 3):
        px, py = keypoints[base], keypoints[base + 1]
        if px > 0:
            xs.append(px)
        if py > 0:
            ys.append(py)
    if not (xs and ys):
        return 0, 0, 0, 0  # No valid keypoints
    left, top = min(xs), min(ys)
    return left, top, max(xs) - left, max(ys) - top

def calculate_bbox_area(bbox):
    """Return ``w * h`` for an ``(x, y, w, h)`` box; the origin is ignored."""
    _left, _top, box_w, box_h = bbox
    area = box_w * box_h
    return area

def cal_valid_kpts(keypoints):
    """Count ``(x, y, score)`` triplets in a flat list whose score exceeds 0.3."""
    return sum(1 for base in range(0, len(keypoints), 3) if keypoints[base + 2] > 0.3)

def visualize_bboxes(img, bbox, face_box, lefthand_box, righthand_box):
    """Draw the body, face and hand boxes on a copy of *img* and return it.

    Boxes are ``(x, y, w, h)``; corners are truncated to ``int``. Colours are
    BGR (OpenCV order): blue for the overall bbox, green for the face, red for
    the left hand and cyan for the right hand. The overall bbox is always
    drawn; each part box is skipped when it is ``None``. Lines are 2 px thick.
    """
    canvas = img.copy()

    def _draw(box, color):
        left, top, width, height = box[0], box[1], box[2], box[3]
        cv2.rectangle(canvas, (int(left), int(top)),
                      (int(left + width), int(top + height)),
                      color, 2)

    _draw(bbox, (255, 0, 0))  # Blue box for overall bbox
    optional_parts = (
        (face_box, (0, 255, 0)),        # Green box for face
        (lefthand_box, (0, 0, 255)),    # Red box for left hand
        (righthand_box, (255, 255, 0)), # Cyan box for right hand
    )
    for part_box, color in optional_parts:
        if part_box is not None:
            _draw(part_box, color)
    return canvas

def process_image(pose_model, img, min_valid_ratio=0.3, min_area_ratio=0.005):
    """Run the pose model on one frame and derive body, face and hand boxes.

    The keypoints returned by *pose_model* are sliced as:
    0-16 body, 17-22 foot, 23-90 face, 91-111 left hand, 112+ right hand.
    Each slice is flattened with ``numpy_to_list``, which uses only the first
    person.

    A face or hand box is reported only if:

    * at least ``min_valid_ratio`` of its keypoints (68 for the face, 21 per
      hand) pass the score filter, and
    * its area is at least ``min_area_ratio`` of the whole-body bbox area.

    Otherwise that box is ``None``.

    Args:
        pose_model: Callable mapping an image to keypoint info.
            NOTE(review): assumed to return an array indexed as
            ``[person, keypoint, ...]`` (as the slicing below requires) —
            confirm against the DWPose fork.
        img: ``H x W x C`` image.
        min_valid_ratio: Fraction of part keypoints that must be valid.
        min_area_ratio: Minimum part-box area relative to the whole-body bbox.

    Returns:
        ``(bbox, face_box, lefthand_box, righthand_box)``. Each box is
        ``(x, y, w, h)``; the part boxes may be ``None``. When no person is
        detected, the result is ``((0, 0, 0, 0), None, None, None)``.
    """
    H, W, _ = img.shape
    keypoints_info = pose_model(img)
    if keypoints_info is None or len(keypoints_info) == 0:
        # No detection: flattening an empty person axis would raise
        # IndexError. The all-zero bbox matches what calculate_bounding_box
        # returns when there are no valid keypoints.
        return (0, 0, 0, 0), None, None, None

    body = keypoints_info[:, :17]
    foot = keypoints_info[:, 17:23]
    faces = keypoints_info[:, 23:91]
    lefthand = keypoints_info[:, 91:112]
    righthand = keypoints_info[:, 112:]

    keypoints = numpy_to_list(body, H, W)
    foot_kpts = numpy_to_list(foot, H, W)
    face_kpts = numpy_to_list(faces, H, W)
    lefthand_kpts = numpy_to_list(lefthand, H, W)
    righthand_kpts = numpy_to_list(righthand, H, W)

    # Whole-body box over every valid keypoint. It is the reference for the
    # minimum part area.
    bbox = calculate_bounding_box(keypoints + face_kpts + lefthand_kpts + righthand_kpts + foot_kpts)
    min_area = calculate_bbox_area(bbox) * min_area_ratio

    def _part_box(part_kpts, num_kpts):
        # Return the part's box, or None if too few keypoints are valid or the
        # box is negligibly small relative to the whole body.
        box = calculate_bounding_box(part_kpts)
        if cal_valid_kpts(part_kpts) < num_kpts * min_valid_ratio or calculate_bbox_area(box) < min_area:
            return None
        return box

    face_box = _part_box(face_kpts, 68)
    lefthand_box = _part_box(lefthand_kpts, 21)  # Red box
    righthand_box = _part_box(righthand_kpts, 21)  # Cyan box

    return bbox, face_box, lefthand_box, righthand_box

def resize_connected_components(image, scale_factor):
    """Scale every connected component of *image* about its own centre.

    Foreground is any pixel that is non-zero after BGR->gray conversion, and
    components use 8-connectivity. For each component:

    * take the patch inside its bounding box, with pixels belonging to other
      labels masked out;
    * resize the patch by *scale_factor* with bicubic interpolation;
    * paste it onto a black canvas of the same shape, centred on the original
      box centre.

    Parts of a scaled patch that fall outside the canvas are cropped, so the
    component keeps its scale and aspect ratio. Overlapping patches are
    combined with a per-pixel maximum. Components that shrink to zero width
    or height are dropped.

    Args:
        image: BGR image (3 or 4 channels, as ``COLOR_BGR2GRAY`` requires).
        scale_factor: Multiplier applied to each component's width and height.

    Returns:
        Array with the same shape and dtype as *image*.
    """
    img_h, img_w = image.shape[:2]
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    output = np.zeros_like(image)
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
        gray, connectivity=8, ltype=cv2.CV_32S
    )
    for i in range(1, num_labels):  # label 0 is the background
        x = int(stats[i, cv2.CC_STAT_LEFT])
        y = int(stats[i, cv2.CC_STAT_TOP])
        w = int(stats[i, cv2.CC_STAT_WIDTH])
        h = int(stats[i, cv2.CC_STAT_HEIGHT])
        # Build the mask only inside the component's bounding box. A
        # full-image mask per label costs O(components * H * W).
        component_mask = (labels[y:y + h, x:x + w] == i).astype(np.uint8)
        component_region = image[y:y + h, x:x + w] * component_mask[:, :, np.newaxis]
        new_w = int(w * scale_factor)
        new_h = int(h * scale_factor)
        if new_w <= 0 or new_h <= 0:
            # Component shrinks to nothing; cv2.resize rejects an empty dsize.
            continue
        resized_region = cv2.resize(component_region, (new_w, new_h), interpolation=cv2.INTER_CUBIC)
        # Unclipped top-left corner of the scaled patch, centred on the
        # original box centre. It may lie outside the canvas.
        cx = x + w // 2
        cy = y + h // 2
        left = cx - new_w // 2
        top = cy - new_h // 2
        # Clip to the canvas and crop the patch by the same offsets. Clamping
        # only the origin and shrinking the target size would shift the patch
        # at the left/top edge and squash it at the right/bottom edge.
        dst_x0, dst_y0 = max(0, left), max(0, top)
        dst_x1, dst_y1 = min(img_w, left + new_w), min(img_h, top + new_h)
        if dst_x1 <= dst_x0 or dst_y1 <= dst_y0:
            continue
        patch = resized_region[dst_y0 - top:dst_y1 - top, dst_x0 - left:dst_x1 - left]
        target = output[dst_y0:dst_y1, dst_x0:dst_x1]
        output[dst_y0:dst_y1, dst_x0:dst_x1] = np.maximum(target, patch)
    return output

def visualize_masks(img, face_box, lefthand_box, righthand_box, expand_ratio=1.2):
    """Build a 3-channel binary mask covering the expanded face and hand boxes.

    Each non-``None`` box ``(x, y, w, h)`` is scaled by *expand_ratio* about
    its centre, clipped to the image, and filled with 255. All other pixels
    are 0.

    Args:
        img: Image whose first two dimensions ``(H, W)`` set the mask size.
        face_box, lefthand_box, righthand_box: ``(x, y, w, h)`` or ``None``.
        expand_ratio: Scale applied to each box's width and height.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with three identical channels.
    """
    mask = np.zeros(img.shape[:2], dtype=np.uint8)

    def expand_box(box, expand_ratio, img_shape):
        if box is None:
            return None
        x, y, w, h = box
        cx, cy = x + w / 2, y + h / 2
        new_w, new_h = w * expand_ratio, h * expand_ratio
        # Compute both unclipped edges first, then clip each to the image.
        # Clamping only the origin would shift a box that crosses the
        # left/top border instead of cropping it, enlarging its far side.
        x0 = int(cx - new_w / 2)
        y0 = int(cy - new_h / 2)
        x1 = min(x0 + int(new_w), img_shape[1])
        y1 = min(y0 + int(new_h), img_shape[0])
        x0, y0 = max(0, x0), max(0, y0)
        return x0, y0, max(0, x1 - x0), max(0, y1 - y0)

    for box in (face_box, lefthand_box, righthand_box):
        expanded_box = expand_box(box, expand_ratio, img.shape[:2])
        if expanded_box is not None:
            x, y, w, h = expanded_box
            mask[y:y + h, x:x + w] = 255
    # Replicate into 3 identical channels (equivalent to cv2.merge([m, m, m])).
    return np.dstack([mask, mask, mask])

def visualize_masks_separately(img, face_box, lefthand_box, righthand_box, expand_ratio=1.5):
    """Build a 3-channel binary mask of the expanded face box only.

    The face box ``(x, y, w, h)`` is scaled by *expand_ratio* about its
    centre, clipped to the image, and filled with 255. All other pixels are 0.
    ``lefthand_box`` and ``righthand_box`` are accepted for interface
    compatibility but are currently ignored.

    Args:
        img: Image whose first two dimensions ``(H, W)`` set the mask size.
        face_box: ``(x, y, w, h)`` or ``None``. ``None`` yields an all-zero
            mask.
        lefthand_box, righthand_box: Unused.
        expand_ratio: Scale applied to the face box's width and height.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)`` with three identical channels.
    """
    mask = np.zeros(img.shape[:2], dtype=np.uint8)

    def expand_box(box, expand_ratio, img_shape):
        if box is None:
            return None
        x, y, w, h = box
        cx, cy = x + w / 2, y + h / 2
        new_w, new_h = w * expand_ratio, h * expand_ratio
        # Compute both unclipped edges first, then clip each to the image.
        # Clamping only the origin would shift a box that crosses the
        # left/top border instead of cropping it, enlarging its far side.
        x0 = int(cx - new_w / 2)
        y0 = int(cy - new_h / 2)
        x1 = min(x0 + int(new_w), img_shape[1])
        y1 = min(y0 + int(new_h), img_shape[0])
        x0, y0 = max(0, x0), max(0, y0)
        return x0, y0, max(0, x1 - x0), max(0, y1 - y0)

    expanded_box = expand_box(face_box, expand_ratio, img.shape[:2])
    if expanded_box is not None:
        x, y, w, h = expanded_box
        mask[y:y + h, x:x + w] = 255
    # Replicate into 3 identical channels (equivalent to cv2.merge([m, m, m])).
    return np.dstack([mask, mask, mask])

def visualize_face_hand(img, face_box, lefthand_box, righthand_box, expand_ratio=1.5):
    """Cut the expanded face region out of *img* and enlarge it 2x in place.

    The face box ``(x, y, w, h)`` is scaled by *expand_ratio* about its centre
    and clipped to the image. Pixels inside it are copied from *img* onto a
    black canvas of the same shape. Every connected component of that canvas
    is then scaled by 2 via ``resize_connected_components``. The hand boxes
    are accepted for interface compatibility but are currently ignored.

    Args:
        img: BGR image.
        face_box: ``(x, y, w, h)`` or ``None``. ``None`` yields a black frame.
        lefthand_box, righthand_box: Unused.
        expand_ratio: Scale applied to the face box before cropping.

    Returns:
        Array with the same shape and dtype as *img*.
    """
    output_img = np.zeros_like(img)

    def expand_box(box, expand_ratio, img_shape):
        if box is None:
            return None
        x, y, w, h = box
        cx, cy = x + w / 2, y + h / 2
        new_w, new_h = w * expand_ratio, h * expand_ratio
        # Compute both unclipped edges first, then clip each to the image.
        # Clamping only the origin would shift a box that crosses the
        # left/top border instead of cropping it, enlarging its far side.
        x0 = int(cx - new_w / 2)
        y0 = int(cy - new_h / 2)
        x1 = min(x0 + int(new_w), img_shape[1])
        y1 = min(y0 + int(new_h), img_shape[0])
        x0, y0 = max(0, x0), max(0, y0)
        return x0, y0, max(0, x1 - x0), max(0, y1 - y0)

    expanded_box = expand_box(face_box, expand_ratio, img.shape[:2])
    if expanded_box is not None:
        x, y, w, h = expanded_box
        output_img[y:y + h, x:x + w] = img[y:y + h, x:x + w]
    output_img = resize_connected_components(output_img, scale_factor=2)
    return output_img

def _write_rgb_video(frames, fps, video_path):
    """Encode a list of RGB frames to *video_path* with libx264, creating parent dirs."""
    os.makedirs(os.path.dirname(video_path), exist_ok=True)
    clip = ImageSequenceClip(frames, fps=fps)
    clip.write_videofile(video_path, codec="libx264")


def process_video(input_video_path, output_path, pose_model, fps=None):
    """Generate the face-mask and face-crop driving videos for one input video.

    For every frame, ``process_image`` locates the face and hand boxes. The
    frame is then rendered in two ways:

    1. as a binary face mask (``visualize_masks_separately``);
    2. as a black frame holding only the enlarged face crop
       (``visualize_face_hand``).

    Both sequences are converted to RGB and written as H.264 videos to
    ``<output_path>/driving_face_mask/<name>`` and
    ``<output_path>/driving_face/<name>``, where ``<name>`` is the input file's
    basename.

    Args:
        input_video_path: Video path readable by ``cv2.VideoCapture``.
        output_path: Root output directory; sub-directories are created.
        pose_model: Callable passed through to ``process_image``.
        fps: Output frame rate. ``None`` (the default) reuses the input
            video's FPS, falling back to 30 when the container reports none.

    Raises:
        OSError: If the video cannot be opened.
        ValueError: If no frame could be decoded.
    """
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Could not open video: {input_video_path}")
    frame_list_mask = []
    frame_list_face_hand = []
    try:
        if fps is None:
            source_fps = cap.get(cv2.CAP_PROP_FPS)
            # A fixed 30 FPS changed the playback speed of any other input.
            # Some containers report 0 (or NaN), hence the fallback.
            fps = source_fps if source_fps > 0 else 30
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Total frames in video: {frame_count}")
        frame_idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            print(f"Processing frame {frame_idx + 1}/{frame_count}")
            _bbox, face_box, lefthand_box, righthand_box = process_image(pose_model, frame)
            # Face mask frame (BGR -> RGB for moviepy).
            visualized_mask = visualize_masks_separately(frame, face_box, lefthand_box, righthand_box)
            frame_list_mask.append(cv2.cvtColor(visualized_mask, cv2.COLOR_BGR2RGB))
            # Enlarged face crop frame (BGR -> RGB for moviepy).
            visualized_face_hand = visualize_face_hand(frame, face_box, lefthand_box, righthand_box)
            frame_list_face_hand.append(cv2.cvtColor(visualized_face_hand, cv2.COLOR_BGR2RGB))
            frame_idx += 1
    finally:
        cap.release()

    if not frame_list_mask:
        raise ValueError(f"No frames could be decoded from: {input_video_path}")

    # os.path.basename handles every separator the platform accepts.
    video_name = os.path.basename(input_video_path)
    _write_rgb_video(frame_list_mask, fps, os.path.join(output_path, "driving_face_mask", video_name))
    _write_rgb_video(frame_list_face_hand, fps, os.path.join(output_path, "driving_face", video_name))


if __name__ == "__main__":
    import argparse

    # The paths used to be hard-coded empty strings, which made the script
    # unusable without editing it. They are now taken from the command line.
    parser = argparse.ArgumentParser(
        description="Generate driving face-mask and face-crop videos from an input video using DWPose."
    )
    parser.add_argument("input_video_path", help="Path to the input video.")
    parser.add_argument(
        "output_path",
        help="Root output directory; driving_face_mask/ and driving_face/ are created under it.",
    )
    args = parser.parse_args()

    # Parse arguments before loading the model so bad invocations fail fast.
    pose_model = DWposeDetector()
    # ======================== process video ========================
    process_video(args.input_video_path, args.output_path, pose_model)
