import random
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
import os
import numpy as np
import cv2


def random_resize(faceparsing_image, low=0.8, high=1.2, low_all=0.7, high_all=1.3, max_tries=100):
    """Randomly rescale each colour-coded region of a face-parsing map.

    The map encodes four regions as grey levels (see ``region_color_dict``).
    First a random new height is drawn for the union bounding box of all
    regions. Then every region is cropped, rescaled by its own random factor,
    and re-centred so that its relative vertical position inside the
    stretched union box is preserved. Horizontally it stays centred where it was.

    Args:
        faceparsing_image: image array as produced by ``np.array`` on the
            parsing PNG. The colour bounds are 3-tuples, so an HxWx3 array is
            presumably expected — TODO confirm for single-channel maps.
        low, high: range of the per-region scale factor.
        low_all, high_all: range of the factor applied to the union box height.
        max_tries: maximum number of draws for the union box height. After that,
            the unscaled height is used (it always fits inside the image).

    Returns:
        A new array with the same shape and dtype as ``faceparsing_image``.
        It contains only the (possibly rescaled) regions; every other pixel is 0.
    """
    def is_valid_image(img):
        return img is not None and img.size > 0 and img.shape[0] > 0 and img.shape[1] > 0

    def is_valid_size(size):
        return size[0] > 0 and size[1] > 0

    # Inclusive (lower, upper) colour bounds of each parsing region, in drawing order.
    region_color_dict = {
        0: [(46, 46, 46), (69, 69, 69)],
        1: [(92, 92, 92), (115, 115, 115)],
        2: [(139, 139, 139), (139, 139, 139)],
        3: [(162, 162, 162), (208, 208, 208)],
    }
    img_h, img_w = faceparsing_image.shape[:2]

    # Union bounding box of every pixel whose colour lies between the lowest
    # bound of region 0 and the highest bound of region 3.
    _, Y, _, H = cv2.boundingRect(
        cv2.inRange(faceparsing_image, region_color_dict[0][0], region_color_dict[3][1]))
    center_Y = Y + H // 2

    # Draw a new union height whose centred span stays inside the image.
    # The unscaled height H spans exactly [Y, Y + H), so it always fits and is
    # used as the fallback. Without the bound, e.g. low_all > 1 with a face that
    # fills the image height would resample forever.
    new_H = H
    for _ in range(max_tries):
        candidate_H = int(H * random.uniform(low_all, high_all))
        candidate_start = center_Y - candidate_H // 2
        if candidate_start >= 0 and candidate_start + candidate_H <= img_h:
            new_H = candidate_H
            break
    start_Y = center_Y - new_H // 2

    new_regions = []
    for region_idx in range(len(region_color_dict)):
        range_min, range_max = region_color_dict[region_idx]

        # Keep only this region's pixels; everything else is zeroed.
        mask = cv2.inRange(faceparsing_image, range_min, range_max)
        selected_region = cv2.bitwise_and(faceparsing_image, faceparsing_image, mask=mask)
        x, y, w, h = cv2.boundingRect(mask)

        # Crop the region and rescale it by a random factor in [low, high].
        # The factor is drawn even for empty regions, so the random stream is
        # consumed the same way regardless of image content.
        cropped_region = selected_region[y:y + h, x:x + w]
        scale_factor = random.uniform(low, high)
        new_size = (int(w * scale_factor), int(h * scale_factor))  # (width, height), as cv2.resize expects
        if not (is_valid_image(cropped_region) and is_valid_size(new_size)):
            # Empty region or degenerate target size: keep the region unchanged.
            new_regions.append(selected_region)
            continue
        # NOTE: bilinear interpolation blends the region border with the zeroed
        # background, so edge pixels receive intermediate grey levels.
        resized_region = cv2.resize(cropped_region, new_size, interpolation=cv2.INTER_LINEAR)

        # Keep the horizontal centre. Move the vertical centre so that its
        # relative position inside the union box follows the stretched box.
        center_x, center_y = x + w // 2, y + h // 2
        ratio = (center_y - Y) / H  # H > 0 here: a non-empty region lies inside the union box
        center_y = int(start_Y + new_H * ratio)
        start_x = center_x - new_size[0] // 2
        start_y = center_y - new_size[1] // 2
        end_x = start_x + new_size[0]
        end_y = start_y + new_size[1]
        # If the rescaled region would leave the image, paste it back unscaled.
        if start_y < 0 or start_x < 0 or end_y > img_h or end_x > img_w:
            start_x, start_y, end_x, end_y = x, y, x + w, y + h
            resized_region = cropped_region

        new_region = np.zeros_like(faceparsing_image)
        new_region[start_y:end_y, start_x:end_x] = resized_region
        new_regions.append(new_region)

    # Composite with an element-wise maximum instead of `+=`. The canvases are
    # disjoint unless rescaling made regions overlap. In that case, uint8
    # addition would wrap around (e.g. 208 + 139 -> 91) and invent colours of
    # unrelated regions. The maximum equals the sum on disjoint pixels and
    # otherwise keeps the brighter (higher-index) region.
    result = np.zeros_like(faceparsing_image)
    for new_region in new_regions:
        np.maximum(result, new_region, out=result)
    return result

