import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.checkpoint import checkpoint
from functools import partial

# public names of this module: a star-import exposes only the UNet class
__all__ = ['UNet']

def sinusoidal_embedding(timesteps, dim):
    """Return sinusoidal embeddings for a batch of timesteps.

    Args:
        timesteps: 1-D tensor of timestep values; it is cast to float32
            before use.
        dim: Width of the returned embedding.

    Returns:
        Float tensor of shape ``(len(timesteps), dim)``. The first
        ``dim // 2`` columns hold cosines and the next ``dim // 2`` columns
        hold sines; when ``dim`` is odd, one trailing zero column pads the
        result to the requested width.
    """
    half = dim // 2
    timesteps = timesteps.float()

    # frequencies 10000^(-k / half) for k = 0 .. half - 1, created with the
    # same dtype and device as the timesteps
    freqs = torch.pow(10000, -torch.arange(half).to(timesteps).div(half))

    # compute sinusoidal embedding
    sinusoid = torch.outer(timesteps, freqs)
    x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
    if dim % 2 != 0:
        # Build the pad column explicitly: slicing ``x[:, :1]`` yields zero
        # columns when ``half == 0`` (dim == 1), which would leave the
        # output one column short.
        x = torch.cat([x, x.new_zeros(x.size(0), 1)], dim=1)
    return x

