import json
import random
from typing import List

import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from decord import VideoReader
from PIL import Image
from torch.utils.data import Dataset
from transformers import CLIPImageProcessor
import os
import cv2


# Colour regions of the face-parsing maps: region index -> [lower, upper]
# inclusive bounds (equal values on every channel, i.e. grey levels), in the
# form passed to cv2.inRange.
region_color_dict = {
    idx: [(lo, lo, lo), (hi, hi, hi)]
    for idx, (lo, hi) in enumerate([(46, 69), (92, 115), (139, 139), (162, 208)])
}

def random_resize(faceparsing_image, low=0.8, high=1.2, low_all=0.7, high_all=1.3, scale_all=-1, scale_part=None,
                  max_tries=100):
    """Randomly rescale the colour-coded regions of a face-parsing image.

    Each region is identified by the grey-level bounds in the local
    ``region_color_dict`` (passed to ``cv2.inRange``). Two kinds of scaling
    are applied:

    * a global factor ``scale_all`` applied to the height of the bounding box
      of all pixels in the overall colour range. It only moves each region's
      centre vertically, so that the region keeps its relative position
      inside the rescaled box;
    * one factor per region that resizes the region's own bounding-box crop
      about its (re-positioned) centre.

    A region is kept at its original size and position when its crop is
    empty, when the scaled size collapses to zero, or when the scaled region
    would fall outside the image.

    Args:
        faceparsing_image: image array accepted by ``cv2.inRange``.
        low, high: range for the per-region factors when they are sampled.
        low_all, high_all: range for the global factor when it is sampled.
        scale_all: global factor to reuse; ``-1`` means sample a new one.
        scale_part: per-region factors to reuse, indexed by region index.
            ``None`` or a list whose first item is ``-1`` means sample new ones.
        max_tries: maximum number of attempts at sampling a global factor
            whose scaled box stays inside the image. After that the factor
            falls back to ``1.0``, which always fits.

    Returns:
        ``(result, scale_all, scale_factors)``:

        * the recomposed image, with the same shape and dtype as the input;
        * the global factor used;
        * the list of per-region factors used.

        Feeding the last two back in reproduces the same geometry.
    """
    # Colour regions. This mirrors the module-level table; it is kept local so
    # the function is self-contained.
    region_color_dict = {
        0: [(46, 46, 46), (69, 69, 69)],
        1: [(92, 92, 92), (115, 115, 115)],
        2: [(139, 139, 139), (139, 139, 139)],
        3: [(162, 162, 162), (208, 208, 208)]
    }
    img_h, img_w = faceparsing_image.shape[:2]

    # Bounding box of every pixel in the overall colour range
    # (from region 0's lower bound to region 3's upper bound).
    _, Y, _, H = cv2.boundingRect(
        cv2.inRange(faceparsing_image, region_color_dict[0][0], region_color_dict[3][1]))
    center_Y = Y + H // 2

    if scale_all == -1:
        # Rejection-sample a global factor whose rescaled box stays inside the
        # image. The number of attempts is bounded so that an unsatisfiable
        # range (e.g. low_all > 1 with the face touching the border) cannot
        # loop forever.
        for _ in range(max_tries):
            scale_all = random.uniform(low_all, high_all)
            new_H = int(H * scale_all)
            start_Y = center_Y - new_H // 2
            if start_Y >= 0 and start_Y + new_H <= img_h:
                break
        else:
            # Scale 1.0 reproduces the original box, which always fits.
            scale_all = 1.0
            new_H = H
            start_Y = center_Y - new_H // 2
    else:
        new_H = int(H * scale_all)
        start_Y = center_Y - new_H // 2

    sample_parts = scale_part is None or scale_part[0] == -1
    scale_factors = []
    result = np.zeros_like(faceparsing_image)
    for region_idx, (range_min, range_max) in region_color_dict.items():
        # Isolate this region and find its bounding box.
        mask = cv2.inRange(faceparsing_image, range_min, range_max)
        selected_region = cv2.bitwise_and(faceparsing_image, faceparsing_image, mask=mask)
        x, y, w, h = cv2.boundingRect(mask)
        cropped_region = selected_region[y:y + h, x:x + w]

        scale_factor = random.uniform(low, high) if sample_parts else scale_part[region_idx]
        scale_factors.append(scale_factor)
        new_w, new_h = int(w * scale_factor), int(h * scale_factor)

        if cropped_region.size == 0 or new_w <= 0 or new_h <= 0:
            # Nothing to resize: keep the region as it is.
            new_region = selected_region
        else:
            resized_region = cv2.resize(cropped_region, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

            # Keep the horizontal centre. Move the vertical centre so that it
            # keeps its relative position inside the globally rescaled box.
            # H > 0 here, because this region's pixels lie inside the overall range.
            center_x = x + w // 2
            ratio = (y + h // 2 - Y) / H
            center_y = int(start_Y + new_H * ratio)
            start_x = center_x - new_w // 2
            start_y = center_y - new_h // 2
            end_x, end_y = start_x + new_w, start_y + new_h

            if start_x < 0 or start_y < 0 or end_x > img_w or end_y > img_h:
                # The scaled region would leave the image: keep it unscaled in place.
                start_x, start_y, end_x, end_y = x, y, x + w, y + h
                resized_region = cropped_region

            new_region = np.zeros_like(faceparsing_image)
            new_region[start_y:end_y, start_x:end_x] = resized_region

        # Composite with a per-pixel maximum instead of summing. Where moved
        # regions overlap, a uint8 sum wraps around and creates colours that
        # belong to no region. Where they do not overlap, the result equals
        # the sum.
        np.maximum(result, new_region, out=result)

    return result, scale_all, scale_factors

class VideoDataset(Dataset):
    """Clip dataset pairing video frames with colour-coded face-parsing maps.

    Paths are built under ``folder`` as follows:

    * ``jpgs/<video>/``: video frames (``<frame>.jpg``);
    * ``parsing_align_no_contour_new_color/<video>/``: parsing maps used as
      the per-frame condition. Their file names define the frame list;
    * ``parsings/<video>/``: parsing maps, looked up with the same file names,
      used to mask the reference frame's foreground.

    Only videos whose two parsing folders hold the same number of files, and
    more than ``n_sample_frames`` files, are kept.

    Each item is a dict with:

    * ``pixel_values_vid``: ``(n_sample_frames, c, height, width)`` frames,
      normalised to [-1, 1];
    * ``pixel_values_pose``: the matching condition maps in [0, 1], after
      ``random_resize`` with one set of scales shared by the whole clip;
    * ``pixel_values_ref_img`` and ``pixel_values_pre_ref_img``: a random
      reference frame and a random frame at or before it, both normalised;
    * ``clip_ref_img_f`` and ``clip_ref_img_b``: CLIP pixel values of the raw
      reference frame's foreground and background.
    """

    def __init__(
        self,
        sample_rate,        # frame stride inside a clip, e.g. 4
        n_sample_frames,    # frames per clip, e.g. 24
        width,              # output width, e.g. 512
        height,             # output height, e.g. 512
        img_scale=(1.0, 1.0),
        img_ratio=(0.9, 1.0),
        folder='data',
        limit=10000
    ):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_sample_frames = n_sample_frames
        self.width = width
        self.height = height
        self.img_scale = img_scale
        self.img_ratio = img_ratio
        self.folder = folder
        self.imgs_folder = os.path.join(folder, 'jpgs')
        self.parsings_folder = os.path.join(folder, 'parsing_align_no_contour_new_color')
        self.masks_folder = os.path.join(folder, 'parsings')

        self.clip_image_processor = CLIPImageProcessor()

        # Frames: random crop + resize, then normalise to [-1, 1].
        self.pixel_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Condition maps: the same crop parameters, but kept in [0, 1].
        self.cond_transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    (height, width),
                    scale=self.img_scale,
                    ratio=self.img_ratio,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
            ]
        )

        names = sorted(os.listdir(self.imgs_folder))[:limit]
        parsings = []
        valid_names = []
        for name in names:
            _parsings = sorted(os.listdir(os.path.join(self.parsings_folder, name)))
            _masks = sorted(os.listdir(os.path.join(self.masks_folder, name)))
            # Keep videos with matching parsing/mask counts and enough frames.
            if len(_parsings) == len(_masks) and len(_parsings) > n_sample_frames:
                parsings.append(_parsings)
                valid_names.append(name)
        self.parsings = parsings
        self.names = valid_names

        print(f'prepare dataset finish, {len(self.names)} names, {len(parsings)} parsings')

    def augmentation(self, images, transform, state=None):
        """Apply ``transform`` to one PIL image or to a list of them.

        When ``state`` is given, the torch RNG is reset to it first. Calls
        that share a ``state`` therefore draw their random parameters from
        the same sequence. For inputs of equal sizes, item ``i`` of each call
        gets the same crop.

        Returns a ``(f, c, h, w)`` tensor for a list and ``(c, h, w)`` otherwise.
        """
        if state is not None:
            torch.set_rng_state(state)
        if isinstance(images, list):
            return torch.stack([transform(img) for img in images], dim=0)  # (f, c, h, w)
        return transform(images)  # (c, h, w)

    def _jpg_path(self, video_name, parsing_name):
        """Return the path of the ``.jpg`` frame matching ``parsing_name``."""
        stem = os.path.splitext(parsing_name)[0]
        return os.path.join(self.imgs_folder, video_name, stem + '.jpg')

    def _foreground_mask(self, video_name, parsing_name, size):
        """Return a ``(h, w, 1)`` int64 mask of the non-zero pixels in a ``parsings`` map.

        ``size`` is the PIL ``(width, height)`` of the image the mask will be
        applied to. The map is resized with nearest-neighbour interpolation
        when it differs. For multi-channel maps, channel 0 is used.
        """
        path = os.path.join(self.masks_folder, video_name, parsing_name)
        with Image.open(path) as ref_pose_pil:
            if ref_pose_pil.size != size:
                ref_pose_pil = ref_pose_pil.resize(size, Image.NEAREST)
            ref_pose_np = np.array(ref_pose_pil)
        if ref_pose_np.ndim == 3:
            ref_pose_np = ref_pose_np[..., 0]
        return np.where(ref_pose_np == 0, 0, 1).astype(np.int64)[:, :, np.newaxis]

    def __getitem__(self, index):
        video_name = self.names[index]    # e.g. Clip+_-91nXXjrVo+P0+C0+F1537-1825
        parsing_path = os.path.join(self.parsings_folder, video_name)
        parsing_names = self.parsings[index]    # e.g. [00000000.png, 00000001.png, ...]
        video_length = len(parsing_names)

        # Pick n_sample_frames indices spread evenly (np.linspace includes both
        # endpoints) over a window of (n_sample_frames - 1) * sample_rate + 1
        # frames that starts at a random offset, i.e. every sample_rate-th frame.
        # If the video is shorter than that window, the whole video is used and
        # the stride becomes about video_length / n_sample_frames.
        # Example: 24 of 50 frames -> [0, 2, 4, ..., 46, 49].
        clip_length = min(
            video_length, (self.n_sample_frames - 1) * self.sample_rate + 1
        )
        start_idx = random.randint(0, video_length - clip_length)
        batch_index = np.linspace(
            start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
        ).tolist()

        # Load frames and condition maps. The first random_resize call samples
        # the scales; they are then fed back in, so every frame of the clip
        # shares the same geometry.
        vid_pil_image_list = []
        pose_pil_image_list = []
        scale_all = -1
        scale_part = [-1]
        for idx in batch_index:
            vid_pil_image_list.append(Image.open(self._jpg_path(video_name, parsing_names[idx])))
            pose_np = np.array(Image.open(os.path.join(parsing_path, parsing_names[idx])))
            pose_np, scale_all, scale_part = random_resize(pose_np, scale_all=scale_all, scale_part=scale_part)
            pose_pil_image_list.append(Image.fromarray(pose_np))

        ref_img_idx = random.randint(0, video_length - 1)
        ref_img = Image.open(self._jpg_path(video_name, parsing_names[ref_img_idx]))
        pre_ref_img_idx = random.randint(0, ref_img_idx)
        pre_ref_img = Image.open(self._jpg_path(video_name, parsing_names[pre_ref_img_idx]))

        # One shared torch RNG state, so frames, condition maps and reference
        # images draw matching crop parameters.
        state = torch.get_rng_state()
        pixel_values_vid = self.augmentation(
            vid_pil_image_list, self.pixel_transform, state
        )
        pixel_values_pose = self.augmentation(
            pose_pil_image_list, self.cond_transform, state
        )
        pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
        pixel_values_pre_ref_img = self.augmentation(pre_ref_img, self.pixel_transform, state)

        # The CLIP inputs use the raw (un-augmented) reference frame, so the
        # mask is built at that frame's own resolution. A mask taken from the
        # cropped/resized (height x width) tensor would not broadcast against
        # it unless the frames were already height x width.
        ref_img_np = np.array(ref_img)
        mask_np = self._foreground_mask(video_name, parsing_names[ref_img_idx], ref_img.size)
        foreground = ref_img_np * mask_np
        background = ref_img_np * (1 - mask_np)
        clip_ref_img_f = self.clip_image_processor(
            images=foreground, return_tensors="pt"
        ).pixel_values[0]
        clip_ref_img_b = self.clip_image_processor(
            images=background, return_tensors="pt"
        ).pixel_values[0]

        return dict(
            pixel_values_vid=pixel_values_vid,
            pixel_values_pose=pixel_values_pose,
            pixel_values_ref_img=pixel_values_ref_img,
            pixel_values_pre_ref_img=pixel_values_pre_ref_img,
            clip_ref_img_f=clip_ref_img_f,
            clip_ref_img_b=clip_ref_img_b,
        )

    def __len__(self):
        return len(self.names)
