
import torch
from torch import nn
from torch.nn import functional as F
from my_model_deepf import Discriminator,ImageGenerator
import cv2
import numpy as np
from util import patch_gram_matrix,VGG19

class PT2(nn.Module):
    """Image generator bundled with two discriminators, their optimizers and losses.

    The module owns:

    * ``generator`` -- an ``ImageGenerator``;
    * ``discriminator`` -- scores an image on its own;
    * ``discriminator_pose`` -- scores an image concatenated with ``pose``
      along the channel axis (built with ``first_in_channel=3+2``);
    * ``vgg`` -- a ``VGG19`` feature extractor used only inside the
      generator loss; its parameters are in neither optimizer.

    ``loss_weights`` is indexed as ``[adversarial, L1 reconstruction,
    perceptual, style]``.
    """

    # VGG feature names compared directly (L1) for the perceptual term.
    _PERCEPTUAL_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
    # VGG feature names whose patch Gram matrices are compared for the style term.
    _STYLE_LAYERS = ('relu2_2', 'relu3_4', 'relu4_4', 'relu5_2')

    def __init__(self, img_size=256, ngf=32, norm_layer=None, loss_weights=None):
        """Build the sub-networks and one AdamW optimizer per side.

        Args:
            img_size: Forwarded to the generator and both discriminators.
            ngf: Forwarded to the generator; the discriminators get ``ngf // 2``.
            norm_layer: Forwarded to the generator.
            loss_weights: Four weights ``[adversarial, reconstruction,
                perceptual, style]``. ``None`` selects ``[2.0, 5.0, 0.5, 1.5]``.
                A caller-supplied sequence is stored as-is (not copied).
        """
        # Original author's note kept verbatim: 1:4,2:8,3:16
        # (its meaning is not evident from this file).
        super().__init__()
        self.generator = ImageGenerator(img_size=img_size, ngf=ngf, norm_layer=norm_layer)
        self.discriminator = Discriminator(img_size=img_size, ngf=ngf // 2)
        self.discriminator_pose = Discriminator(
            img_size=img_size, ngf=ngf // 2, first_in_channel=3 + 2)

        # A list literal as the default would be one object shared by every
        # instance; build a fresh list per instance instead.
        if loss_weights is None:
            loss_weights = [2.0, 5.0, 0.5, 1.5]
        self.loss_weights = loss_weights

        # Parameter lists are captured once here; backwardD/backwardG toggle
        # requires_grad on exactly these tensors.
        self.g_pars = [p for p in self.generator.parameters() if p.requires_grad]
        self.d_pars = (
            [p for p in self.discriminator.parameters() if p.requires_grad]
            + [p for p in self.discriminator_pose.parameters() if p.requires_grad]
        )

        self.optimizerD = torch.optim.AdamW(self.d_pars, lr=1e-4, betas=(0.5, 0.999))
        self.optimizerG = torch.optim.AdamW(self.g_pars, lr=1e-3, betas=(0.5, 0.999))

        self.vgg = VGG19()

    def forward(self, permuted_img, pose, z, bg):
        """Run the generator and return its two outputs as ``(img, seg)``."""
        img, seg = self.generator(permuted_img, pose, z, bg)
        return img, seg

    def backwardD(self, gen, pose, real):
        """Take one optimizer step on both discriminators.

        Uses a least-squares objective: scores of generated images are pushed
        toward 0 and scores of real images toward 1. ``gen`` is detached, so
        nothing propagates back into the generator.

        Side effect: generator parameters are left with
        ``requires_grad=False`` and discriminator parameters with
        ``requires_grad=True``.

        Args:
            gen: Generated images.
            pose: Pose maps, concatenated with images along dim 1 for
                ``discriminator_pose``.
            real: Ground-truth images.

        Returns:
            The summed discriminator loss (the tensor ``backward()`` was
            called on).
        """
        for p in self.g_pars:
            p.requires_grad = False
        for p in self.d_pars:
            p.requires_grad = True
        self.optimizerD.zero_grad()

        gen = gen.detach()
        adv_weight = self.loss_weights[0] * 0.25

        # Pose-conditioned discriminator.
        score_f = self.discriminator_pose(torch.cat([gen, pose], 1))
        score_r = self.discriminator_pose(torch.cat([real, pose], 1).detach())
        dis_loss = adv_weight * (torch.square(score_f) + torch.square(1 - score_r)).mean()

        # Unconditional discriminator.
        score_f = self.discriminator(gen)
        score_r = self.discriminator(real)
        dis_loss = dis_loss + adv_weight * (
            torch.square(score_f) + torch.square(1 - score_r)).mean()

        dis_loss.backward()
        self.optimizerD.step()
        return dis_loss

    def backwardG(self, gen, real, pose):
        """Take one optimizer step on the generator.

        The objective is the sum of a least-squares adversarial term from
        both discriminators, an L1 reconstruction term, and a VGG term
        (L1 on features plus L1 on patch Gram matrices).

        ``gen`` must still carry the autograd graph of the generator forward
        pass, otherwise ``backward()`` has nothing to differentiate.

        Side effect: generator parameters are left with
        ``requires_grad=True`` and discriminator parameters with
        ``requires_grad=False``.

        Args:
            gen: Generated images (non-detached generator output).
            real: Ground-truth images.
            pose: Pose maps, concatenated with ``gen`` along dim 1 for
                ``discriminator_pose``.

        Returns:
            ``(gen_loss, rec_loss, vgg_loss)`` -- the weighted adversarial,
            reconstruction and VGG (perceptual + style) terms.
        """
        for p in self.g_pars:
            p.requires_grad = True
        for p in self.d_pars:
            p.requires_grad = False
        self.optimizerG.zero_grad()

        adv_weight = self.loss_weights[0] * 0.5
        score_f = self.discriminator(gen)
        gen_loss = adv_weight * torch.square(1 - score_f).mean()
        score_f = self.discriminator_pose(torch.cat([gen, pose], 1))
        gen_loss = gen_loss + adv_weight * torch.square(1 - score_f).mean()

        rec_loss = self.loss_weights[1] * torch.abs(real - gen).mean()

        # The real image is a fixed target, so its features need no graph.
        with torch.no_grad():
            feats = self.vgg(real)
        feats2 = self.vgg(gen)

        p_loss = 0.0
        for layer in self._PERCEPTUAL_LAYERS:
            p_loss += torch.abs(feats[layer] - feats2[layer]).mean()

        style_loss = 0.0
        for k, layer in enumerate(self._STYLE_LAYERS):
            # All-ones mask whose resolution halves per layer (H/2 ... H/16).
            # NOTE(review): assumes this matches what patch_gram_matrix
            # expects for each VGG layer -- confirm against util.
            scale = 2 ** (k + 1)
            seg = torch.ones(
                (gen.shape[0], 1, gen.shape[2] // scale, gen.shape[3] // scale),
                device=gen.device,  # follow the input instead of forcing CUDA
            )
            feats_g = patch_gram_matrix(feats[layer], seg)
            feats2_g = patch_gram_matrix(feats2[layer], seg)
            style_loss += torch.abs(feats_g - feats2_g).mean()

        vgg_loss = self.loss_weights[2] * p_loss + self.loss_weights[3] * style_loss

        (gen_loss + vgg_loss + rec_loss).backward()

        self.optimizerG.step()
        return gen_loss, rec_loss, vgg_loss
         
if __name__ == '__main__':
    # Smoke test: one forward pass followed by one discriminator step and
    # one generator step on random data. Requires a CUDA device.
    def _cuda_randn(*shape):
        """Sample a standard-normal tensor on the CPU, then move it to CUDA."""
        return torch.randn(*shape).cuda()

    batch, height, width = 2, 256, 176

    network = PT2().cuda()

    # Sampling order is kept fixed so a seeded run draws the same values.
    inputs = _cuda_randn(batch, 11, height, width)
    img = _cuda_randn(batch, 3, height, width)
    bg = _cuda_randn(batch, 3, height, width)
    pose = _cuda_randn(batch, 2, height, width)
    z = _cuda_randn(batch, 256)

    gen, _ = network(inputs, pose, z, bg)
    dis_loss = network.backwardD(gen, pose, img)
    gen_loss, rec_loss, p_loss = network.backwardG(gen, img, pose)