import math
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
from transformer import TransformerDecoder,TransformerEncoder
from utils import *
import pdb

class Swish(nn.Module):
    """Swish activation: every element is scaled by its own sigmoid."""

    def forward(self, x):
        """Return ``x * sigmoid(x)``, computed element-wise."""
        gate = torch.sigmoid(x)
        return x * gate

class SA_ResBlock(nn.Module):
    """Residual convolution block with an optional self-attention stage.

    Two ``GroupNorm -> Swish -> 3x3 conv`` stacks (the second one with
    dropout before its conv) are applied to the input, the result is added
    to a shortcut of the input, and the sum is passed through
    ``self.selfattn``.

    Args:
        in_ch: Number of input channels. ``GroupNorm`` is built with 32
            groups, so this must be divisible by 32.
        out_ch: Number of output channels; must be divisible by 32 as well.
        dropout: Dropout probability used inside the second conv stack.
        attn: If true, ``self.selfattn`` is a ``TransformerEncoder``;
            otherwise it is ``nn.Identity``.
    """

    def __init__(self, in_ch, out_ch, dropout=0.1, attn=True):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )

        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )

        # A 1x1 conv matches the channel count on the skip path when needed.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()

        if attn:
            # NOTE(review): the positional arguments are presumably
            # (num_layers, width, num_heads, dropout) -- confirm against
            # transformer.TransformerEncoder. The 8 and 0.1 are fixed here
            # and do not follow this block's ``dropout`` argument.
            self.selfattn = TransformerEncoder(1, out_ch, 8, 0.1)
        else:
            self.selfattn = nn.Identity()

        self.initialize()

    def initialize(self):
        """Xavier-initialise all conv/linear weights and zero their biases.

        Every ``nn.Conv2d`` / ``nn.Linear`` reachable through
        ``self.modules()`` is covered, which includes any such layers inside
        ``self.selfattn``. The last conv of ``block2`` is then re-initialised
        with a gain of 1e-5, so the residual branch contributes almost
        nothing right after construction.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                # Compare to None by identity, not with `!=`.
                if module.bias is not None:
                    init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x):
        """Apply the residual block to ``x`` of shape (B, in_ch, H, W).

        Returns:
            Tensor of shape (B, out_ch, H, W).
        """
        h = self.block1(x)
        h = self.block2(h)
        h = h + self.shortcut(x)

        # Without attention the token-layout round trip below would only
        # copy the data and hand back the same values, so skip it. The
        # result is then in the default contiguous layout.
        if isinstance(self.selfattn, nn.Identity):
            return h

        B, C, H, W = h.shape
        # Flatten the spatial grid into a token axis for the attention stage.
        h = h.permute(0, 2, 3, 1).flatten(start_dim=1, end_dim=2)  # B, H*W, C
        h = self.selfattn(h)
        # Restore the channel-first image layout.
        h = h.permute(0, 2, 1).reshape(B, C, H, W)  # B, C, H, W
        return h
    
class SA_DownSample(nn.Module):
    """Downsample with a stride-2 3x3 convolution that keeps the channel count.

    With padding 1, an input of even height/width comes out at exactly half
    that height/width.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=3,
            stride=2,
            padding=1,
        )
        self.initialize()

    def initialize(self):
        """Reset the convolution: Xavier-uniform weight, all-zero bias."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x):
        """Return the strided convolution of ``x``."""
        return self.main(x)


class SA_UpSample(nn.Module):
    """Upsample by a factor of 2, then apply a 3x3 convolution.

    The upsampling uses nearest-neighbour interpolation; the convolution
    (stride 1, padding 1) keeps both the channel count and the enlarged
    spatial size.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        """Reset the convolution: Xavier-uniform weight, all-zero bias."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x):
        """Return ``x`` upsampled to twice its height and width.

        Args:
            x: 4-D tensor laid out as (B, C, H, W).

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        # Explicit rank check: inputs of any other rank are rejected with
        # ValueError instead of being interpolated along the wrong axes.
        if x.dim() != 4:
            raise ValueError(
                f"SA_UpSample expects a 4-D (B, C, H, W) input, got {x.dim()}-D")
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(x)
    
