import math
import random
import torch
from torch import nn
from torch.nn import functional as F
from torch_utils.ops import conv2d_resample
from torch_utils.ops import upfirdn2d
import numpy as np
import cv2

def token2feature(x,img_height=128,num_downsamples=2):
    """Fold a token sequence back into a channels-first 2-D feature map.

    Args:
        x: tensor of shape ``(B, N, C)``.
        img_height: reference height; the output has
            ``img_height // 2**num_downsamples`` rows.
        num_downsamples: number of halvings applied to ``img_height``.

    Returns:
        Tensor of shape ``(B, C, h, w)`` with ``h`` as above and ``w = N // h``.
    """
    batch, num_tokens, channels = x.shape
    rows = img_height // (2 ** num_downsamples)
    cols = num_tokens // rows
    # Move channels ahead of the token axis, then split tokens into rows/cols.
    channels_first = x.permute(0, 2, 1)
    return channels_first.reshape(batch, channels, rows, cols)


def feature2token(x):
    """Flatten a ``(B, C, H, W)`` feature map into ``(B, H*W, C)`` tokens.

    The spatial axes are merged row-major into one token axis, which is then
    swapped with the channel axis. The result is a view of ``x``.
    """
    batch, channels, height, width = x.shape
    flattened = x.view(batch, channels, height * width)
    return flattened.permute(0, 2, 1)


class EqualConv2d(nn.Module):
    """2-D convolution whose weight is rescaled on every forward pass.

    The weight is drawn from a standard normal and multiplied at call time by
    ``1 / sqrt(in_channel * kernel_size**2)``. The convolution itself is
    delegated to ``conv2d_resample.conv2d_resample`` together with the
    ``up``/``down`` factors and the prepared resampling filter.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: side length of the square kernel.
        up: value forwarded as ``up`` to ``conv2d_resample``.
        down: value forwarded as ``down`` to ``conv2d_resample``.
        bias: when true, a zero-initialised bias of shape
            ``(1, out_channel, 1, 1)`` is added to the output.
        resample_filter: passed to ``upfirdn2d.setup_filter``; the result is
            stored as the ``resample_filter`` buffer.
    """

    def __init__(
            self, in_channel, out_channel, kernel_size, up=1,down=1,bias=True, resample_filter = [1,3,3,1]
    ):
        super().__init__()

        weight_shape = (out_channel, in_channel, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))

        fan_in = in_channel * kernel_size ** 2
        self.scale = 1 / math.sqrt(fan_in)

        prepared_filter = upfirdn2d.setup_filter(resample_filter)
        self.register_buffer('resample_filter', prepared_filter)

        self.padding = kernel_size // 2
        self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) if bias else None

        self.up = up
        self.down = down
        self.out_channel = out_channel

    def forward(self, input):
        """Convolve ``input`` with the rescaled weight and add the bias, if any."""
        scaled_weight = self.weight * self.scale
        result = conv2d_resample.conv2d_resample(
            x=input,
            w=scaled_weight,
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
        )
        if self.bias is None:
            return result
        return result + self.bias
        
class EqualLinear(nn.Module):
    """Linear layer whose weight is rescaled on every forward pass.

    The weight is initialised as ``randn(out_dim, in_dim) / lr_mul`` and
    multiplied by ``(1 / sqrt(in_dim)) * lr_mul`` when the layer is applied.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        bias: when true, a learnable bias of length ``out_dim`` is created.
        bias_init: constant the bias is filled with.
        lr_mul: multiplier folded into the weight init and the runtime scale
            (the bias is not multiplied by it).
    """

    def __init__(
            self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1
    ):
        super().__init__()

        initial_weight = torch.randn(out_dim, in_dim).div_(lr_mul)
        self.weight = nn.Parameter(initial_weight)

        self.bias = (
            nn.Parameter(torch.full((out_dim,), float(bias_init))) if bias else None
        )

        inv_sqrt_fan_in = 1 / math.sqrt(in_dim)
        self.scale = inv_sqrt_fan_in * lr_mul

    def forward(self, input):
        """Apply the affine map using the rescaled weight."""
        scaled_weight = self.weight * self.scale
        return F.linear(input, scaled_weight, bias=self.bias)

        
