import json
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from typing import Dict, List, Tuple, Union
import torchvision.transforms.functional as F
import random

def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px):
    """Build the CLIP-style image preprocessing pipeline.

    The image is resized (bicubic) so its shorter side is ``n_px`` and
    center-cropped to ``n_px x n_px``. It is then converted to RGB, turned
    into a tensor, and normalized with the mean/std constants published with
    OpenAI CLIP.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    resize = transforms.Resize(n_px, interpolation=transforms.InterpolationMode.BICUBIC)
    pipeline = [
        resize,
        transforms.CenterCrop(n_px),
        _convert_image_to_rgb,
        transforms.ToTensor(),
        transforms.Normalize(clip_mean, clip_std),
    ]
    return transforms.Compose(pipeline)
      
class PeopleDataset(Dataset):
    """Person-image dataset with face-parsing masks, read from JSON-lines files.

    Each line of every file in ``jsonl_files`` is one JSON object. Keys read
    by this class::

        {
            'path': full image path,
            'caption': image caption,
            'mask_path': face parsing mask path (loaded as grayscale "L"),
            'square_bbox': crop box (left, upper, right, lower) used when the
                           hair label is kept in the mask,
            'square_bbox_face': crop box used when the hair label is excluded,
        }

    Other keys (for example 'width' / 'height') may be present; they are not
    read here.

    Face parsing mask pixel values::

        {
            'background': 0,
            'face': 23,
            'rb': 46,
            'lb': 69,
            're': 92,
            'le': 115,
            'nose': 139,
            'ulip': 162,
            'imouth': 185,
            'llip': 208,
            'hair': 231,
        }

    Every non-zero label is treated as foreground. With probability
    ``p_exclude_hair`` the hair label is zeroed out first.

    Each item is a dict::

        {
            "image1": torch.Tensor,       # full image, resized + center-cropped to
                                          # train_resolution, normalized to [-1, 1]
            "img4clip": torch.Tensor,     # full image through _transform(224)
            "caption": str,
            "detect": torch.Tensor,       # bbox crop resized + center-cropped to
                                          # detect_resolution (default 224), in [-1, 1]
            "detect_mask": torch.Tensor,  # bbox crop of the binary mask, same size
                                          # as "detect", values in [0, 1]
            "mask": torch.Tensor,         # binary mask at train_resolution // 8,
                                          # values in [0, 1]
            "type": str,                  # always "object"
            "data_type": int,             # always 0 (for unidiffuser)
        }

    Resizing the binary masks uses torchvision's default interpolation, so
    edge pixels of "mask" / "detect_mask" may take fractional values.
    """

    # Pixel value of the 'hair' label in the face parsing masks.
    _HAIR_PIXEL_VALUE = 231

    def __init__(self, jsonl_files, repeat=1, train_resolution=512, detect_resolution=224, p_exclude_hair=0, p_flip_reference=0.5):
        """
        Args:
            jsonl_files: iterable of JSON-lines metadata file paths.
            repeat: how many times the metadata list is virtually repeated.
            train_resolution: output size of "image1"; "mask" uses 1/8 of it.
            detect_resolution: output size of "detect" / "detect_mask".
            p_exclude_hair: probability of removing the hair label from the mask.
            p_flip_reference: probability of horizontally flipping the crops.
        """
        super().__init__()
        self.metadatas = self.init_metadatas(jsonl_files)
        self.repeat = repeat
        self.transform_clip = _transform(224)
        resolution = train_resolution
        mask_resolution = train_resolution // 8
        self.p_exclude_hair = p_exclude_hair
        self.p_flip_reference = p_flip_reference
        # ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1].
        self.transform = transforms.Compose([transforms.Resize(resolution),
                                             transforms.CenterCrop(resolution),
                                             transforms.ToTensor(),
                                             transforms.Normalize(0.5, 0.5)])
        self.transform_mask = transforms.Compose([transforms.Resize(mask_resolution),
                                                  transforms.CenterCrop(mask_resolution),
                                                  transforms.ToTensor()])

        self.detect_transform = transforms.Compose([transforms.Resize(detect_resolution),
                                                    transforms.CenterCrop(detect_resolution),
                                                    transforms.ToTensor(),
                                                    transforms.Normalize(0.5, 0.5)])
        self.detect_mask_transform = transforms.Compose([transforms.Resize(detect_resolution),
                                                         transforms.CenterCrop(detect_resolution),
                                                         transforms.ToTensor()])

    def init_metadatas(self, jsonl_files):
        """Load and concatenate the records of every JSON-lines file.

        Files are decoded as UTF-8, the JSON standard encoding, so non-ASCII
        captions load regardless of the platform's default encoding. Blank or
        whitespace-only lines (e.g. a trailing newline) are skipped instead of
        crashing ``json.loads``.

        Returns:
            list of dicts, one per non-blank line, in file order.
        """
        metadatas = []
        for jsonl_file in jsonl_files:
            with open(jsonl_file, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    metadatas.append(json.loads(line))
        return metadatas

    def get_single(self, metadata):
        """Build one sample dict (see the class docstring) from a metadata record."""
        caption = metadata['caption']

        # Use context managers so the file handles are released right away.
        # convert() returns an independent, fully loaded image.
        with Image.open(metadata['path']) as img:
            image1_pil = img.convert('RGB')  # full image
        image1 = self.transform(image1_pil)

        with Image.open(metadata['mask_path']) as mask_img:
            mask_arr = np.array(mask_img.convert("L"))

        if random.random() < self.p_exclude_hair:
            # Remove the hair label and switch to the matching crop box.
            mask_arr[mask_arr == self._HAIR_PIXEL_VALUE] = 0
            bbox = metadata['square_bbox_face']
        else:
            bbox = metadata['square_bbox']

        # Collapse all remaining parsing labels into a binary foreground mask.
        mask_arr[mask_arr > 0] = 255
        mask_pil = Image.fromarray(mask_arr)
        mask = self.transform_mask(mask_pil)

        detect_pil = image1_pil.crop(bbox)
        detect_mask_pil = mask_pil.crop(bbox)
        if random.random() < self.p_flip_reference:
            # Flip the image crop and its mask together so they stay aligned.
            detect_mask_pil = F.hflip(detect_mask_pil)
            detect_pil = F.hflip(detect_pil)
        detect = self.detect_transform(detect_pil)
        detect_mask = self.detect_mask_transform(detect_mask_pil)

        return dict(image1=image1,
                    caption=caption,
                    img4clip=self.transform_clip(image1_pil),
                    mask=mask,
                    detect=detect,
                    detect_mask=detect_mask,
                    type="object",
                    data_type=0)

    def __len__(self):
        """Number of records times ``repeat``."""
        return len(self.metadatas) * self.repeat

    def __getitem__(self, idx):
        """Return the sample for ``idx``, wrapping around the metadata list."""
        return self.get_single(self.metadatas[idx % len(self.metadatas)])