import argparse
import os

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import cv2

from network_market import PT2
from dataloader_market import MarketDataset, permute_images
from util import tensor2im

import warnings
# NOTE(review): this silences *all* warnings (including deprecation and
# scheduler-ordering warnings); consider narrowing it.
warnings.filterwarnings("ignore")

# Training hyper-parameters.
num_epochs=50
batch_size=4
start_epoch=0  # overwritten when resuming from --checkpoint_path

# Command-line options. Each option previously listed its own flag twice
# (e.g. '--result_dir', '--result_dir'), which only duplicated it in --help.
parser = argparse.ArgumentParser(description='PyTorch Training')
parser.add_argument('--result_dir', default='result/market/', type=str,
                    help='directory for visualizations')
parser.add_argument('--model_dir', default='model/market/', type=str,
                    help='directory for models')
parser.add_argument('--save_every', default=5, type=int,
                    help='save every x epochs')
parser.add_argument('--checkpoint_path', default=None, type=str,
                    help='network checkpoint (written by save_model) to resume training from')
parser.add_argument('--checkpoint_ema_path', default=None, type=str,
                    help='checkpoint to restore the EMA generator weights from')
args = parser.parse_args()


result_dir=args.result_dir
model_dir=args.model_dir
checkpoint_path=args.checkpoint_path
checkpoint_ema_path=args.checkpoint_ema_path

# Seed every RNG used by the script.
random_seed = 1234
torch.manual_seed(random_seed)
np.random.seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
# cudnn.benchmark lets cuDNN pick (possibly non-deterministic) kernels at
# runtime, so the seeds above do not make runs bit-for-bit reproducible.
torch.backends.cudnn.benchmark = True

def train(network, dataloader, epoch, ema_helper):
    """Run one training pass of ``network`` over ``dataloader``.

    For each batch this builds the conditioning input with
    ``permute_images``, generates one image conditioned on ``UV1`` and one
    on ``UV2``, calls ``network.backwardD`` / ``network.backwardG`` on the
    UV1 image (presumably these perform the optimizer steps; they are
    defined in network_market, not visible here), then updates the EMA of
    ``network.generator``. The UV2 image is only used for an MSE metric
    against ``P2``.

    Every 100 iterations the running mean (sum / batches so far) of each
    tracked loss is printed.

    Args:
        network: model exposing ``generator``, ``backwardD`` and
            ``backwardG`` (PT2 in this script).
        dataloader: yields ``(P1, P2, seg, bg, UV1, UV2, name, parsing_map)``.
        epoch: current epoch index, used only for logging.
        ema_helper: EMA tracker updated after every batch.
    """
    network.train()
    loss_dict = {'gen_loss': 0, 'rec_loss': 0, 'dis_loss': 0,
                 'total_G_loss': 0, 'mse': 0, 'p_loss': 0}
    for i, (P1, P2, seg, bg, UV1, UV2, name, parsing_map) in enumerate(dataloader):
        # Transfer each reused tensor to the GPU once rather than per use.
        bg = bg.cuda()
        P1_gpu = P1.cuda()
        UV1_gpu = UV1.cuda()
        # Noise is still drawn from the CPU generator so seeded runs match.
        z = torch.normal(0, 1, size=(len(P1), 128)).cuda()

        inputs = torch.cat([P1, parsing_map, parsing_map, UV1], 1)
        permuted_imgs = permute_images(inputs.cuda(), seg.cuda(), test=False)
        gen_img, _ = network(permuted_imgs, UV1_gpu, z, bg)
        gen_img2, _ = network(permuted_imgs, UV2.cuda(), z, bg)

        dis_loss = network.backwardD(gen_img, UV1_gpu, P1_gpu)
        gen_loss, rec_loss, p_loss = network.backwardG(gen_img, P1_gpu, UV1_gpu)

        ema_helper.update(network.generator)

        loss_dict['gen_loss'] += gen_loss.item()
        # Previously ``im_loss.cpu().item()``: ``im_loss`` is never defined,
        # so the very first batch raised NameError.
        loss_dict['rec_loss'] += rec_loss.item()
        loss_dict['dis_loss'] += dis_loss.item()
        loss_dict['p_loss'] += p_loss.item()
        # Sum of the accumulated generator terms, so dividing by (i + 1)
        # below yields its running mean like the other entries.
        loss_dict['total_G_loss'] = (loss_dict['gen_loss'] + loss_dict['rec_loss']
                                     + loss_dict['p_loss'])

        loss_dict['mse'] += torch.square(gen_img2 - P2.cuda()).mean().item()
        if i % 100 == 0:
            print('---------loss at epoch %d iteration %d----------' % (epoch, i))
            for k in loss_dict.keys():
                print(k, loss_dict[k] / (i + 1))
               
          
