import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from inspect import isfunction
from einops import rearrange, repeat


class Swish(nn.Module):
    """Swish activation: scales every element of the input by its own sigmoid."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise (same shape as ``x``)."""
        gate = torch.sigmoid(x)
        return x * gate

class Downsample(nn.Module):
    """Reduce spatial resolution by two with a strided 3x3 convolution.

    The channel count is preserved and the border is padded by replicating
    edge values.

    Args:
        in_channels: number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            padding_mode='replicate',
        )

    def forward(self, x):
        """Apply the stride-2 convolution to ``x`` and return the result."""
        return self.conv(x)

class Upsample(nn.Module):
    """Double spatial resolution with a transposed convolution, then refine.

    The transposed convolution is followed by a 3x3 convolution with
    replicate padding. The channel count is preserved throughout.

    Args:
        in_channels: number of input (and output) channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 gives an output exactly twice the input size.
        self.upsample = nn.ConvTranspose2d(
            in_channels, in_channels, kernel_size=4, stride=2, padding=1
        )
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='replicate',
        )

    def forward(self, x):
        """Upsample ``x`` and pass it through the refinement convolution."""
        return self.conv(self.upsample(x))
        

class ResnetBlock(nn.Module):
    """Residual block made of two GroupNorm -> Swish -> 3x3 conv stages.

    The input is added back to the result of the two stages. When the input
    and output channel counts differ, the skip path is projected with a 1x1
    convolution (``nin_shortcut``) so the shapes match.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
        group: number of groups used by both GroupNorm layers.
    """

    def __init__(self, in_channels, out_channels, group=32):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(group, in_channels, eps=1e-6)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.norm2 = nn.GroupNorm(group, out_channels, eps=1e-6)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.act = Swish()

        # The projection only exists when it is needed, so the attribute is
        # absent for channel-preserving blocks.
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``skip(x) + residual(x)``."""
        residual = self.conv1(self.act(self.norm1(x)))
        residual = self.conv2(self.act(self.norm2(residual)))

        if self.in_channels == self.out_channels:
            skip = x
        else:
            skip = self.nin_shortcut(x)
        return skip + residual


def Normalize(in_channels):
    """Build the affine GroupNorm layer (32 groups, eps=1e-6) for ``in_channels`` channels."""
    norm_layer = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
    return norm_layer

        
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial positions of a feature map.

    The input is group-normalised, projected to queries, keys and values with
    1x1 convolutions, attended over all ``h * w`` positions, projected back
    with another 1x1 convolution and added to the input.

    Args:
        in_channels: number of channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise():
            # A 1x1 convolution acts as a per-position linear projection.
            return nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

        self.norm = Normalize(in_channels)
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def forward(self, x):
        """Return ``x`` plus the attention output; the shape is unchanged."""
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        _, channels, height, _ = query.shape
        query = rearrange(query, 'b c h w -> b (h w) c')
        key = rearrange(key, 'b c h w -> b c (h w)')
        # scores[b, i, j] is the similarity of query position i and key position j.
        scores = torch.einsum('bij,bjk->bik', query, key)

        scores = scores * (int(channels) ** (-0.5))
        # Normalise over key positions (last axis).
        attn = F.softmax(scores, dim=2)

        value = rearrange(value, 'b c h w -> b c (h w)')
        # Transpose so the contraction below runs over key positions.
        attn = rearrange(attn, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', value, attn)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=height)
        return x + self.proj_out(attended)

class MidExFeature(nn.Module):
    """Resolution-preserving feature extractor.

    Pipeline: 3x3 conv-in -> ResnetBlock -> SpatialSelfAttention ->
    ResnetBlock -> GroupNorm -> Swish -> 3x3 conv-out. No layer changes the
    spatial size.

    Args:
        in_channels: channels of the input tensor.
        base_channels: channel width used by the middle stack.
        f_channels: channels of the output tensor.
        group: number of groups of the output GroupNorm only; the residual
            blocks and the attention layer keep their own defaults.
    """

    def __init__(self, in_channels=3, base_channels=96, f_channels=8, group=32):
        super().__init__()
        self.conv_in = nn.Conv2d(
            in_channels, base_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )

        # Plain nn.Module used as a namespace so parameters live under "mid.*".
        self.mid = nn.Module()
        width = base_channels
        self.mid.block_1 = ResnetBlock(width, width)
        self.mid.attn_1 = SpatialSelfAttention(width)
        self.mid.block_2 = ResnetBlock(width, width)

        # Output head.
        self.norm_out = nn.GroupNorm(group, width, eps=1e-6)
        self.conv_out = nn.Conv2d(
            width, f_channels,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.act = Swish()

    def forward(self, x):
        """Map ``x`` to an ``f_channels``-channel feature map."""
        feat = self.conv_in(x)
        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            feat = layer(feat)
        return self.conv_out(self.act(self.norm_out(feat)))
        
class ExFeature(nn.Module):
    """Multi-resolution (downsampling) feature extractor.

    Pipeline: 3x3 conv-in -> one stage per entry of ``ch_mult`` (each stage is
    ``num_res_blocks`` residual blocks, followed by a stride-2 ``Downsample``
    on every stage except the last) -> ResnetBlock -> SpatialSelfAttention ->
    ResnetBlock -> GroupNorm -> Swish -> 3x3 conv-out.

    Args:
        in_channels: channels of the input tensor.
        base_channels: channel width that ``ch_mult`` multiplies.
        ch_mult: per-stage channel multipliers. Any sequence is accepted
            (list or tuple); it is only read, never modified.
        num_res_blocks: residual blocks per stage.
        f_channels: channels of the output tensor.
        group: number of groups of the output GroupNorm only; the residual
            blocks and the attention layer keep their own defaults.
    """

    def __init__(self,in_channels=3, base_channels=96, ch_mult=[1,2,4], num_res_blocks=2, f_channels=8, group=32):
        super(ExFeature,self).__init__()
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.conv_in = nn.Conv2d(in_channels,base_channels,3,1,1, padding_mode='replicate')

        # Copy into a fresh list: ``[1] + ch_mult`` raised TypeError for a
        # tuple ``ch_mult``. The shared default list itself is never mutated.
        in_ch_mult = [1] + list(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()

            # The first block of a stage adapts from the previous stage's width.
            block_in_channels = base_channels * in_ch_mult[i_level]
            block_out_channels = base_channels * ch_mult[i_level]

            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(block_in_channels,block_out_channels))
                block_in_channels = block_out_channels
            down = nn.Module()
            down.block = block
            # The last stage keeps its resolution.
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in_channels)
            self.down.append(down)

        self.mid = nn.Module()
        block_in_channels = base_channels*ch_mult[-1]
        self.mid.block_1 = ResnetBlock(block_in_channels,block_in_channels)
        self.mid.attn_1 = SpatialSelfAttention(block_in_channels)
        self.mid.block_2 = ResnetBlock(block_in_channels,block_in_channels)

        self.norm_out = nn.GroupNorm(group,block_in_channels,eps=1e-6)
        self.conv_out = nn.Conv2d(block_in_channels,f_channels,3,1,1, padding_mode='replicate')
        self.act = Swish()

    def forward(self,x):
        """Map ``x`` to an ``f_channels``-channel feature map at reduced resolution."""
        # Only the most recent activation is ever consumed, so keep a single
        # running tensor instead of a list holding every intermediate map.
        h = self.conv_in(x)

        last_level = self.num_resolutions - 1
        for i_level in range(self.num_resolutions):
            stage = self.down[i_level]
            for i_block in range(self.num_res_blocks):
                h = stage.block[i_block](h)

            if i_level != last_level:
                h = stage.downsample(h)

        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        h = self.norm_out(h)
        h = self.act(h)
        h = self.conv_out(h)
        return h

class ReFeature(nn.Module):
    """Multi-resolution (upsampling) feature decoder.

    Pipeline: 3x3 conv-in -> ResnetBlock -> SpatialSelfAttention ->
    ResnetBlock -> one stage per entry of ``ch_mult`` walked from the last
    entry to the first (each stage is ``num_res_blocks`` residual blocks,
    followed by an ``Upsample`` on every stage except level 0) -> GroupNorm
    -> Swish -> 3x3 conv-out.

    Args:
        out_channels: accepted but not used; the output convolution always
            produces 256 channels.
        base_channels: channel width that ``ch_mult`` multiplies.
        ch_mult: per-level channel multipliers (only read, never modified).
        num_res_blocks: residual blocks per level.
        f_channels: channels of the input tensor.
        group: number of groups of the output GroupNorm only; the residual
            blocks and the attention layer keep their own defaults.
    """

    def __init__(self,out_channels=3,base_channels=96,ch_mult=[1,2,4],num_res_blocks=2,f_channels=4,group=32):
        super().__init__()

        num_levels = len(ch_mult)
        # NOTE: the attribute name is misspelt but kept as-is, since it is
        # part of the object's public surface.
        self.num_reslutions = num_levels
        self.num_res_blocks = num_res_blocks

        width = base_channels * ch_mult[-1]
        self.conv_in = nn.Conv2d(
            f_channels, width,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )

        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(width, width)
        self.mid.attn_1 = SpatialSelfAttention(width)
        self.mid.block_2 = ResnetBlock(width, width)

        self.up = nn.ModuleList()
        # Stages are created from the last level to the first, but inserted at
        # the front so that self.up[i] always corresponds to level i.
        for level in range(num_levels - 1, -1, -1):
            stage_width = base_channels * ch_mult[level]
            res_blocks = nn.ModuleList()
            for _ in range(num_res_blocks):
                res_blocks.append(ResnetBlock(width, stage_width))
                width = stage_width

            stage = nn.Module()
            stage.block = res_blocks
            if level > 0:
                stage.upsample = Upsample(width)
            self.up.insert(0, stage)

        self.norm_out = nn.GroupNorm(group, width, eps=1e-6)
        self.conv_out = nn.Conv2d(
            width, 256,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )
        self.act = Swish()

    def forward(self,z):
        """Decode ``z`` into a 256-channel feature map at increased resolution."""
        feat = self.conv_in(z)
        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            feat = layer(feat)

        # Walk the levels in the same order in which they were built.
        for level in range(self.num_reslutions - 1, -1, -1):
            stage = self.up[level]
            for block_idx in range(self.num_res_blocks):
                feat = stage.block[block_idx](feat)
            if level > 0:
                feat = stage.upsample(feat)

        return self.conv_out(self.act(self.norm_out(feat)))


class Encoder(nn.Module):
    """Embed a bit message into an image.

    Image features are extracted with ``ex_image``. Two feature maps
    (``message0`` / ``message1``) are derived from a detached copy of those
    features; for every message bit the two maps are blended according to the
    bit value. The per-bit maps are stacked along the channel axis, reduced by
    ``bit_mix_64``, fused with the image features (``fmix``), decoded
    (``re_image``) and finally mixed with the original image into a
    3-channel output.
    """

    def __init__(self):
        super().__init__()
        self.message_length = 64
        width = 256

        # Produce the feature maps selected by a 0 bit and a 1 bit respectively.
        self.message0 = MidExFeature(in_channels=width, f_channels=width)
        self.message1 = MidExFeature(in_channels=width, f_channels=width)

        self.ex_image = ExFeature(f_channels=width)

        # Input channel count is fixed to message_length stacked feature maps.
        self.bit_mix_64 = MidExFeature(in_channels=width * self.message_length, f_channels=width)

        self.fmix = MidExFeature(in_channels=width + width, f_channels=width)

        self.re_image = ReFeature(f_channels=width)

        # Decoder output concatenated with the 3-channel image.
        self.final_im_mix = nn.Conv2d(
            width + 3, 3,
            kernel_size=3, stride=1, padding=1, padding_mode='replicate',
        )

    def forward(self, image, message):
        """Return the 3-channel image produced from ``image`` and ``message``.

        Args:
            image: input image batch fed to ``ex_image`` and to the final mix.
            message: 2-D tensor of bits, one row per batch element. Its second
                dimension has to equal ``message_length`` (64) for the stacked
                channel count to match ``bit_mix_64``.
        """
        batch, num_bits = message.shape

        image_feat = self.ex_image(image)
        # Detached input: no gradient reaches ex_image through these two branches.
        carrier0 = self.message0(image_feat.detach())
        carrier1 = self.message1(image_feat.detach())

        # Insert a bit axis and broadcast each map across all bits.
        carrier0 = carrier0.unsqueeze(1).expand(-1, num_bits, -1, -1, -1)
        carrier1 = carrier1.unsqueeze(1).expand(-1, num_bits, -1, -1, -1)

        # A bit value of 0 selects carrier0, 1 selects carrier1; values in
        # between blend the two linearly.
        bit_mask = message.view(batch, num_bits, 1, 1, 1).float()
        per_bit = (1 - bit_mask) * carrier0 + bit_mask * carrier1

        # Fold the bit axis into the channel axis.
        _, _, feat_c, feat_h, feat_w = per_bit.shape
        stacked = per_bit.view(batch, num_bits * feat_c, feat_h, feat_w)

        mixed_message = self.bit_mix_64(stacked)

        fused = self.fmix(torch.cat([image_feat, mixed_message], dim=1))

        decoded = self.re_image(fused)
        return self.final_im_mix(torch.cat([decoded, image], dim=1))