import torch
import math
import pathlib

import cv2
import numpy as np
import os
from scipy.ndimage import binary_dilation
from skimage.io import imread, imsave

import imageio
from PIL import Image
from imageio_ffmpeg import get_ffmpeg_exe

import torchvision.transforms as T

# Module-wide interpolation callable. It stays None until
# set_tensor_interpolation_method() assigns either ``slerp`` or ``linear``.
tensor_interpolation = None


def get_tensor_interpolation_method():
    """Return the module-level ``tensor_interpolation`` value.

    This is None unless set_tensor_interpolation_method() has been called.
    """
    return tensor_interpolation


def set_tensor_interpolation_method(is_slerp):
    """Select the module-wide interpolation function.

    Args:
        is_slerp: Truthy to use ``slerp``; falsy to use ``linear``.
    """
    global tensor_interpolation
    if is_slerp:
        tensor_interpolation = slerp
    else:
        tensor_interpolation = linear


def linear(v1, v2, t):
    """Linearly blend ``v1`` and ``v2``; ``t=0`` gives ``v1`` and ``t=1`` gives ``v2``."""
    weight_v1 = 1.0 - t
    return weight_v1 * v1 + t * v2


def slerp(
        v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
) -> torch.Tensor:
    """Spherically interpolate between two tensors.

    Args:
        v0: Start tensor.
        v1: End tensor.
        t: Interpolation weight; 0 yields ``v0`` and 1 yields ``v1``.
        DOT_THRESHOLD: When the absolute cosine between the normalised inputs
            exceeds this value they are treated as (anti-)parallel and plain
            linear interpolation is used instead.

    Returns:
        The interpolated tensor.
    """
    unit_start = v0 / v0.norm()
    unit_end = v1 / v1.norm()
    cos_omega = torch.sum(unit_start * unit_end)

    nearly_parallel = cos_omega.abs() > DOT_THRESHOLD
    if nearly_parallel:
        # The slerp weights divide by sin(omega), which approaches zero here.
        return (1.0 - t) * v0 + t * v1

    omega = torch.acos(cos_omega)
    weight_start = torch.sin((1.0 - t) * omega)
    weight_end = torch.sin(t * omega)
    return (weight_start * v0 + weight_end * v1) / torch.sin(omega)


def interpolate_latents_expand(latents1, latents2, video_length, interpolate_method='slerp'):
    """Build a sequence of frames interpolating from ``latents1`` to ``latents2``.

    Args:
        latents1: Start latents, unpacked as [batch, channel, height, width].
        latents2: End latents with the same shape as ``latents1``.
        video_length: Number of frames to generate along the new third axis.
        interpolate_method: ``'slerp'`` uses ``slerp``; any other value uses
            plain linear blending.

    Returns:
        Tensor of shape [batch, channel, video_length, height, width] on the
        device and with the dtype of ``latents1``.
    """
    assert latents1.shape == latents2.shape, "The input latents must have the same shape"

    n_batch, n_channel, n_rows, n_cols = latents1.shape
    use_slerp = interpolate_method == 'slerp'

    frames = torch.zeros(
        (n_batch, n_channel, video_length, n_rows, n_cols),
        device=latents1.device,
        dtype=latents1.dtype,
    )
    weights = torch.linspace(0, 1, steps=video_length, device=latents1.device)

    for frame_idx in range(video_length):
        weight = weights[frame_idx]
        if use_slerp:
            frame = slerp(latents1, latents2, weight)
        else:
            frame = (1 - weight) * latents1 + weight * latents2
        frames[:, :, frame_idx, :, :] = frame

    return frames


