import torch 
import torch.nn as nn      


class Swish(nn.Module):
    """Parameter-free Swish activation: the input scaled by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)`` computed element-wise on tensor ``x``."""
        return x * torch.sigmoid(x)

class Downsample(nn.Module):
    """Learned spatial downsampling that keeps the channel count unchanged.

    A single 3x3 convolution with stride 2 and one pixel of replicate
    padding, so a spatial size ``H`` becomes ``ceil(H / 2)``.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='replicate',
        )

    def forward(self, x):
        """Apply the strided convolution to ``x`` and return the result."""
        return self.conv(x)

class Upsample(nn.Module):
    """Learned 2x spatial upsampling that keeps the channel count unchanged.

    A transposed convolution (kernel 4, stride 2, padding 1) doubles the
    height and width; a 3x3 replicate-padded convolution is then applied to
    the upsampled map.

    Args:
        in_channels: Number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.upsample = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=4, stride=2, padding=1
        )
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='replicate',
        )

    def forward(self, x):
        """Upsample ``x`` by a factor of two and apply the follow-up convolution."""
        return self.conv(self.upsample(x))
    

class ResnetBlock(nn.Module):
    """Pre-activation residual block.

    The residual branch applies GroupNorm -> Swish -> 3x3 convolution twice.
    The skip path is the identity when the channel counts match; otherwise
    a 1x1 convolution (``nin_shortcut``) maps the input to ``out_channels``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        group: Number of groups for both GroupNorm layers; it must divide
            ``in_channels`` and ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, group=32):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=group, num_channels=in_channels, eps=1e-6)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.norm2 = nn.GroupNorm(num_groups=group, num_channels=out_channels, eps=1e-6)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.act = Swish()

        # The projection only exists when the skip path has to change width.
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``shortcut(x) + residual(x)``."""
        residual = self.conv1(self.act(self.norm1(x)))
        residual = self.conv2(self.act(self.norm2(residual)))

        if self.in_channels == self.out_channels:
            shortcut = x
        else:
            shortcut = self.nin_shortcut(x)
        return shortcut + residual


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values
    with 1x1 convolutions, and every spatial position attends to every other
    position. The attended values go through a 1x1 output projection and are
    added back to the input.

    Args:
        in_channels: Channels of the input (and output) feature map.
        group: Number of groups for the GroupNorm layer.
    """

    def __init__(self, in_channels, group=32):
        super().__init__()
        self.in_channels = in_channels
        self.norm = nn.GroupNorm(num_groups=group, num_channels=in_channels, eps=1e-6)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``x`` plus its attention output; ``x`` must be 4-D (b, c, h, w)."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape

        # Collapse the spatial grid into one axis of length h*w.
        query = query.flatten(2).transpose(1, 2)   # (b, hw, c)
        key = key.flatten(2)                       # (b, c, hw)
        value = value.flatten(2)                   # (b, c, hw)

        # scores[b, i, j]: similarity of query position i and key position j,
        # scaled by 1/sqrt(c) and normalised over the key positions j.
        scores = torch.bmm(query, key)
        scores = scores * (channels ** (-0.5))
        scores = torch.nn.functional.softmax(scores, dim=-1)

        # out[b, c, i] = sum_j value[b, c, j] * scores[b, i, j]
        attended = torch.bmm(value, scores.transpose(1, 2))
        attended = attended.reshape(batch, channels, height, width)

        return x + self.proj_out(attended)


class MidExFeature(nn.Module):
    """Resolution-preserving bottleneck: conv -> ResNet/attention/ResNet -> conv.

    Every convolution here uses stride 1 with "same" padding, so the spatial
    size of the input is preserved; only the channel count changes
    (``in_channels`` -> ``base_channels`` -> ``f_channels``).

    Args:
        in_channels: Channels of the input tensor.
        base_channels: Working width of the middle blocks.
        f_channels: Channels of the returned feature map.
        group: Number of GroupNorm groups used by every normalisation layer
            in this module, including those inside the residual and
            attention blocks. It must divide ``base_channels``.
    """

    def __init__(self, in_channels=3, base_channels=96, f_channels=8, group=32):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels, base_channels, 3, 1, 1, padding_mode='replicate')

        # ``group`` is forwarded to the sub-blocks so that a non-default
        # value applies to all GroupNorm layers, not only ``norm_out``.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(base_channels, base_channels, group=group)
        self.mid.attn_1 = SpatialSelfAttention(base_channels, group=group)
        self.mid.block_2 = ResnetBlock(base_channels, base_channels, group=group)

        self.norm_out = nn.GroupNorm(group, base_channels, eps=1e-6)
        self.conv_out30 = nn.Conv2d(base_channels, f_channels, 3, 1, 1, padding_mode='replicate')
        self.act = Swish()

    def forward(self, x):
        """Map ``x`` to an ``f_channels`` feature map of the same spatial size."""
        h = self.conv_in(x)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = self.act(h)
        return self.conv_out30(h)

class ExFeature(nn.Module):
    """Convolutional encoder: stacked ResNet stages with strided downsampling.

    Stage ``i`` holds ``num_res_blocks`` residual blocks with a width of
    ``base_channels * ch_mult[i]``. Every stage except the last is followed
    by a ``Downsample`` (stride-2 convolution). After the stages come a
    ResNet/attention/ResNet middle section and a GroupNorm -> Swish -> 3x3
    convolution head that emits ``f_channels`` channels.

    Args:
        in_channels: Channels of the input tensor.
        base_channels: Base width, multiplied by ``ch_mult`` for each stage.
        ch_mult: Per-stage channel multipliers; any non-empty sequence of
            ints (list or tuple).
        num_res_blocks: Residual blocks per stage.
        f_channels: Channels of the returned feature map.
        group: Number of GroupNorm groups used by every normalisation layer
            in this module, including those inside the residual and
            attention blocks. It must divide every stage width.
    """

    def __init__(self, in_channels=3, base_channels=96, ch_mult=(1, 2, 4), num_res_blocks=2, f_channels=64, group=32):
        super().__init__()
        # Normalise to a tuple so lists and tuples are both accepted and the
        # caller's sequence is never shared or mutated.
        ch_mult = tuple(ch_mult)
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels, base_channels, 3, 1, 1, padding_mode='replicate')

        # in_ch_mult[i] is the multiplier of the tensor entering stage i.
        in_ch_mult = (1,) + ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block_in_channels = base_channels * in_ch_mult[i_level]
            block_out_channels = base_channels * ch_mult[i_level]

            block = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(block_in_channels, block_out_channels, group=group))
                block_in_channels = block_out_channels

            down = nn.Module()
            down.block = block
            # The last stage keeps its resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in_channels)
            self.down.append(down)

        self.mid = nn.Module()
        block_in_channels = base_channels * ch_mult[-1]
        self.mid.block_1 = ResnetBlock(block_in_channels, block_in_channels, group=group)
        self.mid.attn_1 = SpatialSelfAttention(block_in_channels, group=group)
        self.mid.block_2 = ResnetBlock(block_in_channels, block_in_channels, group=group)

        self.norm_out = nn.GroupNorm(group, block_in_channels, eps=1e-6)
        self.conv_out = nn.Conv2d(block_in_channels, f_channels, 3, 1, 1, padding_mode='replicate')
        self.act = Swish()

    def forward(self, x):
        """Encode ``x`` into an ``f_channels`` feature map at reduced resolution."""
        # Only the running activation is needed, so no list of intermediate
        # feature maps is kept alive.
        h = self.conv_in(x)

        last_level = self.num_resolutions - 1
        for i_level, stage in enumerate(self.down):
            for res_block in stage.block:
                h = res_block(h)
            if i_level != last_level:
                h = stage.downsample(h)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = self.act(h)
        return self.conv_out(h)



class ReFeature(nn.Module):
    """Convolutional decoder: middle section, then ResNet stages with upsampling.

    The input is first widened to ``base_channels * ch_mult[-1]`` and passed
    through a ResNet/attention/ResNet middle section. Stages are then
    applied from the highest index down to 0; stage ``i`` holds
    ``num_res_blocks`` residual blocks with a width of
    ``base_channels * ch_mult[i]``, and every stage except stage 0 ends with
    an ``Upsample`` (2x). A GroupNorm -> Swish -> 3x3 convolution head emits
    ``out_channels`` channels.

    Args:
        out_channels: Channels of the returned tensor.
        base_channels: Base width, multiplied by ``ch_mult`` for each stage.
        ch_mult: Per-stage channel multipliers; any non-empty sequence of
            ints (list or tuple).
        num_res_blocks: Residual blocks per stage.
        f_channels: Channels of the input feature map.
        group: Number of GroupNorm groups used by every normalisation layer
            in this module, including those inside the residual and
            attention blocks. It must divide every stage width.
    """

    def __init__(self, out_channels=256, base_channels=96, ch_mult=(1, 2, 4), num_res_blocks=2, f_channels=4, group=32):
        super().__init__()

        # Normalise to a tuple so the caller's sequence is never shared.
        ch_mult = tuple(ch_mult)
        self.num_resolutions = len(ch_mult)
        # Historical misspelling, kept as an alias for existing code that reads it.
        self.num_reslutions = self.num_resolutions
        self.num_res_blocks = num_res_blocks

        block_in_channels = base_channels * ch_mult[-1]
        self.conv_in = nn.Conv2d(f_channels, block_in_channels, 3, 1, 1, padding_mode='replicate')

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(block_in_channels, block_in_channels, group=group)
        self.mid.attn_1 = SpatialSelfAttention(block_in_channels, group=group)
        self.mid.block_2 = ResnetBlock(block_in_channels, block_in_channels, group=group)

        # Stages are created in execution order (highest level first) so the
        # running channel count can be threaded through them, then stored so
        # that ``self.up[i]`` is the stage for level ``i``.
        stages = []
        for i_level in reversed(range(self.num_resolutions)):
            block_out_channels = base_channels * ch_mult[i_level]

            block = nn.ModuleList()
            for _ in range(self.num_res_blocks):
                block.append(ResnetBlock(block_in_channels, block_out_channels, group=group))
                block_in_channels = block_out_channels

            up = nn.Module()
            up.block = block
            # Level 0 is applied last and keeps its resolution.
            if i_level != 0:
                up.upsample = Upsample(block_in_channels)
            stages.append(up)
        self.up = nn.ModuleList(stages[::-1])

        self.norm_out = nn.GroupNorm(group, block_in_channels, eps=1e-6)
        self.conv_out_256 = nn.Conv2d(block_in_channels, out_channels, 3, 1, 1, padding_mode='replicate')
        self.act = Swish()

    def forward(self, z):
        """Decode feature map ``z`` into an ``out_channels`` map at increased resolution."""
        h = self.conv_in(z)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        for i_level in reversed(range(self.num_resolutions)):
            stage = self.up[i_level]
            for res_block in stage.block:
                h = res_block(h)
            if i_level != 0:
                h = stage.upsample(h)

        h = self.norm_out(h)
        h = self.act(h)
        return self.conv_out_256(h)


class Decoder(nn.Module):
    """Maps an image batch to a 64-dimensional message vector per sample.

    Pipeline: ``ExFeature`` encoder -> ``MidExFeature`` bottleneck ->
    ``ReFeature`` decoder (256 output channels) -> global average pooling ->
    GroupNorm + Swish -> linear layer producing 64 values.
    """

    def __init__(self):
        super().__init__()
        self.message_length = 64

        mid_channel = 256   # width of the features passed between sub-networks
        pooled_dim = 256    # channel count produced by ``re_image`` by default

        self.ex_image = ExFeature(f_channels=mid_channel)
        self.mid = MidExFeature(in_channels=mid_channel, f_channels=mid_channel)
        self.re_image = ReFeature(f_channels=mid_channel)
        self.pool = nn.AdaptiveAvgPool2d(1)

        self.norm_out_256 = nn.GroupNorm(32, pooled_dim, eps=1e-6)
        self.act = Swish()
        self.message_layer_64 = nn.Linear(pooled_dim, self.message_length)

    def forward(self, image_with_wm):
        """Return a ``(batch, 64)`` tensor computed from ``image_with_wm``.

        ``image_with_wm`` must be a 4-D batch whose channel count matches
        the ``ExFeature`` default (3). Judging by the names, it is presumably
        a watermarked image and the output the recovered message — confirm
        with the caller.
        """
        batch_size = image_with_wm.shape[0]

        features = self.ex_image(image_with_wm)
        features = self.mid(features)
        features = self.re_image(features)

        # Global average pool to (batch, 256, 1, 1), then drop the 1x1 grid.
        pooled = self.pool(features).view(batch_size, 256)

        pooled = self.act(self.norm_out_256(pooled))
        return self.message_layer_64(pooled)
    
    
    
        