import json
import random
import re

import numpy as np
import torch
import torchvision.io
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
from torch.utils.data import Dataset, DataLoader

from datasets.utils import load_kps_images

def check_tensor(tensor, name):
    """Print a one-line debug summary of ``tensor`` labelled with ``name``.

    The summary reports whether the tensor is contiguous, its dtype and its
    shape. Nothing is returned.
    """
    contiguous = tensor.is_contiguous()
    print(f"{name} - is_contiguous: {contiguous}, dtype: {tensor.dtype}, shape: {tensor.shape}")


class TalkingFaceImage(Dataset):
    """Single-image face dataset shaped like a one-frame talking-face video sample.

    Each metadata entry (read from the JSON files in ``meta_paths``) describes one
    image. It uses the keys ``image_file``, optional ``text_detailed`` / ``text``,
    ``faces`` (the first face's ``bbox`` and ``kps`` are used), ``face_embeds``
    and optional ``face_embeds_mask``. Per-frame outputs carry a ``num_frames``
    axis of length 1 so samples line up with the video dataset's layout.

    Args:
        image_size: output ``(height, width)``; a single int means a square output.
        image_scale: ``scale`` range forwarded to ``RandomResizedCrop``.
        image_ratio: ``ratio`` range forwarded to ``RandomResizedCrop``.
        meta_paths: iterable of JSON metadata files, each holding a list of entries.
            ``None`` yields an empty dataset.
        flip_rate: probability of horizontally flipping a sample.
        vae_scale_rate: spatial downsampling factor applied to the mask outputs.
    """

    # Expected leading dimension of a reference face embedding; samples whose
    # embedding has a different size are skipped.
    FACE_EMBED_DIM = 512

    def __init__(
            self,
            image_size=(512, 512),
            image_scale=(1.0, 1.0),
            image_ratio=(0.9, 1.0),
            meta_paths=None,
            flip_rate=0.0,
            vae_scale_rate=8,
    ):
        super().__init__()

        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = tuple(image_size)
        self.flip_rate = flip_rate
        self.vae_scale_rate = vae_scale_rate

        self.images_info = []
        self.captions = {}  # NOTE(review): never read in this class
        for meta_path in meta_paths or ():
            # Context manager so the metadata file handle is closed promptly.
            with open(meta_path, "r", encoding="utf-8") as meta_file:
                self.images_info.extend(json.load(meta_file))
        self.default_caption = 'best quality, high quality'

        self.img_transform = transforms.Compose([
            transforms.RandomResizedCrop(
                self.image_size,
                scale=image_scale,
                ratio=image_ratio,
                interpolation=transforms.InterpolationMode.BILINEAR,
            ),
        ])

    def get_face_mask(self, target_image, face_info):
        """Build a face bounding-box mask for ``target_image``.

        Args:
            target_image: ``[C, H, W]`` image tensor (as returned by
                ``torchvision.io.read_image``).
            face_info: dict whose ``'bbox'`` is ``(x1, y1, x2, y2)`` in pixels;
                ``x2``/``y2`` are inclusive.

        Returns:
            Tensor with the same dtype and channel count as ``target_image``.
            It is 255 inside the box and 0 elsewhere, resized to ``self.image_size``.
        """
        face_mask_image = torch.zeros_like(target_image, requires_grad=False)
        height, width = target_image.shape[-2], target_image.shape[-1]

        x1, y1, x2, y2 = face_info['bbox']
        # Clamp to the image: a box that starts left of / above the border would
        # otherwise produce a negative start index, which Python slicing treats as
        # "count from the end" and turns into an empty or wrong region.
        top, left = max(int(y1), 0), max(int(x1), 0)
        bottom, right = min(int(y2) + 1, height), min(int(x2) + 1, width)
        face_mask_image[:, top:bottom, left:right] = 255

        return transforms_f.resize(face_mask_image, size=[self.image_size[0], self.image_size[1]])

    def process_image(self, reference_image, do_flip, rand_state):
        """Resize, randomly crop (seeded by ``rand_state``), optionally flip and normalize an image.

        Returns a float tensor normalized to roughly ``[-1, 1]`` (mean 0.5, std 0.5).
        """
        reference_image = transforms_f.resize(reference_image, size=[self.image_size[0], self.image_size[1]])
        reference_image = transforms_f.to_pil_image(reference_image)
        reference_image = self.augmentation(reference_image, self.img_transform, rand_state)
        if do_flip:
            reference_image = transforms_f.hflip(reference_image)
        reference_image = transforms_f.to_tensor(reference_image)
        reference_image = transforms_f.normalize(reference_image, mean=[.5], std=[.5])
        return reference_image

    def process_kps_images(self, kps_image, do_flip, rand_state):
        """Apply the same resize/crop/flip as the images to a keypoint rendering.

        Returns a ``[3, 1, h, w]`` tensor (channels, num_frames, height, width).
        """
        kps_image = transforms_f.resize(kps_image, size=[self.image_size[0], self.image_size[1]])
        kps_image = transforms_f.to_pil_image(kps_image)
        kps_image = self.augmentation(kps_image, self.img_transform, rand_state)
        if do_flip:
            kps_image = transforms_f.hflip(kps_image)
        kps_image = transforms_f.to_tensor(kps_image)
        if do_flip:
            # Cheap flip handling: swap the first two channels (RGB -> GRB).
            # NOTE(review): presumably the keypoint rendering color-codes
            # left/right landmarks — confirm against load_kps_images.
            kps_image = torch.stack([kps_image[1], kps_image[0], kps_image[2]], dim=0)
        kps_image = kps_image.unsqueeze(0).permute(1, 0, 2, 3)  # [3, num_frames, h, w]
        return kps_image

    def process_masks(self, masks, do_flip, rand_state):
        """Crop/flip each mask like the images, then downsample it to latent resolution.

        Returns a ``[1, num_frames, h // vae_scale_rate, w // vae_scale_rate]``
        float tensor built from the first channel of each mask.
        """
        processed_masks = []
        for mask in masks:
            mask = transforms_f.to_pil_image(mask)
            mask = self.augmentation(mask, self.img_transform, rand_state)
            mask = transforms_f.resize(mask, size=[
                self.image_size[0] // self.vae_scale_rate,
                self.image_size[1] // self.vae_scale_rate
            ])
            if do_flip:
                mask = transforms_f.hflip(mask)
            mask = transforms_f.to_tensor(mask)
            mask = mask[0, ...]
            processed_masks.append(mask)
        masks = torch.stack(processed_masks, dim=0)  # [num_frames, h, w]
        masks = masks.unsqueeze(dim=0)  # [1, num_frames, h, w]
        return masks

    @staticmethod
    def augmentation(image, transform, state=None):
        """Apply ``transform`` to ``image`` after restoring the torch RNG to ``state``.

        Restoring the same state before every call makes random transforms (the
        crop) identical across the image, keypoints and masks of one sample.
        Side effect: this overwrites the global torch RNG state.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    def replace_extensions(self, file_path, suffix=''):
        """Replace a trailing lowercase ``.jpg``/``.png``/``.jpeg`` extension with ``suffix``."""
        return re.sub(r'\.(jpg|png|jpeg)$', suffix, file_path)

    def __getitem__(self, index):
        """Return the sample at ``index``, skipping ahead past unusable entries.

        If the entry fails a shape check, the following entries are tried in order,
        wrapping around to the start of the metadata list.

        Raises:
            IndexError: if ``index`` is out of range. This also keeps plain
                ``for x in dataset`` iteration terminating.
            RuntimeError: if no entry in the dataset yields a valid sample.
        """
        num_items = len(self.images_info)
        if not -num_items <= index < num_items:
            raise IndexError(f"index {index} out of range for dataset of size {num_items}")

        for offset in range(num_items):
            sample = self._build_sample((index + offset) % num_items)
            if sample is not None:
                return sample
        raise RuntimeError("TalkingFaceImage: no entry produced a valid sample")

    def _build_sample(self, index):
        """Load and preprocess entry ``index``; return ``None`` if a shape check fails."""
        # Snapshot the torch RNG so every augmentation below draws the same random
        # crop, keeping image, keypoints and masks spatially aligned.
        transform_rand_state = torch.get_rng_state()
        do_flip = random.random() < self.flip_rate

        image_info = dict(self.images_info[index])
        image_path = image_info["image_file"]
        # NOTE(review): this fallback text differs from self.default_caption
        # ('best quality, high quality'), which is never used — confirm which is intended.
        caption = (
            image_info.get("text_detailed")
            or image_info.get("text")
            or 'Best quality; high quality'
        )
        face_info = image_info["faces"][0]
        face_embeds_path = image_info.get('face_embeds', None)
        face_embeds_mask_path = image_info.get('face_embeds_mask', None)
        image_id = self.replace_extensions(image_path.split('/')[-1], suffix='')

        # The same picture serves as both reference and target.
        reference_image = torchvision.io.read_image(image_path)
        height, width = reference_image.shape[1], reference_image.shape[2]

        # Face embedding and its mask.
        ref_face_embed = torch.load(face_embeds_path)[0]
        if face_embeds_mask_path:
            ref_face_embed_mask = torch.load(face_embeds_mask_path)[0]
        else:
            # Entries without a mask file keep the previous behaviour of
            # reading the mask from the embeddings file.
            ref_face_embed_mask = torch.load(face_embeds_path)[0]

        # Keypoint rendering for a single frame.
        # NOTE(review): assumes load_kps_images returns a per-frame sequence
        # whose element 0 is an image tensor/PIL image — confirm.
        kps_image = load_kps_images(
            kps_sequence=face_info['kps'],
            video_length=1,
            image_height=height,
            image_width=width,
        )

        try:
            face_mask = self.get_face_mask(reference_image, face_info=face_info)
        except Exception as err:  # best effort: a missing/malformed bbox yields an empty mask
            print(f'{image_id} Failed to build face mask ({err!r}); using an empty mask')
            face_mask = torch.zeros_like(reference_image)

        reference_image = self.process_image(reference_image, do_flip, transform_rand_state)
        # The target is the same picture processed with the same RNG state and flip,
        # so it equals the processed reference. Add the num_frames axis (length 1)
        # to match the video dataset's [3, num_frames, h, w] layout.
        target_image = reference_image.clone().unsqueeze(1)
        kps_image = self.process_kps_images(kps_image[0], do_flip, transform_rand_state)
        face_masks = self.process_masks([face_mask], do_flip, transform_rand_state)

        out_h, out_w = self.image_size
        shape_checks = (
            ('reference_image', tuple(reference_image.shape) == (3, out_h, out_w)),
            ('target_image', tuple(target_image.shape) == (3, 1, out_h, out_w)),
            ('kps_image', tuple(kps_image.shape) == (3, 1, out_h, out_w)),
            ('ref_face_embed', ref_face_embed.shape[0] == self.FACE_EMBED_DIM),
        )
        for tensor_name, matches in shape_checks:
            if not matches:
                print(f'{image_id} Shape of {tensor_name} doesnt match!')
                return None

        latent_h, latent_w = out_h // self.vae_scale_rate, out_w // self.vae_scale_rate
        return dict(
            reference_image=reference_image,
            target_images=target_image,
            kps_images=kps_image,
            # This image dataset has no audio; a zero tensor keeps the sample keys
            # consistent with the video dataset.
            audio_frame_embeddings=torch.zeros((1, 10, 768)),
            face_masks=face_masks,
            lip_masks=torch.zeros((1, 1, latent_h, latent_w)),
            video_id=image_id,
            caption=caption,
            ref_face_embed=ref_face_embed,
            ref_face_embed_mask=ref_face_embed_mask,
        )

    def __len__(self):
        """Number of metadata entries loaded from ``meta_paths``."""
        return len(self.images_info)

if __name__ == "__main__":
    # Manual smoke test: build the dataset from one metadata file and dump each batch.
    demo_dataset = TalkingFaceImage(
        image_size=(512, 512),
        meta_paths=['/root/datasets/face_data/face1_data_crop2x_id_240626.json'],
        flip_rate=0.0,
    )
    demo_loader = DataLoader(demo_dataset, batch_size=4, num_workers=1)
    for batch in demo_loader:
        for key, value in batch.items():
            # Tensors are summarized by shape; everything else is printed verbatim.
            summary = value.shape if isinstance(value, torch.Tensor) else value
            print(key, type(value), summary)
        print()
        input()  # pause between batches so the output can be inspected