def load_img(image_path, device):
    """Load an image file as a batched tensor on ``device``.

    The image is converted to RGB, passed through ``T.Resize(512)``, converted
    with ``T.ToTensor()`` and given a leading batch dimension.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    resized_image = T.Resize(512)(rgb_image)
    image_tensor = T.ToTensor()(resized_image)
    return image_tensor.unsqueeze(0).to(device)


def draw_kps_image(image, kps, color_list=((255, 0, 0), (0, 255, 0), (0, 0, 255)), stick_width=4, w_limbs=True):
    """Draw keypoints, and optionally the limbs joining them, onto ``image``.

    Args:
        image: Image array to draw on. It is modified in place and returned.
        kps: Sequence of (x, y) keypoints.
        color_list: One colour per keypoint index. A limb uses the colour of
            its first keypoint, darkened by a factor of 0.6.
        stick_width: Half-width passed to ``cv2.ellipse2Poly`` for each limb.
        w_limbs: When True, draw the limbs 0-2 and 1-2 before the keypoints.

    Returns:
        The same array object as ``image`` with the drawing applied.
    """
    # Keypoint index pairs that are joined by a limb.
    limb_seq = ((0, 2), (1, 2))
    kps = np.array(kps)
    num_kps = len(kps)

    canvas = image

    if w_limbs:
        for start_idx, end_idx in limb_seq:
            # A limb can only be drawn when both of its endpoints exist, e.g.
            # a single-point keypoint set has no limbs at all.
            if max(start_idx, end_idx) >= num_kps:
                continue

            color = color_list[start_idx]
            x = kps[[start_idx, end_idx], 0]
            y = kps[[start_idx, end_idx], 1]

            # Each limb is a filled ellipse centred between its endpoints and
            # rotated to lie along the line joining them.
            length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
            angle = int(math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])))
            center = (int(np.mean(x)), int(np.mean(y)))
            polygon = cv2.ellipse2Poly(center, (int(length / 2), stick_width), angle, 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color])

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        cv2.circle(canvas, (int(x), int(y)), 4, color, -1)

    return canvas


# keypoints prior
def extract_kps_img(
    reference_image_for_kps,
    reference_kps,
    video_len,
    kps_path="",
    retarget_strategy="no_retarget",
    point_kps=False,
    stick_width=4,
    w_limbs=True,
):
    """Render one keypoint image per video frame.

    Args:
        reference_image_for_kps: Array whose shape and dtype are used for each
            blank canvas (via ``np.zeros_like``).
        reference_kps: Keypoints of the reference face.
        video_len: Number of frames to produce.
        kps_path: Optional path to a keypoint sequence saved with torch. When
            given, the sequence is resampled to ``video_len`` frames.
        retarget_strategy: One of ``"fix_face"`` (repeat ``reference_kps``),
            ``"no_retarget"`` (use the loaded sequence unchanged),
            ``"offset_retarget"`` or ``"naive_retarget"`` (both go through
            ``retarget_kps``).
        point_kps: When True, keep only the second-to-last keypoint of every
            frame; limbs are then not drawn.
        stick_width: Limb width forwarded to ``draw_kps_image``.
        w_limbs: Whether limbs should be drawn.

    Returns:
        List of ``video_len`` PIL images.

    Raises:
        ValueError: If the strategy is unknown, or if it needs a keypoint
            sequence but ``kps_path`` is empty.
    """
    kps_sequence = None
    if kps_path != "":
        assert os.path.exists(kps_path), f"{kps_path} does not exist"
        # NOTE(security): torch.load unpickles the file; only load keypoint
        # files from trusted sources.
        kps_sequence = torch.tensor(torch.load(kps_path))  # [len, 3, 2]
        # Resample along the frame axis so the sequence has video_len entries.
        kps_sequence = torch.nn.functional.interpolate(
            kps_sequence.permute(1, 2, 0), size=video_len, mode="linear"
        )
        kps_sequence = kps_sequence.permute(2, 0, 1)

    # retarget_strategy
    if retarget_strategy == "fix_face":
        kps_sequence = torch.tensor([reference_kps] * video_len)
    elif retarget_strategy in ("no_retarget", "offset_retarget", "naive_retarget"):
        if kps_sequence is None:
            raise ValueError(
                f"The retarget strategy {retarget_strategy} requires a keypoint "
                "sequence, but kps_path is empty."
            )
        if retarget_strategy != "no_retarget":
            kps_sequence = retarget_kps(
                reference_kps,
                kps_sequence,
                only_offset=(retarget_strategy == "offset_retarget"),
            )
    else:
        raise ValueError(f"The retarget strategy {retarget_strategy} is not supported.")

    if point_kps:
        kps_sequence = kps_sequence[:, -2:-1, :]

    # With a single keypoint per frame there is nothing to connect.
    draw_limbs = w_limbs and not point_kps

    kps_images = []
    for i in range(video_len):
        kps_image = np.zeros_like(reference_image_for_kps)
        kps_image = draw_kps_image(kps_image, kps_sequence[i], stick_width=stick_width, w_limbs=draw_limbs)
        kps_images.append(Image.fromarray(kps_image))

    return kps_images


def video_to_pil_images(video_path):
    """
    Reads a video file and returns a list of PIL Images for each frame.

    :param video_path: Path to the video file.
    :return: List of PIL Images.

    # Usage example
        video_path = 'path/to/your/video.mp4'
        pil_images_list = video_to_pil_images(video_path)
    """
    # Use the reader as a context manager so it is closed once all frames
    # have been read, even if decoding a frame raises.
    with imageio.get_reader(video_path) as reader:
        return [Image.fromarray(frame) for frame in reader]


def apply_binary_dilation(mask, structure=None, iterations=3):
    """
    Dilate a binary mask a fixed number of times.

    Parameters:
    - mask (ndarray): The binary mask to be dilated.
    - structure (ndarray, optional): Structuring element used for dilation. If None, a 3x3 all-ones (square) element is used.
    - iterations (int): Number of times the dilation operation is applied. With 0 the input mask is returned as is.

    Returns:
    - ndarray: The dilated mask.
    """
    kernel = np.ones((3, 3), dtype=bool) if structure is None else structure

    result = mask
    for _step in range(iterations):
        result = binary_dilation(result, structure=kernel)
    return result

def load_masked_image_from_faceinfo(image_pil, face_info, unmasked=True, dilation_kernel=0):
    """Black out either the face box or everything except it.

    Args:
        image_pil: Input image (anything ``np.array`` can convert), either
            single-channel [H, W] or RGB [H, W, 3].
        face_info: Mapping with a ``'bbox'`` entry ``(x1, y1, x2, y2)``; the
            box is inclusive of ``x2``/``y2``. When None, None is returned.
        unmasked: If True the face box is zeroed and the rest is kept; if
            False only the face box is kept.
        dilation_kernel: When > 0, the kept region of the mask is dilated this
            many times with the default element of ``apply_binary_dilation``.

    Returns:
        The masked image as a PIL image, or None when ``face_info`` is None.

    Raises:
        ValueError: If the image has neither 1 nor 3 channels.
    """
    if face_info is None:
        return None

    # Convert to a float tensor in [0, 1] with channels first.
    image_array = np.array(image_pil)
    image_tensor = torch.from_numpy(image_array).float() / 255.0
    if len(image_tensor.shape) < 3:
        # Grayscale / single-channel input: add the channel dimension.
        image_tensor = image_tensor.unsqueeze(0)
    else:
        # HWC -> CHW
        image_tensor = image_tensor.permute(2, 0, 1)

    height, width = image_tensor.shape[-2:]

    x1, y1, x2, y2 = face_info['bbox']
    # Clamp at 0 so a box starting outside the frame does not turn into a
    # negative (wrap-around) slice index.
    top, bottom = max(int(y1), 0), max(int(y2) + 1, 0)
    left, right = max(int(x1), 0), max(int(x2) + 1, 0)

    # The mask is kept 2-D ([H, W]) and broadcast over channels, so dilation
    # works identically for grayscale and RGB images.
    outside_value, inside_value = (1., 0.) if unmasked else (0., 1.)
    face_mask = torch.full((height, width), outside_value)
    face_mask[top:bottom, left:right] = inside_value

    if dilation_kernel > 0:
        face_mask_np = face_mask.numpy() > 0.5  # Convert to binary mask
        dilated_mask_np = apply_binary_dilation(face_mask_np, iterations=int(dilation_kernel))
        face_mask = torch.from_numpy(dilated_mask_np).float()

    masked_image_tensor = image_tensor * face_mask.unsqueeze(0)
    # Round before the integer cast so float error cannot shift a pixel value.
    masked_image_tensor_scaled = torch.clamp(masked_image_tensor * 255, 0, 255).round().type(torch.uint8)

    # Ensure correct shape for PIL conversion
    if masked_image_tensor_scaled.shape[0] == 3:  # RGB
        masked_image_np = masked_image_tensor_scaled.permute(1, 2, 0).numpy()
    elif masked_image_tensor_scaled.shape[0] == 1:  # Grayscale
        masked_image_np = masked_image_tensor_scaled.squeeze(0).numpy()
    else:
        raise ValueError("Unsupported image tensor shape.")

    return Image.fromarray(masked_image_np.astype("uint8"))


def save_video(video_tensor, output_path, audio_path=None, fps=30.0):
    """Encode a video tensor to ``output_path``, optionally muxing in audio.

    The frames are first written to a temporary ``<stem>-temp<suffix>`` file
    next to the output with OpenCV, then re-encoded to H.264 with ffmpeg. The
    temporary file is removed afterwards.

    Args:
        video_tensor: Tensor indexed as [batch, channel, frame, height, width];
            only the first batch entry is written. Channels are treated as RGB
            and values are scaled by 255 (presumably the input is in [0, 1] -
            confirm against callers).
        output_path: Destination video path; parent directories are created.
        audio_path: Optional audio file to mux in. The result is cut to the
            shorter of the two streams.
        fps: Frame rate of the written video.
    """
    import subprocess

    output_file = pathlib.Path(output_path)
    output_file.parent.mkdir(exist_ok=True, parents=True)
    # Derive the temp name from the file name only; a plain str.replace on the
    # whole path would also rewrite directories that contain the stem.
    temp_file = output_file.with_name(f"{output_file.stem}-temp{output_file.suffix}")

    video_tensor = video_tensor[0, ...]
    _, num_frames, height, width = video_tensor.shape

    video_writer = cv2.VideoWriter(str(temp_file), cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for i in range(num_frames):
            frame_tensor = video_tensor[:, i, ...]  # [c, h, w]
            frame_tensor = frame_tensor.permute(1, 2, 0)  # [h, w, c]

            # Clamp so out-of-range values saturate instead of wrapping around
            # in the uint8 cast.
            frame_image = (frame_tensor.detach().cpu() * 255).clamp(0, 255).numpy().astype(np.uint8)
            frame_image = cv2.cvtColor(frame_image, cv2.COLOR_RGB2BGR)
            video_writer.write(frame_image)
    finally:
        video_writer.release()

    # Build the command as an argument list (no shell), so paths with quotes
    # or other shell metacharacters are passed through verbatim.
    cmd = [get_ffmpeg_exe(), '-loglevel', 'quiet', '-y', '-i', str(temp_file)]
    if audio_path is not None:
        cmd += ['-i', str(audio_path), '-map', '0:v', '-map', '1:a']
    else:
        cmd += ['-map', '0:v']
    cmd += ['-c:v', 'h264', '-shortest', str(output_file)]

    try:
        # Best effort, as before: a failing ffmpeg run does not raise.
        subprocess.run(cmd, check=False)
    finally:
        if temp_file.exists():
            temp_file.unlink()


def compute_dist(x1, y1, x2, y2):
    """Return the Euclidean distance between points (x1, y1) and (x2, y2)."""
    delta_x = x1 - x2
    delta_y = y1 - y2
    return math.sqrt(delta_x ** 2 + delta_y ** 2)


def compute_ratio(kps):
    """Return the ratio of the kps[0]-to-kps[2] and kps[1]-to-kps[2] distances.

    The variable names in the original treat index 0 as the left eye, index 1
    as the right eye and index 2 as the nose. A small epsilon in the
    denominator avoids division by zero.
    """
    left_eye, right_eye, nose = kps[0], kps[1], kps[2]
    left_to_nose = compute_dist(left_eye[0], left_eye[1], nose[0], nose[1])
    right_to_nose = compute_dist(right_eye[0], right_eye[1], nose[0], nose[1])
    return left_to_nose / (right_to_nose + 1e-6)


def point_to_line_dist(point, line_points):
    """Return the distance from ``point`` to a line segment.

    Args:
        point: Coordinates of the query point.
        line_points: Sequence whose first two entries are the segment's end
            points.

    Returns:
        Distance from ``point`` to the nearest point on the segment. The
        projection is clamped to the segment, so beyond either end the
        distance to that end point is returned. If both end points coincide,
        the distance to that single point is returned.
    """
    point = np.array(point)
    line_points = np.array(line_points)
    start = line_points[0]
    line_vec = line_points[1] - start
    point_vec = point - start

    line_len_sq = np.sum(line_vec ** 2)
    if line_len_sq == 0:
        # Degenerate segment: projecting would divide by zero and yield NaN.
        return np.sqrt(np.sum(point_vec ** 2))

    # Fractional position of the projection along the segment, clamped to [0, 1].
    t = np.dot(line_vec, point_vec) / line_len_sq
    t = min(max(t, 0.0), 1.0)

    nearest = start + t * line_vec
    return np.sqrt(np.sum((point - nearest) ** 2))


def get_face_size(kps):
    """Measure a face from its three keypoints.

    Args:
        kps: 2-D array indexed as kps[i, :]; row 0 is the left eye, row 1 the
            right eye and row 2 the nose.

    Returns:
        Tuple ``(eye_distance, nose_to_eye_line_distance)``.
    """
    left_eye, right_eye, nose = kps[0, :], kps[1, :], kps[2, :]

    delta_x = left_eye[0] - right_eye[0]
    delta_y = left_eye[1] - right_eye[1]
    eye_distance = math.sqrt(delta_x ** 2 + delta_y ** 2)

    nose_distance = point_to_line_dist(nose, [left_eye, right_eye])
    return eye_distance, nose_distance


def get_rescale_params(kps_ref, kps_target):
    """Return the (width, height) factors that scale the target face to the reference.

    Width compares the eye-to-eye distances; height compares the distances
    from the nose to the eye line, both as reported by ``get_face_size``.
    """
    ref_array = np.array(kps_ref)
    target_array = np.array(kps_target)

    ref_width, ref_height = get_face_size(ref_array)
    target_width, target_height = get_face_size(target_array)

    return ref_width / target_width, ref_height / target_height


def retarget_kps(ref_kps, tgt_kps_list, only_offset=True):
    '''
    Retarget a sequence of target keypoints onto a reference face.

    The target frame whose eye/nose distance ratio (see compute_ratio) is
    closest to the reference ratio is used to derive width/height scale
    factors, which are then applied to every target frame.

    Args:
        ref_kps: reference keypoints; index 2 is used as the nose.
        tgt_kps_list: target keypoints, one set per frame.
        only_offset: if True, the result is the reference keypoints repeated
            for every frame with the (rescaled) target nose displacement
            relative to the first frame subtracted. If False, the rescaled
            target keypoints are shifted so the first frame's nose coincides
            with the reference nose.

    Returns:
        Array with one retargeted keypoint set per target frame.

    Raises:
        ValueError: if tgt_kps_list is empty.
    '''
    ref_kps = np.array(ref_kps)
    tgt_kps_list = np.array(tgt_kps_list)

    if tgt_kps_list.size == 0:
        raise ValueError("tgt_kps_list must contain at least one set of keypoints.")

    # The in-place scaling and offsetting below need floating-point storage;
    # integer keypoints would otherwise fail to accept float results.
    if not np.issubdtype(ref_kps.dtype, np.floating):
        ref_kps = ref_kps.astype(np.float64)
    if not np.issubdtype(tgt_kps_list.dtype, np.floating):
        tgt_kps_list = tgt_kps_list.astype(np.float64)

    ref_ratio = compute_ratio(ref_kps)

    # Pick the target frame whose ratio is closest to the reference ratio
    # (the first one on ties).
    ratio_deltas = [math.fabs(compute_ratio(tgt_kps) - ref_ratio) for tgt_kps in tgt_kps_list]
    selected_tgt_kps_idx = int(np.argmin(ratio_deltas))

    scale_width, scale_height = get_rescale_params(
        kps_ref=ref_kps,
        kps_target=tgt_kps_list[selected_tgt_kps_idx],
    )

    # Rescale every target frame by the same factors.
    rescaled_tgt_kps_list = np.array(tgt_kps_list)
    rescaled_tgt_kps_list[:, :, 0] *= scale_width
    rescaled_tgt_kps_list[:, :, 1] *= scale_height

    if only_offset:
        # Nose displacement of each frame relative to the first frame.
        nose_offset = rescaled_tgt_kps_list[:, 2, :] - rescaled_tgt_kps_list[0, 2, :]
        nose_offset = nose_offset[:, np.newaxis, :]
        ref_kps_repeat = np.tile(ref_kps, (tgt_kps_list.shape[0], 1, 1))

        ref_kps_repeat[:, :, :] -= nose_offset
        rescaled_tgt_kps_list = ref_kps_repeat
    else:
        # Shift all frames so the first frame's nose lands on the reference nose.
        nose_offset_x = rescaled_tgt_kps_list[0, 2, 0] - ref_kps[2][0]
        nose_offset_y = rescaled_tgt_kps_list[0, 2, 1] - ref_kps[2][1]

        rescaled_tgt_kps_list[:, :, 0] -= nose_offset_x
        rescaled_tgt_kps_list[:, :, 1] -= nose_offset_y

    return rescaled_tgt_kps_list

def _unit_tensor_to_pil(tensor):
    """Convert a channel-first float tensor with values in [0, 1] to a PIL image."""
    # Round before the integer cast so float error cannot shift a pixel value.
    pixels = torch.clamp(tensor * 255, 0, 255).round().type(torch.uint8).cpu()
    if pixels.shape[0] == 1:
        # Single channel: PIL expects a plain [H, W] array.
        return Image.fromarray(pixels.squeeze(0).numpy())
    return Image.fromarray(pixels.permute(1, 2, 0).contiguous().numpy())


def get_face_mask(target_image, face_info):
    """Split an image into its face-box part and its background part.

    Args:
        target_image: PIL image, numpy array or tensor with values in 0-255.
            The mask is applied as ``[:, y, x]``, so arrays and tensors are
            assumed to be channel-first - TODO confirm against callers that
            pass numpy arrays.
        face_info: Mapping with a ``'bbox'`` entry ``(x1, y1, x2, y2)``; the
            box is inclusive of ``x2``/``y2``.

    Returns:
        Tuple ``(masked_image_pil, masked_bg_pil, face_mask)``: the image with
        everything outside the box zeroed, the image with the box zeroed, and
        the mask tensor (255 inside the box, 0 elsewhere).
    """
    if isinstance(target_image, Image.Image):
        target_image = T.functional.pil_to_tensor(target_image)
    if isinstance(target_image, np.ndarray):
        target_image = torch.from_numpy(target_image)
    face_mask = torch.zeros_like(target_image)

    x1, y1, x2, y2 = face_info['bbox']
    # Clamp at 0 so a box starting outside the frame does not turn into a
    # negative (wrap-around) slice index.
    top, bottom = max(int(y1), 0), max(int(y2) + 1, 0)
    left, right = max(int(x1), 0), max(int(x2) + 1, 0)
    face_mask[:, top:bottom, left:right] = 255

    # Work in [0, 1] floats, then apply the mask and its complement.
    image_tensor = target_image.float() / 255.0
    masked_image_tensor = image_tensor * face_mask / 255.0
    masked_bg_tensor = image_tensor * (255.0 - face_mask) / 255.0

    masked_image_pil = _unit_tensor_to_pil(masked_image_tensor)
    masked_bg_pil = _unit_tensor_to_pil(masked_bg_tensor)

    return masked_image_pil, masked_bg_pil, face_mask
