import gc
import os
import shutil
import subprocess
import tempfile

import cv2
import numpy as np


from .dwpose import util
from .skeleton_info import COCO_JOINT_MAP, COCO_SKELETON_INFO

# Lookup tables keyed by skeleton type name; only "coco" is registered here.
# Their contents come from .skeleton_info (presumably limb connectivity and
# joint name/index maps -- verify there). Not referenced by the functions below.
skeletons = {"coco": COCO_SKELETON_INFO}
joints = {"coco": COCO_JOINT_MAP}


def draw_pose(i, pose, H, W):
    """Render the pose of frame ``i`` onto a new black ``(H, W, 3)`` uint8 image.

    ``pose[i]`` is a mapping whose ``"bodies"`` entry holds ``"candidate"`` and
    ``"subset"`` arrays; optional ``"faces"`` and ``"hands"`` entries are drawn
    as well when present. Returns the drawn canvas.
    """
    frame = pose[i]
    body = frame["bodies"]
    points = body["candidate"].squeeze()
    person_subset = body["subset"][None]  # prepend a leading axis for util's drawer
    blank = np.zeros(shape=(H, W, 3), dtype=np.uint8)
    image = util.draw_body_and_foot(blank, points, person_subset, stickwidth=2)

    if "faces" in frame:
        image = util.draw_face_contour(image, frame["faces"], thickness=1)
    if "hands" in frame:
        image = util.draw_handpose(image, frame["hands"].squeeze(), thickness=1)
    return image


def draw_text(canvas, text, position):
    """Stamp white, anti-aliased ``text`` onto ``canvas`` at ``position``.

    Canvases whose height + width is below 1000 px use font scale 0.5 with a
    1 px stroke; larger canvases use scale 1 with a 2 px stroke. ``cv2.putText``
    draws into the array itself; the same canvas is returned for chaining.
    """
    height, width = canvas.shape[0], canvas.shape[1]
    is_small = width + height < 1000
    font_scale, stroke = (0.5, 1) if is_small else (1, 2)
    white = (255, 255, 255)
    cv2.putText(
        canvas, text, position, cv2.FONT_HERSHEY_SIMPLEX,
        font_scale, white, stroke, cv2.LINE_AA,
    )
    return canvas


def generate_pose_video(
    frames,
    bodies_subset,
    skeleton_type,
    image_size,
    video_path,
    fps=10,
    frames_ref=None,
    streaming_size=-1,
    audio_path=None,
    audio_offset=None,
    rescale=False,
):
    """
    Generate a video from pose frames.

    Each frame is rendered to a numbered PNG in a scratch directory derived
    from ``video_path``. The PNGs are encoded into ``video_path`` with ffmpeg,
    and the scratch directory is removed afterwards, also when rendering or
    encoding raises.

    Args:
        frames: Indexable sequence of per-frame pose mappings (each with a
            ``"bodies"`` entry and optional ``"faces"`` / ``"hands"`` entries).
        bodies_subset: Currently unused; kept for API compatibility.
        skeleton_type: Currently unused; kept for API compatibility.
        image_size: (height, width) of each rendered pose canvas.
        video_path: Path of the output video.
        fps: Frames per second of the output video.
        frames_ref: Optional reference ("GT") poses, drawn side by side to the
            right of each output frame. Presumably at least ``len(frames)``
            long (indexed with the same frame index) -- TODO confirm.
        streaming_size: When not -1, frame labels include the streaming round
            ``i // streaming_size``.
        audio_path: Optional audio file forwarded to the encoding step.
        audio_offset: Optional audio offset forwarded to the encoding step.
        rescale: Currently unused; kept for API compatibility.
    """
    frame_dir = _create_frame_dir(video_path)
    try:
        for i in range(len(frames)):
            # 1-based, zero-padded names so the encoder reads them as a sequence.
            frame_path = os.path.join(frame_dir, f"{i + 1:04d}.png")
            canvas = _generate_frame(i, frames, frames_ref, image_size, streaming_size)
            _save_frame(canvas, frame_path)
            gc.collect()  # Collect garbage after saving each frame to free memory

        _generate_video(frame_dir, fps, video_path, audio_path, audio_offset)
    finally:
        # Always drop the temporary frames; cleanup is best-effort so it never
        # masks an exception raised while rendering or encoding.
        shutil.rmtree(frame_dir, ignore_errors=True)


