import json
import random

from PIL import Image
import numpy as np
import torch
import torchvision.io
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
from torch.utils.data import Dataset, DataLoader


class TalkingFaceVideo(Dataset):
    """
    Dataset that samples short training clips from talking-face videos.

    Each item is built from one video listed in the metadata files: a clip of
    ``num_frames`` target frames, one reference frame taken away from the clip,
    the matching key-point frames, face / lip masks and per-frame windows of
    audio embeddings.

    Args:
        image_size (tuple, optional): Output (height, width) handed to
            ``RandomResizedCrop``. Defaults to (512, 512).
        image_scale (tuple, optional): ``scale`` range of ``RandomResizedCrop``.
            Defaults to (1.0, 1.0).
        image_ratio (tuple, optional): ``ratio`` range of ``RandomResizedCrop``.
            Defaults to (0.9, 1.0).
        meta_paths (list, optional): JSON files, each holding a list of per-video
            entries. An entry provides the keys "video", "landmark", "face_info"
            and "audio_embeds" (or "xlsr_audio_embeds"), and optionally
            "face_embeds" / "face_embeds_mask". ``None`` means no files.
        prompt_paths (list, optional): JSON files, each holding a mapping from
            video id to caption. ``None`` means no files.
        flip_rate (float, optional): Probability of horizontally flipping a
            sample. Defaults to 0.0.
        sample_rate (int, optional): Frame stride between consecutive target
            frames. Defaults to 1.
        num_frames (int, optional): Number of target frames per sample.
            Defaults to 10.
        reference_margin (int, optional): Margin (in frames) kept between the
            clip and the reference frame. Defaults to 30.
        num_padding_audio_frames (int, optional): Number of neighbouring video
            frames of audio context on each side of a frame. Defaults to 2.
        n_motion_frames (int, optional): Smallest allowed clip start index
            (frames reserved in front of the clip). Defaults to 0.
        standard_audio_fps (int, optional): Stored on the instance; not used by
            this class. Defaults to 16000.
        vae_scale_rate (int, optional): Factor by which masks are downscaled
            relative to ``image_size``. Defaults to 8.
        audio_embeddings_interpolation_mode (str, optional): ``mode`` passed to
            ``torch.nn.functional.interpolate``. Defaults to "linear".

    Attributes:
        videos_info (list): Concatenated entries of all metadata files.
        captions (dict): Merged video-id -> caption mapping.
        default_caption (str): Caption used when a video id has no entry.
        img_transform: The ``RandomResizedCrop`` pipeline shared by all images.
    """

    def __init__(
        self,
        image_size=(512, 512),
        image_scale=(1.0, 1.0),
        image_ratio=(0.9, 1.0),
        meta_paths=None,
        prompt_paths=None,
        flip_rate=0.0,
        sample_rate=1,
        num_frames=10,
        reference_margin=30,
        num_padding_audio_frames=2,
        n_motion_frames=0,
        standard_audio_fps=16000,
        vae_scale_rate=8,
        audio_embeddings_interpolation_mode: str = "linear",
    ):
        super().__init__()

        self.image_size = image_size
        self.flip_rate = flip_rate
        self.sample_rate = sample_rate
        self.num_frames = num_frames
        self.reference_margin = reference_margin
        self.num_padding_audio_frames = num_padding_audio_frames
        self.n_motion_frames = n_motion_frames
        self.standard_audio_fps = standard_audio_fps
        self.vae_scale_rate = vae_scale_rate
        self.audio_embeddings_interpolation_mode = audio_embeddings_interpolation_mode

        self.videos_info = []
        self.captions = {}
        # ``None`` is the documented default and means "no files"; iterating it
        # directly would raise TypeError.
        for meta_path in meta_paths or []:
            with open(meta_path, "r") as meta_file:
                self.videos_info.extend(json.load(meta_file))
        for prompt_path in prompt_paths or []:
            with open(prompt_path, "r") as prompt_file:
                self.captions.update(json.load(prompt_file))
        self.default_caption = "best quality, high quality"

        self.img_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    image_size,
                    scale=image_scale,
                    ratio=image_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
            ]
        )

    def get_audio_frame_embeddings(self, audio_embeddings, frame_ids, video_len):
        """
        Cut one window of audio embeddings for every requested video frame.

        Args:
            audio_embeddings: Tensor of shape [num_embeds, 1, dim].
            frame_ids: Indices of the video frames to build windows for.
            video_len: Number of frames in the video.

        Returns:
            Tensor of shape
            [len(frame_ids), 4 * num_padding_audio_frames + 2, dim].
        """
        # The length of the input audio embeddings is between video_len and 2*video_len.
        # 1. Resample them to exactly two embeddings per video frame.
        audio_embeddings = torch.nn.functional.interpolate(
            audio_embeddings.permute(1, 2, 0),  # [1, dim, num_embeds]
            size=2 * video_len,
            mode=self.audio_embeddings_interpolation_mode,
        )[0, :, :].permute(
            1, 0
        )  # [2*video_len, dim]

        # 2. Zero-pad head and tail. The padding is doubled because every video
        #    frame now owns two embeddings. ``new_zeros`` keeps dtype and device
        #    in line with the embeddings.
        padding = audio_embeddings.new_zeros(
            (2 * self.num_padding_audio_frames, audio_embeddings.shape[-1])
        )
        audio_embeddings = torch.cat([padding, audio_embeddings, padding], dim=0)

        # 3. Thanks to the head padding, the window centred on frame ``i`` starts
        #    at row ``2 * i`` and spans the frame plus the context on both sides.
        window_size = 2 * (2 * self.num_padding_audio_frames + 1)
        audio_frame_embeddings = [
            audio_embeddings[2 * frame_idx : 2 * frame_idx + window_size, :]
            for frame_idx in frame_ids
        ]
        return torch.stack(audio_frame_embeddings, dim=0)

    @staticmethod
    def _fill_box(mask, x1, y1, x2, y2):
        """Set the inclusive box (x1, y1)-(x2, y2) of a [C, H, W] mask to 255."""
        # Clamp to the image origin: a negative slice bound would otherwise be
        # interpreted as "count from the end" and fill the wrong region.
        left = max(int(x1), 0)
        top = max(int(y1), 0)
        right = max(int(x2) + 1, 0)
        bottom = max(int(y2) + 1, 0)
        mask[:, top:bottom, left:right] = 255
        return mask

    @staticmethod
    def get_face_mask(target_image, face_info):
        """
        Build a mask that is 255 inside the face bounding box and 0 elsewhere.

        Args:
            target_image: [C, H, W] tensor; only its shape/dtype are used.
            face_info: Mapping with a "bbox" entry ``(x1, y1, x2, y2)``.

        Returns:
            Tensor shaped like ``target_image``.
        """
        x1, y1, x2, y2 = face_info["bbox"]
        return TalkingFaceVideo._fill_box(
            torch.zeros_like(target_image), x1, y1, x2, y2
        )

    @staticmethod
    def get_lip_mask(target_image, face_info, scale=2.0):
        """
        Build a mask that is 255 inside the enlarged lip box and 0 elsewhere.

        Args:
            target_image: [C, H, W] tensor; only its shape/dtype are used.
            face_info: Mapping with a "landmark_2d_106" array of (x, y) points;
                rows 52..71 are taken as the lip landmarks.
            scale: Factor by which the tight lip box is enlarged around its centre.

        Returns:
            Tensor shaped like ``target_image``.
        """
        lip_landmarks = face_info["landmark_2d_106"][52:72]
        x1 = int(min(lip_landmarks[:, 0]))
        x2 = int(max(lip_landmarks[:, 0]))
        y1 = int(min(lip_landmarks[:, 1]))
        y2 = int(max(lip_landmarks[:, 1]))

        half_w = (x2 - x1) / 2 * scale
        half_h = (y2 - y1) / 2 * scale
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2

        return TalkingFaceVideo._fill_box(
            torch.zeros_like(target_image),
            cx - half_w,
            cy - half_h,
            cx + half_w,
            cy + half_h,
        )

    def process_reference_image(self, reference_image, do_flip, rand_state):
        """
        Crop, optionally flip, and normalise one image tensor.

        The result is a float tensor normalised with mean 0.5 and std 0.5.
        """
        reference_image = transforms_f.to_pil_image(reference_image)
        reference_image = self.augmentation(
            reference_image, self.img_transform, rand_state
        )
        if do_flip:
            reference_image = transforms_f.hflip(reference_image)
        reference_image = transforms_f.to_tensor(reference_image)
        reference_image = transforms_f.normalize(reference_image, mean=[0.5], std=[0.5])
        return reference_image

    def process_target_images(self, target_images, do_flip, rand_state):
        """Process every target frame like the reference image; returns [3, num_frames, h, w]."""
        processed_target_images = [
            self.process_reference_image(target_image, do_flip, rand_state)
            for target_image in target_images
        ]
        target_images = torch.stack(
            processed_target_images, dim=0
        )  # [num_frames, 3, h, w]
        return target_images.permute(1, 0, 2, 3)  # [3, num_frames, h, w]

    def process_kps_images(self, kps_images, do_flip, rand_state):
        """Crop and optionally flip the key-point frames; returns [3, num_frames, h, w]."""
        processed_kps_images = []
        for kps_image in kps_images:
            kps_image = transforms_f.to_pil_image(kps_image)
            kps_image = self.augmentation(kps_image, self.img_transform, rand_state)
            if do_flip:
                kps_image = transforms_f.hflip(kps_image)
            kps_image = transforms_f.to_tensor(kps_image)
            if do_flip:
                # An easy implementation of flipping for kps images: swap the
                # first two colour channels (RGB -> GRB).
                kps_image = torch.stack(
                    [kps_image[1], kps_image[0], kps_image[2]], dim=0
                )
            processed_kps_images.append(kps_image)
        kps_images = torch.stack(processed_kps_images, dim=0)  # [num_frames, 3, h, w]
        return kps_images.permute(1, 0, 2, 3)  # [3, num_frames, h, w]

    def process_masks(self, masks, do_flip, rand_state):
        """
        Crop, downscale and optionally flip the masks.

        Returns:
            Tensor of shape [1, num_frames, h, w], where h and w are
            ``image_size`` divided by ``vae_scale_rate``.
        """
        mask_size = [
            self.image_size[0] // self.vae_scale_rate,
            self.image_size[1] // self.vae_scale_rate,
        ]
        processed_masks = []
        for mask in masks:
            mask = transforms_f.to_pil_image(mask)
            mask = self.augmentation(mask, self.img_transform, rand_state)
            mask = transforms_f.resize(mask, size=mask_size)
            if do_flip:
                mask = transforms_f.hflip(mask)
            mask = transforms_f.to_tensor(mask)
            # Only the first channel of each mask is kept.
            processed_masks.append(mask[0, ...])
        masks = torch.stack(processed_masks, dim=0)  # [num_frames, h, w]
        return masks.unsqueeze(dim=0)  # [1, num_frames, h, w]

    @staticmethod
    def augmentation(image, transform, state=None):
        """
        Apply ``transform`` to ``image``, optionally from a fixed torch RNG state.

        Restoring the same state before every call is what lets all images of a
        sample share one random crop. Note that this overwrites the global torch
        RNG state.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def _load_sample(self, index):
        """Build the sample for ``videos_info[index]``; return None if the video is unusable."""
        video_info = dict(self.videos_info[index])
        video_path = video_info["video"]
        kps_video_path = video_info["landmark"]
        face_info_path = video_info["face_info"]
        face_embeds_path = video_info.get("face_embeds", None)
        face_embeds_mask_path = video_info.get("face_embeds_mask", None)
        # The xlsr embeddings take precedence over the legacy ones when present.
        use_xlsr_embeds = "xlsr_audio_embeds" in video_info
        if use_xlsr_embeds:
            audio_embeddings_path = video_info["xlsr_audio_embeds"]
        else:
            audio_embeddings_path = video_info["audio_embeds"]

        video_id = video_path.split("/")[-1].replace(".mp4", "")
        caption = self.captions.get(video_id, self.default_caption)

        video_frames, _, _ = torchvision.io.read_video(
            video_path,
            pts_unit="sec",
            output_format="TCHW",
        )
        kps_frames, _, _ = torchvision.io.read_video(
            kps_video_path, pts_unit="sec", output_format="TCHW"
        )
        video_len, kps_len = video_frames.shape[0], kps_frames.shape[0]

        assert video_len == kps_len, (
            f"The frame numbers is not equal in {video_path} and {kps_video_path}! "
            f"(video_len is {video_len}, kps_len is {kps_len})"
        )

        if video_len < self.num_frames:
            return None

        clip_video_len = min(video_len, (self.num_frames - 1) * self.sample_rate + 1)
        max_start_idx = video_len - clip_video_len
        if max_start_idx < self.n_motion_frames:
            # No valid start position; random.randint would raise ValueError.
            return None
        start_idx = random.randint(self.n_motion_frames, max_start_idx)
        batch_ids = np.linspace(
            start_idx, start_idx + clip_video_len - 1, self.num_frames, dtype=int
        ).tolist()

        # Frames inside [exclude_start, exclude_stop) are too close to the clip
        # to serve as reference.
        # NOTE(review): the window is half-open, so the right side is one frame
        # tighter than the left side; kept as in the original sampling.
        exclude_start = min(batch_ids) - self.reference_margin
        exclude_stop = max(batch_ids) + self.reference_margin
        reference_candidates = [
            idx
            for idx in range(video_len)
            if idx < exclude_start or idx >= exclude_stop
        ]
        if not reference_candidates:
            return None

        # NOTE(review): torch.load unpickles the files; only load trusted data.
        face_info = torch.load(face_info_path)
        face_embeds = (
            torch.load(face_embeds_path)
            if face_embeds_path
            else torch.zeros(video_len, 512)
        )
        face_embeds_mask = (
            torch.load(face_embeds_mask_path)
            if face_embeds_mask_path
            else torch.zeros(video_len, 512)
        )

        reference_idx = random.choice(reference_candidates)
        reference_image = video_frames[reference_idx, ...]
        ref_face_embed = face_embeds[reference_idx, ...]
        ref_face_embed_mask = face_embeds_mask[reference_idx, ...]

        audio_embeddings = torch.load(audio_embeddings_path, map_location="cpu")
        # Note: the earlier audio feature representation consists of three
        # parts, so the global one has to be selected here. The newer xlsr
        # representation is used as is.
        if not use_xlsr_embeds:
            audio_embeddings = audio_embeddings["global_embeds"]
        audio_embeddings = audio_embeddings.float().detach()  # [num_embeds, 1, dim]

        target_images, kps_images, face_masks, lip_masks = [], [], [], []
        for frame_idx in batch_ids:
            target_image = video_frames[frame_idx, ...]
            kps_image = kps_frames[frame_idx, ...]
            # Best effort: a frame without usable face information gets an
            # all-zero mask instead of failing the whole sample.
            try:
                face_mask = self.get_face_mask(
                    target_image, face_info=face_info[frame_idx][0]
                )
            except Exception:
                face_mask = torch.zeros_like(target_image)
            try:
                lip_mask = self.get_lip_mask(
                    target_image, face_info=face_info[frame_idx][0], scale=2.0
                )
            except Exception:
                lip_mask = torch.zeros_like(target_image)

            target_images.append(target_image)
            kps_images.append(kps_image)
            face_masks.append(face_mask)
            lip_masks.append(lip_mask)

        # NOTE: extraction of the ``n_motion_frames`` frames preceding
        # ``start_idx`` (motion images) is currently disabled.

        audio_frame_embeddings = self.get_audio_frame_embeddings(
            audio_embeddings, batch_ids, video_len
        )

        # One RNG state and one flip decision are shared by every image of the
        # sample so that all of them receive the same augmentation.
        transform_rand_state = torch.get_rng_state()
        do_flip = random.random() < self.flip_rate

        reference_image = self.process_reference_image(
            reference_image, do_flip, transform_rand_state
        )
        target_images = self.process_target_images(
            target_images, do_flip, transform_rand_state
        )
        kps_images = self.process_kps_images(kps_images, do_flip, transform_rand_state)
        face_masks = self.process_masks(face_masks, do_flip, transform_rand_state)
        lip_masks = self.process_masks(lip_masks, do_flip, transform_rand_state)

        return dict(
            reference_image=reference_image,
            target_images=target_images,
            kps_images=kps_images,
            audio_frame_embeddings=audio_frame_embeddings,
            face_masks=face_masks,
            lip_masks=lip_masks,
            video_id=video_id,
            caption=caption,
            ref_face_embed=ref_face_embed,
            ref_face_embed_mask=ref_face_embed_mask,
        )

    def __getitem__(self, index):
        """
        Return the sample for ``index``, falling back to later videos if needed.

        A video that is too short, or that leaves no room for a reference frame,
        is skipped and the next video (wrapping around at the end of the list)
        is tried instead.

        Raises:
            IndexError: If ``index`` is outside the dataset.
            RuntimeError: If no video in the dataset yields a sample.
        """
        num_videos = len(self.videos_info)
        # Keep the sequence protocol intact: an out-of-range request must fail
        # rather than be wrapped around by the retry logic below.
        if not -num_videos <= index < num_videos:
            raise IndexError(
                f"index {index} is out of range for {num_videos} videos"
            )

        for offset in range(num_videos):
            sample = self._load_sample((index + offset) % num_videos)
            if sample is not None:
                return sample

        raise RuntimeError(
            "No video in the dataset is long enough to sample "
            f"{self.num_frames} frames together with a reference frame."
        )

    def __len__(self):
        """Return the number of videos listed in the metadata."""
        return len(self.videos_info)


if __name__ == "__main__":
    # Manual smoke test: build the dataset from a local metadata file and print
    # the contents of every batch, waiting for <Enter> between batches.
    dataset = TalkingFaceVideo(
        image_size=(512, 512),
        image_scale=(1.0, 1.0),
        image_ratio=(0.9, 1.0),
        meta_paths=[
            "/root/multimodal_dataset/animate_image/audio_to_face/part_02/p2.json"
        ],
        # No caption files: every sample falls back to the default caption.
        # Passed explicitly because the constructor iterates this argument.
        prompt_paths=[],
        flip_rate=0.0,
        sample_rate=1,
        num_frames=1,
        reference_margin=30,
        num_padding_audio_frames=2,
        standard_audio_fps=16000,
        vae_scale_rate=8,
        audio_embeddings_interpolation_mode="linear",
    )
    dataloader = DataLoader(dataset, batch_size=4, num_workers=1)
    for item in dataloader:
        for key, value in item.items():
            # String fields (video_id, caption) are collated into lists, which
            # have no ``.shape``; print those values directly.
            if isinstance(value, torch.Tensor):
                print(key, value.shape)
            else:
                print(key, value)
        print()
        input()
