# coding=utf-8
import os

import PIL
import cv2
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image, ImageDraw
import json

import random
import os.path as osp
import numpy as np
from torch.utils.data import DataLoader
#from datasets.posemap import get_coco_body25_mapping,kpoint_to_heatmap

import matplotlib.pyplot as plt
import torch.nn.functional as F


def mask2bbox(mask):
    """
    Return a randomly enlarged bounding box around the non-zero pixels of ``mask``.

    The tight box of the foreground is grown about its centre by a random factor
    in [0.1, 0.2) (one ``random.random()`` draw) and clipped to the mask bounds.

    Args:
        mask: array-like with at least 2 dims; non-zero entries are foreground.
            Only the first two axes are used.

    Returns:
        tuple ``(top, bottom, left, right)`` of ints, usable as
        ``mask[top:bottom, left:right]`` (``bottom`` and ``right`` are exclusive).

    Raises:
        ValueError: if ``mask`` has no non-zero pixel.
    """
    # Compute the foreground coordinates once instead of four separate np.where calls.
    nonzero = np.nonzero(mask)
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        raise ValueError("mask2bbox: mask has no non-zero pixels")

    row_min, row_max = rows.min(), rows.max()
    col_min, col_max = cols.min(), cols.max()
    center_row = (row_max + row_min) // 2
    center_col = (col_min + col_max) // 2

    factor = random.random() * 0.1 + 0.1

    # Push each edge away from the centre by `factor` times its distance to it;
    # the +1 turns the inclusive max row/col into an exclusive slice bound.
    bottom = int(min(row_max * (1 + factor) - center_row * factor + 1, mask.shape[0]))
    top = int(max(row_min * (1 + factor) - center_row * factor, 0))
    left = int(max(col_min * (1 + factor) - center_col * factor, 0))
    right = int(min(col_max * (1 + factor) - center_col * factor + 1, mask.shape[1]))
    return (top, bottom, left, right)

def show(title, array):
    """Draw ``array`` as an image on the current pyplot axes, titled ``title``, then show it."""
    plt.title(label=title)
    plt.imshow(X=array)
    plt.show()


import torch
import torch.nn.functional as F
import numpy as np

def align_to_mask(cloth, warp_cloth_mask, cloth_mask):
    """
    Fit the warped garment into the bounding box of ``cloth_mask``.

    The foreground bounding box of ``warp_cloth_mask`` is cropped out of both
    ``cloth`` and ``warp_cloth_mask``. The crop is resized with one uniform scale
    factor (aspect ratio preserved) so that it fits inside the foreground
    bounding box of ``cloth_mask``. It is then pasted, centred in that box, onto
    blank canvases the spatial size of ``cloth_mask``.

    Args:
        cloth (torch.Tensor): image tensor of shape (C, H, W) (C is 3 here).
        warp_cloth_mask (torch.Tensor): mask of shape (1, H, W); non-zero pixels
            mark the garment inside ``cloth``.
        cloth_mask (torch.Tensor): target mask of shape (1, H', W'); non-zero
            pixels define the target box and its size defines the output size.

    Returns:
        aligned_cloth (torch.Tensor): (C, H', W'), same dtype/device as ``cloth``.
        aligned_mask (torch.Tensor): (1, H', W'), same dtype/device as ``cloth_mask``.

    Raises:
        ValueError: if either mask has no non-zero pixel.
    """
    def _fg_bbox(mask, empty_msg):
        # Inclusive (top, bottom, left, right) of non-zero pixels. torch.nonzero
        # also works for CUDA tensors and tensors that require grad.
        rows, cols = torch.nonzero(mask.squeeze(0), as_tuple=True)
        if rows.numel() == 0:
            raise ValueError(empty_msg)
        return int(rows.min()), int(rows.max()), int(cols.min()), int(cols.max())

    t, b, l, r = _fg_bbox(cloth_mask, "cloth_mask 中没有前景区域（全为零）。")
    wt, wb, wl, wr = _fg_bbox(warp_cloth_mask, "warp_cloth_mask 中没有前景区域（全为零）。")

    # Inclusive extents: a one-pixel-thick box has size 1, not 0.
    mask_height, mask_width = b - t + 1, r - l + 1
    warped_height, warped_width = wb - wt + 1, wr - wl + 1

    scale_factor = min(mask_height / warped_height, mask_width / warped_width)
    # Keep at least 1 px so interpolate never gets an empty size, and never
    # exceed the target box (guards against float round-up).
    target_height = min(mask_height, max(1, int(warped_height * scale_factor)))
    target_width = min(mask_width, max(1, int(warped_width * scale_factor)))

    # Resize only the garment region. Resizing the whole image to the bbox size
    # would shrink the garment far below the target box.
    cloth_crop = cloth[:, wt:wb + 1, wl:wr + 1]
    mask_crop = warp_cloth_mask[:, wt:wb + 1, wl:wr + 1]
    cloth_crop = F.interpolate(
        cloth_crop.unsqueeze(0),  # add batch dim
        size=(target_height, target_width),
        mode='bilinear',
        align_corners=False
    ).squeeze(0)
    mask_crop = F.interpolate(
        mask_crop.unsqueeze(0),  # add batch dim
        size=(target_height, target_width),
        mode='nearest'
    ).squeeze(0)

    # Centre the resized crop inside the target box. It always fits, because
    # target_* <= mask_* and the box lies inside the canvas.
    paste_y = t + (mask_height - target_height) // 2
    paste_x = l + (mask_width - target_width) // 2

    height, width = cloth_mask.shape[-2:]
    aligned_cloth = cloth.new_zeros((cloth.shape[0], height, width))
    aligned_mask = torch.zeros_like(cloth_mask)
    aligned_cloth[:, paste_y:paste_y + target_height, paste_x:paste_x + target_width] = cloth_crop
    aligned_mask[:, paste_y:paste_y + target_height, paste_x:paste_x + target_width] = mask_crop

    return aligned_cloth, aligned_mask