class ConvLayer(nn.Module):
    """Thin wrapper that picks the resampling mode of an ``EqualConv2d``.

    ``downsample`` builds the convolution with ``down=2``; otherwise
    ``upsample`` builds it with ``up=2``; with neither flag both factors are 1.
    When both flags are set, ``downsample`` wins.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        kernel_size: side length of the square kernel.
        downsample: select ``down=2``.
        upsample: select ``up=2`` (ignored when ``downsample`` is set).
        bias: forwarded to ``EqualConv2d``.
        resample_filter: forwarded to ``EqualConv2d``.
    """

    def __init__(
            self,
            in_channel,
            out_channel,
            kernel_size,
            downsample=False,
            upsample=False,
            bias=True,
            resample_filter=[1,3,3,1]
    ):
        super().__init__()

        if downsample:
            up_factor, down_factor = 1, 2
        elif upsample:
            up_factor, down_factor = 2, 1
        else:
            up_factor, down_factor = 1, 1

        self.model = EqualConv2d(
            in_channel,
            out_channel,
            kernel_size,
            up=up_factor,
            down=down_factor,
            bias=bias,
            resample_filter=resample_filter,
        )

    def forward(self, input):
        """Run the wrapped convolution."""
        return self.model(input)
       
def gelu(x):
    """Exact (erf-based) GELU activation.

    Computes ``x * 0.5 * (1 + erf(x / sqrt(2)))``. The tanh-based variant
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))`` used by
    OpenAI GPT is an approximation and yields slightly different values.
    """
    one_plus_erf = 1.0 + torch.erf(x / math.sqrt(2.0))
    return (x * 0.5) * one_plus_erf

class GELU(nn.Module):
    """Module form of the ``gelu`` function; holds no parameters or state."""

    def forward(self, x):
        """Return ``gelu(x)``."""
        return gelu(x)

        
class TransAttentionBlock(nn.Module):
    """Attention block that takes and returns 2-D feature maps.

    ``q``, ``k`` and ``v`` are flattened to token sequences, passed through
    ``SelfAttention``, added residually to the query tokens, refined by a
    two-layer MLP (also residual) and folded back to the query's spatial shape.

    Args:
        n_embd: embedding width (channel count of the query features).
        n_head: number of attention heads.
        H, W: accepted for interface compatibility; not used by this block.
        in_channel: forwarded to ``SelfAttention``.
    """

    def __init__(self, n_embd, n_head, H, W,in_channel=None):
        super().__init__()
        self.att = SelfAttention(n_embd, n_head, in_channel=in_channel)
        self.mlp = nn.Sequential(
            EqualLinear(n_embd, 2 * n_embd),
            GELU(),
            EqualLinear(2 * n_embd, n_embd),
        )
        # Never applied in forward(); kept so parameter names and state_dict
        # keys of existing checkpoints remain valid.
        self.norm_layer = nn.InstanceNorm1d(n_embd, affine=True)

    def forward(self,q,k,v,return_att=False,z1=None,z2=None):
        """Attend from ``q`` to ``k``/``v`` and return features shaped like ``q``.

        Args:
            q, k, v: 4-D feature maps ``(B, C, H, W)``.
            return_att: when true, also return the attention weights.
            z1: optional extra tokens concatenated to the keys along dim 1.
            z2: optional extra tokens concatenated to the values along dim 1.

        Returns:
            The output feature map, or ``(feature_map, attention)`` when
            ``return_att`` is true.
        """
        # Remember the query's spatial layout so the output can be folded back
        # to it instead of assuming a fixed 128-pixel-high input.
        batch, _, height, width = q.shape

        q = feature2token(q)
        k = feature2token(k)
        v = feature2token(v)
        if z1 is not None:
            k = torch.cat([k, z1], 1)
        if z2 is not None:
            v = torch.cat([v, z2], 1)

        # Always request the attention map: with return_att=False the attention
        # module returns a bare tensor, and unpacking it into two names would
        # split it along the batch axis (or raise).
        x, att = self.att(q, k, v, return_att=True)
        x = x + q
        x = x + self.mlp(x)
        x = x.permute(0, 2, 1).reshape(batch, -1, height, width)

        if return_att:
            return x, att
        return x
        
        
       
