# coding=utf-8
import os

import PIL
import cv2
import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image, ImageDraw
import json

import random
import os.path as osp
import numpy as np
from torch.utils.data import DataLoader
#from datasets.posemap import get_coco_body25_mapping,kpoint_to_heatmap

import matplotlib.pyplot as plt

def show(title, array):
    """Display ``array`` as an image with ``title`` as the heading.

    Draws on matplotlib's current axes and then calls ``plt.show()``.

    Args:
        title: text placed above the image.
        array: anything ``plt.imshow`` accepts.
    """
    # Drawing the image and setting the title both target the current axes,
    # so their relative order does not affect the resulting figure.
    plt.imshow(array)
    plt.title(title)
    plt.show()



class CPDataset(data.Dataset):
    """Person / garment pair dataset for CP-VTON style virtual try-on.

    Files read by this class, relative to ``dataroot``:

    * ``<mode>_pairs.txt`` -- one ``<image_name> <cloth_name>`` pair per line.
    * ``<caption_folder>`` -- optional JSON file of garment captions.
    * ``<mode>/image/<image_name>`` -- person image.
    * ``<mode>/cloth/<cloth_name>`` -- garment image.
    * ``<mode>/cloth-mask/<cloth_name>`` -- garment mask.
    * ``<mode>/agnostic-mask/<image_name with '.jpg' -> '_mask.png'>``
    * ``<mode>/openpose_img/<image_name with '.jpg' -> '_rendered.png'>``

    In paired mode the garment name is the person image name itself; in
    unpaired mode it is the second column of the pair list.

    Args:
        dataroot: dataset root directory.
        image_size: side length of the square that images and masks are
            resized to.
        mode: split name; selects ``<mode>_pairs.txt`` and the ``<mode>/``
            sub-directory.
        unpaired: use the garment listed in the pair file instead of the
            garment sharing the person image's name.
        caption_folder: path (relative to ``dataroot``) of a JSON caption
            file, or ``None`` to skip loading captions. Despite the name it
            is opened as a file.
    """

    def __init__(self, dataroot, image_size=512, mode='train',  unpaired=False,
                 caption_folder=None
                 ):
        super().__init__()
        # base setting
        self.root = dataroot
        self.unpaired = unpaired
        self.datamode = mode  # train or test or self-defined
        self.data_list = mode + '_pairs.txt'
        self.fine_height = image_size
        self.data_path = osp.join(dataroot, mode)
        self.crop_size = (self.fine_height, self.fine_height)
        self.to_tensor = transforms.ToTensor()
        # ToTensor followed by (x - 0.5) / 0.5 on each of the three channels.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        # NOTE: a check that 'train' mode is never combined with
        # unpaired=True was temporarily removed here; the original author
        # marked it as required for real training:
        # assert not (self.datamode == "train" and self.unpaired), f"train must use paired dataset"

        self.cloth_caption_folder = caption_folder
        if self.cloth_caption_folder is not None:
            caption_path = osp.join(dataroot, self.cloth_caption_folder)
            # JSON text is UTF-8; do not depend on the locale's default codec.
            with open(caption_path, 'r', encoding='utf-8') as f:
                self.cloth_captions_dict = json.load(f)
        # NOTE: loading of a second, "key:value" per line spatial-caption
        # file is currently disabled.

        # load data list
        im_names, c_names = self._read_pairs(osp.join(dataroot, self.data_list))

        self.im_names = im_names
        self.c_names = {
            'paired': im_names,    # garment file shares the person image's name
            'unpaired': c_names,
        }

    @staticmethod
    def _read_pairs(list_path):
        """Parse a pair list into ``(image_names, cloth_names)``.

        Blank lines are ignored.

        Raises:
            ValueError: if a non-blank line does not hold exactly two
                whitespace-separated names.
        """
        im_names = []
        c_names = []
        with open(list_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # e.g. a trailing empty line at the end of the file
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{list_path}:{line_no}: expected '<image> <cloth>', "
                        f"got {line.strip()!r}")
                im_names.append(fields[0])
                c_names.append(fields[1])
        return im_names, c_names

    def _companion_path(self, folder, file_name, suffix):
        """Return the path of the file in ``folder`` that accompanies ``file_name``.

        Only the ``'.jpg'`` part of the file name is swapped for ``suffix``;
        the folder is joined separately so that a file name which itself
        contains the word ``image`` is left intact.
        """
        return osp.join(self.data_path, folder, file_name.replace('.jpg', suffix))

    @staticmethod
    def _open_resized(path, pil_mode, size, interpolation):
        """Open ``path``, convert it to ``pil_mode`` and resize it to ``size``."""
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(path) as handle:
            converted = handle.convert(pil_mode)
        return transforms.Resize(size, interpolation=interpolation)(converted)

    def _binarize(self, pil_mask):
        """Turn a PIL mask into a float tensor that is 1.0 where the value exceeds 0.5."""
        return (self.to_tensor(pil_mask) > 0.5).float()

    def name(self):
        """Return the dataset's display name."""
        return "CPDataset"

    def __getitem__(self, index):
        """Load the sample at ``index``.

        Returns:
            dict with the keys

            * ``"ref_human"`` -- person image, resized to ``crop_size`` and
              passed through ``self.transform``.
            * ``"ref_cloth"`` -- garment image, same processing.
            * ``"caption"`` -- the fixed prompt ``"upper garment"``.
            * ``"agn_mask"`` -- agnostic mask as a 0/1 float tensor.
            * ``"cloth_mask"`` -- garment mask as a 0/1 float tensor.
            * ``"pose_img"`` -- rendered pose image, resized to half of
              ``fine_height`` per side and passed through ``self.transform``.
            * ``"im_name"`` -- person image file name from the pair list.
        """
        im_name = self.im_names[index]
        pairing = 'unpaired' if self.unpaired else 'paired'
        c_name = self.c_names[pairing][index]

        # Explicit enum members; BILINEAR is what the legacy integer code 2
        # stood for, and masks keep NEAREST so they stay binary-friendly.
        bilinear = transforms.InterpolationMode.BILINEAR
        nearest = transforms.InterpolationMode.NEAREST

        # load cloth
        cloth = self.transform(self._open_resized(
            osp.join(self.data_path, 'cloth', c_name), 'RGB', self.crop_size, bilinear))

        # load image
        im = self.transform(self._open_resized(
            osp.join(self.data_path, 'image', im_name), 'RGB', self.crop_size, bilinear))

        # load agn_mask
        agn_mask = self._binarize(self._open_resized(
            self._companion_path('agnostic-mask', im_name, '_mask.png'),
            'L', self.crop_size, nearest))

        # load cloth_mask
        cloth_mask = self._binarize(self._open_resized(
            osp.join(self.data_path, 'cloth-mask', c_name), 'L', self.crop_size, nearest))

        # NOTE: loading of OpenPose keypoint heatmaps, DensePose images and
        # human-parsing images is currently disabled.

        # load openpose_img
        half_side = int(self.fine_height / 2)
        openpose_img = self.transform(self._open_resized(
            self._companion_path('openpose_img', im_name, '_rendered.png'),
            'RGB', (half_side, half_side), bilinear))

        # load captions
        # NOTE: the joined caption is not returned at the moment -- the sample
        # carries the fixed prompt below. The lookup is kept so that switching
        # back is a one-line change.
        cloth_captions = ''
        if self.cloth_caption_folder is not None:
            # Work on a copy: shuffling the stored list in place would
            # silently reorder the dataset's caption table on every access.
            caption_parts = list(self.cloth_captions_dict[c_name.split('_')[0]])
            # randomise the caption order during training
            if self.datamode == 'train':
                random.shuffle(caption_parts)
            cloth_captions = ", ".join(caption_parts)

        result = {
            "ref_human": im,
            "ref_cloth": cloth,
            "caption": "upper garment",  # cloth_captions
            # the entries below are only needed for training
            "agn_mask": agn_mask,
            "cloth_mask": cloth_mask,
            'pose_img': openpose_img,
            "im_name": im_name,
        }
        return result

    def __len__(self):
        """Return the number of entries in the pair list."""
        return len(self.im_names)