def _create_frame_dir(video_path):
    """Create a fresh, empty directory for the temporary frames of ``video_path``.

    The directory is placed next to the output video; the video's parent
    directory is created if missing, so the encoder can write there later. It
    gets a unique name starting with the video's file stem. This means it can
    never coincide with the output file, with an existing user directory, or
    with leftover frames from an earlier interrupted run. Stale frames would
    otherwise be picked up by the ``%04d`` sequence pattern.

    Returns:
        Path of the newly created directory. The caller must remove it.
    """
    parent = os.path.dirname(video_path) or os.curdir
    os.makedirs(parent, exist_ok=True)
    stem = os.path.splitext(os.path.basename(video_path))[0]
    return tempfile.mkdtemp(prefix=f"{stem}_frames_", dir=parent)


def _generate_frame(i, pose, pose_ref, image_size, streaming_size):
    """Render frame ``i`` as an image labelled "Output".

    If ``pose_ref`` is given, frame ``i`` of it is rendered as well, labelled
    "GT", and concatenated to the right of the output image. When
    ``streaming_size`` is not -1, each label is suffixed with
    " (Round NN)", where NN is ``i // streaming_size``.
    """
    height, width = image_size[0], image_size[1]
    label_origin = (10, 30)

    def _label(base):
        # Append the streaming round only when streaming is enabled.
        if streaming_size == -1:
            return base
        return base + f" (Round {i//streaming_size:02d})"

    output = draw_text(draw_pose(i, pose, H=height, W=width), _label("Output"), label_origin)
    if pose_ref is None:
        return output

    reference = draw_text(draw_pose(i, pose_ref, H=height, W=width), _label("GT"), label_origin)
    return np.concatenate([output, reference], axis=1)


def _save_frame(canvas, frame_path):
    """Save a single frame as an image.

    The channel axis is reversed before writing (presumably RGB -> BGR, the
    order OpenCV expects).

    Raises:
        OSError: if OpenCV reports the write failed. ``cv2.imwrite`` signals
            failure by returning False rather than raising, which previously
            let missing frames go unnoticed until encoding.
    """
    # The JPEG quality flag only matters for .jpg paths; PNG output ignores it.
    ok = cv2.imwrite(frame_path, canvas[:, :, ::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 100])
    if not ok:
        raise OSError(f"cv2.imwrite failed to write frame to {frame_path!r}")


def _generate_video(frame_dir, fps, video_path, audio_path=None, audio_offset=None):
    """Run ffmpeg to compile numbered PNG frames into an H.264 video.

    Args:
        frame_dir: Directory holding frames named ``0001.png``, ``0002.png``, ...
        fps: Input frame rate of the image sequence.
        video_path: Output file; overwritten if it exists (``-y``).
        audio_path: Optional audio file muxed into the output.
        audio_offset: Optional offset in seconds applied to the audio input
            (``-itsoffset``); ignored when ``audio_path`` is None.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero; its ``stderr``
            attribute carries ffmpeg's error messages.
    """
    # ffmpeg applies options to the next input (-i) or output file that follows
    # them. Every input and its per-input options must therefore come before the
    # output path. Previously the audio input and -itsoffset were appended after
    # video_path, leaving -itsoffset as an ignored trailing option.
    # "%" is the image2 sequence metacharacter; escape literal ones in the dir.
    frame_pattern = os.path.join(frame_dir.replace("%", "%%"), "%04d.png")
    cmd = [
        "ffmpeg",
        "-y",
        "-loglevel",
        "error",
        # Input 0: the image sequence.
        "-f",
        "image2",
        "-framerate",
        str(fps),
        "-i",
        frame_pattern,
    ]
    if audio_path is not None:
        # Input 1: the audio track. -itsoffset must precede the -i it shifts.
        if audio_offset is not None:
            cmd.extend(["-itsoffset", str(audio_offset)])
        cmd.extend(["-i", audio_path])
        # Pin streams explicitly so a video stream inside the audio file can
        # never replace the rendered frames.
        cmd.extend(["-map", "0:v:0", "-map", "1:a:0"])
    cmd.extend(
        [
            "-vcodec",
            "libx264",
            "-crf",
            "17",
            "-pix_fmt",
            "yuv420p",
            "-shortest",
            video_path,
        ]
    )
    # Keep ffmpeg output out of the main log: stdout is discarded and stderr is
    # captured (not printed), so a failure still carries diagnostics.
    subprocess.run(
        cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, check=True
    )