class SelfAttention(nn.Module):
    """Multi-head attention with instance-normalised queries and keys.

    Queries and keys are normalised with affine ``InstanceNorm1d`` before
    their linear projections; values are projected without normalisation.

    Args:
        channel: feature width of queries, values and the output.
        num_heads: number of attention heads (``channel`` is split across them).
        in_channel: feature width of the keys; defaults to ``channel``.
    """

    def __init__(self, channel, num_heads=8,in_channel=None):
        super().__init__()
        self.num_heads = num_heads
        in_channel = channel if in_channel is None else in_channel
        self.scale = (channel // num_heads) ** -0.5
        self.q = EqualLinear(channel, channel)
        self.k = EqualLinear(in_channel, channel)
        self.v = EqualLinear(channel, channel)
        self.proj = EqualLinear(channel, channel)
        self.softmax = nn.Softmax(dim=-1)

        self.norm1 = nn.InstanceNorm1d(channel, affine=True)
        # norm2 runs on the raw keys, which carry in_channel features (it is
        # applied before self.k projects them to `channel`), so it must be
        # sized with in_channel. Identical to before when in_channel == channel.
        self.norm2 = nn.InstanceNorm1d(in_channel, affine=True)

    def forward(self, q_pose,k_app,v_app,rel_pos=None,return_att=False):
        """Compute attention of ``q_pose`` over ``k_app``/``v_app``.

        Args:
            q_pose: query tokens ``(B, N, C)``.
            k_app: key tokens with ``B`` as the leading dimension.
            v_app: value tokens with ``B`` as the leading dimension.
            rel_pos: optional bias added to the attention logits before softmax.
            return_att: when true, also return the attention weights.

        Returns:
            ``(B, N, C)`` output tokens, or ``(tokens, attention)`` when
            ``return_att`` is true.
        """
        B_, N, C = q_pose.shape
        head_dim = C // self.num_heads

        # InstanceNorm1d expects channels on dim 1, hence the permute round trip.
        norm_q = self.norm1(q_pose.permute(0, 2, 1)).permute(0, 2, 1)
        norm_k = self.norm2(k_app.permute(0, 2, 1)).permute(0, 2, 1)

        q = self.q(norm_q).reshape(B_, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # Keys are laid out (B, heads, head_dim, tokens) so `q @ k` gives logits.
        k = self.k(norm_k).view(B_, -1, self.num_heads, head_dim).permute(0, 2, 3, 1)
        v = self.v(v_app).view(B_, -1, self.num_heads, head_dim).permute(0, 2, 1, 3)

        attn = (q @ k) * self.scale
        if rel_pos is not None:
            attn += rel_pos
        attn = self.softmax(attn)

        # (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        if return_att:
            return x, attn
        return x
       
class ToRGB_wostyle(nn.Module):
    """Output head producing a 3-channel, tanh-bounded map from features.

    Pipeline: optional ``norm(in_channel)`` -> ReLU ->
    ``ConvLayer(in_channel, 3, kernel_size)`` -> Tanh.

    Args:
        in_channel: number of input feature channels.
        kernel_size: kernel size of the projection convolution.
        norm: optional callable building a normalisation layer from a
            channel count; skipped when ``None``.
    """

    def __init__(self,in_channel,kernel_size,norm=None):
        super().__init__()
        stages = []
        if norm is not None:
            stages.append(norm(in_channel))
        stages.append(nn.ReLU())
        stages.append(ConvLayer(in_channel, 3, kernel_size))
        stages.append(nn.Tanh())
        self.model = nn.Sequential(*stages)
        # Not used by forward(); retained as an attribute.
        self.activ = nn.Tanh()

    def forward(self,input):
        """Apply the head to ``input``."""
        return self.model(input)
        
class ResBlock(nn.Module):
    """Pre-activation residual block with optional resampling.

    Main path: ``[norm] -> act -> 3x3 conv -> [norm] -> act -> 3x3 conv``,
    where only the second convolution changes channel count and applies the
    requested down/up-sampling. The skip path is chosen to match:
    average-pool + 1x1 conv when downsampling, an upsampling 3x3 conv when
    upsampling, identity when channels are equal, else a 1x1 conv.
    The two paths are summed and divided by ``sqrt(2)``.

    Args:
        in_channel: input channels.
        out_channel: output channels.
        downsample: downsample in the second conv and the skip.
        upsample: upsample in the second conv and the skip (used only when
            ``downsample`` is false for the skip path).
        norm_layer: callable building a norm layer from a channel count,
            or ``None`` for no normalisation.
        activation: activation module instance, reused at both positions.
    """

    def __init__(self, in_channel, out_channel, downsample=False,upsample=False,norm_layer=nn.BatchNorm2d,activation=nn.LeakyReLU()):
        super().__init__()

        with_norm = norm_layer is not None

        layers = []
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(ConvLayer(in_channel, in_channel, 3))
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(
            ConvLayer(in_channel, out_channel, 3, downsample=downsample, upsample=upsample)
        )
        self.model = nn.Sequential(*layers)

        if downsample:
            self.skip = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2),
                ConvLayer(in_channel, out_channel, 1, bias=False),
            )
        elif upsample:
            self.skip = ConvLayer(in_channel, out_channel, 3, upsample=True, bias=False)
        elif out_channel == in_channel:
            self.skip = nn.Identity()
        else:
            self.skip = ConvLayer(in_channel, out_channel, 1)

    def forward(self, input):
        """Return ``(main(input) + skip(input)) / sqrt(2)``."""
        residual = self.model(input)
        shortcut = self.skip(input)
        return (residual + shortcut) / math.sqrt(2)
        