def _preview_first_batch(dataroot='./DATA/VITON-HD'):
    """Smoke test: print and display every entry of the first batch.

    Builds the unpaired training split under ``dataroot`` with captions from
    ``captions.json``, pulls a single batch of size one and, for each key,
    prints its shape (or the value itself when it is a list) and displays
    tensor entries after mapping them with ``(x + 1) / 2``.

    Kept inside a function so that the loop variables do not leak into module
    scope; the original loop variable was called ``data`` and rebound the
    module-level ``torch.utils.data`` alias of the same name.
    """
    dataset = CPDataset(dataroot, 512, mode='train', unpaired=True, caption_folder='captions.json')
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    for batch in loader:
        for key, value in batch.items():
            if isinstance(value, list):
                # non-tensor entries arrive as plain lists and are only printed
                print(f"{key}:{value}")
                continue
            print(f"{key}:{value.shape}")
            if value.shape[1] > 3:
                # more than three channels: show each channel on its own
                for channel in range(value.shape[1]):
                    show(f"{key}-{channel}", ((value[:, channel, :, :].unsqueeze(3) + 1) / 2)[0])
            else:
                # channels-last layout for imshow
                show(key, ((value.permute(0, 2, 3, 1) + 1) / 2)[0])

        # only the first batch is inspected
        break


if __name__ == '__main__':
    _preview_first_batch()