class CPDataset(data.Dataset):
    """
        Dataset for CP-VTON.

    Pairs are read from ``<dataroot>/<mode>_pairs.txt`` (one
    ``<image name> <cloth name>`` pair per line). Per-sample files are read from
    sub-folders of ``<dataroot>/<mode>/``. Each item is a dict holding the person
    image, the in-shop cloth, the pre-warped cloth, the agnostic mask, the cloth
    mask and the rendered OpenPose image.
    """

    # Grouping of raw human-parsing labels (image-parse-v3 values; they are
    # scattered into a 20-channel map, so they must lie in [0, 20)) into
    # semantic classes: class index -> [name, raw labels].
    PARSE_LABELS = {
        0: ['background', [0, 10]],
        1: ['hair', [1, 2]],
        2: ['face', [4, 13]],
        3: ['upper', [5, 6, 7]],
        4: ['bottom', [9, 12]],
        5: ['left_arm', [14]],
        6: ['right_arm', [15]],
        7: ['left_leg', [16]],
        8: ['right_leg', [17]],
        9: ['left_shoe', [18]],
        10: ['right_shoe', [19]],
        11: ['socks', [8]],
        12: ['noise', [3, 11]]
    }
    # Semantic class index (key of PARSE_LABELS) treated as the worn upper garment.
    UPPER_CLOTH_CLASS = 3

    def __init__(self, dataroot, image_size=512, mode='train', unpaired=False, semantic_nc=13,
                 caption_folder=None
                 ):
        super(CPDataset, self).__init__()
        # base setting
        self.root = dataroot
        self.unpaired = unpaired
        self.datamode = mode  # train or test or self-defined
        self.data_list = mode + '_pairs.txt'
        self.fine_height = image_size
        self.fine_width = int(image_size / 256 * 256)  # equals image_size: images are resized to a square
        self.semantic_nc = semantic_nc
        self.data_path = osp.join(dataroot, mode)
        self.crop_size = (self.fine_height, self.fine_height)
        self.to_tensor = transforms.ToTensor()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # NOTE: the "training must use paired data" check is disabled for now;
        # re-enable it for real training:
        # assert not (self.datamode == "train" and self.unpaired), f"train must use paired dataset"

        self.cloth_caption_folder = caption_folder
        if self.cloth_caption_folder is not None:
            # JSON is UTF-8 by specification; don't rely on the platform default encoding.
            with open(osp.join(dataroot, self.cloth_caption_folder), 'r', encoding='utf-8') as f:
                self.cloth_captions_dict = json.load(f)
        # Disabled: spatial captions read from a "key:value" text file.
        # if self.spatial_caption_folder is not None:
        #     with open(osp.join(dataroot, self.spatial_caption_folder), 'r') as f:
        #         self.spatial_captions_dict = {}
        #         for line in f:
        #             key, value = line.strip().split(':', 1)
        #             self.spatial_captions_dict[key] = value

        # load data list
        im_names = []
        c_names = []
        with open(osp.join(dataroot, self.data_list), 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines (e.g. a trailing empty line)
                im_name, c_name = fields
                im_names.append(im_name)
                c_names.append(c_name)

        self.im_names = im_names
        self.c_names = dict()
        # paired: the cloth file shares the person image's name; unpaired: use the listed cloth
        self.c_names['paired'] = im_names
        self.c_names['unpaired'] = c_names

    def name(self):
        return "CPDataset"

    def _load_rgb(self, *parts, size=None):
        """Load ``<data_path>/<parts...>`` as RGB, bilinearly resize it (default ``crop_size``), normalise to [-1, 1]."""
        img = Image.open(osp.join(self.data_path, *parts)).convert('RGB')
        img = transforms.Resize(self.crop_size if size is None else size,
                                interpolation=transforms.InterpolationMode.BILINEAR)(img)
        return self.transform(img)

    def _load_mask(self, *parts, binarize=True):
        """Load ``<data_path>/<parts...>`` as a (1, H, W) grayscale mask in [0, 1], optionally thresholded at 0.5."""
        img = Image.open(osp.join(self.data_path, *parts)).convert('L')
        img = transforms.Resize(self.crop_size, interpolation=transforms.InterpolationMode.NEAREST)(img)
        mask = self.to_tensor(img)
        return (mask > 0.5).float() if binarize else mask

    def _upper_cloth_mask(self, im_name):
        """Return a (1, H, W) float mask of the upper garment worn in ``im_name``.

        Built from the image-parse-v3 map: raw labels are grouped by
        ``PARSE_LABELS``, and the upper-garment class is dilated three times and
        eroded once, so the mask grows slightly.
        """
        parse_name = im_name.replace('image', 'image-parse-v3').replace('.jpg', '.png')
        im_parse_pil = Image.open(osp.join(self.data_path, parse_name))
        im_parse_pil = transforms.Resize(self.crop_size, interpolation=transforms.InterpolationMode.NEAREST)(im_parse_pil)
        parse = torch.from_numpy(np.array(im_parse_pil)[None]).long()

        # One-hot over the 20 raw labels, then collapse to a per-pixel class index.
        parse_map = torch.zeros(20, self.fine_height, self.fine_width, dtype=torch.float32).scatter_(0, parse, 1.0)
        parse_onehot = torch.zeros(1, self.fine_height, self.fine_width, dtype=torch.float32)
        for class_idx, (_, raw_labels) in self.PARSE_LABELS.items():
            for label in raw_labels:
                parse_onehot[0] += parse_map[label] * class_idx

        upper = (parse_onehot[0] == self.UPPER_CLOTH_CLASS).numpy().astype(np.uint8)
        kernel_size = int(5 * (self.fine_width / 256))
        kernel = np.ones((kernel_size, kernel_size))
        upper = cv2.dilate(upper, kernel=kernel, iterations=3)
        upper = cv2.erode(upper, kernel=kernel, iterations=1)
        return self.to_tensor(upper.astype(np.float32))

    def _cloth_caption(self, c_name):
        """Return the comma-joined captions of cloth ``c_name`` ('' when no caption file was given).

        In 'train' mode the caption order is randomised.
        """
        if self.cloth_caption_folder is None:
            return ''
        captions = self.cloth_captions_dict[c_name.split('_')[0]]
        if self.datamode == 'train':
            # Shuffle a copy: random.shuffle would reorder the list stored in the shared dict.
            captions = random.sample(captions, len(captions))
        return ", ".join(captions)

    def __getitem__(self, index):
        im_name = 'image/' + self.im_names[index]
        key = 'unpaired' if self.unpaired else 'paired'
        c_name = self.c_names[key][index]

        # in-shop cloth and person image, normalised to [-1, 1]
        cloth = self._load_rgb('cloth', c_name)
        im = self._load_rgb(im_name)

        # agnostic mask of the person, binarised to {0, 1}
        agn_mask = self._load_mask(im_name.replace('image', 'agnostic-mask').replace('.jpg', '_mask.png'))

        # in-shop cloth mask, binarised to {0, 1}
        cloth_mask = self._load_mask('cloth-mask', c_name)

        # upper garment worn in the person image (from the human-parsing map)
        gt_cloth_mask = self._upper_cloth_mask(im_name)

        # pre-warped cloth, blanked outside its (non-thresholded) mask
        warped_cloth_name = im_name.replace('image', 'cloth-warp' if not self.unpaired else 'unpaired-cloth-warp')
        warped_cloth = self._load_rgb(warped_cloth_name)
        warped_cloth_mask_name = im_name.replace('image',
                                                 'cloth-warp-mask' if not self.unpaired else 'unpaired-cloth-warp-mask')
        warped_cloth_mask = self._load_mask(warped_cloth_mask_name, binarize=False)
        warped_cloth = warped_cloth * warped_cloth_mask

        # Union of the worn garment and the warped-cloth mask, which keeps as much
        # of the person (e.g. arms) as possible.
        # NOTE: the inpaint_* tensors are currently not returned (see the
        # commented-out keys in `result`).
        inpaint_mask = torch.logical_or(gt_cloth_mask, warped_cloth_mask).float()
        inpaint_feature = (1 - inpaint_mask) * im
        inpaint = (1 - inpaint_mask) * im + inpaint_mask * warped_cloth
        # Disabled: align_warp_cloth, align_warp_cloth_mask = align_to_mask(warped_cloth, warped_cloth_mask, cloth_mask)

        # Disabled: keypoint heatmaps from openpose_json (needs datasets.posemap helpers).
        # pose_name = im_name.replace('image', 'openpose_json').replace('.jpg', '_keypoints.json')
        # with open(osp.join(self.data_path, pose_name), 'r') as f:
        #     pose_label = json.load(f)
        #     pose_data = np.array(pose_label['people'][0]['pose_keypoints_2d'])
        #     pose_data = pose_data.reshape((-1, 3))[:, :2]  # (x, y, confidence) -> keep x, y
        #     pose_data[:, 0] = pose_data[:, 0] * (self.fine_width / 768)
        #     pose_data[:, 1] = pose_data[:, 1] * (self.fine_height / 1024)
        # pose_mapping = get_coco_body25_mapping()
        # d = []
        # for idx in range(len(pose_mapping)):
        #     px = pose_data[pose_mapping[idx], 0]
        #     py = pose_data[pose_mapping[idx], 1]
        #     d.append(kpoint_to_heatmap(np.array([px, py]), (self.fine_height, self.fine_width), 9))
        # openpose_map = torch.stack(d)  # (C, H, W), values in [0, 1]
        # openpose_map = (openpose_map > 0.5).float()

        # rendered OpenPose image, resized to half resolution
        openpose_name = im_name.replace('image', 'openpose_img').replace('.jpg', '_rendered.png')
        half = int(self.fine_height / 2)
        openpose_img = self._load_rgb(openpose_name, size=(half, half))

        # Disabled: densepose and colour parse images.
        # densepose_img = self._load_rgb(im_name.replace('image', 'image-densepose'))
        # parse_img = self._load_rgb(im_name.replace('image', 'image-parse-v3').replace('.jpg', '.png'))

        # NOTE: computed for the (currently disabled) caption output; `result`
        # uses a fixed caption instead.
        cloth_captions = self._cloth_caption(c_name)

        result = {
            "ref_human": im,
            "ref_cloth": cloth,
            "warp_cloth": warped_cloth,
            # "warp_mask": warped_cloth_mask,
            # "gt_cloth_mask": gt_cloth_mask,
            # "inpaint_mask": inpaint_mask,
            # "inpaint_feature": inpaint_feature,
            # "inpaint": inpaint,  # the final inpaint is produced in get_learned_conditioning, since the mask is predicted
            "caption": "upper garment",  # cloth_captions,
            # only needed for training:
            "agn_mask": agn_mask,
            "cloth_mask": cloth_mask,
            # "openpose_map": openpose_map,  # 1 where the heatmap > 0.5
            'pose_img': openpose_img,
            # "densepose_img": densepose_img,
            # "parse_img": parse_img,
            "im_name": self.im_names[index],
        }
        return result

    def __len__(self):
        return len(self.im_names)



if __name__ == '__main__':
    # Smoke test: load one unpaired test sample and display every tensor field.
    dataset = CPDataset('DATA/vitonhd/VITON-HD', 512, mode='test', unpaired=True, caption_folder='captions.json')
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)
    # Named `batch`, not `data`: `data` is the torch.utils.data alias imported at the top of the file.
    for batch in loader:
        for key, value in batch.items():
            is_list = isinstance(value, list)
            print(f"{key}:{value.shape if not is_list else value}")
            if is_list:
                continue  # strings (captions, names) are collated into lists
            if value.shape[1] > 3:
                # more than 3 channels: show each channel separately
                for j in range(value.shape[1]):
                    show(f"{key}-{j}", ((value[:, j, :, :].unsqueeze(3) + 1) / 2)[0])
            else:
                show(key, ((value.permute(0, 2, 3, 1) + 1) / 2)[0])

        break