class StyleBlock(nn.Module):
    """Residual block whose main path is scaled per channel by a style vector.

    The layer layout matches a pre-activation residual block:
    ``[norm] -> act -> 3x3 conv -> [norm] -> act -> 3x3 conv`` on the main
    path, and a skip path that is average-pool + 1x1 conv (downsample), an
    upsampling 3x3 conv (upsample), identity (equal channels) or a 1x1 conv.
    A style code ``z`` is mapped through ``ReLU -> EqualLinear`` (bias
    initialised to 1) to ``out_channel`` scales that multiply the main-path
    output before it is merged with the skip path.

    Args:
        in_channel: input channels.
        out_channel: output channels.
        style_dim: width of the style code ``z``.
        downsample: downsample in the second conv and the skip.
        upsample: upsample in the second conv and the skip.
        norm_layer: callable building a norm layer from a channel count,
            or ``None`` for no normalisation.
        activation: activation module instance, reused at both positions.
    """

    def __init__(self, in_channel, out_channel, style_dim=128,downsample=False,upsample=False,norm_layer=nn.BatchNorm2d,activation=nn.LeakyReLU()):
        super().__init__()

        with_norm = norm_layer is not None

        layers = []
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(ConvLayer(in_channel, in_channel, 3))
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(
            ConvLayer(in_channel, out_channel, 3, downsample=downsample, upsample=upsample)
        )
        self.model = nn.Sequential(*layers)

        if downsample:
            self.skip = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2),
                ConvLayer(in_channel, out_channel, 1, bias=False),
            )
        elif upsample:
            self.skip = ConvLayer(in_channel, out_channel, 3, upsample=True, bias=False)
        elif out_channel == in_channel:
            self.skip = nn.Identity()
        else:
            self.skip = ConvLayer(in_channel, out_channel, 1)

        self.to_style = nn.Sequential(
            nn.ReLU(),
            EqualLinear(style_dim, out_channel, bias_init=1),
        )

    def forward(self, input,z):
        """Return ``(main(input) * style(z) + skip(input)) / sqrt(2)``."""
        channel_scale = self.to_style(z)
        modulated = self.model(input)
        # Append two singleton axes so the scales broadcast over H and W.
        modulated = modulated * channel_scale[:, :, None, None]
        shortcut = self.skip(input)
        return (modulated + shortcut) / math.sqrt(2)
        
