import torch
from torch.utils.data import DataLoader
from torch import nn
from util import tensor2im,compare_ssim,FID,LPIPS_calcuator,IS
import numpy as np
import cv2
import torch.nn.functional as F
import argparse
from PIL import Image
import os 

import warnings
warnings.filterwarnings("ignore")


# Batch size handed to the test DataLoader further down in this script.
batch_size=4

# Command-line options. Each flag is registered once; listing the same option
# string twice in add_argument only duplicates it in the --help output.
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--dataset', default='deepf', type=str, help='which dataset to use')
parser.add_argument('--result_dir', default='result/deepf/', type=str, help='directory for visulizations')
parser.add_argument('--checkpoint_ema_path', default='model/deepf/ema.ckpt', type=str, help='')
args = parser.parse_args()

DATASET=args.dataset
result_dir=args.result_dir
checkpoint_ema_path=args.checkpoint_ema_path

# 'deepf' selects the DeepFashion network/dataloader; any other value falls
# through to the Market variants.
if DATASET=='deepf':
    from network_deepf import PT2
    from dataloader_deepf import DeepFashionDataset,permute_images
else:
    from network_market import PT2
    from dataloader_market import MarketDataset,permute_images

# Seed every RNG this script draws from (torch CPU, numpy, all CUDA devices).
random_seed = 1234 # or any of your favorite number 
torch.manual_seed(random_seed)
np.random.seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
# NOTE(review): cuDNN autotuning may pick different kernels between runs, so
# results might not be bit-identical despite the fixed seeds -- confirm whether
# exact reproducibility matters here.
torch.backends.cudnn.benchmark = True

def addBounding(image, bound=40):
    """Pad an image (or a batch of images) on the left and right with 255.

    Args:
        image: array whose ``shape`` has either three entries, unpacked as
            (h, w, c), or four entries, unpacked as (n, c, h, w).
        bound: number of columns added on each side of the width axis.

    Returns:
        A new float32 array whose width is ``w + 2 * bound``; the added
        columns are filled with 255 and the original content sits in the
        middle.
    """
    if len(image.shape) == 3:
        height, width, channels = image.shape
        padded_shape = (height, width + 2 * bound, channels)
        width_axis = 1
    else:
        batch, channels, height, width = image.shape
        padded_shape = (batch, channels, height, width + 2 * bound)
        width_axis = 3

    # Start from an all-255 canvas, then paste the input into the centre
    # columns of the width axis.
    canvas = np.full(padded_shape, 255, dtype=np.float32)
    centre = [slice(None)] * len(padded_shape)
    centre[width_axis] = slice(bound, bound + width)
    canvas[tuple(centre)] = image
    return canvas
    
