import os
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
import torch.nn.functional as F
from util import iuv2smpluv,TexTransformer,Colorize
from PIL import Image
import pandas as pd
import torch
import cv2
import pose_utils

# Human-parsing class names. NOTE(review): presumably the list index is the
# label id used in the seg/ maps (index 0 = 'Background') — confirm.
LABEL = [
    'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses',
    'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants',
    'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm',
    'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe',
]
  
def dataframe2pairlist(data):
    """Convert a pairs table into a list of ``(from, to)`` tuples.

    Args:
        data: DataFrame with ``'from'`` and ``'to'`` columns. One tuple is
            produced per row, in row order, regardless of the index labels.

    Returns:
        list of ``(from_value, to_value)`` tuples (empty for an empty frame).
    """
    # Walk the two columns positionally. The previous loop passed index
    # *labels* to ``iloc`` (a positional indexer), which only worked while
    # the frame had a default 0..n-1 RangeIndex.
    return list(zip(data['from'], data['to']))
   
def uvflip(uv_ori):
    """Swap the two halves of a 160-wide strip in channel 0 of a UV map.

    ``uv_ori`` holds coordinates normalized to [-1, 1]. Channels 0 and 1 are
    mapped to the [0, 255] scale. Wherever channel 1 is >= 110, channel 0 is
    shifted by half the strip width (80):
    values in [0, 80) move to [80, 160), and values in [80, 159] move to
    [0, 79]. Every other entry, including NaNs, is left exactly as it was.

    Args:
        uv_ori: tensor whose first two entries along dim 0 are the UV channels.

    Returns:
        A new tensor of the same shape; ``uv_ori`` is not modified.
    """
    half = 160 * 0.5
    u = (uv_ori[0] + 1) / 2 * 255
    v = (uv_ori[1] + 1) / 2 * 255
    in_region = torch.ge(v, 110)
    # Disjoint masks. The original used ``<= half`` on the left and
    # ``>= half`` on the right, so u == half fell in both masks and the
    # blend weight for the unchanged term became -1.
    mask_left = in_region & torch.lt(u, half)
    mask_right = in_region & torch.ge(u, half) & torch.le(u, 159)
    shifted = torch.where(mask_left, u + half, u - half) / 255 * 2 - 1
    uv = uv_ori.clone()
    # Only flipped entries are rewritten, so untouched values (and channel 1)
    # keep their exact original bits instead of being round-tripped.
    uv[0] = torch.where(mask_left | mask_right, shifted, uv_ori[0])
    return uv
    
def get_uv_stats(uv,mask):
    data=torch.masked_select(uv[0].flatten(),mask.flatten())
    #print(data,len(data))
    std,mean=torch.std_mean(data,0,True)
    return std
    
def new_uv(uv):
    """Return ``uv`` or ``uvflip(uv)``, whichever has the smaller channel-0 std.

    The region compared is fixed from the input: entries whose channel 0 is
    at most 150 on the 0-255 scale that ``uvflip`` uses. The unflipped map
    wins ties. If either std is NaN, the ``<=`` comparison is False, so the
    flipped map is returned in that case.
    """
    region = torch.le(uv[0], 150 / 255 * 2 - 1)
    flipped = uvflip(uv)
    keep_original = get_uv_stats(uv, region) <= get_uv_stats(flipped, region)
    return uv if keep_original else flipped
    
def parse2bbox(parse, margin_rate=0.2, mode='train'):
    """Build a padded bounding-box mask around the non-zero entries of ``parse``.

    Args:
        parse: label map indexed as (rows, cols, ...); any non-zero entry
            counts as foreground.
        margin_rate: fraction of the foreground's row/column extent added on
            each side of the tight box, clipped to the map bounds.
        mode: unused; kept so existing callers keep working.

    Returns:
        float32 tensor of shape (1, rows, cols). It is 1 inside the padded
        box and 0 elsewhere, and all zeros when ``parse`` has no foreground.
    """
    parse = np.asarray(parse)
    height, width = parse.shape[0], parse.shape[1]
    box = torch.zeros((1, height, width), dtype=torch.float32)

    nonzero = np.nonzero(parse)
    rows, cols = nonzero[0], nonzero[1]
    if rows.size == 0:
        # No foreground: np.min/np.max would raise on the empty selection.
        return box

    r1, r2 = int(rows.min()), int(rows.max())
    c1, c2 = int(cols.min()), int(cols.max())
    margin_r = int((r2 - r1) * margin_rate)
    margin_c = int((c2 - c1) * margin_rate)

    # r2/c2 are inclusive pixel indices, so +1 turns them into exclusive
    # slice ends. Without it the last foreground row/column was excluded.
    top = max(r1 - margin_r, 0)
    bottom = min(r2 + margin_r + 1, height)
    left = max(c1 - margin_c, 0)
    right = min(c2 + margin_c + 1, width)

    box[:, top:bottom, left:right] = 1
    return box
    