class EncoderBlock(nn.Module):
    """Two-convolution encoder stage without a skip connection.

    Layout: ``[norm] -> act -> conv -> [norm] -> act -> conv``; only the
    second convolution changes the channel count and applies resampling.
    With ``small=True`` the kernels are 1 and 2 and both convolutions are
    built with ``resample_filter=None``; otherwise both kernels are 3 and the
    ``ConvLayer`` default filter is used.

    Args:
        in_channel: input channels.
        out_channel: output channels.
        downsample: forwarded to the second convolution.
        upsample: forwarded to the second convolution.
        norm_layer: callable building a norm layer from a channel count,
            or ``None`` for no normalisation.
        activation: activation module instance, reused at both positions.
        small: select the small-kernel variant described above.
    """

    def __init__(self, in_channel, out_channel, downsample=False,upsample=False,norm_layer=nn.BatchNorm2d,activation=nn.ReLU(),small=False):
        super().__init__()

        if small:
            first_kernel, second_kernel = 1, 2
            conv_options = {'resample_filter': None}
        else:
            first_kernel, second_kernel = 3, 3
            conv_options = {}

        with_norm = norm_layer is not None

        layers = []
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(ConvLayer(in_channel, in_channel, first_kernel, **conv_options))
        if with_norm:
            layers.append(norm_layer(in_channel))
        layers.append(activation)
        layers.append(
            ConvLayer(
                in_channel,
                out_channel,
                second_kernel,
                downsample=downsample,
                upsample=upsample,
                **conv_options,
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Run the stage on ``input``."""
        return self.model(input)
        
class UNet(nn.Module):
    """Encoder/decoder with concatenated skip connections and an RGB head.

    A 1x1 stem maps 3 input channels to ``ngf // 2``; it is followed by
    ``log2(img_size) - 5`` downsampling ``EncoderBlock`` stages whose width
    doubles up to a cap of 128. Each decoder stage is an upsampling
    ``ResBlock`` fed the concatenation of an encoder feature map and the
    current decoder features. The result goes through ``ToRGB_wostyle``.

    Args:
        img_size: nominal image size; only determines the number of stages.
        ngf: base width (the stem outputs ``ngf // 2`` channels).
        norm_layer: normalisation factory forwarded to all stages, or ``None``.
    """

    def __init__(self,img_size=256,ngf=32,norm_layer=None):
        super().__init__()
        log_size = int(math.log(img_size, 2))
        stem_width = ngf // 2

        encoders = [ConvLayer(3, stem_width, 1)]
        decoders = []
        width = stem_width
        for _ in range(log_size - 5):
            wider = min(128, width * 2)
            encoders.append(
                EncoderBlock(width, wider, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True)
            )
            # Decoder input is a skip tensor concatenated with equally wide features.
            decoders.append(
                ResBlock(2 * wider, width, norm_layer=norm_layer, activation=nn.ReLU(), upsample=True)
            )
            width = wider

        # Decoders were collected shallow-to-deep; they run deep-to-shallow.
        decoders.reverse()
        self.activ = ToRGB_wostyle(stem_width, 3, norm=norm_layer)

        self.convs = nn.ModuleList(encoders)
        self.deconvs = nn.ModuleList(decoders)

    def forward(self,out):
        """Encode ``out``, decode with skip connections, and project to RGB."""
        encoder_outputs = []
        for encoder in self.convs:
            out = encoder(out)
            encoder_outputs.append(out)

        # Pair each decoder with encoder outputs from deepest to shallowest.
        # The first pairing concatenates the bottleneck with itself, and the
        # stem output is left unused because there is one decoder fewer than
        # encoders.
        for decoder, skip in zip(self.deconvs, reversed(encoder_outputs)):
            out = decoder(torch.cat([skip, out], 1))

        return self.activ(out)

        
               
class Discriminator(nn.Module):
    """Convolutional discriminator producing one score per spatial location.

    A 1x1 stem maps ``first_in_channel`` channels to ``ngf``; it is followed by
    ``log2(img_size) - 5`` downsampling ``ResBlock`` stages (no normalisation,
    LeakyReLU(0.2)) whose width doubles up to a cap of 256. A two-layer
    ``EqualLinear`` head is then applied along the channel axis.

    Args:
        img_size: nominal image size; only determines the number of stages.
        ngf: width of the stem.
        first_in_channel: number of input channels.
    """

    def __init__(self, img_size=128, ngf=32,first_in_channel=3):
        super().__init__()
        log_size = int(math.log(img_size, 2))

        convs = [ConvLayer(first_in_channel, ngf, 1)]
        in_channel = ngf
        for _ in range(log_size - 5):
            out_channel = min(256, in_channel * 2)
            convs.append(
                ResBlock(in_channel, out_channel, norm_layer=None, activation=nn.LeakyReLU(0.2), downsample=True)
            )
            in_channel = out_channel
        self.convs = nn.Sequential(*convs)

        # `in_channel` is the width leaving the last stage. Sizing the head
        # from it (rather than the loop-local `out_channel`) keeps construction
        # valid when the loop runs zero times, i.e. for img_size < 64.
        self.final_layer2 = nn.Sequential(
            EqualLinear(in_channel, in_channel // 2),
            nn.LeakyReLU(0.2),
            EqualLinear(in_channel // 2, 1),
        )

    def forward(self,input):
        """Return scores with the channel axis moved last and reduced to 1."""
        out = self.convs(input)
        # Channels last so the linear head acts on each location's features.
        out = out.permute(0, 2, 3, 1)
        score = self.final_layer2(out)

        return score
       

class ImageGenerator(nn.Module):
    """Generator combining encoded features via attention and style blocks.

    Five parallel encoders are built, each a 1x1 stem plus
    ``num_downsamples`` downsampling ``EncoderBlock`` stages:

    * ``convs`` / ``convs_small``: applied to the first 6 channels of
      ``inputs`` (regular and small-kernel variants).
    * ``dis_convs`` / ``dis_convs_small``: applied to the remaining 23
      channels of ``inputs`` (regular and small-kernel variants).
    * ``pose_convs``: applied to the 20-channel ``pose`` input.

    The decoder first runs the attention stages (each a pair of
    ``TransAttentionBlock`` whose outputs are concatenated and fused by a
    ``StyleBlock``), then ``num_downsamples`` upsampling ``StyleBlock`` stages.
    A separate ``UNet`` processes ``bg``; the final image blends the two with
    a predicted mask.

    Args:
        img_size: nominal image size; sets the number of attention stages.
        ngf: width of the encoder stems.
        style_dim: width of the style code and cap on encoder width.
        num_downsamples: number of downsampling encoder stages.
        norm_layer: normalisation factory forwarded to the blocks, or ``None``.
    """

    def __init__(self,img_size=128,ngf=32,style_dim=128,num_downsamples=2,norm_layer=None):
        super().__init__()
        log_size = int(math.log(img_size, 2))
        convs = [ConvLayer(3 + 3, ngf, 1)]
        pose_convs = [ConvLayer(2 + 18, ngf, 1)]
        dis_convs = [ConvLayer(3 + 2 + 18, ngf, 1)]

        convs_small = [ConvLayer(3 + 3, ngf, 1)]
        dis_convs_small = [ConvLayer(3 + 2 + 18, ngf, 1)]
        deconvs = []
        deconvs_res = []
        in_channel = ngf
        # Value handed to the attention blocks as their H/W arguments.
        res = img_size // (2 ** num_downsamples)
        for i in range(1, log_size - 1):
            if i <= num_downsamples:
                out_channel = min(style_dim, in_channel * 2)
                convs.append(EncoderBlock(in_channel, out_channel, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True))
                deconvs.append(StyleBlock(out_channel, in_channel, style_dim=style_dim, norm_layer=norm_layer, activation=nn.ReLU(), upsample=True))

                pose_convs.append(EncoderBlock(in_channel, out_channel, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True))
                dis_convs.append(EncoderBlock(in_channel, out_channel, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True))

                convs_small.append(EncoderBlock(in_channel, out_channel, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True, small=True))
                dis_convs_small.append(EncoderBlock(in_channel, out_channel, norm_layer=norm_layer, activation=nn.ReLU(), downsample=True, small=True))
            else:
                # Attention stage: `out_channel` still holds the width of the
                # last downsampling stage built above.
                block = nn.ModuleList([
                    TransAttentionBlock(out_channel, 8, res, res),
                    TransAttentionBlock(out_channel, 8, res, res),
                ])
                deconvs.append(block)
                # Fuses the two concatenated attention outputs (hence 2x width).
                deconvs_res.append(StyleBlock(2 * out_channel, in_channel, style_dim=style_dim, norm_layer=norm_layer, activation=nn.ReLU()))
            in_channel = out_channel

        # Built shallow-to-deep; executed deep-to-shallow (attention stages first).
        deconvs = deconvs[::-1]
        self.convs = nn.ModuleList(convs)
        self.deconvs = nn.ModuleList(deconvs)
        self.pose_convs = nn.ModuleList(pose_convs)
        self.dis_convs = nn.ModuleList(dis_convs)
        self.activ = ToRGB_wostyle(ngf, 3, norm=norm_layer)

        self.bgnet = UNet(ngf=ngf)
        self.mlp1 = nn.Sequential(EqualLinear(style_dim, style_dim // 2), nn.ReLU(), EqualLinear(style_dim // 2, style_dim))
        self.convs_small = nn.ModuleList(convs_small)
        self.dis_convs_small = nn.ModuleList(dis_convs_small)

        self.deconvs_res = nn.ModuleList(deconvs_res)
        self.num_downsamples = num_downsamples

    def forward(self,inputs,pose,z,bg):
        """Generate an image and its blending mask.

        Args:
            inputs: tensor whose first 6 channels feed ``convs``/``convs_small``
                and whose remaining channels feed ``dis_convs``/``dis_convs_small``.
            pose: tensor fed to ``pose_convs``; its features act as queries.
            z: style code, mapped through ``mlp1`` before use.
            bg: tensor fed to the background ``UNet``.

        Returns:
            ``(image, pred_mask)`` where
            ``image = activ(features) * pred_mask + bgnet(bg) * (1 - pred_mask)``
            and ``pred_mask`` is the sigmoid of channel 0 of the decoder features.
        """
        permuted_tex = inputs[:, :6]
        permuted_pose = inputs[:, 6:]

        out_pose = pose

        permuted_tex_small = inputs[:, :6]
        permuted_pose_small = inputs[:, 6:]

        encoder_stages = zip(
            self.pose_convs,
            self.convs,
            self.dis_convs,
            self.convs_small,
            self.dis_convs_small,
        )
        for pose_enc, tex_enc, dis_enc, tex_small_enc, dis_small_enc in encoder_stages:
            out_pose = pose_enc(out_pose)

            permuted_tex = tex_enc(permuted_tex)
            permuted_pose = dis_enc(permuted_pose)

            permuted_tex_small = tex_small_enc(permuted_tex_small)
            permuted_pose_small = dis_small_enc(permuted_pose_small)

        z = self.mlp1(z)

        num_attention = len(self.deconvs) - self.num_downsamples
        for i, stage in enumerate(self.deconvs):
            if i < num_attention:
                # The last attention stage reads the small-kernel encoder
                # features; earlier ones read the regular features.
                if i == num_attention - 1:
                    keys, values = permuted_pose_small, permuted_tex_small
                else:
                    keys, values = permuted_pose, permuted_tex

                out_pose2, _ = stage[1](out_pose, keys, values, return_att=True)
                out_pose1, _ = stage[0](out_pose, values, values, return_att=True)

                out_pose = self.deconvs_res[i](torch.cat([out_pose1, out_pose2], 1), z)
            else:
                out_pose = stage(out_pose, z)

        # torch.sigmoid replaces F.sigmoid, which was deprecated in PyTorch 1.x.
        pred_mask = torch.sigmoid(out_pose[:, 0:1])
        bg = self.bgnet(bg)

        return self.activ(out_pose) * pred_mask + bg * (1 - pred_mask), pred_mask
         