def to_fp16(m):
    """Cast the weight and bias of a convolution module to float16 in place.

    Modules that are not ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` are
    left untouched, which makes this function usable with
    ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    # bias is None for convolutions built with bias=False
    for param in (m.weight, m.bias):
        if param is not None:
            param.data = param.data.half()

def to_fp32(m):
    """Cast the weight and bias of a convolution module to float32 in place.

    Modules that are not ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` are
    left untouched, which makes this function usable with
    ``nn.Module.apply``.
    """
    if not isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return
    # bias is None for convolutions built with bias=False
    for param in (m.weight, m.bias):
        if param is not None:
            param.data = param.data.float()

class GroupNorm(nn.GroupNorm):
    """Group normalization computed in float32 and returned in the input dtype."""

    def forward(self, x):
        # normalize a float32 copy of the input, then cast the result back so
        # callers get the dtype they passed in
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)

class Resample(nn.Module):
    """Resample a feature map by a factor of 0.5, 1.0 or 2.0.

    * ``scale_factor == 2.0``: nearest-neighbour upsampling, followed by a
      3x3 convolution (``in_dim`` -> ``out_dim``) when ``use_conv`` is set.
    * ``scale_factor == 0.5``: a stride-2 3x3 convolution (``in_dim`` ->
      ``out_dim``) when ``use_conv`` is set, otherwise 2x2 average pooling.
    * ``scale_factor == 1.0``: identity.

    Without ``use_conv`` no layer changes the channel count, so ``in_dim``
    and ``out_dim`` are only recorded as attributes.
    """

    def __init__(self, in_dim, out_dim, scale_factor, use_conv=False):
        assert scale_factor in [0.5, 1.0, 2.0]
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.scale_factor = scale_factor
        self.use_conv = use_conv

        # layers
        if scale_factor == 2.0:
            upsample = nn.Upsample(scale_factor=scale_factor, mode='nearest')
            if use_conv:
                refine = nn.Conv2d(in_dim, out_dim, 3, padding=1)
            else:
                refine = nn.Identity()
            # always a two-element Sequential so the conv sits at index 1
            resample = nn.Sequential(upsample, refine)
        elif scale_factor == 0.5:
            if use_conv:
                resample = nn.Conv2d(in_dim, out_dim, 3, stride=2, padding=1)
            else:
                resample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            resample = nn.Identity()
        self.resample = resample

    def forward(self, x):
        return self.resample(x)

class ResidualBlock(nn.Module):
    """Residual block of two norm/activation/conv stages conditioned on an embedding.

    The input is optionally resampled (by ``scale_factor``) between the
    activation and the convolution of the first stage; the skip path is
    resampled the same way. The embedding ``e`` is projected and either added
    to the features before the second stage, or split into a scale and a
    shift applied right after the second stage's normalization when
    ``use_scale_shift_norm`` is set.

    Args:
        in_dim: Number of input channels.
        embed_dim: Width of the conditioning embedding ``e``.
        out_dim: Number of output channels.
        use_scale_shift_norm: Use the embedding as (scale, shift) instead of
            a plain additive term.
        scale_factor: Resampling factor passed to ``Resample``.
        dropout: Dropout probability used in the second stage.
    """

    def __init__(self, in_dim, embed_dim, out_dim, use_scale_shift_norm=True,
                 scale_factor=1.0, dropout=0.0):
        super().__init__()
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.out_dim = out_dim
        self.use_scale_shift_norm = use_scale_shift_norm
        self.scale_factor = scale_factor

        # first stage: norm -> activation -> 3x3 conv (in_dim -> out_dim)
        self.layer1 = nn.Sequential(
            GroupNorm(32, in_dim),
            nn.SiLU(),
            nn.Conv2d(in_dim, out_dim, 3, padding=1))

        # resampling shared by the main path and the skip path
        self.resample = Resample(in_dim, in_dim, scale_factor)

        # the projection is twice as wide when it must supply scale and shift
        if use_scale_shift_norm:
            embed_out = out_dim * 2
        else:
            embed_out = out_dim
        self.embedding = nn.Sequential(
            nn.SiLU(),
            nn.Linear(embed_dim, embed_out))

        # second stage: norm -> activation -> dropout -> 3x3 conv
        self.layer2 = nn.Sequential(
            GroupNorm(32, out_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Conv2d(out_dim, out_dim, 3, padding=1))

        # a 1x1 conv matches the channel count on the skip path when needed
        if in_dim == out_dim:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_dim, out_dim, 1)

        # zero out the weight of the last conv (its bias keeps the default init)
        nn.init.zeros_(self.layer2[-1].weight)

    def forward(self, x, e):
        skip = self.resample(x)

        # first stage, with the resampling placed just before the conv
        pre_conv = self.layer1[:-1]
        conv1 = self.layer1[-1]
        h = conv1(self.resample(pre_conv(x)))

        # project the embedding and make it broadcastable over H and W
        cond = self.embedding(e)
        cond = cond.unsqueeze(-1).unsqueeze(-1).type(h.dtype)

        if not self.use_scale_shift_norm:
            h = self.layer2(h + cond)
        else:
            scale, shift = cond.chunk(2, dim=1)
            norm2 = self.layer2[0]
            after_norm = self.layer2[1:]
            h = after_norm(norm2(h) * (1 + scale) + shift)

        return h + self.shortcut(skip)

class AttentionBlock(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input of shape ``(b, dim, h, w)`` is normalized, projected to
    queries, keys and values with a 1x1 convolution, attended over the
    ``h * w`` positions, projected back with a 1x1 convolution and added to
    the input.

    Args:
        dim: Number of channels of the input and output.
        num_heads: Number of attention heads; used only when ``head_dim`` is
            falsy.
        head_dim: Channels per head; takes precedence over ``num_heads``.
    """

    def __init__(self, dim, num_heads=None, head_dim=None):
        # head_dim wins over num_heads when both are given
        if head_dim:
            num_heads = dim // head_dim
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        # applied to both q and k, so their product is scaled by head_dim ** -0.5
        self.scale = math.pow(head_dim, -0.25)

        # layers
        self.norm = GroupNorm(32, dim)
        self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        # zero out the weight of the output projection
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        b, c, h, w = x.size()
        n, d = self.num_heads, self.head_dim
        residual = x

        # compute query, key, value: the channel axis is split into 3 * n
        # heads of width d, and the spatial axes are flattened
        qkv = self.to_qkv(self.norm(x)).view(b, n * 3, d, h * w)
        q, k, v = qkv.chunk(3, dim=1)

        # compute attention; the softmax runs in float32 and is cast back
        scores = torch.einsum('bnci,bncj->bnij', q * self.scale, k * self.scale)
        weights = F.softmax(scores.float(), dim=-1).type(scores.dtype)

        # gather context and restore the (b, c, h, w) layout
        context = torch.einsum('bnij,bncj->bnci', weights, v)
        context = context.reshape(b, c, h, w)

        # output
        return self.proj(context) + residual

class UNet(nn.Module):
    """U-Net conditioned on a timestep and, optionally, a class label.

    The network is an encoder / middle / decoder stack of ``ResidualBlock``
    and ``AttentionBlock`` modules. Every encoder output is kept and
    concatenated (along channels) to the input of the matching decoder block.

    Args:
        in_dim: Number of channels of the input image tensor.
        dim: Base channel width; also the width of the sinusoidal timestep
            embedding.
        out_dim: Number of channels produced by the output head.
        dim_mult: Per-level channel multipliers (relative to ``dim``). Any
            sequence is accepted; it must be non-empty.
        num_heads: Forwarded to ``AttentionBlock``.
        dim_scale: Hidden width of the timestep MLP is ``dim * dim_scale``.
        out_dim_scale: Output width of the timestep MLP (and of the label
            embedding) is ``dim * out_dim_scale``.
        tstructure: Depth of the timestep MLP: ``'shallow'``, ``'deep'`` or
            ``'verydeep'`` (0, 2 or 4 hidden ``Linear`` layers between the
            input and the output projection).
        head_dim: Forwarded to ``AttentionBlock``.
        num_res_blocks: Residual blocks per encoder level (the decoder uses
            one more per level).
        attn_scales: Resolution scales at which attention is added. The
            scale starts at 1.0 and is halved at every downsampling step.
        num_classes: When not None, a label embedding with this many entries
            is created.
        resblock_resample: Resample with a ``ResidualBlock`` (True) or with a
            plain ``Resample`` module (False).
        use_conv: Forwarded to ``Resample`` when ``resblock_resample`` is
            False.
        use_scale_shift_norm: Forwarded to every ``ResidualBlock``.
        use_fp16: Convert the convolutions of the encoder, middle and decoder
            to float16 and run those parts in float16.
        use_checkpoint: Run residual and attention blocks through
            ``torch.utils.checkpoint.checkpoint``.
        dropout: Dropout probability forwarded to every ``ResidualBlock``.
    """

    def __init__(self,
                 in_dim=3,
                 dim=192,
                 out_dim=6,
                 dim_mult=[1, 2, 3, 4],
                 num_heads=4,
                 dim_scale=4,
                 out_dim_scale=4,
                 tstructure='shallow',
                 head_dim=64,
                 num_res_blocks=3,
                 attn_scales=[1 / 2, 1 / 4, 1 / 8],
                 num_classes=None,
                 resblock_resample=True,
                 use_conv=True,
                 use_scale_shift_norm=False,
                 use_fp16=False,
                 use_checkpoint=False,
                 dropout=0.1):
        # Copy the sequence arguments: the instance then never aliases the
        # shared default lists, and tuples (or other sequences) work with the
        # list concatenations below.
        dim_mult = list(dim_mult)
        attn_scales = list(attn_scales)
        embed_dim = dim * dim_scale
        out_embed_dim = dim * out_dim_scale
        super(UNet, self).__init__()
        self.in_dim = in_dim
        self.dim = dim
        self.out_dim = out_dim
        self.dim_mult = dim_mult
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.num_classes = num_classes
        self.resblock_resample = resblock_resample
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_fp16 = use_fp16
        self.use_checkpoint = use_checkpoint

        # params
        enc_dims = [int(dim * u) for u in [1] + dim_mult]
        dec_dims = [int(dim * u) for u in [dim_mult[-1]] + dim_mult[::-1]]
        shortcut_dims = []
        scale = 1.0

        # embeddings: Linear/SiLU pairs, with the number of hidden
        # embed_dim -> embed_dim layers chosen by tstructure
        assert tstructure in ['shallow', 'deep', 'verydeep']
        num_hidden = {'shallow': 0, 'deep': 2, 'verydeep': 4}[tstructure]
        time_layers = [nn.Linear(dim, embed_dim), nn.SiLU()]
        for _ in range(num_hidden):
            time_layers += [nn.Linear(embed_dim, embed_dim), nn.SiLU()]
        time_layers.append(nn.Linear(embed_dim, out_embed_dim))
        self.time_embedding = nn.Sequential(*time_layers)
        if num_classes is not None:
            self.label_embedding = nn.Embedding(num_classes, out_embed_dim)

        # encoder
        self.encoder = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)])
        shortcut_dims.append(dim)
        for i, (c_in, c_out) in enumerate(zip(enc_dims[:-1], enc_dims[1:])):
            for j in range(num_res_blocks):
                # residual (+attention) blocks
                block = nn.ModuleList([ResidualBlock(c_in, out_embed_dim, c_out, use_scale_shift_norm, 1.0, dropout)])
                if scale in attn_scales:
                    block.append(AttentionBlock(c_out, num_heads, head_dim))
                shortcut_dims.append(c_out)
                # only the first block of a level changes the channel count
                c_in = c_out
                self.encoder.append(block)

                # downsample after the last block of every level but the last
                if i != len(dim_mult) - 1 and j == num_res_blocks - 1:
                    if resblock_resample:
                        downsample = ResidualBlock(c_out, out_embed_dim, c_out, use_scale_shift_norm, 0.5, dropout)
                    else:
                        downsample = Resample(c_out, c_out, 0.5, use_conv=use_conv)
                    shortcut_dims.append(c_out)
                    scale /= 2.0
                    self.encoder.append(downsample)

        # middle (runs at the channel width of the last encoder level)
        mid_dim = enc_dims[-1]
        self.middle = nn.ModuleList([
            ResidualBlock(mid_dim, out_embed_dim, mid_dim, use_scale_shift_norm, 1.0, dropout),
            AttentionBlock(mid_dim, num_heads, head_dim),
            ResidualBlock(mid_dim, out_embed_dim, mid_dim, use_scale_shift_norm, 1.0, dropout)])

        # decoder
        self.decoder = nn.ModuleList()
        for i, (c_in, c_out) in enumerate(zip(dec_dims[:-1], dec_dims[1:])):
            for j in range(num_res_blocks + 1):
                # residual (+attention) blocks; the input is widened by the
                # channels of the encoder output concatenated in forward()
                block = nn.ModuleList([ResidualBlock(c_in + shortcut_dims.pop(), out_embed_dim, c_out, use_scale_shift_norm, 1.0, dropout)])
                if scale in attn_scales:
                    block.append(AttentionBlock(c_out, num_heads, head_dim))
                c_in = c_out

                # upsample after the last block of every level but the last
                if i != len(dim_mult) - 1 and j == num_res_blocks:
                    if resblock_resample:
                        upsample = ResidualBlock(c_out, out_embed_dim, c_out, use_scale_shift_norm, 2.0, dropout)
                    else:
                        upsample = Resample(c_out, c_out, 2.0, use_conv=use_conv)
                    scale *= 2.0
                    block.append(upsample)
                self.decoder.append(block)

        # head (consumes the channel width of the last decoder level)
        last_dim = dec_dims[-1]
        self.head = nn.Sequential(
            GroupNorm(32, last_dim),
            nn.SiLU(),
            nn.Conv2d(last_dim, self.out_dim, 3, padding=1))

        # zero out the weight of the last conv
        nn.init.zeros_(self.head[-1].weight)

        # set precision (the head and the embeddings are not converted)
        if use_fp16:
            self.encoder.apply(to_fp16)
            self.middle.apply(to_fp16)
            self.decoder.apply(to_fp16)

    def forward(self, x, t, y=None):
        """Run the network.

        Args:
            x: Input feature map with ``in_dim`` channels.
            t: 1-D tensor of timesteps, one per batch element.
            y: Optional class indices; used only when the model was built
                with ``num_classes``.

        Returns:
            Tensor with ``out_dim`` channels, in the dtype of ``x``.
        """
        src_dtype = x.dtype
        tar_dtype = torch.float16 if self.use_fp16 else torch.float32

        # embeddings
        e = self.time_embedding(sinusoidal_embedding(t, self.dim))
        if self.num_classes is not None and y is not None:
            e = e + self.label_embedding(y)

        # encoder: keep every output for the decoder's skip connections
        xs = []
        x = x.to(tar_dtype)
        for block in self.encoder:
            x = self._forward_single(block, x, e)
            xs.append(x)

        # middle
        for block in self.middle:
            x = self._forward_single(block, x, e)

        # decoder: consume the encoder outputs in reverse order
        for block in self.decoder:
            x = torch.cat([x, xs.pop()], dim=1)
            x = self._forward_single(block, x, e)

        # head runs in the caller's dtype
        x = x.to(src_dtype)
        x = self.head(x)
        return x

    def _forward_single(self, module, x, e):
        """Apply ``module`` to ``x``, passing ``e`` only to residual blocks.

        ``nn.ModuleList`` containers are unrolled recursively. Residual and
        attention blocks go through activation checkpointing when
        ``use_checkpoint`` is set.
        """
        if isinstance(module, nn.ModuleList):
            for block in module:
                x = self._forward_single(block, x, e)
            return x

        # Decide the call arguments from the module's own type BEFORE any
        # wrapping: a checkpoint wrapper is not a ResidualBlock instance, so
        # testing the wrapper would drop ``e`` for residual blocks.
        args = (x, e) if isinstance(module, ResidualBlock) else (x,)
        if self.use_checkpoint and isinstance(module, (ResidualBlock, AttentionBlock)):
            return checkpoint(module, *args)
        return module(*args)