class MarketDataset(data.Dataset):
    """Market-1501 pose-transfer pairs with UV maps, keypoints and parsing.

    Each item is one ``(from, to)`` row of the ``market-pairs-<split>.csv``
    file under ``data_root``. See ``__getitem__`` for the returned tuple.
    """

    def __init__(self, data_root='data/market', mode='test'):
        """Index the split selected by ``mode`` under ``data_root``.

        Args:
            data_root: dataset root holding the ``bounding_box_*`` folders,
                the pair/annotation CSV files, ``uvmap/``, ``seg/`` and
                ``smpltexmap.npy``.
            mode: ``'train'`` reads the training split; any other value reads
                the test split. Stored as ``self.mode``.
        """
        self.tex_transformer = TexTransformer(os.path.join(data_root, 'smpltexmap.npy'))
        # Normalize with mean=std=0.5 maps ToTensor's [0, 1] output to [-1, 1].
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # A single split drives images, pairs and annotations. The annotation
        # file used to follow the opposite fallback ("anything but 'test'
        # reads train"), so any mode other than 'train'/'test' paired test
        # images with train keypoints.
        split = 'train' if mode == 'train' else 'test'
        self.image_dir = os.path.join(data_root, 'bounding_box_' + split + '/')
        pairs = pd.read_csv(os.path.join(data_root, 'market-pairs-' + split + '.csv'), sep=',')
        self.pair = dataframe2pairlist(pairs)

        self.uv_dir = os.path.join(data_root, 'uvmap/')
        self.parse_dir = os.path.join(data_root, 'seg/')
        self.color_labelmap = Colorize()

        annotations = pd.read_csv(
            os.path.join(data_root, 'market-annotation-' + split + '.csv'), sep=':')
        self.annotation_file = annotations.set_index('name')
        self.mode = mode

    def __len__(self):
        """Number of (from, to) pairs in the selected split."""
        return len(self.pair)

    def _load_image(self, name):
        """Read ``name`` from ``image_dir`` as RGB and apply ``self.trans``."""
        with Image.open(os.path.join(self.image_dir, name)) as img:
            return self.trans(img.convert('RGB'))

    def _load_uv(self, name):
        """Return the UV map for image ``name`` as a channels-first tensor.

        Reads ``uv_dir/<name with '.jpg' replaced by '.png'>``, converts it
        with ``iuv2smpluv``, and permutes axes (2, 0, 1). It then picks the
        flip orientation via ``new_uv`` and replaces NaNs with -10. A missing
        file yields a -10-filled tensor of shape (2, 128, 64).
        """
        uv_path = os.path.join(self.uv_dir, name.replace('.jpg', '.png'))
        if not os.path.exists(uv_path):
            return -10 * torch.ones((2, 128, 64))
        with Image.open(uv_path) as img:
            iuv = np.array(img)
        uv = torch.from_numpy(iuv2smpluv(iuv, self.tex_transformer)).permute(2, 0, 1)
        uv = new_uv(uv)
        uv[torch.isnan(uv)] = -10
        return uv

    def _load_pose(self, name):
        """Return ``(annotation_row, heatmaps)`` for image ``name``.

        ``heatmaps`` is ``pose_utils.cords_to_map`` output for a (128, 64)
        canvas, transposed with axes [2, 0, 1]. NOTE(review): presumably
        (H, W, K) -> (K, H, W); confirm against pose_utils.
        """
        row = self.annotation_file.loc[name.replace('.png', '.jpg')]
        cords = pose_utils.load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
        heatmaps = pose_utils.cords_to_map(cords, (128, 64))
        return row, np.transpose(heatmaps, axes=[2, 0, 1])

    def __getitem__(self, index):
        """Build one sample for pair ``index``.

        Returns:
            ``(P1, P2, seg, bg, UV1, UV2, out_name, parsing_map, ssim_mask2)``:

            - ``P1``, ``P2``: the normalized source and target images.
            - ``seg``: a float mask of the non-zero labels of the source
              parsing map, with a leading axis added.
            - ``bg``: the source image with a box around its parsing
              foreground (5% margin) zeroed out.
            - ``UV1``, ``UV2``: each image's UV map concatenated with its
              keypoint heatmaps along dim 0, as float32.
            - ``out_name``: ``'<from>to<to>.png'``.
            - ``parsing_map``: the colorized source parsing map.
            - ``ssim_mask2``: ``pose_utils.create_image_mask`` of the target's
              annotation row.
        """
        p1_name, p2_name = self.pair[index]
        P1 = self._load_image(p1_name)
        P2 = self._load_image(p2_name)

        uv1 = self._load_uv(p1_name)
        uv2 = self._load_uv(p2_name)

        _, joints1 = self._load_pose(p1_name)
        UV1 = torch.from_numpy(np.concatenate([uv1, joints1], 0)).float()
        b_row, joints2 = self._load_pose(p2_name)
        UV2 = torch.from_numpy(np.concatenate([uv2, joints2], 0)).float()
        ssim_mask2 = pose_utils.create_image_mask(b_row)

        with Image.open(os.path.join(self.parse_dir, p1_name[:-4] + '.png')) as img:
            parse1 = np.array(img)
        parsing_map = self.color_labelmap(parse1)
        mask = np.not_equal(parse1, 0).astype(np.uint8)
        seg = torch.from_numpy(mask[None]).float()
        # Background keeps only what lies outside a slightly padded box
        # around the parsing foreground.
        mask_bg = 1 - parse2bbox(parse1, margin_rate=0.05, mode=self.mode)
        bg = P1 * mask_bg

        out_name = p1_name + 'to' + p2_name + '.png'
        return P1, P2, seg, bg, UV1, UV2, out_name, parsing_map, ssim_mask2
    
