# coding=utf-8
import os

import PIL
import cv2
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image, ImageDraw
import json

import random
import os.path as osp
import numpy as np
from torch.utils.data import DataLoader
#from datasets.posemap import get_coco_body25_mapping,kpoint_to_heatmap

import matplotlib.pyplot as plt
import torch.nn.functional as F


def mask2bbox(mask):
    """Return a randomly enlarged bounding box of the non-zero pixels of ``mask``.

    The tight box is scaled about its centre by a random factor in
    ``[1.1, 1.2)`` and clipped to the mask extent.

    Args:
        mask: 2-D array; non-zero entries are foreground.

    Returns:
        ``(down, up, left, right)``. ``down``/``up`` are the smaller/larger
        row bounds and ``left``/``right`` the column bounds. ``up`` and
        ``right`` are exclusive, so ``mask[down:up, left:right]`` slices the box.

    Raises:
        ValueError: if ``mask`` has no non-zero pixels.
    """
    # One scan of the mask instead of four separate np.where calls.
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError("mask2bbox: mask has no foreground pixels")

    up, down = rows.max(), rows.min()
    left, right = cols.min(), cols.max()
    center = ((up + down) // 2, (left + right) // 2)

    # Expansion amount relative to the box size, drawn from [0.1, 0.2).
    factor = random.random() * 0.1 + 0.1

    # x * (1 + f) - c * f == x + f * (x - c): push each edge away from the centre.
    up = int(min(up * (1 + factor) - center[0] * factor + 1, mask.shape[0]))
    down = int(max(down * (1 + factor) - center[0] * factor, 0))
    left = int(max(left * (1 + factor) - center[1] * factor, 0))
    right = int(min(right * (1 + factor) - center[1] * factor + 1, mask.shape[1]))
    return (down, up, left, right)

def strict_mask2bbox(mask):
    """Return the tight bounding box of the non-zero pixels of ``mask``.

    Args:
        mask: 2-D array; non-zero entries are foreground.

    Returns:
        ``(down, up, left, right)`` with ``down``/``up`` the min/max row and
        ``left``/``right`` the min/max column, all inclusive.

    Raises:
        ValueError: if ``mask`` has no non-zero pixels.
    """
    # Single scan instead of four np.where calls.
    rows, cols = np.nonzero(mask)
    if rows.size == 0:
        raise ValueError("strict_mask2bbox: mask has no foreground pixels")
    return (rows.min(), rows.max(), cols.min(), cols.max())

def show(title, array):
    """Draw ``array`` as an image on the current matplotlib axes, title it, and show the figure."""
    # Both calls act on the current axes, so drawing before titling is equivalent.
    plt.imshow(array)
    plt.title(title)
    plt.show()
"""
def crop_mask(mask):
    # 找到所有的连通区域
    num_labels, labels = cv2.connectedComponents(mask)

   # 随机选择比例：1/4、1/2 或 3/4
    ratios = [1/8,1/4, 1/2, 3/4]

    # 遍历每个连通区域（跳过背景，label=0）
    for label in range(1, num_labels):
        # 找到当前连通区域的所有位置
        region_indices = np.where(labels == label)
        rows, cols = region_indices

        # 随机选择一个比例
        ratio = np.random.choice(ratios)

        # 找到区域的上面部分
        min_row = np.min(rows)  # 区域的最小行索引
        max_row = np.max(rows)  # 区域的最大行索引
        height = max_row - min_row + 1  # 区域的高度
        split_row = min_row + int(height * ratio)  # 上面部分的分界线

        # 找到上面部分的位置
        upper_rows = rows[rows <= split_row]
        upper_cols = cols[rows <= split_row]

        # 将这些位置的 1 变成 0
        mask[upper_rows, upper_cols] = 0

    return mask
"""
def random_mask_crop_single_region(mask, target_ratio=0.3):
    """Zero out a random subset of the pixels of ``mask`` that equal 1.

    Despite the name, this is random pixel dropout over the whole mask, not a
    crop of one contiguous region. Exactly ``int(count * target_ratio)`` of the
    ``count`` pixels equal to 1 are set to 0 (sampled without replacement via
    ``np.random``). Pixels with any other value are left untouched.

    Args:
        mask: 2-D array.
        target_ratio: fraction of the value-1 pixels to clear.

    Returns:
        A modified copy; ``mask`` itself is not changed.
    """
    row_idx, col_idx = np.where(mask == 1)
    n_fg = row_idx.size
    n_drop = int(n_fg * target_ratio)

    picked = np.random.choice(n_fg, size=n_drop, replace=False)

    out = mask.copy()
    out[row_idx[picked], col_idx[picked]] = 0
    return out

def crop_skin_mask(mask):
    """Erase one horizontal band from every connected component of ``mask``.

    For each connected component found by ``cv2.connectedComponents``
    (background label 0 is skipped) a band is chosen with ``np.random.choice``:

    * ``"top"`` (p=0.5): rows from the component's first row through its
      half-height row (note: half, not a third),
    * ``"middle"`` (p=0.25): rows between the 1/3 and 2/3 height marks,
      inclusive,
    * ``"bottom"`` (p=0.25): rows from the 2/3 height mark to the last row.

    The component's pixels inside the chosen band are set to 0.

    Args:
        mask: 2-D mask; presumably a uint8 binary image, since
            ``cv2.connectedComponents`` expects 8-bit single-channel input.

    Returns:
        The same ``mask`` object, modified in place.
    """
    band_names = ["top", "middle", "bottom"]
    band_probs = [0.5, 0.25, 0.25]

    n_components, component_ids = cv2.connectedComponents(mask)
    for comp in range(1, n_components):
        ys, xs = np.where(component_ids == comp)
        band = np.random.choice(band_names, p=band_probs)

        first_row = np.min(ys)
        span = np.max(ys) - first_row + 1

        if band == "top":
            in_band = ys <= first_row + int(span * 1/2)
        elif band == "middle":
            lo = first_row + int(span * 1/3)
            hi = first_row + int(span * 2/3)
            in_band = (ys >= lo) & (ys <= hi)
        else:
            in_band = ys >= first_row + int(span * 2/3)

        mask[ys[in_band], xs[in_band]] = 0

    return mask


def crop_main_cloth(mask, scale=0.6):
    """Keep only the central part of the bounding box of ``mask``.

    The bounding box from ``cv2.boundingRect(mask)`` is shrunk about its
    centre to ``scale`` times its width and height. Pixels of ``mask`` inside
    the shrunken box are copied into an otherwise all-zero array.

    Args:
        mask: 2-D mask passed to ``cv2.boundingRect``. This presumably is a
            single-channel uint8 image; confirm at the call sites.
        scale: fraction of the box width/height to keep. Defaults to 0.6,
            the value that used to be hard-coded.

    Returns:
        A new array with the same shape and dtype as ``mask``. It is zero
        outside the shrunken box, and all zero if the box is empty.
    """
    x, y, w, h = cv2.boundingRect(mask)

    center_x = x + w // 2
    center_y = y + h // 2

    new_w = int(w * scale)
    new_h = int(h * scale)

    # Keep the shrunken box inside the image bounds.
    new_x = max(0, center_x - new_w // 2)
    new_y = max(0, center_y - new_h // 2)
    new_w = min(mask.shape[1] - new_x, new_w)
    new_h = min(mask.shape[0] - new_y, new_h)

    final_mask = np.zeros_like(mask)
    if new_w > 0 and new_h > 0:  # skip degenerate (empty) boxes
        rows = slice(new_y, new_y + new_h)
        cols = slice(new_x, new_x + new_w)
        final_mask[rows, cols] = mask[rows, cols]
    return final_mask

import torch
import torch.nn.functional as F
import numpy as np

def _foreground_bbox(mask, empty_message):
    """Return inclusive (top, bottom, left, right) of the non-zero pixels of a (1, H, W) mask tensor."""
    rows, cols = torch.nonzero(mask.squeeze(0), as_tuple=True)
    if rows.numel() == 0:
        raise ValueError(empty_message)
    return int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max())


def align_to_mask(cloth, warp_cloth_mask, cloth_mask):
    """Fit the garment marked by ``warp_cloth_mask`` into the box of ``cloth_mask``.

    The foreground bounding box of ``warp_cloth_mask`` is cropped from both
    ``cloth`` and ``warp_cloth_mask``. The crop is resized with one uniform
    scale (aspect ratio preserved) so it fits inside the foreground bounding
    box of ``cloth_mask``. It is then pasted onto blank canvases, centred in
    that box.

    Args:
        cloth (torch.Tensor): image tensor of shape (3, H, W).
        warp_cloth_mask (torch.Tensor): mask tensor of shape (1, H, W); its
            non-zero pixels mark the garment inside ``cloth``.
        cloth_mask (torch.Tensor): target mask tensor of shape (1, H, W).

    Returns:
        aligned_cloth (torch.Tensor): (3, H, W), zero outside the pasted crop.
        aligned_mask (torch.Tensor): (1, H, W), zero outside the pasted crop.
        Both take dtype/device from ``cloth_mask``.

    Raises:
        ValueError: if either mask has no foreground pixels.
    """
    # cloth_mask is validated first, as before.
    t, b, l, r = _foreground_bbox(cloth_mask, "cloth_mask 中没有前景区域（全为零）。")
    wt, wb, wl, wr = _foreground_bbox(warp_cloth_mask, "warp_cloth_mask 中没有前景区域（全为零）。")

    # Inclusive sizes, so a one-pixel-thick box has size 1 instead of 0.
    mask_h, mask_w = b - t + 1, r - l + 1
    warp_h, warp_w = wb - wt + 1, wr - wl + 1

    # Largest uniform scale that keeps the warped box inside the target box.
    scale = min(mask_h / warp_h, mask_w / warp_w)
    target_h = min(mask_h, max(1, round(warp_h * scale)))
    target_w = min(mask_w, max(1, round(warp_w * scale)))

    # Resize only the garment crop; resizing the full canvas would shrink the
    # garment by an extra H/warp_h factor and misplace it.
    cloth_crop = cloth[:, wt:wb + 1, wl:wr + 1]
    mask_crop = warp_cloth_mask[:, wt:wb + 1, wl:wr + 1]
    cloth_resized = F.interpolate(
        cloth_crop.unsqueeze(0),  # add batch dimension
        size=(target_h, target_w),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)
    mask_resized = F.interpolate(
        mask_crop.unsqueeze(0),
        size=(target_h, target_w),
        mode='nearest',
    ).squeeze(0)

    # Centre the resized crop in the target box; it fits because
    # target_h <= mask_h and target_w <= mask_w.
    top = t + (mask_h - target_h) // 2
    left = l + (mask_w - target_w) // 2

    aligned_cloth = torch.zeros_like(cloth_mask).repeat(3, 1, 1)  # (3, H, W)
    aligned_mask = torch.zeros_like(cloth_mask)  # (1, H, W)
    aligned_cloth[:, top:top + target_h, left:left + target_w] = cloth_resized
    aligned_mask[:, top:top + target_h, left:left + target_w] = mask_resized

    return aligned_cloth, aligned_mask

class CPDataset(data.Dataset):
    """
        Dataset for CP-VTON-style garment inpainting.

        Samples come from ``<dataroot>/<mode>/cloth`` and
        ``<dataroot>/<mode>/cloth-mask``; the pair list is read from
        ``<dataroot>/<mode>_pairs.txt``; predicted images are read from
        ``pred_dataroot``.
    """

    def __init__(self, dataroot, image_size=512, mode='train',  unpaired=False,semantic_nc=13,
                 caption_folder=None,pred_dataroot=None
                 ):
        super(CPDataset, self).__init__()
        # base setting
        self.pred_dataroot = pred_dataroot
        self.root = dataroot
        self.unpaired = unpaired
        self.datamode = mode  # train or test or self-defined
        self.data_list = mode + '_pairs.txt'
        self.fine_height = image_size
        # NOTE(review): this always equals image_size (square, like crop_size
        # below); confirm whether a non-square width was intended.
        self.fine_width = int(image_size / 256 * 256)
        self.semantic_nc = semantic_nc
        self.data_path = osp.join(dataroot, mode)
        self.crop_size = (self.fine_height, self.fine_height)
        self.to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.clip_normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                   (0.26862954, 0.26130258, 0.27577711))
        # Resizers do not depend on the sample, so build them once.
        # BILINEAR is what the deprecated integer code 2 used to select.
        self.cloth_resize = transforms.Resize(
            self.crop_size, interpolation=transforms.InterpolationMode.BILINEAR)
        self.mask_resize = transforms.Resize(
            self.crop_size, interpolation=transforms.InterpolationMode.NEAREST)

        # NOTE: an assertion that training must use the paired setting is
        # temporarily disabled; re-enable it for real training runs:
        # assert not (self.datamode == "train" and self.unpaired), "train must use paired dataset"

        self.cloth_caption_folder = caption_folder
        if self.cloth_caption_folder is not None:
            with open(osp.join(dataroot, self.cloth_caption_folder), 'r', encoding='utf-8') as f:
                self.cloth_captions_dict = json.load(f)
        # Spatial-caption loading (a ``key:value`` text file) is currently disabled.

        # load data list
        im_names, c_names = self._read_pairs(osp.join(dataroot, self.data_list))
        self.im_names = im_names
        self.c_names = {'paired': im_names, 'unpaired': c_names}

    @staticmethod
    def _read_pairs(path):
        """Parse a pairs file of ``<image name> <cloth name>`` lines, skipping blank lines.

        Returns:
            ``(im_names, c_names)`` as two parallel lists.

        Raises:
            ValueError: if a non-blank line does not have exactly two fields.
        """
        im_names, c_names = [], []
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:  # e.g. a trailing empty line
                    continue
                im_name, c_name = fields
                im_names.append(im_name)
                c_names.append(c_name)
        return im_names, c_names

    def name(self):
        return "CPDataset"

    def get_agnostic(self, im, im_parse, pose_data):
        """Return a copy of ``im`` with the upper-body garment area painted grey.

        Torso, neck and arm regions around the keypoints are painted grey.
        Pixels with head labels (4, 13), lower-body labels (9, 12, 16-19) and
        arm labels (14, 15) near the arm keypoints are then restored from ``im``.

        Args:
            im: PIL image of the person.
            im_parse: label map (PIL image or array) of the same size as ``im``.
            pose_data: 2-D numpy array of keypoint (x, y) rows indexed by
                keypoint id. A (0, 0) row is treated as a missing keypoint.
                The array is copied, so the caller's data is not modified.

        Returns:
            A new PIL image.
        """
        parse_array = np.array(im_parse)
        parse_head = np.isin(parse_array, [4, 13]).astype(np.float32)
        parse_lower = np.isin(parse_array, [9, 12, 16, 17, 18, 19]).astype(np.float32)

        # Work on a copy: the 9/12 points are rescaled below.
        pose_data = np.array(pose_data, copy=True)

        agnostic = im.copy()
        agnostic_draw = ImageDraw.Draw(agnostic)

        length_a = np.linalg.norm(pose_data[5] - pose_data[2])
        length_b = np.linalg.norm(pose_data[12] - pose_data[9])
        # Stretch the 9-12 segment about its midpoint to the 2-5 length; skip
        # when the two points coincide to avoid dividing by zero.
        if length_b > 0:
            point = (pose_data[9] + pose_data[12]) / 2
            pose_data[9] = point + (pose_data[9] - point) / length_b * length_a
            pose_data[12] = point + (pose_data[12] - point) / length_b * length_a

        r = int(length_a / 16) + 1

        # mask torso
        for i in [9, 12]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx - r * 3, pointy - r * 6, pointx + r * 3, pointy + r * 6), 'gray', 'gray')
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 9]], 'gray', width=r * 6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [5, 12]], 'gray', width=r * 6)
        agnostic_draw.line([tuple(pose_data[i]) for i in [9, 12]], 'gray', width=r * 12)
        agnostic_draw.polygon([tuple(pose_data[i]) for i in [2, 5, 12, 9]], 'gray', 'gray')

        # mask neck
        pointx, pointy = pose_data[1]
        agnostic_draw.rectangle((pointx - r * 5, pointy - r * 9, pointx + r * 5, pointy), 'gray', 'gray')

        # mask arms
        agnostic_draw.line([tuple(pose_data[i]) for i in [2, 5]], 'gray', width=r * 12)
        for i in [2, 5]:
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx - r * 5, pointy - r * 6, pointx + r * 5, pointy + r * 6), 'gray', 'gray')
        for i in [3, 4, 6, 7]:
            if (pose_data[i - 1, 0] == 0.0 and pose_data[i - 1, 1] == 0.0) or (
                    pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
                continue
            agnostic_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'gray', width=r * 10)
            pointx, pointy = pose_data[i]
            agnostic_draw.ellipse((pointx - r * 5, pointy - r * 5, pointx + r * 5, pointy + r * 5), 'gray', 'gray')

        # The arm canvas must match the parse map (PIL takes (width, height));
        # it was previously hard-coded to 768x1024.
        arm_canvas_size = (parse_array.shape[1], parse_array.shape[0])
        for parse_id, pose_ids in [(14, [5, 6, 7]), (15, [2, 3, 4])]:
            mask_arm = Image.new('L', arm_canvas_size, 'white')
            mask_arm_draw = ImageDraw.Draw(mask_arm)
            pointx, pointy = pose_data[pose_ids[0]]
            mask_arm_draw.ellipse((pointx - r * 5, pointy - r * 6, pointx + r * 5, pointy + r * 6), 'black', 'black')
            for i in pose_ids[1:]:
                if (pose_data[i - 1, 0] == 0.0 and pose_data[i - 1, 1] == 0.0) or (
                        pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
                    continue
                mask_arm_draw.line([tuple(pose_data[j]) for j in [i - 1, i]], 'black', width=r * 10)
                pointx, pointy = pose_data[i]
                if i != pose_ids[-1]:
                    mask_arm_draw.ellipse((pointx - r * 5, pointy - r * 5, pointx + r * 5, pointy + r * 5), 'black',
                                          'black')
            # Drawn at the last point reached above (the first point if every segment was skipped).
            mask_arm_draw.ellipse((pointx - r * 4, pointy - r * 4, pointx + r * 4, pointy + r * 4), 'black', 'black')

            parse_arm = (np.array(mask_arm) / 255) * (parse_array == parse_id).astype(np.float32)
            agnostic.paste(im, None, Image.fromarray(np.uint8(parse_arm * 255), 'L'))

        agnostic.paste(im, None, Image.fromarray(np.uint8(parse_head * 255), 'L'))
        agnostic.paste(im, None, Image.fromarray(np.uint8(parse_lower * 255), 'L'))
        return agnostic

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            dict with
            ``"im_name"`` (cloth file name),
            ``"im"`` (cloth tensor normalised to [-1, 1]),
            ``"inpaint_mask"`` (dilated-then-eroded binary cloth mask, float tensor (1, H, W)),
            ``"pred_im"`` (predicted image from ``pred_dataroot``, normalised to [-1, 1]).
        """
        key = 'unpaired' if self.unpaired else 'paired'
        c_name = self.c_names[key][index]

        # load cloth
        cloth = Image.open(osp.join(self.data_path, 'cloth', c_name)).convert('RGB')
        cloth = self.transform(self.cloth_resize(cloth))  # [-1,1]

        # load cloth_mask and binarise it
        cloth_mask = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
        cloth_mask = (self.to_tensor(self.mask_resize(cloth_mask)) > 0.5).float()

        # Inpainting mask: dilate (4 iterations, originally 3) then erode once,
        # with a square kernel that scales with the image width.
        kernel_size = int(5 * (self.fine_width / 256))
        kernel = np.ones((kernel_size, kernel_size))
        aug_cloth_mask = cloth_mask[0].numpy().astype(np.uint8)
        aug_cloth_mask = cv2.dilate(aug_cloth_mask, kernel=kernel, iterations=4)
        aug_cloth_mask = cv2.erode(aug_cloth_mask, kernel=kernel, iterations=1)
        aug_cloth_mask = self.to_tensor(aug_cloth_mask.astype(np.float32))

        # Predicted images are stored as PNG under the cloth's base name.
        pred_im = Image.open(osp.join(self.pred_dataroot, c_name.replace('.jpg', '.png'))).convert('RGB')
        pred_im = self.transform(pred_im)

        return {
            "im_name": c_name,
            "im": cloth,
            "inpaint_mask": aug_cloth_mask,
            'pred_im': pred_im,
        }

    def __len__(self):
        return len(self.im_names)



if __name__ == '__main__':
    # Compositing script: keep garment pixels outside the inpainting mask,
    # take predicted pixels inside it, and save each result as PNG.
    from tqdm import tqdm
    from einops import rearrange

    dataset = CPDataset('DATA/VITON-HD', 512, mode='test', unpaired=False,
                        pred_dataroot='outputs/Tryon-samples/all_pair_VITOHCloth_202_499_predClothMask copy')

    loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)
    result_path = os.path.join('outputs', "all_NoCatVitonCloth_inpaintMask")  # "upper_body"
    os.makedirs(result_path, exist_ok=True)

    # The loop variable is not named ``data``: that name is the module-level
    # alias for torch.utils.data.
    for batch in tqdm(loader, desc='Test Dataset', total=len(loader)):
        x_im = torch.clamp((batch['im'] + 1.0) / 2.0, min=0.0, max=1.0)
        x_pred = torch.clamp((batch['pred_im'] + 1.0) / 2.0, min=0.0, max=1.0)
        x_mask = batch['inpaint_mask']
        x_result = x_im * (1 - x_mask) + x_mask * x_pred

        for filename, x_sample in zip(batch['im_name'], x_result):
            save_x = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            stem = osp.splitext(filename)[0]  # robust to extensions of any length
            img = Image.fromarray(save_x.astype(np.uint8))
            img.save(os.path.join(result_path, stem + ".png"))