def _psnr(image_a, image_b, data_range=255.0):
    """Return the peak signal-to-noise ratio (dB) of two same-shaped arrays."""
    diff = np.asarray(image_a, dtype=np.float64) - np.asarray(image_b, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Identical images: PSNR is unbounded.
        return float('inf')
    return float(10.0 * np.log10(data_range ** 2 / mse))


def test_fid(network,dataloader,ema_helper,result_dir):
    """Run the generator over ``dataloader`` and print evaluation metrics.

    For every batch the generated image is written next to the two input
    images under ``result_dir``, and LPIPS, SSIM and PSNR are accumulated.
    After the loop the function prints LPIPS (mean, std), SSIM (mean, std),
    PSNR (mean, std), the inception score with its std, and FID.

    Args:
        network: model exposing ``generator``; called as
            ``network(permuted_imgs, UV2, z, bg)`` and expected to return a
            pair whose first element is the generated batch.
        dataloader: yields 8-tuples
            ``(P1, P2, seg, bg, UV1, UV2, name, parsing_map)``. ``P2`` is
            used as the ground truth; images are presumably in [-1, 1]
            since they are rescaled with ``(x + 1) / 2 * 255``.
        ema_helper: object whose ``ema(module)`` is applied to
            ``network.generator`` before evaluation.
        result_dir: directory that receives one visualisation per sample,
            named by the ``name`` entry of the batch.

    Returns:
        None. All results are printed.
    """
    tool=FID()
    tool2=IS()
    lpips_calcuator=LPIPS_calcuator()
    gen_imgs=[]
    gt_imgs=[]
    lpips_accu=[]
    ssim_accu=[]
    psnr_accu=[]
    ema_helper.ema(network.generator)
    network.eval()
    with torch.no_grad():
        for i,(P1,P2,seg,bg,UV1,UV2,name,parsing_map) in enumerate(dataloader):
            permuted_imgs=permute_images(torch.cat([P1,parsing_map,parsing_map,UV1],1).cuda(),seg.cuda(),test=True)
            bg=bg.cuda()
            # One 256-dimensional standard-normal latent per sample.
            z=torch.normal(0,1,size=(len(P1),256))
            gen_img,_=network(permuted_imgs,UV2.cuda(),z.cuda(),bg)
            for j in range(len(P1)):
                # Save P1 | P2 | generated concatenated along the last axis.
                tensor2im(torch.cat([P1[j].cuda(),P2[j].cuda(),gen_img[j]],-1),os.path.join(result_dir,name[j]))

            # Rescale to 0..255 and pad the width with 255 on both sides.
            gt=addBounding((P2.numpy()+1)/2*255)
            gen=addBounding((gen_img.cpu().numpy()+1)/2*255)
            gen_imgs.append(gen)
            gt_imgs.append(gt)

            # LPIPS is fed the padded images mapped back to [-1, 1].
            lpips_accu.append(lpips_calcuator(gt/255*2-1,gen/255*2-1))
            for j in range(len(P1)):
                # Move the channel axis last for the per-image metrics.
                gen_hwc=np.transpose(gen[j],axes=[1,2,0])
                gt_hwc=np.transpose(gt[j],axes=[1,2,0])
                ssim_accu.append(compare_ssim(gen_hwc,gt_hwc))
                # Images are on a 0..255 scale, hence data_range=255.
                psnr_accu.append(_psnr(gen_hwc,gt_hwc,data_range=255.0))
            if i%100==0:
                print(i)

    lpips_accu=np.concatenate(lpips_accu,0)
    print('lpips',np.mean(lpips_accu),np.std(lpips_accu))

    print('ssim',np.mean(ssim_accu),np.std(ssim_accu))

    print('psnr',np.mean(psnr_accu),np.std(psnr_accu))

    gen_imgs=np.concatenate(gen_imgs,0)
    gt_imgs=np.concatenate(gt_imgs,0)

    is_score,is_std=tool2(gen_imgs/255,get_std=True)
    print('is',is_score,is_std)

    fid=tool(gen_imgs/255,gt_imgs/255)
    print('fid',fid)
    
def test_fid_market(network,dataloader,ema_helper,result_dir):
    """Run the generator over ``dataloader`` and print masked/unmasked metrics.

    For every batch the generated image is written next to the two input
    images under ``result_dir``; LPIPS and SSIM are accumulated both on the
    full images and on images multiplied by ``ssim_mask``. After the loop
    the function prints mean LPIPS, mean masked LPIPS, the inception score,
    FID, and the mean SSIM / masked SSIM.

    Args:
        network: model exposing ``generator``; called as
            ``network(permuted_imgs, UV2, z, bg)`` and expected to return a
            pair whose first element is the generated batch.
        dataloader: yields 9-tuples
            ``(P1, P2, seg, bg, UV1, UV2, name, parsing_map, ssim_mask)``.
            ``P2`` is used as the ground truth; images are presumably in
            [-1, 1] since they are rescaled with ``(x + 1) / 2 * 255``.
            ``ssim_mask`` is presumably (N, H, W) -- it gets a channel axis
            added before being multiplied with the images; TODO confirm.
        ema_helper: object whose ``ema(module)`` is applied to
            ``network.generator`` before evaluation.
        result_dir: directory that receives one visualisation per sample,
            named by the ``name`` entry of the batch.

    Returns:
        None. All results are printed.
    """
    tool=FID()
    tool2=IS()
    lpips_calcuator=LPIPS_calcuator()
    gen_imgs=[]
    gt_imgs=[]
    lpips_accu=[]
    mask_lpips_accu=[]
    ssim_accu=[]
    masked_ssim_accu=[]
    ema_helper.ema(network.generator)
    network.eval()
    with torch.no_grad():
        for i,(P1,P2,seg,bg,UV1,UV2,name,parsing_map,ssim_mask) in enumerate(dataloader):
            permuted_imgs=permute_images(torch.cat([P1,parsing_map,parsing_map,UV1],1).cuda(),seg.cuda(),test=True)
            bg=bg.cuda()
            # One 128-dimensional standard-normal latent per sample.
            z=torch.normal(0,1,size=(len(P1),128))
            gen_img,_=network(permuted_imgs,UV2.cuda(),z.cuda(),bg)
            for j in range(len(P1)):
                # Save P1 | P2 | generated concatenated along the last axis.
                tensor2im(torch.cat([P1[j].cuda(),P2[j].cuda(),gen_img[j]],-1),os.path.join(result_dir,name[j]))

            # Rescale to 0..255 (no width padding for this dataset).
            gt=(P2.numpy()+1)/2*255
            gen=(gen_img.cpu().numpy()+1)/2*255
            gen_imgs.append(gen)
            gt_imgs.append(gt)

            lpips_accu.append(lpips_calcuator(gt/255*2-1,gen/255*2-1))

            # Insert an axis at position 1 so the mask broadcasts over channels.
            batch_mask=ssim_mask.numpy()[:,None]
            mask_lpips_accu.append(lpips_calcuator(gt*batch_mask/255*2-1,gen*batch_mask/255*2-1))

            for j in range(len(P1)):
                # Trailing axis so the mask broadcasts against channel-last images.
                sample_mask=ssim_mask[j].numpy()[:,:,None]
                x=np.transpose(gen[j],axes=[1,2,0])
                y=np.transpose(gt[j],axes=[1,2,0])
                ssim_accu.append(compare_ssim(x,y))
                masked_ssim_accu.append(compare_ssim(sample_mask*x,sample_mask*y))
            if i%100==0:
                print(i)

    lpips_accu=np.concatenate(lpips_accu,0)
    print('lpips',np.mean(lpips_accu))

    mask_lpips_accu=np.concatenate(mask_lpips_accu,0)
    print('masked-lpips',np.mean(mask_lpips_accu))

    gen_imgs=np.concatenate(gen_imgs,0)
    gt_imgs=np.concatenate(gt_imgs,0)

    is_score=tool2(gen_imgs/255)
    print('is',is_score)

    fid=tool(gen_imgs/255,gt_imgs/255)
    print('fid',fid)

    print('ssim,',np.mean(ssim_accu),np.mean(masked_ssim_accu))
    
# Build the test split that matches the selected dataset.
if DATASET=='deepf':
    test_data=DeepFashionDataset(mode='test')
else:
    test_data=MarketDataset(mode='test')    
# shuffle=False keeps the sample order fixed across runs.
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False,num_workers=2) 


network=PT2(norm_layer=None).cuda()

# The EMA helper is registered on the generator; its state is filled from the
# checkpoint below and applied to the generator inside the test functions.
from util import EMAHelper
ema_helper = EMAHelper(mu=0.999)
ema_helper.register(network.generator)


if checkpoint_ema_path is not None:  
    # NOTE: torch.load unpickles the file -- only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_ema_path,map_location=torch.device('cuda'))
    ema_helper.load_state_dict(checkpoint['network_state_dict'])
    print('ema model loaded from %s'%checkpoint_ema_path)
    # The test functions write one image per sample under result_dir, so make
    # sure the directory exists before evaluation starts.
    if result_dir:
        os.makedirs(result_dir, exist_ok=True)
    if DATASET=='deepf':
        test_fid(network,test_dataloader,ema_helper,result_dir)
    else:
        test_fid_market(network,test_dataloader,ema_helper,result_dir)
        
    