class SA_UNet(nn.Module):
    """U-Net assembled from ``SA_ResBlock``, ``SA_DownSample`` and ``SA_UpSample``.

    Stages, in execution order:

    * ``head``: 4x4 stride-4 convolution, ``vocab_size`` -> ``ch`` channels.
    * ``downblocks``: for each entry of ``ch_mult``, ``num_res_blocks``
      residual blocks (built with ``attn=False``) at ``ch * mult`` channels,
      followed by a downsample on every level except the last.
    * ``middleblocks``: two residual blocks built with ``attn=True``.
    * ``upblocks``: for each level in reverse, ``num_res_blocks + 1``
      residual blocks (built with ``attn=False``), each fed the running
      features concatenated with one stored down-path output, followed by
      an upsample on every level except level 0.
    * ``tail``: ``GroupNorm -> Swish -> 3x3 conv`` producing ``ch`` channels.

    Args:
        vocab_size: Channel count of the input tensor (the head's in_channels).
        ch: Base channel width.
        ch_mult: Per-level multipliers applied to ``ch``.
        num_res_blocks: Residual blocks per level on the down path.
        num_heads: Accepted for interface compatibility; not used here.
    """

    def __init__(self, vocab_size, ch, ch_mult, num_res_blocks=2, num_heads=8):
        super().__init__()
        self.head = nn.Conv2d(vocab_size, ch, kernel_size=4, stride=4, padding=0)

        # Down path. ``skip_chs`` records the channel count of every tensor
        # the forward pass will stash, so the up path can size its inputs.
        # Residual blocks here are built without attention.
        self.downblocks = nn.ModuleList()
        skip_chs = [ch]
        cur_ch = ch
        for level, mult in enumerate(ch_mult):
            width = ch * mult
            for _ in range(num_res_blocks):
                res = SA_ResBlock(in_ch=cur_ch, out_ch=width, attn=False)
                self.downblocks.append(res)
                cur_ch = width
                skip_chs.append(cur_ch)
            if level == len(ch_mult) - 1:
                # The last level is not downsampled.
                continue
            self.downblocks.append(SA_DownSample(cur_ch))
            skip_chs.append(cur_ch)

        # Middle: the only residual blocks built with attention enabled.
        self.middleblocks = nn.ModuleList(
            [SA_ResBlock(cur_ch, cur_ch, attn=True) for _ in range(2)]
        )

        # Up path: walk the levels in reverse and consume one recorded
        # channel count per residual block.
        self.upblocks = nn.ModuleList()
        for level, mult in list(enumerate(ch_mult))[::-1]:
            width = ch * mult
            for _ in range(num_res_blocks + 1):
                skip_ch = skip_chs.pop()
                res = SA_ResBlock(in_ch=skip_ch + cur_ch, out_ch=width, attn=False)
                self.upblocks.append(res)
                cur_ch = width
            if level > 0:
                self.upblocks.append(SA_UpSample(cur_ch))
        # Every recorded channel count must have been consumed.
        assert not skip_chs

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, ch, kernel_size=3, stride=1, padding=1),
        )
        self.initialize()

    def initialize(self):
        """Xavier-initialise the head and the final conv; zero their biases.

        The final conv of ``tail`` uses a gain of 1e-5, the head the default
        gain of 1.0.
        """
        for conv, gain in ((self.head, 1.0), (self.tail[-1], 1e-5)):
            init.xavier_uniform_(conv.weight, gain=gain)
            init.zeros_(conv.bias)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``tail``."""
        h = self.head(x)

        # Down path: keep the head output and every block output.
        skips = [h]
        for block in self.downblocks:
            h = block(h)
            skips.append(h)

        # Middle
        for block in self.middleblocks:
            h = block(h)

        # Up path: residual blocks take one stored tensor each, concatenated
        # on the channel axis; upsample layers take the features as they are.
        for block in self.upblocks:
            if isinstance(block, SA_ResBlock):
                block_in = torch.cat((h, skips.pop()), dim=1)
            else:
                block_in = h
            h = block(block_in)

        h = self.tail(h)

        # Every stored tensor must have been consumed.
        assert not skips
        return h