class ImageDataset(Dataset):
    """Reference/target frame pairs sampled from per-video folders.

    Layout under ``folder`` (as built by the paths below):
        jpgs/<video>/<frame>.jpg                                  frames (loaded as RGB)
        parsing_align_no_contour_new_color/<video>/<frame>.png    target pose condition
        parsings/<video>/<frame>.png                              reference pose, also used as foreground mask

    Each item holds:
    - an augmented target frame and its randomly rescaled parsing map;
    - an augmented reference frame and its parsing map;
    - CLIP pixel values for the reference frame's foreground and background.
    """

    def __init__(
        self,
        img_size,    # (512, 512)
        img_scale=(1.0, 1.0),    # (0.9, 1.0)
        img_ratio=(0.9, 1.0),
        drop_ratio=0.1,
        sample_margin=30,
        folder='data',
        limit=10000
    ):
        super().__init__()

        self.img_size = img_size
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.sample_margin = sample_margin
        self.folder = folder
        self.imgs_folder = os.path.join(folder, 'jpgs')
        # Contour-free parsing maps (with eye keypoints); used as the target pose condition.
        self.parsings_folder = os.path.join(folder, 'parsing_align_no_contour_new_color')
        # Parsing maps used for the reference pose and its foreground mask.
        # These must cover the same frames as parsings_folder.
        self.masks_folder = os.path.join(folder, 'parsings')

        self.clip_image_processor = CLIPImageProcessor()

        # Frames: random resized crop, then tensor in [0, 1] mapped to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Pose conditions: the same crop, left in [0, 1].
        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    self.img_size,
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        self.drop_ratio = drop_ratio  # stored only; not used within this class

        # Keep a video only if it has at least one frame and both parsing folders
        # list exactly the same files. Frames are looked up by the same filename
        # in both folders, and an empty video cannot be sampled.
        names = sorted(os.listdir(self.imgs_folder))[:limit]
        self.names = []
        self.parsings = []
        for name in names:
            parsing_files = sorted(os.listdir(os.path.join(self.parsings_folder, name)))
            mask_files = sorted(os.listdir(os.path.join(self.masks_folder, name)))
            if parsing_files and parsing_files == mask_files:
                self.names.append(name)
                self.parsings.append(parsing_files)

        print(f'prepare dataset finish, {len(self.names)} names, {len(self.parsings)} parsings')

    def augmentation(self, image, transform, state=None):
        """Apply ``transform`` to ``image``, first restoring torch's global RNG to ``state`` if given.

        Restoring the same state before each call replays the same torch random
        draws. Torch-RNG-driven transforms (such as torchvision's
        RandomResizedCrop) therefore pick identical parameters for each image.
        """
        if state is not None:
            torch.set_rng_state(state)
        return transform(image)

    @staticmethod
    def _sample_frame_indices(video_length, sample_margin):
        """Pick (ref_idx, tgt_idx) in [0, video_length) that are at least ``sample_margin`` apart when possible.

        The target is searched after the reference first, then before it. If
        neither side leaves room, any frame (possibly the reference itself) is
        returned. Requires ``video_length >= 1``.
        """
        margin = min(sample_margin, video_length)
        ref_idx = random.randint(0, video_length - 1)
        if ref_idx + margin < video_length:
            tgt_idx = random.randint(ref_idx + margin, video_length - 1)
        elif ref_idx - margin >= 0:
            # Symmetric to the branch above: a target exactly `margin` frames
            # earlier (index 0 included) is valid.
            tgt_idx = random.randint(0, ref_idx - margin)
        else:
            tgt_idx = random.randint(0, video_length - 1)
        return ref_idx, tgt_idx

    @staticmethod
    def _open_image(path, mode=None):
        """Load the image at ``path`` fully into memory and close the file, converting to ``mode`` if given."""
        with Image.open(path) as img:
            return img.convert(mode) if mode is not None else img.copy()

    @staticmethod
    def _split_foreground_background(image_np, pose_np):
        """Split an HxWxC image into (foreground, background) arrays.

        A pixel is foreground where the first channel of ``pose_np`` is
        non-zero. ``pose_np`` may be HxW or HxWxK, and must have the same
        height and width as ``image_np``.
        """
        pose_channel = pose_np[..., 0] if pose_np.ndim == 3 else pose_np
        mask = (pose_channel != 0).astype(np.uint8)[:, :, np.newaxis]
        return image_np * mask, image_np * (1 - mask)

    def __getitem__(self, index):
        video_name = self.names[index]    # e.g. Clip+_-91nXXjrVo+P0+C0+F1537-1825
        parsing_names = self.parsings[index]    # e.g. [00000000.png, 00000001.png, ...]
        imgs_path = os.path.join(self.imgs_folder, video_name)
        parsing_path = os.path.join(self.parsings_folder, video_name)
        masks_path = os.path.join(self.masks_folder, video_name)

        # Some frames have no parsing map, so sample among the parsing files rather than the JPEGs.
        ref_img_idx, tgt_img_idx = self._sample_frame_indices(len(parsing_names), self.sample_margin)
        ref_name = parsing_names[ref_img_idx]
        tgt_name = parsing_names[tgt_img_idx]

        ref_img_pil = self._open_image(
            os.path.join(imgs_path, os.path.splitext(ref_name)[0] + '.jpg'), 'RGB')
        tgt_img_pil = self._open_image(
            os.path.join(imgs_path, os.path.splitext(tgt_name)[0] + '.jpg'), 'RGB')
        # The reference pose comes from the masks folder because it also provides the foreground mask.
        ref_pose_pil = self._open_image(os.path.join(masks_path, ref_name))
        # The target pose is the contour-free parsing map with random per-region rescaling.
        tgt_pose_np = np.array(self._open_image(os.path.join(parsing_path, tgt_name)))
        tgt_pose_pil = Image.fromarray(random_resize(tgt_pose_np))

        # Restore one torch RNG state before each transform so all four images share crop parameters.
        state = torch.get_rng_state()
        tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
        tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
        ref_img = self.augmentation(ref_img_pil, self.transform, state)
        ref_pose_img = self.augmentation(ref_pose_pil, self.cond_transform, state)

        # CLIP sees the un-augmented reference frame, so the mask must come from
        # the un-augmented pose as well. The augmented pose is cropped and
        # resized to img_size, so it would be misaligned or mismatched in shape.
        mask_pose_pil = ref_pose_pil
        if mask_pose_pil.size != ref_img_pil.size:
            mask_pose_pil = mask_pose_pil.resize(ref_img_pil.size, Image.NEAREST)
        foreground, background = self._split_foreground_background(
            np.array(ref_img_pil), np.array(mask_pose_pil))
        pixel_values = self.clip_image_processor(
            images=[foreground, background], return_tensors="pt"
        ).pixel_values
        clip_image_f, clip_image_b = pixel_values[0], pixel_values[1]

        sample = dict(
            tgt_img=tgt_img,
            tgt_pose=tgt_pose_img,
            ref_img=ref_img,
            ref_pose=ref_pose_img,
            clip_image_f=clip_image_f,
            clip_image_b=clip_image_b,
        )

        return sample

    def __len__(self):
        return len(self.names)