def permute_images(inputs, parsing, test=False, res=8, mratio_tex=0.2, mratio_pose=0.5):
    """Mask ``inputs`` with ``parsing``; in training, also shuffle and drop blocks.

    Args:
        inputs: tensor of shape (b, c, h, w). It is not modified; a detached
            clone is used. Masking assumes this channel layout: channels 6
            onward get an extra block mask, and channels 11 onward get a
            per-channel mask. NOTE(review): presumably images/UV channels
            followed by keypoint heatmaps; confirm against the caller.
        parsing: tensor broadcastable to ``inputs``, multiplied in first.
        test: if True, return right after applying ``parsing``.
        res: side of the square blocks; h and w must be multiples of it.
        mratio_tex: probability of zeroing a block across all channels.
        mratio_pose: probability of zeroing a block on channels 6+, and of
            zeroing each whole channel from 11 onward.

    Returns:
        tensor of shape (b, c, h, w).
    """
    inputs = inputs.clone().detach()
    b, c, h, w = inputs.shape

    inputs = inputs * parsing
    if test:
        return inputs

    nblocks_h, nblocks_w = h // res, w // res
    blocks = inputs.view(b, c, nblocks_h, res, nblocks_w, res)

    # Shuffle block rows and block columns (one permutation for the batch).
    row_perm = torch.randperm(nblocks_h)
    col_perm = torch.randperm(nblocks_w)
    blocks = blocks[:, :, row_perm]
    blocks = blocks[:, :, :, :, col_perm]

    # Random draws keep their original order, so seeded runs are unchanged.
    keep = torch.greater(torch.rand(b, 1, nblocks_h, 1, nblocks_w, 1), mratio_tex).type_as(blocks)
    blocks = blocks * keep

    keep = torch.greater(torch.rand(b, 1, nblocks_h, 1, nblocks_w, 1), mratio_pose).type_as(blocks)
    blocks[:, 3 + 3:] = blocks[:, 3 + 3:] * keep

    # Keypoint channel dropout. The count follows the input instead of being
    # hard-coded to 18, which only broadcast when c == 29.
    n_keypoints = max(c - (3 + 3 + 3 + 2), 0)
    keep = torch.greater(torch.rand(b, n_keypoints, 1, 1, 1, 1), mratio_pose).type_as(blocks)
    blocks[:, 3 + 3 + 3 + 2:] = blocks[:, 3 + 3 + 3 + 2:] * keep

    # reshape works whether or not the advanced-indexing results above are
    # contiguous; view would require contiguity.
    return blocks.reshape(b, c, h, w)
    
if __name__ == '__main__':
    # Smoke test: build side-by-side previews for the first test batch and
    # hand each one to tensor2im with a path under temp/.
    from torch.utils.data import DataLoader
    from util import tensor2im

    dataset = MarketDataset(mode='test')
    loader = DataLoader(dataset, batch_size=4, shuffle=False)
    os.makedirs('temp', exist_ok=True)
    for P1, P2, seg, bg, UV1, UV2, out_name, parsing_map, ssim_mask2 in loader:
        # Concatenate along the last (width) axis; single-channel maps are
        # repeated to 3 channels so all panels stack.
        preview = torch.cat([
            P1,
            P2,
            seg.repeat(1, 3, 1, 1),
            bg,
            UV1[:, 0:1].repeat(1, 3, 1, 1),
            UV2[:, 0:1].repeat(1, 3, 1, 1),
            parsing_map,
        ], -1)
        for image, name in zip(preview, out_name):
            tensor2im(image, os.path.join('temp', name))
        # Only the first batch is needed. The original stopped with
        # ``assert 1==0``, which raises a traceback and is skipped under -O.
        break
    