def test(epoch, network, dataloader, ema_helper, result_dir='result/'):
    """Evaluate the EMA generator on ``dataloader`` and dump visualizations.

    Calls ``ema_helper.ema(network.generator)``, which presumably copies the
    EMA weights into the generator. It then reports the mean MSE between the
    image generated for ``UV2`` and the ground truth ``P2``.

    For every sample of every 10th batch it saves a horizontal strip of
    [generated segmentation (tiled to 3 channels) | P1 | P2 | gen(UV1) |
    gen(UV2)] to ``result_dir/<name>``. Finally it reloads the weights
    stored in ``<model_dir>/<epoch>.ckpt``, presumably to undo the EMA swap.

    Note: relies on the module-level ``model_dir`` and requires that a
    checkpoint for ``epoch`` has already been written.
    """
    print('====================test==================')
    ema_helper.ema(network.generator)
    network.eval()
    os.makedirs(result_dir, exist_ok=True)
    mse_accu = []
    with torch.no_grad():
        for i, (P1, P2, seg, bg, UV1, UV2, name, parsing_map) in enumerate(dataloader):
            inputs = torch.cat([P1, parsing_map, parsing_map, UV1], 1).cuda()
            permuted_imgs = permute_images(inputs, seg.cuda(), test=True)
            bg = bg.cuda()
            z = torch.normal(0, 1, size=(len(P1), 128)).cuda()
            # Previously referenced the undefined ``permuted_imgs1`` (NameError).
            gen_img, gen_seg = network(permuted_imgs, UV2.cuda(), z, bg)
            gen_img2, _ = network(permuted_imgs, UV1.cuda(), z, bg)
            P1_gpu = P1.cuda()
            P2_gpu = P2.cuda()
            mse_accu.append(torch.square(gen_img - P2_gpu).mean().item())
            if i % 10 == 0:
                segmap = gen_seg.repeat(1, 3, 1, 1)
                r = torch.cat([segmap, P1_gpu, P2_gpu, gen_img2, gen_img], dim=3)
                for j in range(len(P1)):
                    tensor2im(r[j], os.path.join(result_dir, name[j]))

    ckpt_path = os.path.join(model_dir, str(epoch) + '.ckpt')
    state = torch.load(ckpt_path, map_location=torch.device('cuda'))
    network.load_state_dict(state['network_state_dict'])
    print('test mse,', np.mean(mse_accu))
    print('=========================================')
    
             
def save_model(network, epoch, path):
    """Serialize ``network.state_dict()`` to ``path``.

    The stored ``'epoch'`` is ``epoch + 1``, i.e. the epoch to resume from;
    the resume code assigns it directly to ``start_epoch``.

    Args:
        network: any object with a ``state_dict()`` method (used for both
            the model and the EMA helper in this script).
        epoch: index of the epoch that just finished.
        path: destination file; missing parent directories are created.
    """
    directory = os.path.dirname(path)
    if directory:
        # torch.save does not create directories, and the default
        # --model_dir ('model/market/') may not exist yet.
        os.makedirs(directory, exist_ok=True)
    param_dict = {'epoch': epoch + 1, 'network_state_dict': network.state_dict()}
    torch.save(param_dict, path)
    
# Datasets and loaders: training batches are shuffled, test order is fixed.
train_data = MarketDataset(mode='train')
train_dataloader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)
test_data = MarketDataset(mode='test')
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Build the model and move it to the GPU.
network = PT2(norm_layer=None)
network = network.cuda()

# Track a moving average of the generator weights (mu=0.999, presumably the
# EMA decay factor).
from util import EMAHelper
ema_helper = EMAHelper(mu=0.999)
ema_helper.register(network.generator)



# Optional resume of the network weights. The saved 'epoch' is already the
# next epoch to run, so it becomes start_epoch directly.
if checkpoint_path is not None:
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cuda'))
    network.load_state_dict(checkpoint['network_state_dict'])
    start_epoch = checkpoint['epoch']
    print(f'model loaded from {checkpoint_path}')

# Optional restore of the EMA shadow weights (stored under the same key).
if checkpoint_ema_path is not None:
    checkpoint = torch.load(checkpoint_ema_path, map_location=torch.device('cuda'))
    ema_helper.load_state_dict(checkpoint['network_state_dict'])
    print(f'ema model loaded from {checkpoint_ema_path}')

# One StepLR per optimizer; every .step() multiplies the LR by 0.2.
scheduler1 = torch.optim.lr_scheduler.StepLR(network.optimizerG, step_size=1, gamma=0.2)
scheduler2 = torch.optim.lr_scheduler.StepLR(network.optimizerD, step_size=1, gamma=0.2)

    
# LR schedule: epochs 0 and 1 run at the initial LR; after epoch 1 both
# schedulers step once (x0.2) and the LR stays there for the rest of training.
# The checkpoint holds only network.state_dict() (assuming PT2 does not fold
# optimizer state into it), so a run resumed past epoch 1 must re-apply that
# single decay; otherwise it would continue at the undecayed LR.
if start_epoch > 1:
    scheduler1.step()
    scheduler2.step()

for epoch in range(start_epoch, num_epochs + 1):
    train(network, train_dataloader, epoch, ema_helper)
    if epoch % args.save_every == 0:
        save_model(network, epoch, os.path.join(model_dir, str(epoch) + '.ckpt'))
        save_model(ema_helper, epoch, os.path.join(model_dir, str(epoch) + 'ema.ckpt'))
    # This check was previously indented 4 spaces inside a 5-space loop body,
    # which made the whole file fail to parse (IndentationError).
    if epoch == 1:
        scheduler1.step()
        scheduler2.step()
        
    
